package com.github.davidmoten.rx.internal.operators; import java.util.HashMap; import java.util.LinkedList; import java.util.Map; import java.util.Queue; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import com.github.davidmoten.util.Preconditions; import rx.Observable; import rx.Observable.OnSubscribe; import rx.Producer; import rx.Subscriber; import rx.functions.Func1; import rx.functions.Func2; import rx.internal.operators.BackpressureUtils; import rx.internal.util.unsafe.MpscLinkedQueue; import rx.internal.util.unsafe.UnsafeAccess; public final class OnSubscribeMatch<A, B, K, C> implements OnSubscribe<C> { private final Observable<A> a; private final Observable<B> b; private final Func1<? super A, ? extends K> aKey; private final Func1<? super B, ? extends K> bKey; private final Func2<? super A, ? super B, C> combiner; private final long requestSize; private static final Object NULL_SENTINEL = new Object(); public OnSubscribeMatch(Observable<A> a, Observable<B> b, Func1<? super A, ? extends K> aKey, Func1<? super B, ? extends K> bKey, Func2<? super A, ? super B, C> combiner, long requestSize) { Preconditions.checkNotNull(a, "a should not be null"); Preconditions.checkNotNull(b, "b should not be null"); Preconditions.checkNotNull(aKey, "aKey cannot be null"); Preconditions.checkNotNull(bKey, "bKey cannot be null"); Preconditions.checkNotNull(combiner, "combiner cannot be null"); Preconditions.checkArgument(requestSize >= 1, "requestSize must be >=1"); this.a = a; this.b = b; this.aKey = aKey; this.bKey = bKey; this.combiner = combiner; this.requestSize = requestSize; } @Override public void call(Subscriber<? super C> child) { AtomicReference<Receiver> receiverHolder = new AtomicReference<Receiver>(); MySubscriber<A, K> aSub = new MySubscriber<A, K>(Source.A, receiverHolder, requestSize); MySubscriber<B, K> bSub = new MySubscriber<B, K>(Source.B, receiverHolder, requestSize); child.add(aSub); child.add(bSub); MyProducer<A, B, K, C> producer = new MyProducer<A, B, K, C>(a, b, aKey, bKey, combiner, aSub, bSub, child, requestSize); receiverHolder.set(producer); child.setProducer(producer); a.unsafeSubscribe(aSub); b.unsafeSubscribe(bSub); } @SuppressWarnings("serial") private static final class MyProducer<A, B, K, C> extends AtomicInteger implements Producer, Receiver { // extends AtomicInteger as a work-in-progress atomic (wip) private final Queue<Object> queue; private final Map<K, Queue<A>> as = new HashMap<K, Queue<A>>(); private final Map<K, Queue<B>> bs = new HashMap<K, Queue<B>>(); private final Func1<? super A, ? extends K> aKey; private final Func1<? super B, ? extends K> bKey; private final Func2<? super A, ? super B, C> combiner; private final Subscriber<? super C> child; private final MySubscriber<A, K> aSub; private final MySubscriber<B, K> bSub; private final long requestSize; private final AtomicLong requested = new AtomicLong(0); // mutable fields, guarded by `this` atomics private int requestFromA = 0; private int requestFromB = 0; // completion state machine private int completed = COMPLETED_NONE; // completion states private static final int COMPLETED_NONE = 0; private static final int COMPLETED_A = 1; private static final int COMPLETED_B = 2; private static final int COMPLETED_BOTH = 3; MyProducer(Observable<A> a, Observable<B> b, Func1<? super A, ? extends K> aKey, Func1<? super B, ? extends K> bKey, Func2<? super A, ? super B, C> combiner, MySubscriber<A, K> aSub, MySubscriber<B, K> bSub, Subscriber<? super C> child, long requestSize) { this.aKey = aKey; this.bKey = bKey; this.combiner = combiner; this.child = child; this.aSub = aSub; this.bSub = bSub; this.requestSize = requestSize; if (UnsafeAccess.isUnsafeAvailable()) { queue = new MpscLinkedQueue<Object>(); } else { queue = new ConcurrentLinkedQueue<Object>(); } } @Override public void request(long n) { if (BackpressureUtils.validate(n)) { BackpressureUtils.getAndAddRequest(requested, n); drain(); } } void drain() { if (getAndIncrement() != 0) { // work already in progress // so exit return; } int missed = 1; while (true) { long r = requested.get(); int emitted = 0; while (r > emitted) { if (child.isUnsubscribed()) { return; } // note will not return null Object v = queue.poll(); if (v == null) { // queue is empty break; } else if (v instanceof ItemA) { Emitted em = handleItem(((ItemA) v).value, Source.A); if (em == Emitted.FINISHED) { return; } else if (em == Emitted.ONE) { emitted += 1; } } else if (v instanceof Source) { // source completed Status status = handleCompleted((Source) v); if (status == Status.FINISHED) { return; } } else if (v instanceof MyError) { // v must be an error clear(); child.onError(((MyError) v).error); return; } else { // is onNext from B Emitted em = handleItem(v, Source.B); if (em == Emitted.FINISHED) { return; } else if (em == Emitted.ONE) { emitted += 1; } } if (r == emitted) { break; } } if (emitted > 0) { // reduce requested by emitted BackpressureUtils.produced(requested, emitted); } missed = this.addAndGet(-missed); if (missed == 0 ) { return; } } } private Emitted handleItem(Object value, Source source) { final Emitted result; // logic duplication occurs below // would be nice to simplify without making code // unreadable. A bit of a toss-up. if (source == Source.A) { // look for match @SuppressWarnings("unchecked") A a = (A) value; K key; try { key = aKey.call(a); } catch (Throwable e) { clear(); child.onError(e); return Emitted.FINISHED; } Queue<B> q = bs.get(key); if (q == null) { // cache value add(as, key, a); result = Emitted.NONE; } else { // emit match B b = poll(bs, q, key); C c; try { c = combiner.call(replaceSentinel(a), replaceSentinel(b)); } catch (Throwable e) { clear(); child.onError(e); return Emitted.FINISHED; } child.onNext(c); result = Emitted.ONE; } // if the other source has completed and there // is nothing to match with then we should stop if (completed == COMPLETED_B && bs.isEmpty()) { // can finish clear(); child.onCompleted(); return Emitted.FINISHED; } else { requestFromA += 1; } } else { // look for match @SuppressWarnings("unchecked") B b = (B) value; K key; try { key = bKey.call(b); } catch (Throwable e) { clear(); child.onError(e); return Emitted.FINISHED; } Queue<A> q = as.get(key); if (q == null) { // cache value add(bs, key, b); result = Emitted.NONE; } else { // emit match A a = poll(as, q, key); C c; try { c = combiner.call(replaceSentinel(a), replaceSentinel(b)); } catch (Throwable e) { clear(); child.onError(e); return Emitted.FINISHED; } child.onNext(c); result = Emitted.ONE; } // if the other source has completed and there // is nothing to match with then we should stop if (completed == COMPLETED_A && as.isEmpty()) { // can finish clear(); child.onCompleted(); return Emitted.FINISHED; } else { requestFromB += 1; } } // requests are batched so that each source gets a turn checkToRequestMore(); return result; } private enum Emitted { ONE, NONE, FINISHED; } private Status handleCompleted(Source source) { completed(source); final boolean done; if (source == Source.A) { aSub.unsubscribe(); done = (completed == COMPLETED_BOTH) || (completed == COMPLETED_A && as.isEmpty()); } else { bSub.unsubscribe(); done = (completed == COMPLETED_BOTH) || (completed == COMPLETED_B && bs.isEmpty()); } if (done) { clear(); child.onCompleted(); return Status.FINISHED; } else { checkToRequestMore(); return Status.KEEP_GOING; } } private enum Status { FINISHED, KEEP_GOING; } private void checkToRequestMore() { if (requestFromA == requestSize && completed == COMPLETED_B) { requestFromA = 0; aSub.requestMore(requestSize); } else if (requestFromB == requestSize && completed == COMPLETED_A) { requestFromB = 0; bSub.requestMore(requestSize); } else if (requestFromA == requestSize && requestFromB == requestSize) { requestFromA = 0; requestFromB = 0; aSub.requestMore(requestSize); bSub.requestMore(requestSize); } } private void completed(Source source) { if (source == Source.A) { if (completed == COMPLETED_NONE) { completed = COMPLETED_A; } else if (completed == COMPLETED_B) { completed = COMPLETED_BOTH; } } else { if (completed == COMPLETED_NONE) { completed = COMPLETED_B; } else if (completed == COMPLETED_A) { completed = COMPLETED_BOTH; } } } private void clear() { as.clear(); bs.clear(); queue.clear(); aSub.unsubscribe(); bSub.unsubscribe(); } private static <K, T> void add(Map<K, Queue<T>> map, K key, T value) { Queue<T> q = map.get(key); if (q == null) { q = new LinkedList<T>(); map.put(key, q); } q.offer(value); } private static <K, T> T poll(Map<K, Queue<T>> map, Queue<T> q, K key) { T t = q.poll(); if (q.isEmpty()) { map.remove(key); } return t; } @Override public void offer(Object item) { queue.offer(item); drain(); } private static <T> T replaceSentinel(T t) { if (t == NULL_SENTINEL) { return null; } else { return t; } } } interface Receiver { void offer(Object item); } static final class MySubscriber<T, K> extends Subscriber<T> { private final AtomicReference<Receiver> receiver; private final Source source; MySubscriber(Source source, AtomicReference<Receiver> receiver, long requestSize) { this.source = source; this.receiver = receiver; request(requestSize); } @Override public void onNext(T t) { if (source == Source.A) { receiver.get().offer(new ItemA(replaceNull(t))); } else { receiver.get().offer(replaceNull(t)); } } private static Object replaceNull(Object t) { if (t == null) { return NULL_SENTINEL; } else { return t; } } @Override public void onCompleted() { receiver.get().offer(source); } @Override public void onError(Throwable e) { receiver.get().offer(new MyError(e)); } public void requestMore(long n) { request(n); } } static final class MyError { final Throwable error; MyError(Throwable error) { this.error = error; } } static final class ItemA { final Object value; ItemA(Object value) { this.value = value; } } enum Source { A, B; } }