package io.lumify.core.util;
import java.util.*;
import java.util.concurrent.*;
/**
 * A fixed-size pool of worker threads that runs batches of {@link Callable} tasks in parallel and
 * blocks until the whole batch has completed.
 *
 * <p>When debug logging is enabled at construction time the pool is backed by a
 * {@code MetricReportingExecutorService}; otherwise it is a plain fixed thread pool. Each instance
 * registers a JVM shutdown hook that shuts its executor down.
 */
public class WorkerPool {
    private static final LumifyLogger LOGGER = LumifyLoggerFactory.getLogger(WorkerPool.class);

    /**
     * Position of the external caller in the stack trace captured by {@link #describeCaller()}:
     * frame 0 is {@code Thread.getStackTrace}, 1 is {@code describeCaller}, 2 is the public
     * {@code invokeAll} overload, and 3 is whoever called {@code invokeAll}.
     */
    private static final int CALLER_FRAME_INDEX = 3;

    /** Placeholder used in log messages when the caller is not (or cannot be) determined. */
    private static final String UNKNOWN_CALLER = "n/a";

    private final ExecutorService executorService;

    /**
     * Create a pool with the specified number of threads.
     *
     * <p>Also registers a JVM shutdown hook that stops accepting new tasks, waits up to 10 seconds
     * for running tasks to finish, then cancels the stragglers and waits up to 10 more seconds.
     *
     * @param nThreads the number of worker threads in the pool
     */
    public WorkerPool(int nThreads) {
        LOGGER.debug("initializing worker pool with %d threads", nThreads);
        executorService = LOGGER.isDebugEnabled() ? new MetricReportingExecutorService(LOGGER, nThreads) : Executors.newFixedThreadPool(nThreads);
        Runtime.getRuntime().addShutdownHook(new Thread() {
            @Override
            public void run() {
                shutdownAndAwaitTermination(10);
            }
        });
    }

    /**
     * Shut the executor down in two phases: an orderly shutdown followed, if tasks are still
     * running after the timeout, by a forced cancellation.
     *
     * @param seconds how long to wait in each of the two phases
     */
    private void shutdownAndAwaitTermination(int seconds) {
        // disable submission of new tasks
        executorService.shutdown();
        try {
            // wait for existing tasks to terminate
            if (!executorService.awaitTermination(seconds, TimeUnit.SECONDS)) {
                // cancel lingering tasks
                executorService.shutdownNow();
                // wait for lingering tasks to terminate
                if (!executorService.awaitTermination(seconds, TimeUnit.SECONDS)) {
                    System.err.println("executorService did not terminate!");
                }
            }
        } catch (InterruptedException ie) {
            // (re-)cancel if current thread also interrupted
            executorService.shutdownNow();
            // preserve interrupt status
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Execute the {@link Callable} tasks in parallel (per the configured size of the {@link WorkerPool}) and wait for them to complete.
     *
     * <p>The returned map iterates in the same order as {@code tasks.entrySet()} did when this
     * method was called.
     *
     * @param tasks a map of {@link Callable}s with keys by which you will be able to access each return value
     * @return the return values of each {@link Callable}s mapped by their input key
     * @throws RuntimeException wrapping an {@link ExecutionException} if a task threw, or an
     *                          {@link InterruptedException} if the calling thread was interrupted
     *                          (in which case the thread's interrupt status is set again)
     */
    public <K, V> Map<K, V> invokeAll(Map<K, Callable<V>> tasks) {
        String caller = describeCaller();

        // Freeze one iteration order so that keys and tasks (and later results) line up by index.
        List<K> orderedKeys = new ArrayList<K>(tasks.size());
        List<Callable<V>> orderedTasks = new ArrayList<Callable<V>>(tasks.size());
        for (Map.Entry<K, Callable<V>> entry : tasks.entrySet()) {
            orderedKeys.add(entry.getKey());
            orderedTasks.add(entry.getValue());
        }

        List<V> orderedResults = invokeAndCollect(caller, "mapped", orderedTasks);

        Map<K, V> mappedResults = new LinkedHashMap<K, V>(orderedKeys.size());
        for (int i = 0; i < orderedKeys.size(); i++) {
            mappedResults.put(orderedKeys.get(i), orderedResults.get(i));
        }
        return mappedResults;
    }

    /**
     * Execute the {@link Callable} tasks in parallel (per the configured size of the {@link WorkerPool}) and wait for them to complete.
     *
     * @param tasks a list of {@link Callable}s
     * @return the ordered return values
     * @throws RuntimeException wrapping an {@link ExecutionException} if a task threw, or an
     *                          {@link InterruptedException} if the calling thread was interrupted
     *                          (in which case the thread's interrupt status is set again)
     */
    public <T> List<T> invokeAll(List<Callable<T>> tasks) {
        String caller = describeCaller();
        return invokeAndCollect(caller, "listed", tasks);
    }

    /**
     * Submit the whole batch to the executor, wait for it, and unwrap each {@link Future}.
     *
     * @param caller description of the external caller, used only in debug log messages
     * @param kind   word describing the batch shape in log messages ("mapped" or "listed")
     * @param tasks  the tasks to run
     * @return the task results, in the same order as {@code tasks}
     */
    private <T> List<T> invokeAndCollect(String caller, String kind, List<Callable<T>> tasks) {
        LOGGER.debug("[%s] is invoking %d %s tasks", caller, tasks.size(), kind);
        try {
            long start = System.currentTimeMillis();
            List<Future<T>> executorResults = executorService.invokeAll(tasks);
            long finish = System.currentTimeMillis();
            LOGGER.debug("[%s] invoked %d %s tasks in %d ms", caller, tasks.size(), kind, finish - start);

            List<T> results = new ArrayList<T>(executorResults.size());
            for (Future<T> future : executorResults) {
                results.add(future.get());
            }
            return results;
        } catch (InterruptedException e) {
            // The checked exception is hidden behind a RuntimeException, so re-assert the
            // interrupt flag; otherwise code further up the stack cannot tell it was interrupted.
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (ExecutionException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Describe the code that called one of the public {@code invokeAll} overloads, for debug logs.
     * Must be called directly from a public {@code invokeAll} so that the frame arithmetic holds.
     *
     * @return the caller's stack frame as a string, or {@code "n/a"} when debug logging is off or
     *         the JVM did not supply enough stack frames
     */
    private static String describeCaller() {
        if (!LOGGER.isDebugEnabled()) {
            // Capturing a stack trace is not free; skip it when nobody will read the result.
            return UNKNOWN_CALLER;
        }
        StackTraceElement[] stack = Thread.currentThread().getStackTrace();
        // A JVM is allowed to omit frames (even all of them), so never index blindly.
        if (stack.length <= CALLER_FRAME_INDEX) {
            return UNKNOWN_CALLER;
        }
        return stack[CALLER_FRAME_INDEX].toString();
    }

    /**
     * Small demonstration: runs the three examples on a 3-thread pool, prints executor metrics
     * when they are available, and exits the JVM.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        WorkerPool workerPool = new WorkerPool(3);
        mapExample(workerPool);
        listExample(workerPool);
        slowExample(workerPool, 5);
        // The metric-reporting executor is only chosen when debug logging was enabled at
        // construction time, so verify the type instead of trusting the current log level.
        if (LOGGER.isDebugEnabled() && workerPool.executorService instanceof MetricReportingExecutorService) {
            MetricReportingExecutorService metrics = (MetricReportingExecutorService) workerPool.executorService;
            metrics.tick();
            metrics.report();
        }
        System.exit(0);
    }

    /**
     * Calculate the squares of the integers 1 through 10.
     *
     * @param workerPool the pool on which to run the tasks
     */
    private static void mapExample(WorkerPool workerPool) {
        Map<Integer, Callable<Integer>> taskMap = new HashMap<Integer, Callable<Integer>>();
        for (int i = 1; i <= 10; i++) {
            final int input = i;
            Callable<Integer> callable = new Callable<Integer>() {
                @Override
                public Integer call() throws Exception {
                    // do all your parallelizable work here
                    return input * input;
                }
            };
            taskMap.put(i, callable);
        }
        Map<Integer, Integer> resultMap = workerPool.invokeAll(taskMap);
        for (Map.Entry<Integer, Integer> entry : resultMap.entrySet()) {
            System.out.println("key: " + entry.getKey() + ", value: " + entry.getValue());
        }
    }

    /**
     * Return strings including the integers -1 through -10.
     *
     * @param workerPool the pool on which to run the tasks
     */
    private static void listExample(WorkerPool workerPool) {
        List<Callable<String>> taskList = new ArrayList<Callable<String>>();
        for (int i = 1; i <= 10; i++) {
            final String input = Integer.toString(-1 * i);
            Callable<String> callable = new Callable<String>() {
                @Override
                public String call() throws Exception {
                    // do all your parallelizable work here
                    return "R(" + input + ")";
                }
            };
            taskList.add(callable);
        }
        List<String> resultList = workerPool.invokeAll(taskList);
        for (String result : resultList) {
            System.out.println("result: " + result);
        }
    }

    /**
     * Run tasks that each sleep for a random time below 10 seconds, then print how the elapsed
     * wall-clock time compares with the longest, average and total sleep.
     *
     * @param workerPool the pool on which to run the tasks
     * @param nTasks     the number of sleeping tasks to submit
     */
    private static void slowExample(WorkerPool workerPool, int nTasks) {
        List<Callable<Integer>> taskList = new ArrayList<Callable<Integer>>();
        Random random = new Random();
        int maxWait = 10000;
        for (int i = 1; i <= nTasks; i++) {
            final int wait = random.nextInt(maxWait);
            Callable<Integer> callable = new Callable<Integer>() {
                @Override
                public Integer call() throws Exception {
                    Thread.sleep(wait);
                    return wait;
                }
            };
            taskList.add(callable);
        }
        long start = System.currentTimeMillis();
        List<Integer> waits = workerPool.invokeAll(taskList);
        long finish = System.currentTimeMillis();
        int longestWait = 0;
        int totalWait = 0;
        for (int wait : waits) {
            longestWait = Math.max(longestWait, wait);
            totalWait += wait;
        }
        // Integer division by zero would throw when no tasks were run, so report 0 in that case.
        int averageWait = waits.isEmpty() ? 0 : totalWait / waits.size();
        System.out.println(String.format("elapsed time: %5d ms (%.2f%%)", finish - start, 100 * (1.0 * (finish - start)) / totalWait));
        System.out.println(String.format("longest wait: %5d ms", longestWait));
        System.out.println(String.format("average wait: %5d ms", averageWait));
        System.out.println(String.format("total wait: %5d ms", totalWait));
    }
}