package com.mathieubolla.guava; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.FluentIterable.from; import java.util.Iterator; import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.LinkedBlockingQueue; import com.google.common.base.Function; import com.google.common.base.Predicate; import com.google.common.base.Throwables; import com.google.common.collect.AbstractIterator; public final class ParallelUtils { private ParallelUtils() { super(); } /** * Will create a fixed thread pool executor service, and shut it down at * iterator's end. * * @see #parallelTransform(Iterable, com.google.common.base.Function, int, * java.util.concurrent.ExecutorService) */ public static <T, U> Iterable<U> parallelTransform( final Iterable<T> source, final Function<? super T, U> transform, int threads) { ExecutorService executorService = Executors.newFixedThreadPool(threads); return doTransformStuf(source, transform, threads, executorService, true); } /** * Computes transform on source, threads elements at a time, and iterates * over these in source order, tapping into executorService threadPool. */ public static <T, U> Iterable<U> parallelTransform( final Iterable<T> source, final Function<? super T, U> transform, int threads, final ExecutorService executorService) { return doTransformStuf(source, transform, threads, executorService, false); } /** * Computes filter on source, threads elements at a time, and iterates over * these in source order, tapping into executorService threadPool. */ public static <T> Iterable<T> parallelFilter(Iterable<T> source, Predicate<? super T> predicate, int threads, ExecutorService executorService) { return doFilterStuf(source, predicate, threads, executorService, false); } /** * Will create a fixed thread pool executor service, and shut it down at * iterator's end. * * @see #parallelFilter(Iterable, com.google.common.base.Predicate, int, * java.util.concurrent.ExecutorService) */ public static <T> Iterable<T> parallelFilter(Iterable<T> source, Predicate<? super T> predicate, int threads) { ExecutorService executorService = Executors.newFixedThreadPool(threads); return doFilterStuf(source, predicate, threads, executorService, true); } private static <T, U> Iterable<U> doTransformStuf(final Iterable<T> source, final Function<? super T, U> transform, int threads, final ExecutorService executorService, final boolean shutdownInTheEnd) { checkArgument(threads > 0, "amount of threads must be strictly positive"); final LinkedBlockingQueue<Future<U>> queue = new LinkedBlockingQueue<Future<U>>( threads); final Iterator<T> sourceIterator = source.iterator(); return new Iterable<U>() { public Iterator<U> iterator() { return new AbstractIterator<U>() { @Override protected U computeNext() { if (queue.isEmpty() && !sourceIterator.hasNext()) { if (shutdownInTheEnd) { executorService.shutdown(); } return endOfData(); } while (queue.remainingCapacity() > 0 && sourceIterator.hasNext()) { final T next = sourceIterator.next(); Future<U> future = executorService .submit(new Callable<U>() { public U call() throws Exception { return transform.apply(next); } }); try { queue.put(future); } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw Throwables.propagate(e); } } try { return queue.take().get(); } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw Throwables.propagate(e); } catch (ExecutionException e) { Throwable orig = e.getCause(); Throwables.propagateIfPossible(orig); throw Throwables.propagate(orig); } } }; } }; } private static <T> Iterable<T> doFilterStuf(Iterable<T> source, Predicate<? super T> predicate, int threads, ExecutorService executorService, boolean shutdownInTheEnd) { return from( doTransformStuf(source, calculatePredicateResult(predicate), threads, executorService, shutdownInTheEnd)).filter( FilterResult.getState).transform( ParallelUtils.<T> getGetObjectFunction()); } @SuppressWarnings("unchecked") private static <T> Function<FilterResult, T> getGetObjectFunction() { return (Function<FilterResult, T>) FilterResult.getObject; } private static <T> Function<T, FilterResult> calculatePredicateResult( final Predicate<T> predicate) { return new Function<T, FilterResult>() { public FilterResult apply(T input) { return new FilterResult(input, predicate.apply(input)); } }; } private static class FilterResult { private static final Function<FilterResult, Object> getObject = new Function<FilterResult, Object>() { @Override public Object apply(FilterResult filterResult) { return filterResult.object; } }; private static final Predicate<FilterResult> getState = new Predicate<FilterResult>() { public boolean apply(FilterResult filterResult) { return filterResult.state; }; }; private final Object object; private final boolean state; FilterResult(Object a, boolean b) { this.object = a; this.state = b; } } }