package com.github.kmkt.util.concurrent;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.FutureTask;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* submit により投入される引数(タスク)を TaskWorker で実装された処理実体により並列処理するクラス
* ThreadPoolExecutor では runnable, callable を投入するが、TaskWorkerRunner では TaskWorker への
* 引数を投入する
*
* License : MIT License
* Copyright (c) 2015 NagasawaXien
*
* @param <T> TaskWorker への引数の型
* @param <R> TaskWorker からの返り値の型
*/
public class TaskWorkerRunner<T, R> {
private static final Logger logger = LoggerFactory.getLogger(TaskWorkerRunner.class);
/** タスク処理用 ThreadPool */
protected final ExecutorService pool;
/** 処理スレッド数の最大 */
protected int maximimParallel;
/** タスク処理の実体 TaskWorker のサプライヤ・ファクトリ */
protected final TaskWorkerSupplier<T, R> workerSupplier;
/** タスク処理が終了した TaskWorker の通知インタフェース */
protected final TaskWorkerCollector workerCollector;
/** CallbackCaller callback しない場合は null */
protected final CallbackCaller<T, R> callbackCaller;
/** 総タスク数カウンタ */
private AtomicInteger tasks = new AtomicInteger(0);
/** 処理中タスク数カウンタ */
private AtomicInteger runningTasks = new AtomicInteger(0);
/**
* 処理終了時に CallbackCaller に終了タスクを引き渡す FutureTask
* TaskRunner 専用の内部クラス
*
* @see java.util.concurrent.FutureTask
*/
protected class CallbackFutureTask extends FutureTask<R> {
private T req; // TaskRunner が持つ TaskWorker への引数
public CallbackFutureTask(TaskRunner runner) {
super(runner);
if (runner == null)
throw new IllegalArgumentException("runner should not be null");
req = runner.getTaskReq();
}
public T getTaskReq() {
return req;
}
@Override
protected void done() {
// callback 有効時は CallbackCaller にリクエストと自身を登録する
if (callbackCaller != null)
callbackCaller.registerFinishedTask(req, this);
}
}
/**
* TaskWorker を呼び出す Callable wrapper
*/
protected class TaskRunner implements Callable<R> {
private T req; // TaskWorker への引数
/**
*
* @param taskreq TaskWorker への引数
*/
public TaskRunner(T taskreq) {
this.req = taskreq;
}
/**
* TaskWorker への引数を取得する
* @return
*/
public T getTaskReq() {
return req;
}
@Override
public R call() throws Exception {
TaskWorker<T, R> worker = null;
try {
runningTasks.incrementAndGet();
worker = workerSupplier.get(); // TaskWorker 取得
if (worker == null)
throw new TaskWorkerStartException("TaskWorkerSupplier returns null", req);
R result = worker.doTask(req);
return result;
} finally {
if (workerCollector != null && worker != null)
workerCollector.collect(worker);
runningTasks.decrementAndGet();
tasks.decrementAndGet();
}
}
}
/**
* コールバック処理スレッド
* @param <T> TaskWorker への引数の型
* @param <R> TaskWorker からの返り値の型
*/
private static class CallbackCaller<T, R> implements Runnable {
/** 処理終了時の callback */
private TaskCompleteListener<T, R> listener;
private BlockingQueue<Pair> queue = new LinkedBlockingQueue<>();
/**
* タスクとその処理結果を含む Future のペア
*/
private class Pair {
T t;
Future<R> r;
}
/**
*
* @param listener 処理終了時の callback
*/
public CallbackCaller(TaskCompleteListener<T, R> listener) {
if (listener == null)
throw new IllegalArgumentException("listener should not be null");
this.listener = listener;
}
@Override
public void run() {
logger.debug("Start task complete callback thread");
try {
while (true) {
Pair ele = queue.take(); // block
try {
listener.onComplete(ele.t, ele.r); // callback
} catch (Exception e) {
logger.error("Unexpedted exception in TaskCompleteListener callback", e);
}
if (Thread.interrupted())
break; // exit loop by interruption
}
} catch (InterruptedException e) {
// exit loop by interruption
} finally {
if (Thread.interrupted()) {
logger.debug("Exit task complete callback thread by thread interruption");
} else {
logger.debug("Exit task complete callback thread");
}
}
}
/**
* 処理終了したタスクの登録
* @param taskreq 処理終了したタスク
* @param future taskreq の処理結果を含む Future
*/
public void registerFinishedTask(T taskreq, Future<R> future) {
Pair ele = new Pair();
ele.t = taskreq;
ele.r = future;
try {
queue.put(ele);
} catch (InterruptedException e) {
logger.error(null, e); // not occure
}
}
}
/**
* Future のみで処理完了待ちをする TaskWorkerRunner
*
* @param maxparallels 処理スレッド数の最大
* @param supplier タスク処理の実体 TaskWorker のサプライヤorファクトリ
*/
public TaskWorkerRunner(int maxparallels, TaskWorkerSupplier<T, R> supplier) {
this(maxparallels, supplier, null, null);
}
/**
* Future のみで処理完了待ちをする TaskWorkerRunner
*
* @param maxparallels 処理スレッド数の最大
* @param supplier タスク処理の実体 TaskWorker のサプライヤorファクトリ
* @param collector タスク処理が終了した TaskWorker の通知インタフェース
*/
public TaskWorkerRunner(int maxparallels, TaskWorkerSupplier<T, R> supplier, TaskWorkerCollector collector) {
this(maxparallels, supplier, collector, null);
}
/**
* Future での処理完了待機と callback での処理完了通知をする TaskWorkerRunner
*
* @param maxparallels 処理スレッド数の最大
* @param supplier タスク処理の実体 TaskWorker のサプライヤorファクトリ
* @param listener 処理終了時の callback null時は callback しない
*/
public TaskWorkerRunner(int maxparallels, TaskWorkerSupplier<T, R> supplier, TaskCompleteListener<T, R> listener) {
this(maxparallels, supplier, null, listener);
}
/**
* Future での処理完了待機と callback での処理完了通知をする TaskWorkerRunner
*
* @param maxparallels 処理スレッド数の最大
* @param supplier タスク処理の実体 TaskWorker のサプライヤorファクトリ
* @param collector タスク処理が終了した TaskWorker の通知インタフェース
* @param listener 処理終了時の callback null時は callback しない
*/
public TaskWorkerRunner(int maxparallels, TaskWorkerSupplier<T, R> supplier, TaskWorkerCollector collector, TaskCompleteListener<T, R> listener) {
if (maxparallels <= 0)
throw new IllegalArgumentException("maxparallels should be a positive integer");
if (supplier == null)
throw new IllegalArgumentException("supplier should not be null");
maximimParallel = maxparallels;
workerSupplier = supplier;
workerCollector = collector;
if (listener != null) {
callbackCaller = new CallbackCaller<T, R>(listener);
} else {
callbackCaller = null;
}
int threads = (callbackCaller == null) ? maxparallels : maxparallels + 1; // callback 有効時は callback 用スレッドを確保する
pool = Executors.newFixedThreadPool(threads, new ThreadFactory() {
@Override
public Thread newThread(Runnable r) {
Thread t = new Thread(r);
t.setDaemon(true);
return t;
}
});
// callback 有効時は callback 用スレッドを確保する
if (callbackCaller != null) {
pool.execute(callbackCaller);
}
}
/**
* TaskWorkerRunner に処理対象の taskreq を送信する
* @param taskreq TaskWorker に処理させる引数 nullable
* @return タスクの保留完了を表すFuture
*/
public Future<R> submit(T taskreq) {
FutureTask<R> taskhost = new CallbackFutureTask(new TaskRunner(taskreq));
tasks.incrementAndGet();
pool.execute(taskhost);
return taskhost;
}
/**
* 未実行のタスク数を取得する
* @return
*/
public int getNumOfWaitingTask() {
return tasks.get() - runningTasks.get();
}
/**
* 処理中のタスク数を取得する
* @return
*/
public int getNumOfRunningTask() {
return runningTasks.get();
}
/**
* 処理中・未実行タスクを合わせた総タスク数を取得する
* @return
*/
public int getNumOfTask() {
return tasks.get();
}
/**
* 内部の ThreadPool をシャットダウンする
* @see java.util.concurrent.ExecutorService.shutdown()
*/
public void shutdown() {
pool.shutdown();
}
/**
* 内部の ThreadPool をシャットダウンする
* @see java.util.concurrent.ExecutorService.shutdownNow()
* @return 未処理のタスクのリスト
*/
@SuppressWarnings("unchecked")
public List<T> shutdownNow() {
List<Runnable> remain = pool.shutdownNow();
List<T> result = new ArrayList<T>(remain.size());
for (Runnable r : remain) {
if (!(r instanceof CallbackCaller)) {
result.add(((CallbackFutureTask) r).getTaskReq());
}
}
return result;
}
/**
* 内部の ThreadPool のシャットダウン完了を待つ
* @see java.util.concurrent.ExecutorService.awaitTermination(int, TimeUnit)
*
* @param timeout 待機する最長時間
* @param unit timeout引数の時間単位
* @return シャットダウンした場合は true それ以外は false
* @throws InterruptedException
*/
public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
return pool.awaitTermination(timeout, unit);
}
}