package ch.unibe.scg.cells;

import static com.google.common.base.Preconditions.checkNotNull;
import static java.util.Collections.singletonList;

import java.io.Closeable;
import java.io.IOException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;

import javax.inject.Inject;
import javax.inject.Provider;

import com.google.common.base.Throwables;
import com.google.common.collect.Iterables;
import com.google.common.collect.Ordering;
import com.google.common.io.Closer;
import com.google.protobuf.ByteString;

/** Implementation of a {@link Pipeline} meant to run in memory. */
public class InMemoryPipeline<IN, OUT> implements Pipeline<IN, OUT>, Closeable {
	/** Interval between diagnostic counter printouts, in seconds. */
	private static final int PRINT_INTERVAL = 1;
	/** How long to wait for executor termination, in seconds. */
	private static final int SHUTDOWN_TIMEOUT = 20;
	/** How much to sample from each shard to get splitters. */
	private static final int SAMPLE_SIZE = 16;

	private final CellSource<IN> pipeSrc;
	/** Result of the last pipeline run. Null until {@link #mapAndEfflux} has completed once. */
	private Source<OUT> pipeSink;
	private final PipelineStageScope scope;
	/** Synchronized set. All iteration must synchronize on the set itself. */
	private final Provider<Set<LocalCounter>> registry;
	/** Target of diagnostic printing (counters, suppressed exceptions); not pipeline output. */
	private final PrintStream out;
	private final ExecutorService threadPool = Executors.newFixedThreadPool(
			Runtime.getRuntime().availableProcessors());

	InMemoryPipeline(CellSource<IN> pipeSrc, PipelineStageScope scope,
			Provider<Set<LocalCounter>> counterRegistry, PrintStream out) { // Don't subclass.
		this.pipeSrc = pipeSrc;
		this.scope = scope;
		this.registry = counterRegistry;
		this.out = out;
	}

	/** Incredibly hacky way of moving an exception from one thread to another. */
	private static class ExceptionHolder {
		/** Volatile because it can be written and read from different threads. */
		volatile Exception e;
	}

	/** A builder for a {@link InMemoryPipeline}. */
	public static class Builder {
		private final PipelineStageScope scope;
		private final Provider<Set<LocalCounter>> registry;

		@Inject
		Builder(PipelineStageScope scope, @CounterRegistry Provider<Set<LocalCounter>> registry) {
			this.scope = scope;
			this.registry = registry;
		}

		/** Create a pipeline that uses stderr for diagnostic print. No parameters are allowed to be null. */
		public <IN, OUT> InMemoryPipeline<IN, OUT> make(CellSource<IN> pipeSrc) {
			checkNotNull(pipeSrc);
			// Counter info is diagnostic information, not actual output, so stderr is appropriate.
			return new InMemoryPipeline<>(pipeSrc, scope, registry, System.err);
		}
	}

	@Override
	public Source<OUT> lastEfflux() {
		return pipeSink;
	}

	@Override
	public MappablePipeline<IN, OUT> influx(Codec<IN> c) {
		return new InMemoryMappablePipeline<>(pipeSrc, c);
	}

	/** Pipeline stage that knows its source and how to decode it, and is ready to be mapped. */
	private class InMemoryMappablePipeline<I> implements MappablePipeline<I, OUT> {
		private final CellSource<I> src;
		private final Codec<I> srcCodec;

		InMemoryMappablePipeline(CellSource<I> src, Codec<I> srcCodec) {
			this.src = src;
			this.srcCodec = srcCodec;
		}

		@Override
		public <E> ShuffleablePipeline<E, OUT> map(Mapper<I, E> mapper) {
			return new InMemoryShuffleablePipeline<>(src, srcCodec, mapper);
		}

		@Override
		public void mapAndEfflux(Mapper<I, OUT> m, Codec<OUT> sinkCodec)
				throws IOException, InterruptedException {
			// Terminal stage: run the mapper and expose the decoded result as lastEfflux().
			pipeSink = Cells.decodeSource(run(src, srcCodec, m, sinkCodec), sinkCodec);
		}
	}

	/** Pipeline stage holding a pending mapper; shuffling runs it and re-shards the output. */
	private class InMemoryShuffleablePipeline<I, E> implements ShuffleablePipeline<E, OUT> {
		private final CellSource<I> src;
		private final Codec<I> srcCodec;
		private final Mapper<I, E> mapper;

		InMemoryShuffleablePipeline(CellSource<I> src, Codec<I> srcCodec, Mapper<I, E> mapper) {
			this.src = src;
			this.srcCodec = srcCodec;
			this.mapper = mapper;
		}

		@Override
		public MappablePipeline<E, OUT> shuffle(Codec<E> sinkCodec) throws IOException, InterruptedException {
			return new InMemoryMappablePipeline<>(run(src, srcCodec, mapper, sinkCodec), sinkCodec);
		}
	}

	/**
	 * Run the mapper. Closes source and mapper.
	 * If an exception occurs in any mapper, it is printed to {@code out} from within that thread.
	 * Any one of the exceptions is picked and returned to the caller of this method.
	 * The threadpool gets shut down on exceptions, too.
	 */
	// TODO: Perhaps report all exceptions as suppressed?
	private <I, E> CellSource<E> run(final CellSource<I> src, final Codec<I> srcCodec,
			final Mapper<I, E> mapper, final Codec<E> sinkCodec) throws IOException, InterruptedException {
		ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(1);
		// Fork printing thread for periodic counter diagnostics.
		Runnable printer = new Runnable() {
			@Override
			public void run() {
				printCounters();
			}
		};
		scheduler.scheduleAtFixedRate(printer, PRINT_INTERVAL, PRINT_INTERVAL, TimeUnit.SECONDS);

		// One output sink per input shard; mappers write into the sink of their own shard.
		final List<List<Cell<E>>> sinks = new ArrayList<>(src.nShards());
		for (int i = 0; i < src.nShards(); i++) {
			sinks.add(new ArrayList<Cell<E>>());
		}

		// If an exception occurs in a worker, it is stashed here for the calling thread.
		final ExceptionHolder exceptionHolder = new ExceptionHolder();
		try (final Closer closer = Closer.create()) {
			closer.register(mapper);
			closer.register(src);

			// Clone as many mappers as we have threads; workers check them out of this queue.
			final BlockingQueue<Mapper<I, E>> mappers = new ArrayBlockingQueue<>(
					Runtime.getRuntime().availableProcessors());
			for (int i = 0; i < Runtime.getRuntime().availableProcessors(); i++) {
				@SuppressWarnings("resource") // It's getting registered for closing right here.
				Mapper<I, E> cloned = ShallowSerializingCopy.clone(mapper);
				closer.register(cloned);
				mappers.add(cloned);
			}

			// Run all rows of each shard through a mapper, one task per shard.
			final CountDownLatch mapCnt = new CountDownLatch(src.nShards());
			for (int s = 0; s < src.nShards(); s++) {
				final int shard = s;
				threadPool.execute(new Runnable() {
					@Override
					public void run() {
						try {
							for (Iterable<Cell<I>> rawRow : Cells.breakIntoRows(src.getShard(shard))) {
								Iterable<I> decoded = Cells.decode(rawRow, srcCodec);
								Iterator<I> iter = decoded.iterator();
								final I first = iter.next();
								// Re-attach the consumed first element so the mapper sees the full row.
								final Iterable<I> row = Iterables.concat(
										singletonList(first), new AdapterOneShotIterable<>(iter));
								@SuppressWarnings("resource") // Closed by mapperCloser.
								Mapper<I, E> m = mappers.take();
								// BUGFIX: return the mapper even if map() throws. Previously, a failing
								// mapper was never put back, so sibling workers could block forever in
								// take() and mapCnt.await() would deadlock.
								try {
									m.map(first, new AdapterOneShotIterable<>(row),
											encode(sinks.get(shard), sinkCodec));
								} finally {
									mappers.put(m);
								}
							}
						} catch (Exception e) {
							exceptionHolder.e = e;
							out.println("Suppressed exception:");
							e.printStackTrace(out);
						} finally {
							// Counted down exactly once per shard, on success and failure alike.
							mapCnt.countDown();
						}
					}
				});
			}
			mapCnt.await();

			if (exceptionHolder.e != null) {
				Exception cause = exceptionHolder.e; // Don't let anybody overwrite the exception while handling.
				Throwables.propagateIfPossible(cause, IOException.class, InterruptedException.class);
				throw new RuntimeException(cause);
			}
		} finally {
			// Ensure counters get printed upon completion, then tear down the printing thread.
			printCounters();
			if (scope != null) {
				scope.exit();
			}
			scheduler.shutdownNow();
			scheduler.awaitTermination(SHUTDOWN_TIMEOUT, TimeUnit.SECONDS);
		}

		// Filter out empty shards.
		final List<List<Cell<E>>> filteredSinks = new ArrayList<>();
		for (List<Cell<E>> sink : sinks) {
			if (!sink.isEmpty()) {
				filteredSinks.add(sink);
			}
		}

		// BUGFIX: if the mappers produced no output at all, splitters() used to be called with an
		// empty collection, dividing by zero (nPartitions == 0). Short-circuit to an empty source.
		if (filteredSinks.isEmpty()) {
			return InMemorySource.make(new ArrayList<List<Cell<E>>>());
		}

		// Sort output shards in parallel.
		final CountDownLatch sortCnt = new CountDownLatch(filteredSinks.size());
		for (int s = 0; s < filteredSinks.size(); s++) {
			final int shard = s;
			threadPool.execute(new Runnable() {
				@Override
				public void run() {
					Collections.sort(filteredSinks.get(shard));
					sortCnt.countDown();
				}
			});
		}
		sortCnt.await();

		// Suck filteredSinks into the next source: repartition by splitter keys, in parallel.
		final CountDownLatch suckCnt = new CountDownLatch(filteredSinks.size());
		final List<ByteString> splitters = splitters(filteredSinks);
		final List<List<Cell<E>>> ret = new ArrayList<>(filteredSinks.size());
		for (int s = 0; s < filteredSinks.size(); s++) {
			ret.add(null); // Multithreaded `add` is forbidden, so grow beforehand.
		}
		for (int s = 0; s < filteredSinks.size() - 1; s++) { // All shards but last.
			final int shard = s;
			threadPool.execute(new Runnable() {
				@Override
				public void run() {
					List<Cell<E>> merged
							= merge(filteredSinks, splitters.get(shard), splitters.get(shard + 1));
					ret.set(shard, merged);
					suckCnt.countDown();
				}
			});
		}
		// Last shard has no upper bound.
		threadPool.execute(new Runnable() {
			@Override
			public void run() {
				int lastShard = filteredSinks.size() - 1;
				List<Cell<E>> merged = mergeUnbounded(filteredSinks, splitters.get(lastShard));
				ret.set(lastShard, merged);
				suckCnt.countDown();
			}
		});
		suckCnt.await();

		return InMemorySource.make(ret);
	}

	/** @return a {@link Sink} that encodes each written object with {@code codec} into {@code cellSink}. */
	private static <T> Sink<T> encode(final Collection<Cell<T>> cellSink, final Codec<T> codec) {
		return new Sink<T>() {
			private static final long serialVersionUID = 1L;

			@Override
			public void write(T obj) throws IOException, InterruptedException {
				cellSink.add(codec.encode(obj));
			}

			@Override
			public void close() throws IOException {
				// Nothing to do.
			}
		};
	}

	/**
	 * @return the sorted list of all entries in {@code sources} between {@code from} and {@code to}.
	 * @param sources collection of individually sorted lists.
	 * @param from first row key to be included, inclusive.
	 * @param to last row key to be included, exclusive.
	 */
	private static <T> List<Cell<T>> merge(Iterable<List<Cell<T>>> sources, ByteString from, ByteString to) {
		assert shardsOrdered(sources);

		// Probe cells carry only the row key; Cell ordering by row key makes them valid search keys.
		Cell<T> fromProbe = new Cell<>(from, ByteString.EMPTY, ByteString.EMPTY);
		Cell<T> toProbe = new Cell<>(to, ByteString.EMPTY, ByteString.EMPTY);

		Set<Cell<T>> ret = new HashSet<>();
		for (List<Cell<T>> src : sources) {
			int fromPos = insertionPoint(fromProbe, src);
			int toPos = insertionPoint(toProbe, src);
			ret.addAll(src.subList(fromPos, toPos));
		}
		return Ordering.natural().immutableSortedCopy(ret);
	}

	/** Same as {@link #merge}, but assuming parameter {@code to} as infinite. */
	private static <T> List<Cell<T>> mergeUnbounded(Iterable<List<Cell<T>>> sources, ByteString from) {
		assert shardsOrdered(sources);

		Cell<T> fromProbe = new Cell<>(from, ByteString.EMPTY, ByteString.EMPTY);
		Set<Cell<T>> ret = new HashSet<>();
		for (List<Cell<T>> src : sources) {
			int fromPos = insertionPoint(fromProbe, src);
			ret.addAll(src.subList(fromPos, src.size()));
		}
		return Ordering.natural().immutableSortedCopy(ret);
	}

	/** @return the position of {@code needle} in {@code haystack}, or the insertion point, if absent. */
	private static <T> int insertionPoint(Cell<T> needle, List<Cell<T>> haystack) {
		int pos = Collections.binarySearch(haystack, needle);
		if (pos < 0) { // binarySearch encodes a miss as (-(insertion point) - 1).
			pos = ~pos;
		}
		return pos;
	}

	/**
	 * @return row keys of the start of each partition, inclusive, such that all partitions are equal size.
	 *         The returned list has the same size as {@code sources}.
	 * @param sources each source should be sorted, and the collection must be non-empty.
	 */
	private static <T> List<ByteString> splitters(Collection<List<Cell<T>>> sources) {
		assert shardsOrdered(sources);
		assert !sources.isEmpty();

		// Grab a sample of SAMPLE_SIZE elements from each source.
		List<Cell<T>> sample = new ArrayList<>();
		for (List<Cell<T>> src : sources) {
			// step is src.size / SAMPLE_SIZE, rounded up to ensure step > 0.
			int step = (src.size() + SAMPLE_SIZE - 1) / SAMPLE_SIZE;
			for (int i = 0; i < src.size(); i += step) {
				sample.add(src.get(i));
			}
		}
		Collections.sort(sample);

		List<ByteString> ret = new ArrayList<>();
		int nPartitions = sources.size();
		// Rounded up, to ensure step > 0.
		int step = (sample.size() + nPartitions - 1) / nPartitions;
		for (int i = 0; i < sample.size(); i += step) {
			ret.add(sample.get(i).getRowKey());
		}

		assert ret.size() == sources.size()
				: "Should be " + sources.size() + " but was " + ret.size() + " sample was " + sample.size();
		return ret;
	}

	/** @return whether every shard in {@code shards} is sorted. Used in assertions only. */
	private static <T> boolean shardsOrdered(Iterable<List<Cell<T>>> shards) {
		for (List<Cell<T>> s : shards) {
			if (!Ordering.natural().isOrdered(s)) {
				return false;
			}
		}
		return true;
	}

	/** Print all registered counters to {@code out}. Safe to call from the scheduler thread. */
	private void printCounters() {
		Set<LocalCounter> counters = registry.get();
		synchronized (counters) { // Needed for iterating over a synchronized set.
			for (LocalCounter c : counters) {
				out.println(c.toString());
			}
		}
	}

	@Override
	public void close() throws IOException {
		threadPool.shutdownNow();
		try {
			threadPool.awaitTermination(SHUTDOWN_TIMEOUT, TimeUnit.SECONDS);
		} catch (InterruptedException e) {
			Thread.currentThread().interrupt(); // Restore interrupted flag.
		}
	}
}