/*
* Copyright 2007-2010 Brian S O'Neill
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.cojen.dirmi.util;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.LinkedList;
import java.util.TreeSet;
import java.util.concurrent.AbstractExecutorService;
import java.util.concurrent.Callable;
import java.util.concurrent.Delayed;
import java.util.concurrent.Executors;
import java.util.concurrent.FutureTask;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.security.AccessControlContext;
import java.security.AccessController;
import java.security.PrivilegedAction;
/**
 * Custom thread pool implementation which is slightly more efficient than the
 * default JDK1.6 thread pool and also provides scheduling services. More
 * importantly, this implementation immediately removes cancelled tasks (in
 * O(log n) time) instead of leaking memory.
 *
 * <p>The pool has no work queue: {@link #execute execute} hands the command
 * directly to an idle pooled thread, or starts a new thread, and rejects the
 * command when the configured maximum number of threads is already active.
 * Idle threads exit after waiting about ten seconds without work. Scheduled
 * tasks are driven by a task runner which itself occupies one pooled thread.
 *
 * <p>When the thread limit is reached, two system properties alter the
 * behavior for diagnostic purposes:
 * <ul>
 * <li>{@code org.cojen.dirmi.util.ThreadPool.limitReachedThreadDump=true}
 * prints the stack traces of the calling thread and all pooled threads to
 * {@code System.err}.
 * <li>{@code org.cojen.dirmi.util.ThreadPool.limitReachedSystemExit=true}
 * reports the condition to the calling thread's uncaught exception handler
 * and then attempts to exit the JVM.
 * </ul>
 *
 * @author Brian S O'Neill
 */
public class ThreadPool extends AbstractExecutorService implements ScheduledExecutorService {
    private static final boolean LIMIT_REACHED_THREAD_DUMP;
    private static final boolean LIMIT_REACHED_SYSTEM_EXIT;

    static {
        String prefix = ThreadPool.class.getName() + ".limitReached";
        LIMIT_REACHED_THREAD_DUMP = System.getProperty(prefix + "ThreadDump", "").equals("true");
        LIMIT_REACHED_SYSTEM_EXIT = System.getProperty(prefix + "SystemExit", "").equals("true");
    }

    // Source of the pool number embedded in thread names.
    private static final AtomicLong cPoolNumber = new AtomicLong(1);

    // Source of task sequence numbers; breaks ties between equal deadlines.
    static final AtomicLong cTaskNumber = new AtomicLong(1);

    private static final String SHUTDOWN_MESSAGE = "Thread pool is shutdown";

    // Access context of the pool's creator; pooled threads run inside it.
    private final AccessControlContext mContext;
    private final ThreadGroup mGroup;
    private final AtomicLong mThreadNumber = new AtomicLong(1);
    private final String mNamePrefix;
    private final boolean mDaemon;
    private final Thread.UncaughtExceptionHandler mHandler;
    private final int mMax;
    // Milliseconds an idle pooled thread waits for a command before exiting.
    private final long mIdleTimeout = 10000;

    // Idle threads, accessed like a stack. Its monitor also guards mActive and
    // mShutdown, and is notified when termination state may have changed.
    private final LinkedList<PooledThread> mPool;
    // Every started thread which has not yet exited; guarded by its own monitor.
    private final HashSet<PooledThread> mAllThreads;
    // Pending scheduled tasks ordered by deadline. Its monitor guards
    // mTaskRunnerReady and is waited on by the task runner.
    private final TreeSet<Task<?>> mScheduledTasks;

    private boolean mTaskRunnerReady;
    private int mActive;
    private boolean mShutdown;

    /**
     * @param max the maximum allowed number of threads
     * @param daemon pass true for all threads to be daemon -- they won't
     * prevent the JVM from exiting
     */
    public ThreadPool(int max, boolean daemon) {
        this(max, daemon, null, null);
    }

    /**
     * @param max the maximum allowed number of threads
     * @param daemon pass true for all threads to be daemon -- they won't
     * prevent the JVM from exiting
     * @param prefix thread name prefix; default used if null
     */
    public ThreadPool(int max, boolean daemon, String prefix) {
        this(max, daemon, prefix, null);
    }

    /**
     * @param max the maximum allowed number of threads
     * @param daemon pass true for all threads to be daemon -- they won't
     * prevent the JVM from exiting
     * @param prefix thread name prefix; default used if null
     * @param handler optional uncaught exception handler
     * @throws IllegalArgumentException if max is not positive
     */
    public ThreadPool(int max, boolean daemon, String prefix,
                      Thread.UncaughtExceptionHandler handler)
    {
        if (max <= 0) {
            throw new IllegalArgumentException
                ("Maximum number of threads must be greater than zero: " + max);
        }

        mContext = AccessController.getContext();

        SecurityManager s = System.getSecurityManager();
        mGroup = (s != null) ? s.getThreadGroup() : Thread.currentThread().getThreadGroup();

        if (prefix == null) {
            prefix = "pool";
        }

        mNamePrefix = prefix + '-' + cPoolNumber.getAndIncrement() + "-thread-";
        mDaemon = daemon;
        mHandler = handler;
        mMax = max;

        mPool = new LinkedList<PooledThread>();
        mAllThreads = new HashSet<PooledThread>();
        mScheduledTasks = new TreeSet<Task<?>>();
    }

    /**
     * Runs the command on an idle pooled thread, or on a newly started thread.
     *
     * @throws NullPointerException if command is null
     * @throws RejectedExecutionException if shut down, if the maximum number of
     * threads are already active, or if interrupted while waiting for an
     * exiting thread
     */
    @Override
    public void execute(Runnable command) throws RejectedExecutionException {
        execute(command, false);
    }

    /**
     * @param force when true, don't reject because of shutdown (used for the
     * internal task runner); the thread limit still applies
     */
    private void execute(Runnable command, boolean force) throws RejectedExecutionException {
        if (command == null) {
            throw new NullPointerException("Command is null");
        }

        while (true) {
            PooledThread thread;

            synchronized (mPool) {
                if (!force && mShutdown) {
                    throw new RejectedExecutionException(SHUTDOWN_MESSAGE);
                }
                if (mPool.isEmpty()) {
                    if (mActive >= mMax) {
                        // Always throws (or exits the JVM).
                        limitReached();
                    }
                    // Reserve a slot for a new thread, which is created
                    // outside the synchronized block.
                    mActive++;
                    thread = null;
                } else {
                    thread = mPool.removeLast();
                }
            }

            if (thread == null) {
                startReservedThread(command);
                return;
            }

            try {
                if (thread.setCommand(command)) {
                    return;
                }
            } catch (IllegalStateException e) {
                if (isShutdown()) {
                    // Cannot set command because thread is forced to run Shutdown.
                    throw new RejectedExecutionException(SHUTDOWN_MESSAGE);
                }
                throw e;
            }

            // Couldn't set the command because the pooled thread is exiting.
            // Wait for it to exit to ensure that the active count is less
            // than the maximum and try to obtain another thread.
            try {
                thread.join();
            } catch (InterruptedException e) {
                // Preserve the interrupt for the caller.
                Thread.currentThread().interrupt();
                throw new RejectedExecutionException(e);
            }
        }
    }

    /**
     * Starts a new pooled thread for a slot already reserved by incrementing
     * mActive. If the thread cannot be created or started, the reservation is
     * released (under the pool lock) before the failure propagates.
     */
    private void startReservedThread(Runnable command) {
        boolean started = false;
        try {
            startNewPooledThread(command);
            started = true;
        } finally {
            if (!started) {
                synchronized (mPool) {
                    mActive--;
                    // Termination may now have been reached.
                    mPool.notifyAll();
                }
            }
        }
    }

    /**
     * Handles an attempt to exceed the maximum number of active threads.
     * Optionally dumps threads and/or tries to exit the JVM, depending on the
     * limitReached system properties, and otherwise throws. Called while the
     * mPool lock is held.
     *
     * @throws RejectedExecutionException always, unless the JVM exits first
     */
    private void limitReached() {
        String message = "Too many active threads: " + mMax;

        if (LIMIT_REACHED_THREAD_DUMP) {
            System.err.println(new java.util.Date() + ": " + message +
                               "; dumping current thread and all pooled threads");
            Thread current = Thread.currentThread();
            dump(System.err, current);
            synchronized (mAllThreads) {
                for (PooledThread t : mAllThreads) {
                    if (t != current) {
                        dump(System.err, t);
                    }
                }
            }
        }

        if (LIMIT_REACHED_SYSTEM_EXIT) {
            Exception e = new RejectedExecutionException(message + "; exiting");
            try {
                Thread t = Thread.currentThread();
                t.getUncaughtExceptionHandler().uncaughtException(t, e);
                System.exit(1);
            } catch (Throwable e2) {
                // Cannot exit or exception handler is broken. Fall through and
                // throw an exception.
            }
        }

        throw new RejectedExecutionException(message);
    }

    /**
     * Prints the name, state and stack trace of the given thread.
     */
    private static void dump(java.io.PrintStream out, Thread t) {
        out.println('"' + t.getName() + "\" state=" + t.getState());
        try {
            StackTraceElement[] trace = t.getStackTrace();
            for (StackTraceElement element : trace) {
                out.println("\t at " + element);
            }
        } catch (SecurityException e) {
            out.println(e);
        }
    }

    @Override
    public ScheduledFuture<?> schedule(Runnable command, long delay, TimeUnit unit) {
        return new Task<Object>(Executors.callable(command), delay, 0, unit);
    }

    @Override
    public <V> ScheduledFuture<V> schedule(Callable<V> callable, long delay, TimeUnit unit) {
        return new Task<V>(callable, delay, 0, unit);
    }

    @Override
    public ScheduledFuture<?> scheduleAtFixedRate(Runnable command,
                                                  long initialDelay,
                                                  long period,
                                                  TimeUnit unit)
    {
        if (period <= 0) {
            throw new IllegalArgumentException();
        }
        return new Task<Object>(Executors.callable(command), initialDelay, period, unit);
    }

    /**
     * Schedules a task which executes with a randomly selected rate, for
     * applying jitter. After each successful run, the next deadline is the
     * previous deadline plus a period chosen from the given range.
     *
     * @param command task to run
     * @param initialDelay delay before the first run
     * @param lowPeriod low end of the random period range; may be zero
     * @param highPeriod high end of the random period range
     * @param unit unit of initialDelay and the periods
     * @throws IllegalArgumentException if lowPeriod is negative, highPeriod is
     * not positive, or lowPeriod exceeds highPeriod
     * @throws RejectedExecutionException if shut down
     */
    public ScheduledFuture<?> scheduleAtRandomRate(Runnable command,
                                                   long initialDelay,
                                                   long lowPeriod,
                                                   long highPeriod,
                                                   TimeUnit unit)
    {
        if (lowPeriod < 0 || highPeriod <= 0 || lowPeriod > highPeriod) {
            throw new IllegalArgumentException();
        }
        Callable<Object> callable = Executors.callable(command);
        if (lowPeriod == highPeriod) {
            return new Task<Object>(callable, initialDelay, lowPeriod, unit);
        } else {
            return new JitterTask<Object>(callable, initialDelay, lowPeriod, highPeriod, unit);
        }
    }

    @Override
    public ScheduledFuture<?> scheduleWithFixedDelay(Runnable command,
                                                     long initialDelay,
                                                     long delay,
                                                     TimeUnit unit)
    {
        if (delay <= 0) {
            throw new IllegalArgumentException();
        }
        // A negative period encodes fixed-delay execution.
        return new Task<Object>(Executors.callable(command), initialDelay, -delay, unit);
    }

    /**
     * Rejects new commands, tells idle threads to exit, and discards all
     * pending scheduled tasks. Commands already handed to threads still run.
     */
    @Override
    public void shutdown() {
        synchronized (mPool) {
            if (!mShutdown) {
                mShutdown = true;
                // Wake idle threads so they notice the shutdown.
                Runnable shutdown = new Shutdown();
                for (PooledThread thread : mPool) {
                    thread.setCommand(shutdown);
                }
            }
            mPool.notifyAll();
        }

        synchronized (mScheduledTasks) {
            mScheduledTasks.clear();
            mScheduledTasks.notifyAll();
        }
    }

    /**
     * Shuts down and interrupts all pooled threads.
     *
     * @return always an empty list
     */
    @Override
    public List<Runnable> shutdownNow() {
        shutdown();

        synchronized (mAllThreads) {
            for (Thread thread : mAllThreads) {
                thread.interrupt();
            }
        }

        // Implementation has no queue, so nothing to return. Scheduled tasks
        // should not be returned, because it would imply that they are to be
        // invoked immediately.
        return Collections.emptyList();
    }

    @Override
    public boolean isShutdown() {
        synchronized (mPool) {
            return mShutdown;
        }
    }

    @Override
    public boolean isTerminated() {
        synchronized (mPool) {
            return mShutdown && mActive <= 0;
        }
    }

    /**
     * Waits until the pool is terminated or the timeout elapses. A zero or
     * negative timeout does not wait, but still reports the current state.
     */
    @Override
    public boolean awaitTermination(long time, TimeUnit unit) throws InterruptedException {
        synchronized (mPool) {
            if (isTerminated()) {
                return true;
            }
            if (time <= 0) {
                return false;
            }

            long start = System.nanoTime();
            long nanos = unit.toNanos(time);

            do {
                mPool.wait(roundNanosToMillis(nanos));
                long now = System.nanoTime();
                if ((nanos -= now - start) <= 0) {
                    return isTerminated();
                }
                start = now;
            } while (!isTerminated());
        }

        return true;
    }

    /**
     * Converts nanoseconds to milliseconds, rounding up so that a positive
     * input never yields zero (which Object.wait treats as "forever").
     */
    private static long roundNanosToMillis(long nanos) {
        if (nanos <= (Long.MAX_VALUE - 999999)) {
            nanos += 999999;
        }
        return nanos / 1000000;
    }

    /**
     * Returns an idle thread to the pool.
     */
    void threadAvailable(PooledThread thread) {
        synchronized (mPool) {
            mPool.addLast(thread);
            mPool.notify();
        }
    }

    /**
     * Called by a pooled thread as it exits; releases its active slot.
     */
    void threadExiting(PooledThread thread) {
        synchronized (mPool) {
            mPool.remove(thread);
            mActive--;
            // Wake every awaitTermination caller, not just one.
            mPool.notifyAll();
        }

        synchronized (mAllThreads) {
            mAllThreads.remove(thread);
        }
    }

    /**
     * Adds a task to the schedule, waking or starting the task runner when
     * the task becomes the earliest one.
     *
     * @throws RejectedExecutionException only if shutdown
     */
    void scheduleTask(Task<?> task) {
        synchronized (mScheduledTasks) {
            // Checked while holding the schedule lock, so that a task cannot be
            // added after shutdown() has cleared the schedule.
            if (isShutdown()) {
                throw new RejectedExecutionException(SHUTDOWN_MESSAGE);
            }
            if (!mScheduledTasks.add(task)) {
                throw new InternalError();
            }
            if (mScheduledTasks.first() == task) {
                if (mTaskRunnerReady) {
                    mScheduledTasks.notify();
                } else {
                    TaskRunner runner = new TaskRunner();
                    try {
                        execute(runner, true);
                        mTaskRunnerReady = true;
                    } catch (RejectedExecutionException e) {
                        // Task is scheduled as soon as a thread becomes available.
                    }
                }
            }
        }
    }

    /**
     * Called by a pooled thread after running a command. If tasks are
     * scheduled but no task runner is active, claims the runner role and
     * returns a runner for the caller to run next; otherwise returns null.
     */
    TaskRunner needsTaskRunner() {
        synchronized (mScheduledTasks) {
            if (!mTaskRunnerReady && !mScheduledTasks.isEmpty()) {
                TaskRunner runner = new TaskRunner();
                mTaskRunnerReady = true;
                return runner;
            }
        }
        return null;
    }

    void removeTask(Task<?> task) {
        synchronized (mScheduledTasks) {
            mScheduledTasks.remove(task);
            if (mScheduledTasks.isEmpty()) {
                // Let a waiting task runner notice there is nothing left to do.
                mScheduledTasks.notifyAll();
            }
        }
    }

    /**
     * Body of the task runner: waits for the earliest task to become due, hands
     * the runner role to another thread, and then runs the task (plus any
     * others already due) on the current thread.
     */
    void runNextScheduledTask() {
        Task<?> task;
        boolean replaced;

        synchronized (mScheduledTasks) {
            while (true) {
                if (mScheduledTasks.isEmpty()) {
                    mTaskRunnerReady = false;
                    return;
                }
                task = mScheduledTasks.first();
                long delay = task.getAtNanos() - System.nanoTime();
                if (delay <= 0) {
                    mScheduledTasks.remove(task);
                    try {
                        execute(new TaskRunner(), true);
                        replaced = true;
                    } catch (RejectedExecutionException e) {
                        // Task is scheduled as soon as a thread becomes available.
                        mTaskRunnerReady = false;
                        replaced = false;
                    }
                    break;
                }
                try {
                    mScheduledTasks.wait(roundNanosToMillis(delay));
                } catch (InterruptedException e) {
                    // Clear the interrupted status.
                    Thread.interrupted();
                }
            }
        }

        try {
            task.run();

            if (replaced) {
                // Run any more tasks which need to be run
                // immediately, to avoid having to switch context.
                while (true) {
                    synchronized (mScheduledTasks) {
                        if (mScheduledTasks.isEmpty()) {
                            break;
                        }
                        task = mScheduledTasks.first();
                        if ((task.getAtNanos() - System.nanoTime()) > 0) {
                            break;
                        }
                        mScheduledTasks.remove(task);
                    }
                    // Clear the interrupted state, if set by previous task execution.
                    Thread.interrupted();
                    task.run();
                }
            }
        } catch (Throwable e) {
            Thread t = Thread.currentThread();
            t.getUncaughtExceptionHandler().uncaughtException(t, e);
        }
    }

    /**
     * Creates, configures, registers and starts a new pooled thread which runs
     * the given command first.
     */
    private PooledThread startNewPooledThread(Runnable command) {
        PooledThread thread = new PooledThread
            (mGroup, mNamePrefix + mThreadNumber.getAndIncrement(), mContext, command);

        if (thread.isDaemon() != mDaemon) {
            thread.setDaemon(mDaemon);
        }
        if (thread.getPriority() != Thread.NORM_PRIORITY) {
            thread.setPriority(Thread.NORM_PRIORITY);
        }
        if (mHandler != null) {
            thread.setUncaughtExceptionHandler(mHandler);
        }

        synchronized (mAllThreads) {
            mAllThreads.add(thread);
        }

        try {
            thread.start();
        } catch (Error e) {
            synchronized (mAllThreads) {
                mAllThreads.remove(thread);
            }
            throw e;
        }

        return thread;
    }

    private class PooledThread extends Thread {
        private final AccessControlContext mContext;

        // Both guarded by this thread object's monitor.
        private Runnable mCommand;
        private boolean mExiting;

        public PooledThread(ThreadGroup group, String name,
                            AccessControlContext context, Runnable command)
        {
            super(group, null, name);
            mContext = context;
            mCommand = command;
        }

        /**
         * Hands a command to this thread.
         *
         * @return false if the thread has committed to exiting
         * @throws IllegalStateException if a command is already pending
         */
        synchronized boolean setCommand(Runnable command) {
            if (mCommand != null) {
                throw new IllegalStateException("Command in pooled thread is already set");
            }
            if (mExiting) {
                return false;
            } else {
                mCommand = command;
                notify();
                return true;
            }
        }

        /**
         * Waits up to the idle timeout for a command. Returns null, after
         * marking this thread as exiting, if none arrives.
         */
        private synchronized Runnable waitForCommand() throws InterruptedException {
            Runnable command;
            if ((command = mCommand) == null) {
                long idle = mIdleTimeout;
                if (idle != 0) {
                    if (idle < 0) {
                        wait(0);
                    } else {
                        wait(idle);
                    }
                }
                if ((command = mCommand) == null) {
                    mExiting = true;
                }
            }
            mCommand = null;
            return command;
        }

        /**
         * Used once the pool is shut down: takes a command which was handed
         * off before shutdown was observed, or marks this thread as exiting
         * (so setCommand refuses further commands) and returns null.
         */
        private synchronized Runnable takeCommandOrExit() {
            Runnable command = mCommand;
            if (command == null) {
                mExiting = true;
            } else {
                mCommand = null;
            }
            return command;
        }

        @Override
        public void run() {
            AccessController.doPrivileged(new PrivilegedAction<Object>() {
                public Object run() {
                    run0();
                    return null;
                }
            }, mContext);
        }

        void run0() {
            try {
                while (true) {
                    Runnable command;

                    if (isShutdown()) {
                        // execute may have removed this thread from the pool
                        // and handed it a command just before the shutdown;
                        // run it rather than silently dropping it.
                        if ((command = takeCommandOrExit()) == null) {
                            break;
                        }
                    } else {
                        if (Thread.interrupted()) {
                            // Discard a stale interrupt before waiting.
                            continue;
                        }
                        try {
                            if ((command = waitForCommand()) == null) {
                                break;
                            }
                        } catch (InterruptedException e) {
                            continue;
                        }
                    }

                    runCommands(command);
                    threadAvailable(this);
                }
            } finally {
                threadExiting(this);
            }
        }

        /**
         * Runs the command, then keeps serving as task runner for as long as
         * one is needed. Failures go to the uncaught exception handler, except
         * for the deliberate ThreadDeath thrown by Shutdown.
         */
        private void runCommands(Runnable command) {
            do {
                try {
                    command.run();
                } catch (Throwable e) {
                    if (!(command instanceof Shutdown)) {
                        getUncaughtExceptionHandler().uncaughtException(this, e);
                    }
                }
            } while ((command = needsTaskRunner()) != null);
        }
    }

    private class Task<V> extends FutureTask<V> implements ScheduledFuture<V> {
        // Sequence number; orders tasks with identical deadlines.
        private final long mNum;
        // Positive: fixed rate; negative: fixed delay; zero: one-shot.
        final long mPeriodNanos;
        // Deadline in System.nanoTime() terms. Only modified while the task
        // is not in the schedule, because the schedule is ordered by it.
        volatile long mAtNanos;

        /**
         * Creates and schedules the task.
         *
         * @param period Period for repeating tasks. A positive value indicates
         * fixed-rate execution. A negative value indicates fixed-delay
         * execution. A value of 0 indicates a non-repeating task.
         * @throws RejectedExecutionException if shut down
         */
        Task(Callable<V> callable, long initialDelay, long period, TimeUnit unit) {
            super(callable);

            long periodNanos;
            if (period == 0) {
                periodNanos = 0;
            } else if ((periodNanos = unit.toNanos(period)) == 0) {
                // Account for any rounding error.
                periodNanos = period < 0 ? -1 : 1;
            }
            mPeriodNanos = periodNanos;

            mNum = cTaskNumber.getAndIncrement();

            long atNanos = System.nanoTime();
            if (initialDelay > 0) {
                atNanos += unit.toNanos(initialDelay);
            }
            mAtNanos = atNanos;

            start();
        }

        /**
         * Schedules the task. Invoked from the constructor; subclasses which
         * need their own fields initialized first override this to do nothing
         * and call super.start() themselves.
         */
        void start() {
            scheduleTask(this);
        }

        @Override
        public long getDelay(TimeUnit unit) {
            return unit.convert(mAtNanos - System.nanoTime(), TimeUnit.NANOSECONDS);
        }

        @Override
        public int compareTo(Delayed delayed) {
            if (this == delayed) {
                return 0;
            }
            if (delayed instanceof Task) {
                Task<?> other = (Task<?>) delayed;
                // Subtract rather than compare directly, since nanoTime values may wrap.
                long diff = mAtNanos - other.mAtNanos;
                if (diff < 0) {
                    return -1;
                } else if (diff > 0) {
                    return 1;
                } else if (mNum < other.mNum) {
                    return -1;
                } else {
                    return 1;
                }
            }
            long diff = getDelay(TimeUnit.NANOSECONDS) - delayed.getDelay(TimeUnit.NANOSECONDS);
            return diff == 0 ? 0 : (diff < 0 ? -1 : 1);
        }

        @Override
        public void run() {
            removeTask(this);
            long periodNanos = mPeriodNanos;
            if (periodNanos == 0) {
                super.run();
            } else if (super.runAndReset()) {
                if (periodNanos > 0) {
                    mAtNanos += periodNanos;
                } else {
                    mAtNanos = System.nanoTime() - periodNanos;
                }
                reschedule();
            }
        }

        /**
         * Puts a repeating task back into the schedule after a successful run.
         * If cancel ran concurrently, the task is removed again so a cancelled
         * task cannot linger in the schedule until its next deadline. cancel
         * marks the task cancelled before removing it, so one of the two
         * removals always observes the other.
         */
        final void reschedule() {
            try {
                scheduleTask(this);
            } catch (RejectedExecutionException e) {
                // Shutdown; the task is not run again.
                return;
            }
            if (isCancelled()) {
                removeTask(this);
            }
        }

        @Override
        public boolean cancel(boolean mayInterruptIfRunning) {
            // Cancel first, then remove; see reschedule for why the order matters.
            boolean cancelled = super.cancel(mayInterruptIfRunning);
            removeTask(this);
            return cancelled;
        }

        long getAtNanos() {
            return mAtNanos;
        }

        @Override
        public String toString() {
            StringBuilder b = new StringBuilder()
                .append("ScheduledFuture {delayNanos=")
                .append(String.valueOf(getDelay(TimeUnit.NANOSECONDS)));
            if (mPeriodNanos != 0) {
                b.append(", periodNanos=").append(String.valueOf(mPeriodNanos));
            }
            return b.append('}').toString();
        }
    }

    /**
     * One xorshift step. A zero input yields zero, so callers must seed with a
     * non-zero value.
     */
    static int randomInt(int n) {
        // See "Xorshift RNGs" by George Marsaglia for why these numbers were chosen.
        n ^= (n << 13);
        n ^= (n >>> 17);
        n ^= (n << 5);
        return n;
    }

    /**
     * Fixed-rate task whose period is randomly chosen from a range on each run.
     */
    private class JitterTask<V> extends Task<V> {
        private final long mRangeNanos;
        // Xorshift state; never zero.
        private int mRandom;

        JitterTask(Callable<V> callable, long initialDelay,
                   long lowPeriod, long highPeriod, TimeUnit unit)
        {
            super(callable, initialDelay, lowPeriod, unit);
            mRangeNanos = unit.toNanos(highPeriod - lowPeriod);
            // A zero seed would make the xorshift generator stick at zero.
            int seed;
            do {
                seed = Random.randomInt();
            } while (seed == 0);
            mRandom = seed;
            // Schedule only now that the fields used by run are initialized.
            super.start();
        }

        @Override
        void start() {
            // Deferred; the constructor calls super.start() explicitly.
        }

        @Override
        public void run() {
            removeTask(this);
            if (super.runAndReset()) {
                mAtNanos += mPeriodNanos + randomLong(mRangeNanos);
                reschedule();
            }
        }

        @Override
        public String toString() {
            return new StringBuilder()
                .append("ScheduledFuture {delayNanos=")
                .append(String.valueOf(getDelay(TimeUnit.NANOSECONDS)))
                .append(", lowPeriodNanos=").append(String.valueOf(mPeriodNanos))
                .append(", highPeriodNanos=").append(String.valueOf(mPeriodNanos + mRangeNanos))
                .append('}').toString();
        }

        /**
         * Returns a value in [0, n), rejecting samples which would bias the
         * modulo, in the same manner as java.util.Random.nextInt(int).
         */
        private long randomLong(final long n) {
            int n2 = mRandom;
            long bits, val;
            do {
                int n1 = randomInt(n2);
                n2 = randomInt(n1);
                bits = (((((long) n1) << 32) + (long) n2) >>> 1);
                val = bits % n;
            } while (bits - val + n - 1 < 0);
            mRandom = n2;
            return val;
        }
    }

    /**
     * Runnable which drives the schedule on a pooled thread.
     */
    private class TaskRunner implements Runnable {
        @Override
        public void run() {
            runNextScheduledTask();
        }
    }

    /**
     * Command given to idle threads on shutdown; wakes them up. Its ThreadDeath
     * is intentionally not reported to the uncaught exception handler.
     */
    private static class Shutdown implements Runnable {
        @Override
        public void run() {
            throw new ThreadDeath();
        }
    }
}