package org.infinispan.stream.impl;

import java.util.AbstractMap;
import java.util.ArrayDeque;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReferenceArray;
import java.util.function.BinaryOperator;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.BaseStream;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.infinispan.CacheStream;
import org.infinispan.commons.CacheException;
import org.infinispan.commons.util.concurrent.ConcurrentHashSet;
import org.infinispan.container.entries.CacheEntry;
import org.infinispan.distribution.DistributionManager;
import org.infinispan.distribution.ch.ConsistentHash;
import org.infinispan.distribution.ch.KeyPartitioner;
import org.infinispan.factories.ComponentRegistry;
import org.infinispan.partitionhandling.impl.PartitionHandlingManager;
import org.infinispan.remoting.transport.Address;
import org.infinispan.stream.impl.intops.IntermediateOperation;
import org.infinispan.stream.impl.termop.SegmentRetryingOperation;
import org.infinispan.stream.impl.termop.SingleRunOperation;
import org.infinispan.stream.impl.termop.object.FlatMapIteratorOperation;
import org.infinispan.stream.impl.termop.object.MapIteratorOperation;
import org.infinispan.stream.impl.termop.object.NoMapIteratorOperation;
import org.infinispan.util.RangeSet;
import org.infinispan.util.concurrent.TimeoutException;
import org.infinispan.util.logging.Log;
import org.infinispan.util.logging.LogFactory;

/**
 * Abstract stream that provides all of the common functionality required for all types of Streams,
 * including the various primitive types.
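 * <p>
 * Hedged usage sketch: callers never instantiate this class directly; they reach a concrete
 * subclass through the cache collection views. The settings shown below are illustrative only and
 * map onto fields of this class ({@code distributedBatchSize}, {@code timeout}, {@code rehashAware});
 * {@code process} is an arbitrary user callback, not part of any API:
 * <pre>{@code
 *    cache.entrySet().stream()           // CacheStream backed by a subclass of this type
 *         .distributedBatchSize(128)     // batch size used when pulling entries from remote nodes
 *         .timeout(1, TimeUnit.MINUTES)  // how long to wait for remote responses
 *         .forEach(entry -> process(entry));
 * }</pre>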
 * @param <T> The type returned by the stream
 * @param <S> The stream interface
 * @param <S2> The concrete stream interface returned by {@link #unwrap()} and the fluent methods
 */
public abstract class AbstractCacheStream<T, S extends BaseStream<T, S>, S2 extends S> implements BaseStream<T, S> {
   protected final Log log = LogFactory.getLog(getClass());

   protected final Queue<IntermediateOperation> intermediateOperations;
   protected final Address localAddress;
   protected final DistributionManager dm;
   protected final Supplier<CacheStream<CacheEntry>> supplier;
   protected final ClusterStreamManager csm;
   protected final boolean includeLoader;
   protected final Executor executor;
   protected final ComponentRegistry registry;
   protected final PartitionHandlingManager partition;
   protected final KeyPartitioner keyPartitioner;

   protected Runnable closeRunnable = null;

   protected boolean parallel;
   protected Boolean parallelDistribution;
   protected boolean rehashAware = true;

   protected Set<?> keysToFilter;
   protected Set<Integer> segmentsToFilter;

   protected int distributedBatchSize;

   protected CacheStream.SegmentCompletionListener segmentCompletionListener;

   protected IteratorOperation iteratorOperation = IteratorOperation.NO_MAP;

   protected long timeout = 30;
   protected TimeUnit timeoutUnit = TimeUnit.SECONDS;

   protected AbstractCacheStream(Address localAddress, boolean parallel, DistributionManager dm,
         Supplier<CacheStream<CacheEntry>> supplier, ClusterStreamManager<Object> csm, boolean includeLoader,
         int distributedBatchSize, Executor executor, ComponentRegistry registry) {
      this.localAddress = localAddress;
      this.parallel = parallel;
      this.dm = dm;
      this.supplier = supplier;
      this.csm = csm;
      this.includeLoader = includeLoader;
      this.distributedBatchSize = distributedBatchSize;
      this.executor = executor;
      this.registry = registry;
      this.partition = registry.getComponent(PartitionHandlingManager.class);
      this.keyPartitioner = registry.getComponent(KeyPartitioner.class);
      intermediateOperations = new ArrayDeque<>();
   }

   protected AbstractCacheStream(AbstractCacheStream<T, S, S2> other) {
      this.intermediateOperations = other.intermediateOperations;
      this.localAddress = other.localAddress;
      this.dm = other.dm;
      this.supplier = other.supplier;
      this.csm = other.csm;
      this.includeLoader = other.includeLoader;
      this.executor = other.executor;
      this.registry = other.registry;
      this.partition = other.partition;
      this.keyPartitioner = other.keyPartitioner;

      this.closeRunnable = other.closeRunnable;

      this.parallel = other.parallel;
      this.parallelDistribution = other.parallelDistribution;
      this.rehashAware = other.rehashAware;

      this.keysToFilter = other.keysToFilter;
      this.segmentsToFilter = other.segmentsToFilter;

      this.distributedBatchSize = other.distributedBatchSize;

      this.segmentCompletionListener = other.segmentCompletionListener;

      this.iteratorOperation = other.iteratorOperation;

      this.timeout = other.timeout;
      this.timeoutUnit = other.timeoutUnit;
   }

   protected S2 addIntermediateOperation(IntermediateOperation<T, S, T, S> intermediateOperation) {
      intermediateOperation.handleInjection(registry);
      addIntermediateOperation(intermediateOperations, intermediateOperation);
      return unwrap();
   }

   protected void addIntermediateOperationMap(IntermediateOperation<T, S, ?, ?> intermediateOperation) {
      intermediateOperation.handleInjection(registry);
      addIntermediateOperation(intermediateOperations, intermediateOperation);
   }

   protected void addIntermediateOperation(Queue<IntermediateOperation> intermediateOperations,
         IntermediateOperation<T, S, ?, ?> intermediateOperation) {
      intermediateOperations.add(intermediateOperation);
   }

   protected abstract S2 unwrap();
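
   /*
    * Hedged sketch of how the intermediate-operation queue above is expected to be used by a
    * concrete subclass: each intermediate stream call wraps its argument in an IntermediateOperation
    * and appends it, then returns the stream itself. The MapOperation name below is an assumption
    * based on the intops package and is shown for illustration only:
    *
    *    public <R> CacheStream<R> map(Function<? super T, ? extends R> mapper) {
    *       addIntermediateOperationMap(new MapOperation<>(mapper));
    *       return (CacheStream<R>) this;
    *    }
    */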
   @Override
   public boolean isParallel() {
      return parallel;
   }

   boolean getParallelDistribution() {
      return parallelDistribution == null ? true : parallelDistribution;
   }

   @Override
   public S2 sequential() {
      parallel = false;
      return unwrap();
   }

   @Override
   public S2 parallel() {
      parallel = true;
      return unwrap();
   }

   @Override
   public S2 unordered() {
      // This by default is always unordered
      return unwrap();
   }

   @Override
   public S2 onClose(Runnable closeHandler) {
      if (this.closeRunnable == null) {
         this.closeRunnable = closeHandler;
      } else {
         this.closeRunnable = composeWithExceptions(this.closeRunnable, closeHandler);
      }
      return unwrap();
   }

   @Override
   public void close() {
      if (closeRunnable != null) {
         closeRunnable.run();
      }
   }

   <R> R performOperation(Function<? super S2, ? extends R> function, boolean retryOnRehash,
         BinaryOperator<R> accumulator, Predicate<? super R> earlyTerminatePredicate) {
      ResultsAccumulator<R> remoteResults = new ResultsAccumulator<>(accumulator);
      if (rehashAware) {
         return performOperationRehashAware(function, retryOnRehash, remoteResults, earlyTerminatePredicate);
      } else {
         return performOperation(function, remoteResults, earlyTerminatePredicate);
      }
   }

   <R> R performOperation(Function<? super S2, ? extends R> function, ResultsAccumulator<R> remoteResults,
         Predicate<? super R> earlyTerminatePredicate) {
      ConsistentHash ch = dm.getWriteConsistentHash();
      TerminalOperation<R> op = new SingleRunOperation(intermediateOperations,
            supplierForSegments(ch, segmentsToFilter, null), function);
      Object id = csm.remoteStreamOperation(getParallelDistribution(), parallel, ch, segmentsToFilter, keysToFilter,
            Collections.emptyMap(), includeLoader, op, remoteResults, earlyTerminatePredicate);
      try {
         R localValue = op.performOperation();
         remoteResults.onCompletion(null, Collections.emptySet(), localValue);
         if (id != null) {
            try {
               if ((earlyTerminatePredicate == null || !earlyTerminatePredicate.test(localValue))
                     && !csm.awaitCompletion(id, timeout, timeoutUnit)) {
                  throw new TimeoutException();
               }
            } catch (InterruptedException e) {
               throw new CacheException(e);
            }
         }
         log.tracef("Finished operation for id %s", id);
         return remoteResults.currentValue;
      } finally {
         csm.forgetOperation(id);
      }
   }
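
   /*
    * Hedged usage sketch of the two methods above: a terminal operation such as count() could be
    * funnelled through performOperation with an addition accumulator, so the local run and the
    * remote responses are folded together. The call below is illustrative only, not a quote of an
    * actual subclass:
    *
    *    long count = performOperation(s -> s.count(), false, Long::sum, null);
    */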
   <R> R performOperationRehashAware(Function<? super S2, ? extends R> function, boolean retryOnRehash,
         ResultsAccumulator<R> remoteResults, Predicate<? super R> earlyTerminatePredicate) {
      Set<Integer> segmentsToProcess = segmentsToFilter;
      TerminalOperation<R> op;
      do {
         ConsistentHash ch = dm.getReadConsistentHash();
         if (retryOnRehash) {
            op = new SegmentRetryingOperation(intermediateOperations,
                  supplierForSegments(ch, segmentsToProcess, null), function);
         } else {
            op = new SingleRunOperation(intermediateOperations,
                  supplierForSegments(ch, segmentsToProcess, null), function);
         }
         Object id = csm.remoteStreamOperationRehashAware(getParallelDistribution(), parallel, ch, segmentsToProcess,
               keysToFilter, Collections.emptyMap(), includeLoader, op, remoteResults, earlyTerminatePredicate);
         try {
            R localValue;
            boolean localRun = ch.getMembers().contains(localAddress);
            if (localRun) {
               localValue = op.performOperation();
               // TODO: we can do this more efficiently - since we drop all results locally
               if (dm.getReadConsistentHash().equals(ch)) {
                  Set<Integer> ourSegments = ch.getPrimarySegmentsForOwner(localAddress);
                  if (segmentsToProcess != null) {
                     ourSegments.retainAll(segmentsToProcess);
                  }
                  remoteResults.onCompletion(null, ourSegments, localValue);
               } else {
                  if (segmentsToProcess != null) {
                     Set<Integer> ourSegments = ch.getPrimarySegmentsForOwner(localAddress);
                     ourSegments.retainAll(segmentsToProcess);
                     remoteResults.onSegmentsLost(ourSegments);
                  } else {
                     remoteResults.onSegmentsLost(ch.getPrimarySegmentsForOwner(localAddress));
                  }
               }
            } else {
               // This isn't actually used because localRun short circuits first
               localValue = null;
            }
            if (id != null) {
               try {
                  if ((!localRun || earlyTerminatePredicate == null || !earlyTerminatePredicate.test(localValue))
                        && !csm.awaitCompletion(id, timeout, timeoutUnit)) {
                     throw new TimeoutException();
                  }
               } catch (InterruptedException e) {
                  throw new CacheException(e);
               }
            }
            if (!remoteResults.lostSegments.isEmpty()) {
               segmentsToProcess = new HashSet<>(remoteResults.lostSegments);
               remoteResults.lostSegments.clear();
               log.tracef("Found %s lost segments for identifier %s", segmentsToProcess, id);
            } else {
               // If we didn't lose any segments we don't need to process anymore
               if (segmentsToProcess != null) {
                  segmentsToProcess = null;
               }
               log.tracef("Finished rehash aware operation for id %s", id);
            }
         } finally {
            csm.forgetOperation(id);
         }
      } while (segmentsToProcess != null && !segmentsToProcess.isEmpty());

      return remoteResults.currentValue;
   }
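
   /*
    * Hedged sketch of the retry contract shared by the rehash-aware methods: segments reported as
    * lost (because the topology changed mid-operation) are collected by the results callback, and
    * only those segments are re-run on the next pass. In pseudo-code, assuming the callbacks behave
    * as described above:
    *
    *    do {
    *       run the operation locally and remotely for segmentsToProcess;
    *       segmentsToProcess = segments the callback reported as lost;
    *    } while (segmentsToProcess is not empty);
    */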
   void performRehashKeyTrackingOperation(
         Function<Supplier<Stream<CacheEntry>>, KeyTrackingTerminalOperation<Object, ? extends T, Object>> function) {
      final AtomicBoolean complete = new AtomicBoolean();

      ConsistentHash segmentInfoCH = dm.getReadConsistentHash();
      KeyTrackingConsumer<Object, Object> results = new KeyTrackingConsumer<>(keyPartitioner, segmentInfoCH,
            (c) -> {}, c -> c, null);
      Set<Integer> segmentsToProcess = segmentsToFilter == null ?
            new RangeSet(segmentInfoCH.getNumSegments()) : segmentsToFilter;
      do {
         ConsistentHash ch = dm.getReadConsistentHash();
         boolean localRun = ch.getMembers().contains(localAddress);
         Set<Integer> segments;
         Set<Object> excludedKeys;
         if (localRun) {
            segments = ch.getPrimarySegmentsForOwner(localAddress);
            segments.retainAll(segmentsToProcess);

            excludedKeys = segments.stream().flatMap(s -> results.referenceArray.get(s).stream()).collect(
                  Collectors.toSet());
         } else {
            // This null is okay as it is only referenced if it was a localRun
            segments = null;
            excludedKeys = Collections.emptySet();
         }
         KeyTrackingTerminalOperation<Object, ? extends T, Object> op = function.apply(
               supplierForSegments(ch, segmentsToProcess, excludedKeys));
         op.handleInjection(registry);
         Object id = csm.remoteStreamOperationRehashAware(getParallelDistribution(), parallel, ch, segmentsToProcess,
               keysToFilter, new AtomicReferenceArrayToMap<>(results.referenceArray), includeLoader, op, results);
         try {
            if (localRun) {
               Collection<CacheEntry<Object, Object>> localValue = op.performOperationRehashAware(results);
               // TODO: we can do this more efficiently - this hampers performance during rehash
               if (dm.getReadConsistentHash().equals(ch)) {
                  log.tracef("Found local values %s for id %s", localValue.size(), id);
                  results.onCompletion(null, segments, localValue);
               } else {
                  Set<Integer> ourSegments = ch.getPrimarySegmentsForOwner(localAddress);
                  ourSegments.retainAll(segmentsToProcess);
                  log.tracef("CH changed - making %s segments suspect for identifier %s", ourSegments, id);
                  results.onSegmentsLost(ourSegments);
                  // We keep track of those keys so we don't fire them again
                  results.onIntermediateResult(null, localValue);
               }
            }
            if (id != null) {
               try {
                  if (!csm.awaitCompletion(id, timeout, timeoutUnit)) {
                     throw new TimeoutException();
                  }
               } catch (InterruptedException e) {
                  throw new CacheException(e);
               }
            }
            if (!results.lostSegments.isEmpty()) {
               segmentsToProcess = new HashSet<>(results.lostSegments);
               results.lostSegments.clear();
               log.tracef("Found %s lost segments for identifier %s", segmentsToProcess, id);
            } else {
               log.tracef("Finished rehash aware operation for id %s", id);
               complete.set(true);
            }
         } finally {
            csm.forgetOperation(id);
         }
      } while (!complete.get());
   }

   protected boolean isPrimaryOwner(ConsistentHash ch, CacheEntry e) {
      return localAddress.equals(ch.locatePrimaryOwnerForSegment(keyPartitioner.getSegment(e.getKey())));
   }

   static class AtomicReferenceArrayToMap<R> extends AbstractMap<Integer, R> {
      final AtomicReferenceArray<R> array;

      AtomicReferenceArrayToMap(AtomicReferenceArray<R> array) {
         this.array = array;
      }

      @Override
      public boolean containsKey(Object o) {
         if (!(o instanceof Integer))
            return false;
         int i = (int) o;
         return 0 <= i && i < array.length();
      }

      @Override
      public R get(Object key) {
         if (!(key instanceof Integer))
            return null;
         int i = (int) key;
         if (0 <= i && i < array.length()) {
            return array.get(i);
         }
         return null;
      }

      @Override
      public int size() {
         return array.length();
      }

      @Override
      public boolean remove(Object key, Object value) {
         throw new UnsupportedOperationException();
      }

      @Override
      public void clear() {
         throw new UnsupportedOperationException();
      }

      @Override
      public Set<Entry<Integer, R>> entrySet() {
         // Do we want to implement this later?
         throw new UnsupportedOperationException();
      }
   }
   class KeyTrackingConsumer<K, V> implements ClusterStreamManager.ResultsCallback<Collection<CacheEntry<K, Object>>>,
         KeyTrackingTerminalOperation.IntermediateCollector<Collection<CacheEntry<K, Object>>> {
      final KeyPartitioner keyPartitioner;
      final ConsistentHash ch;
      final Consumer<V> consumer;
      final Set<Integer> lostSegments = new ConcurrentHashSet<>();
      final Function<CacheEntry<K, Object>, V> valueFunction;

      final AtomicReferenceArray<Set<K>> referenceArray;

      final DistributedCacheStream.SegmentListenerNotifier listenerNotifier;

      KeyTrackingConsumer(KeyPartitioner keyPartitioner, ConsistentHash ch, Consumer<V> consumer,
            Function<CacheEntry<K, Object>, V> valueFunction,
            DistributedCacheStream.SegmentListenerNotifier completedSegments) {
         this.keyPartitioner = keyPartitioner;
         this.ch = ch;
         this.consumer = consumer;
         this.valueFunction = valueFunction;

         this.listenerNotifier = completedSegments;

         this.referenceArray = new AtomicReferenceArray<>(ch.getNumSegments());
         for (int i = 0; i < referenceArray.length(); ++i) {
            // We only allow 1 request per id
            referenceArray.set(i, new HashSet<>());
         }
      }

      @Override
      public Set<Integer> onIntermediateResult(Address address, Collection<CacheEntry<K, Object>> results) {
         if (results != null) {
            log.tracef("Response from %s with results %s", address, results.size());
            Set<Integer> segmentsCompleted;
            CacheEntry<K, Object>[] lastCompleted = new CacheEntry[1];
            if (listenerNotifier != null) {
               segmentsCompleted = new HashSet<>();
            } else {
               segmentsCompleted = null;
            }
            results.forEach(e -> {
               K key = e.getKey();
               int segment = keyPartitioner.getSegment(key);
               Set<K> keys = referenceArray.get(segment);
               // On completion we null this out first - thus we don't need to add
               if (keys != null) {
                  keys.add(key);
               } else if (segmentsCompleted != null) {
                  segmentsCompleted.add(segment);
                  lastCompleted[0] = e;
               }
               consumer.accept(valueFunction.apply(e));
            });
            if (lastCompleted[0] != null) {
               listenerNotifier.addSegmentsForObject(valueFunction.apply(lastCompleted[0]), segmentsCompleted);
            }
            return segmentsCompleted;
         }
         return null;
      }

      @Override
      public void onCompletion(Address address, Set<Integer> completedSegments,
            Collection<CacheEntry<K, Object>> results) {
         if (!completedSegments.isEmpty()) {
            log.tracef("Completing segments %s", completedSegments);
            // We null this out first so intermediate results don't add for no reason
            completedSegments.forEach(s -> referenceArray.set(s, null));
         } else {
            log.tracef("No segments to complete from %s", address);
         }
         Set<Integer> valueSegments = onIntermediateResult(address, results);
         if (valueSegments != null) {
            // We don't want to modify the completed segments as the caller may need it
            Set<Integer> emptyCompletedSegments = new HashSet<>(completedSegments.size());
            completedSegments.forEach(s -> {
               // First complete the segments that didn't have any keys - completed segments have to wait
               // until the user retrieves them
               if (!valueSegments.contains(s)) {
                  emptyCompletedSegments.add(s);
               }
            });
            listenerNotifier.completeSegmentsNoResults(emptyCompletedSegments);
         }
      }

      @Override
      public void onSegmentsLost(Set<Integer> segments) {
         // Have to use for loop since ConcurrentHashSet doesn't support addAll
         for (Integer segment : segments) {
            lostSegments.add(segment);
         }
      }

      @Override
      public void sendDataResonse(Collection<CacheEntry<K, Object>> response) {
         onIntermediateResult(null, response);
      }
   }
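
   /*
    * Hedged sketch of how the accumulator below behaves: each remote partial result is folded into
    * the running value with the supplied BinaryOperator. The addresses and values here are
    * illustrative only:
    *
    *    ResultsAccumulator<Long> acc = new ResultsAccumulator<>(Long::sum);
    *    acc.onIntermediateResult(nodeA, 10L);   // currentValue == 10
    *    acc.onIntermediateResult(nodeB, 5L);    // currentValue == 15
    */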
   static class ResultsAccumulator<R> implements ClusterStreamManager.ResultsCallback<R> {
      private final BinaryOperator<R> binaryOperator;
      private final Set<Integer> lostSegments = new ConcurrentHashSet<>();
      R currentValue;

      ResultsAccumulator(BinaryOperator<R> binaryOperator) {
         this.binaryOperator = binaryOperator;
      }

      @Override
      public Set<Integer> onIntermediateResult(Address address, R results) {
         if (results != null) {
            synchronized (this) {
               if (currentValue != null) {
                  currentValue = binaryOperator.apply(currentValue, results);
               } else {
                  currentValue = results;
               }
            }
         }
         return null;
      }

      @Override
      public void onCompletion(Address address, Set<Integer> completedSegments, R results) {
         onIntermediateResult(address, results);
      }

      @Override
      public void onSegmentsLost(Set<Integer> segments) {
         // Have to use for loop since ConcurrentHashSet doesn't support addAll
         for (Integer segment : segments) {
            lostSegments.add(segment);
         }
      }
   }

   static class CollectionConsumer<R> implements ClusterStreamManager.ResultsCallback<Collection<R>>,
         KeyTrackingTerminalOperation.IntermediateCollector<Collection<R>> {
      private final Consumer<R> consumer;

      CollectionConsumer(Consumer<R> consumer) {
         this.consumer = consumer;
      }

      @Override
      public Set<Integer> onIntermediateResult(Address address, Collection<R> results) {
         if (results != null) {
            results.forEach(consumer);
         }
         return null;
      }

      @Override
      public void onCompletion(Address address, Set<Integer> completedSegments, Collection<R> results) {
         onIntermediateResult(address, results);
      }

      @Override
      public void onSegmentsLost(Set<Integer> segments) {
      }

      @Override
      public void sendDataResonse(Collection<R> response) {
         onIntermediateResult(null, response);
      }
   }

   protected Supplier<Stream<CacheEntry>> supplierForSegments(ConsistentHash ch, Set<Integer> targetSegments,
         Set<Object> excludedKeys) {
      return supplierForSegments(ch, targetSegments, excludedKeys, true);
   }

   /**
    * If <code>usePrimary</code> is true the supplier is limited to the segments this node primarily
    * owns, further restricted to <code>targetSegments</code> when that set is provided. If
    * <code>usePrimary</code> is false then <code>targetSegments</code> must be non null and is used
    * as-is.
    * @param ch the consistent hash used to determine which segments this node primarily owns
    * @param targetSegments the segments to restrict the stream to; may be null only when
    *           <code>usePrimary</code> is true
    * @param excludedKeys keys that are filtered out of the returned stream; may be null
    * @param usePrimary determines whether we should utilize the primary segments or not
    * @return a supplier of the local stream limited to the requested segments and keys
    */
   protected Supplier<Stream<CacheEntry>> supplierForSegments(ConsistentHash ch, Set<Integer> targetSegments,
         Set<Object> excludedKeys, boolean usePrimary) {
      if (!ch.getMembers().contains(localAddress)) {
         return Stream::empty;
      }
      Set<Integer> segments;
      if (usePrimary) {
         segments = ch.getPrimarySegmentsForOwner(localAddress);
         if (targetSegments != null) {
            segments.retainAll(targetSegments);
         }
      } else {
         segments = targetSegments;
      }
      return () -> {
         if (segments.isEmpty()) {
            return Stream.empty();
         }
         CacheStream<CacheEntry> stream = supplier.get().filterKeySegments(segments);
         if (keysToFilter != null) {
            stream = stream.filterKeys(keysToFilter);
         }
         if (excludedKeys != null) {
            return stream.filter(e -> !excludedKeys.contains(e.getKey()));
         }
         // Make sure the stream is set to be parallel or not
         return parallel ? stream.parallel() : stream.sequential();
      };
   }

   /**
    * Given two Runnables, return a Runnable that executes both in sequence, even if the first
    * throws an exception. If both throw exceptions, any exception thrown by the second is added
    * as a suppressed exception of the first.
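    * <p>
    * A minimal illustration of the composed behaviour (the handler names are hypothetical):
    * <pre>{@code
    *    Runnable first = () -> { throw new IllegalStateException("first failed"); };
    *    Runnable second = () -> System.out.println("second still runs");
    *    composeWithExceptions(first, second).run();
    *    // "second still runs" is printed, then the IllegalStateException from 'first' propagates,
    *    // with any exception thrown by 'second' attached as a suppressed exception.
    * }</pre>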
    */
   static Runnable composeWithExceptions(Runnable a, Runnable b) {
      return () -> {
         try {
            a.run();
         } catch (Throwable e1) {
            try {
               b.run();
            } catch (Throwable e2) {
               try {
                  e1.addSuppressed(e2);
               } catch (Throwable ignore) {
               }
            }
            throw e1;
         }
         b.run();
      };
   }

   enum IteratorOperation {
      NO_MAP {
         @Override
         public KeyTrackingTerminalOperation getOperation(Iterable<IntermediateOperation> intermediateOperations,
               Supplier<Stream<CacheEntry>> supplier, int batchSize) {
            return new NoMapIteratorOperation<>(intermediateOperations, supplier, batchSize);
         }

         @Override
         public <K, V, R> Function<CacheEntry<K, V>, R> getFunction() {
            return e -> (R) e;
         }
      },
      MAP {
         @Override
         public KeyTrackingTerminalOperation getOperation(Iterable<IntermediateOperation> intermediateOperations,
               Supplier<Stream<CacheEntry>> supplier, int batchSize) {
            return new MapIteratorOperation<>(intermediateOperations, supplier, batchSize);
         }
      },
      FLAT_MAP {
         @Override
         public KeyTrackingTerminalOperation getOperation(Iterable<IntermediateOperation> intermediateOperations,
               Supplier<Stream<CacheEntry>> supplier, int batchSize) {
            return new FlatMapIteratorOperation<>(intermediateOperations, supplier, batchSize);
         }

         @Override
         public <V, V2> Consumer<V2> wrapConsumer(Consumer<V> consumer) {
            return new CollectionDecomposerConsumer(consumer);
         }
      };

      public abstract KeyTrackingTerminalOperation getOperation(Iterable<IntermediateOperation> intermediateOperations,
            Supplier<Stream<CacheEntry>> supplier, int batchSize);

      public <K, V, R> Function<CacheEntry<K, V>, R> getFunction() {
         return e -> (R) e.getValue();
      }

      public <V, V2> Consumer<V2> wrapConsumer(Consumer<V> consumer) {
         return (Consumer<V2>) consumer;
      }
   }

   static class CollectionDecomposerConsumer<E> implements Consumer<Iterable<E>> {
      private final Consumer<E> consumer;

      CollectionDecomposerConsumer(Consumer<E> consumer) {
         this.consumer = consumer;
      }

      @Override
      public void accept(Iterable<E> es) {
         es.forEach(consumer);
      }
   }

   /**
    * Given two SegmentCompletionListeners, return a SegmentCompletionListener that executes both
    * in sequence, even if the first throws an exception. If both throw exceptions, any exception
    * thrown by the second is added as a suppressed exception of the first.
    */
   protected static CacheStream.SegmentCompletionListener composeWithExceptions(
         CacheStream.SegmentCompletionListener a, CacheStream.SegmentCompletionListener b) {
      return (segments) -> {
         try {
            a.segmentCompleted(segments);
         } catch (Throwable e1) {
            try {
               b.segmentCompleted(segments);
            } catch (Throwable e2) {
               try {
                  e1.addSuppressed(e2);
               } catch (Throwable ignore) {
               }
            }
            throw e1;
         }
         b.segmentCompleted(segments);
      };
   }
}