package com.hwlcn.ldap.util.parallel; import java.util.ArrayList; import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Semaphore; import java.util.concurrent.ThreadFactory; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import com.hwlcn.ldap.util.Debug; import com.hwlcn.core.annotation.InternalUseOnly; import com.hwlcn.ldap.util.LDAPSDKThreadFactory; import com.hwlcn.core.annotation.ThreadSafety; import com.hwlcn.ldap.util.ThreadSafetyLevel; import com.hwlcn.ldap.util.Validator; @InternalUseOnly() @ThreadSafety(level=ThreadSafetyLevel.COMPLETELY_THREADSAFE) public final class ParallelProcessor<I, O> { private final Processor<I, O> processor; private final List<Thread> workers; private final int minPerThread; private final Semaphore workerSemaphore = new Semaphore(0); private final AtomicReference<List<? extends I>> inputItems = new AtomicReference<List<? extends I>>(); private final AtomicReference<List<Result<I, O>>> outputItems = new AtomicReference<List<Result<I, O>>>(); private final AtomicInteger nextToProcess = new AtomicInteger(); private volatile CountDownLatch processingCompleteSignal; private final AtomicBoolean shutdown = new AtomicBoolean(); public ParallelProcessor(final Processor<I, O> processor, final int totalThreads, final int minPerThread) { this(processor, null, totalThreads, minPerThread); } public ParallelProcessor(final Processor<I, O> processor, final ThreadFactory threadFactory, final int totalThreads, final int minPerThread) { Validator.ensureNotNull(processor); Validator.ensureTrue(totalThreads >= 1, "ParallelProcessor.totalThreads must be at least 1."); Validator.ensureTrue(totalThreads <= 1000, // Upper bound on # of threads "ParallelProcessor.totalThreads must not be greater than 1000."); Validator.ensureTrue(minPerThread >= 1, "ParallelProcessor.minPerThread must be at least 1."); this.processor = processor; this.minPerThread = minPerThread; final ThreadFactory tf; if (threadFactory == null) { tf = new LDAPSDKThreadFactory("ParallelProcessor-Worker", true); } else { tf = threadFactory; } final int numExtraThreads = totalThreads - 1; final List<Thread> workerList = new ArrayList<Thread>(numExtraThreads); for (int i = 0; i < numExtraThreads; i++) { final Thread worker = tf.newThread(new Worker()); workerList.add(worker); worker.start(); } workers = workerList; } public synchronized ArrayList<Result<I, O>> processAll( final List<? extends I> items) throws InterruptedException, IllegalStateException { if (shutdown.get()) { throw new IllegalStateException( "cannot call processAll() after shutdown()"); } Validator.ensureNotNull(items); final int extraThreads = Math.min((items.size() / minPerThread) - 1, workers.size()); if (extraThreads <= 0) { final ArrayList<Result<I, O>> output = new ArrayList<Result<I, O>>(items.size()); for (final I item : items) { output.add(process(item)); } return output; } processingCompleteSignal = new CountDownLatch(extraThreads); inputItems.set(items); final ArrayList<Result<I, O>> output = new ArrayList<Result<I, O>>(items.size()); for (int i = 0; i < items.size(); i++) { output.add(null); } outputItems.set(output); nextToProcess.set(0); workerSemaphore.release(extraThreads); processInParallel(); processingCompleteSignal.await(); return output; } public synchronized void shutdown() throws InterruptedException { if (shutdown.getAndSet(true)) { return; } workerSemaphore.release(workers.size()); for (final Thread worker : workers) { worker.join(); } } private void processInParallel() { try { final List<? extends I> items = inputItems.get(); final List<Result<I, O>> outputs = outputItems.get(); final int size = items.size(); int next; while ((next = nextToProcess.getAndIncrement()) < size) { final I input = items.get(next); outputs.set(next, process(input)); } } catch (Throwable e) { Debug.debugException(e); } } private ProcessResult process(final I input) { O output = null; Throwable failureCause = null; try { output = processor.process(input); } catch (Throwable e) { failureCause = e; } return new ProcessResult(input, output, failureCause); } private final class Worker implements Runnable { private Worker() { } public void run() { while (true) { try { workerSemaphore.acquire(); } catch (InterruptedException e) { Debug.debugException(e); } if (shutdown.get()) { return; } try { processInParallel(); } finally { processingCompleteSignal.countDown(); } } } } private final class ProcessResult implements Result<I, O> { private final I inputItem; private final O outputItem; private final Throwable failureCause; private ProcessResult(final I inputItem, final O outputItem, final Throwable failureCause) { this.inputItem = inputItem; this.outputItem = outputItem; this.failureCause = failureCause; } public I getInput() { return inputItem; } public O getOutput() { return outputItem; } public Throwable getFailureCause() { return failureCause; } } }