package edu.stanford.nlp.util.concurrent;

import edu.stanford.nlp.util.RuntimeInterruptedException;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

/**
 * Provides convenient multicore processing for threadsafe objects. Objects that can
 * be wrapped by MulticoreWrapper must implement the ThreadsafeProcessor interface.
 *
 * <p>Typical use: call {@link #put(Object)} for each input, drain finished results with
 * {@code while (wrapper.peek()) { O result = wrapper.poll(); ... }}, and finally call
 * {@link #join()} followed by one more drain loop.
 *
 * <p>See edu.stanford.nlp.util.concurrent.MulticoreWrapperTest and
 * edu.stanford.nlp.tagger.maxent.documentation.MulticoreWrapperDemo for examples of use.
 *
 * <p>TODO(spenceg): This code does <b>not</b> support multiple consumers, i.e., multi-threaded calls
 * to peek() and poll().
 *
 * @author Spence Green
 *
 * @param <I> input type
 * @param <O> output type
 */
public class MulticoreWrapper<I,O> {

  final int nThreads;

  // Id that will be assigned to the next submitted item. Only touched inside put(),
  // which is synchronized.
  int submittedItemCounter = 0;

  // Which id was the last id returned. Only meaningful in the case
  // of a queue where output order matters.
  private int returnedItemCounter = -1;

  private final boolean orderResults;

  // Finished results keyed by item id. The QueueItem wrapper is stored (rather than the
  // bare result) because ConcurrentHashMap rejects null values, and a failed job -- or a
  // processor that legitimately returns null -- must still be recorded.
  private final Map<Integer,QueueItem<O>> outputQueue;

  final ThreadPoolExecutor threadPool;

  // Ids (indices into processorList) of the processors that are not currently running a job.
  final BlockingQueue<Integer> idleProcessors;

  private final List<ThreadsafeProcessor<I,O>> processorList;

  private final JobCallback<O> callback;

  /**
   * Constructor. Results are returned in the order in which the inputs were submitted.
   *
   * @param nThreads If less than or equal to 0, then automatically determine the number
   *                 of threads. Otherwise, the size of the underlying threadpool.
   * @param processor The processor to run; additional instances are obtained from
   *                  {@code processor.newInstance()}, one per extra thread.
   */
  public MulticoreWrapper(int nThreads, ThreadsafeProcessor<I,O> processor) {
    this(nThreads, processor, true);
  }

  /**
   * Constructor.
   *
   * @param numThreads If less than or equal to 0, then automatically determine the number
   *                   of threads. Otherwise, the size of the underlying threadpool.
   * @param processor The processor to run; additional instances are obtained from
   *                  {@code processor.newInstance()}, one per extra thread.
   * @param orderResults If true, return results in the order submitted. Otherwise, return results
   *                     as they become available.
   */
  public MulticoreWrapper(int numThreads, ThreadsafeProcessor<I,O> processor, boolean orderResults) {
    nThreads = numThreads <= 0 ? Runtime.getRuntime().availableProcessors() : numThreads;
    this.orderResults = orderResults;
    outputQueue = new ConcurrentHashMap<>(2 * nThreads);

    // Set up the processors, one per thread. This happens before the pool is created so
    // that a failing newInstance() cannot leave prestarted worker threads behind.
    List<ThreadsafeProcessor<I,O>> procList = new ArrayList<>(nThreads);
    procList.add(processor);
    for (int i = 1; i < nThreads; ++i) {
      procList.add(processor.newInstance());
    }
    processorList = Collections.unmodifiableList(procList);

    idleProcessors = new ArrayBlockingQueue<>(nThreads, false);
    for (int i = 0; i < nThreads; ++i) {
      idleProcessors.add(i);
    }

    callback = (result, processorId) -> {
      try {
        outputQueue.put(result.id, result);
      } finally {
        // Publish the result first, then release the processor: join() relies on
        // "all processors idle" implying "all results visible". The finally guarantees
        // the processor is released even if storing the result fails.
        idleProcessors.add(processorId);
      }
    };

    threadPool = buildThreadPool(nThreads);
    // Sanity check: Fixed thread pool so prevent timeouts.
    // Default should be false
    threadPool.allowCoreThreadTimeOut(false);
    threadPool.prestartAllCoreThreads();
  }

  /**
   * Creates the underlying thread pool. Called once from the constructor, so overriding
   * implementations must not depend on subclass state.
   *
   * @param nThreads the number of worker threads
   * @return a fixed-size thread pool with {@code nThreads} threads
   */
  protected ThreadPoolExecutor buildThreadPool(int nThreads) {
    return (ThreadPoolExecutor) Executors.newFixedThreadPool(nThreads);
  }

  /**
   * @return the number of threads (and processors) used by this wrapper
   */
  public int nThreads() {
    return nThreads;
  }

  /**
   * Return status information about the underlying threadpool.
   */
  @Override
  public String toString() {
    return String.format("active: %d/%d  submitted: %d  completed: %d  input_q: %d  output_q: %d  idle_q: %d",
        threadPool.getActiveCount(),
        threadPool.getPoolSize(),
        threadPool.getTaskCount(),
        threadPool.getCompletedTaskCount(),
        threadPool.getQueue().size(),
        outputQueue.size(),
        idleProcessors.size());
  }

  /**
   * Allocate instance to a process and return. This call blocks until item
   * can be assigned to a thread.
   *
   * @param item Input to a Processor
   * @throws RejectedExecutionException -- A RuntimeException when the threadpool has been
   *     shut down (e.g., after {@link #join()}), when no processor could be obtained, or when
   *     there is an uncaught exception in the queue. Resolution is for the calling class to
   *     shutdown the wrapper and create a new threadpool.
   */
  public synchronized void put(I item) throws RejectedExecutionException {
    // After join(true) the idle queue is permanently drained, so waiting for a processor
    // would block forever. Fail fast instead.
    if (threadPool.isShutdown()) {
      throw new RejectedExecutionException("Couldn't submit item to threadpool: " + item);
    }
    Integer procId = getProcessor();
    if (procId == null) {
      // String concatenation (not item.toString()) so that a null item cannot mask
      // the rejection with a NullPointerException.
      throw new RejectedExecutionException("Couldn't submit item to threadpool: " + item);
    }
    final int itemId = submittedItemCounter;
    CallableJob<I,O> job = new CallableJob<>(item, itemId, processorList.get(procId), procId, callback);
    try {
      threadPool.submit(job);
    } catch (RejectedExecutionException e) {
      // The job never ran, so hand the processor back and do not consume the item id;
      // otherwise ordered polling would wait forever for a result that cannot arrive.
      idleProcessors.add(procId);
      throw e;
    }
    submittedItemCounter = itemId + 1;
  }

  /**
   * Returns the next available thread id.  Subclasses may wish to
   * override this, for example if they implement a timeout.
   * A {@code null} return value makes {@link #put(Object)} reject the item.
   *
   * @throws RuntimeInterruptedException if interrupted while waiting for an idle processor
   */
  Integer getProcessor() {
    try {
      return idleProcessors.take();
    } catch (InterruptedException e) {
      throw new RuntimeInterruptedException(e);
    }
  }

  /**
   * Wait for all threads to finish, then destroy the pool of
   * worker threads so that the main thread can shutdown.
   */
  public void join() {
    join(true);
  }

  /**
   * Wait for all threads to finish. Does nothing if the threadpool has already been shut down.
   *
   * @param destroyThreadpool -- if true, then destroy the worker threads
   *     so that the main thread can shutdown.
   * @throws RuntimeInterruptedException if interrupted while waiting
   */
  public void join(boolean destroyThreadpool) {
    if (threadPool.isShutdown()) {
      return;
    }
    try {
      // Make blocking calls to the last processes that are running: once every
      // processor id has been claimed here, no job can still be in flight.
      for (int i = nThreads; i > 0; --i) {
        idleProcessors.take();
      }

      if (destroyThreadpool) {
        threadPool.shutdown();
        // Sanity check. The threadpool should be done after iterating over
        // the processors.
        threadPool.awaitTermination(10, TimeUnit.SECONDS);
      } else {
        // Repopulate the list of processors
        for (int i = 0; i < nThreads; ++i) {
          idleProcessors.put(i);
        }
      }
    } catch (InterruptedException e) {
      // Same exception type as getProcessor() uses for interruption.
      throw new RuntimeInterruptedException(e);
    }
  }

  /**
   * Indicates whether or not a new result is available.
   *
   * @return true if a new result is available, false otherwise. When results are ordered,
   *     "available" means that the result for the next item in submission order is ready.
   */
  public boolean peek() {
    if (outputQueue.isEmpty()) {
      return false;
    }
    return !orderResults || outputQueue.containsKey(returnedItemCounter + 1);
  }

  /**
   * Returns the next available result.
   *
   * @return the next completed result, or null if no result is available. A null is also
   *     returned for an item whose processing failed or whose processor returned null;
   *     call {@link #peek()} first to tell the two cases apart.
   */
  public O poll() {
    if (!peek()) {
      return null;
    }
    returnedItemCounter++;
    int itemIndex = orderResults ? returnedItemCounter : outputQueue.keySet().iterator().next();
    QueueItem<O> completed = outputQueue.remove(itemIndex);
    return completed == null ? null : completed.item;
  }

  /**
   * Internal interface through which a finished CallableJob reports its result and
   * hands back the processor it was using.
   *
   * @author Spence Green
   *
   * @param <O> output type
   */
  private interface JobCallback<O> {
    void call(QueueItem<O> result, int processorId);
  }

  /**
   * Internal class for adding a job to the thread pool.
   *
   * @author Spence Green
   *
   * @param <I> input type
   * @param <O> output type
   */
  static class CallableJob<I,O> implements Callable<Integer> {
    final I item;
    private final int itemId;
    private final ThreadsafeProcessor<I,O> processor;
    private final int processorId;
    private final JobCallback<O> callback;

    public CallableJob(I item, int itemId, ThreadsafeProcessor<I,O> processor, int processorId,
                       JobCallback<O> callback) {
      this.item = item;
      this.itemId = itemId;
      this.processor = processor;
      this.processorId = processorId;
      this.callback = callback;
    }

    /**
     * Processes the item and reports the outcome through the callback.
     *
     * @return the id of the processed item
     */
    @Override
    public Integer call() {
      O result = null;
      try {
        result = processor.process(item);
      } catch (Exception e) {
        e.printStackTrace();
        // Hope that the consumer knows how to handle null!
      } finally {
        // Always report, even when process() fails (result stays null), so that the
        // processor is returned to the idle queue and ordered polling does not stall
        // on this item id.
        callback.call(new QueueItem<>(result, itemId), processorId);
      }
      return itemId;
    }
  }

  /**
   * Internal class pairing a result of type O with the id of the item that produced it.
   * Ordering, equality and hashing are based on the id only.
   *
   * @author Spence Green
   *
   * @param <O> output type
   */
  private static class QueueItem<O> implements Comparable<QueueItem<O>> {
    public final int id;
    public final O item;

    public QueueItem(O item, int id) {
      this.item = item;
      this.id = id;
    }

    @Override
    public int compareTo(QueueItem<O> other) {
      // Integer.compare instead of subtraction, which can overflow.
      return Integer.compare(this.id, other.id);
    }

    @Override
    public boolean equals(Object other) {
      if (other == this) return true;
      if ( ! (other instanceof QueueItem)) return false;
      return this.id == ((QueueItem<?>) other).id;
    }

    @Override
    public int hashCode() {
      return id;
    }
  }
}