/** * Copyright 2014 Netflix, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); you may not * use this file except in compliance with the License. You may obtain a copy of * the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations under * the License. */ package rx.internal.operators; import java.util.*; import java.util.concurrent.atomic.*; import rx.*; import rx.Observable; import rx.exceptions.*; import rx.functions.*; import rx.internal.util.RxRingBuffer; import rx.observables.ConnectableObservable; import rx.subscriptions.Subscriptions; public class OperatorPublish<T> extends ConnectableObservable<T> { final Observable<? extends T> source; private final RequestHandler<T> requestHandler; public static <T> ConnectableObservable<T> create(Observable<? extends T> source) { return new OperatorPublish<T>(source); } public static <T, R> Observable<R> create(final Observable<? extends T> source, final Func1<? super Observable<T>, ? extends Observable<R>> selector) { return Observable.create(new OnSubscribe<R>() { @Override public void call(final Subscriber<? super R> child) { OperatorPublish<T> op = new OperatorPublish<T>(source); selector.call(op).unsafeSubscribe(child); op.connect(new Action1<Subscription>() { @Override public void call(Subscription sub) { child.add(sub); } }); } }); } private OperatorPublish(Observable<? extends T> source) { this(source, new Object(), new RequestHandler<T>()); } private OperatorPublish(Observable<? extends T> source, final Object guard, final RequestHandler<T> requestHandler) { super(new OnSubscribe<T>() { @Override public void call(final Subscriber<? super T> subscriber) { subscriber.setProducer(new Producer() { @Override public void request(long n) { requestHandler.requestFromChildSubscriber(subscriber, n); } }); subscriber.add(Subscriptions.create(new Action0() { @Override public void call() { requestHandler.state.removeSubscriber(subscriber); } })); } }); this.source = source; this.requestHandler = requestHandler; } @Override public void connect(Action1<? super Subscription> connection) { // each time we connect we create a new Subscription boolean shouldSubscribe = false; // subscription is the state of whether we are connected or not OriginSubscriber<T> origin = requestHandler.state.getOrigin(); if (origin == null) { shouldSubscribe = true; requestHandler.state.setOrigin(new OriginSubscriber<T>(requestHandler)); } // in the lock above we determined we should subscribe, do it now outside the lock if (shouldSubscribe) { // register a subscription that will shut this down connection.call(Subscriptions.create(new Action0() { @Override public void call() { OriginSubscriber<T> s = requestHandler.state.getOrigin(); requestHandler.state.setOrigin(null); if (s != null) { s.unsubscribe(); } } })); // now that everything is hooked up let's subscribe // as long as the subscription is not null (which can happen if already unsubscribed) OriginSubscriber<T> os = requestHandler.state.getOrigin(); if (os != null) { source.unsafeSubscribe(os); } } } private static class OriginSubscriber<T> extends Subscriber<T> { private final RequestHandler<T> requestHandler; private final AtomicLong originOutstanding = new AtomicLong(); private final long THRESHOLD = RxRingBuffer.SIZE / 4; private final RxRingBuffer buffer = RxRingBuffer.getSpmcInstance(); OriginSubscriber(RequestHandler<T> requestHandler) { this.requestHandler = requestHandler; add(buffer); } @Override public void onStart() { requestMore(RxRingBuffer.SIZE); } private void requestMore(long r) { originOutstanding.addAndGet(r); request(r); } @Override public void onCompleted() { try { requestHandler.emit(requestHandler.notifier.completed()); } catch (MissingBackpressureException e) { onError(e); } } @Override public void onError(Throwable e) { List<Throwable> errors = null; for (Subscriber<? super T> subscriber : requestHandler.state.getSubscribers()) { try { subscriber.onError(e); } catch (Throwable e2) { if (errors == null) { errors = new ArrayList<Throwable>(); } errors.add(e2); } } Exceptions.throwIfAny(errors); } @Override public void onNext(T t) { try { requestHandler.emit(requestHandler.notifier.next(t)); } catch (MissingBackpressureException e) { onError(e); } } } /** * Synchronized mutable state. * * benjchristensen => I have not figured out a non-blocking approach to this that doesn't involve massive object allocation overhead * with a complicated state machine so I'm sticking with mutex locks and just trying to make sure the work done while holding the * lock is small (such as never emitting data). * * This does however mean we can't rely on a reference to State being consistent. For example, it can end up with a null OriginSubscriber. * * @param <T> */ private static class State<T> { private long outstandingRequests = -1; private OriginSubscriber<T> origin; // using AtomicLong to simplify mutating it, not for thread-safety since we're synchronizing access to this class // using LinkedHashMap so the order of Subscribers having onNext invoked is deterministic (same each time the code is run) private final Map<Subscriber<? super T>, AtomicLong> ss = new LinkedHashMap<Subscriber<? super T>, AtomicLong>(); @SuppressWarnings("unchecked") private Subscriber<? super T>[] subscribers = new Subscriber[0]; public synchronized OriginSubscriber<T> getOrigin() { return origin; } public synchronized void setOrigin(OriginSubscriber<T> o) { this.origin = o; } public synchronized boolean canEmitWithDecrement() { if (outstandingRequests > 0) { outstandingRequests--; return true; } return false; } public synchronized boolean hasNoSubscriber() { return subscribers.length == 0; } public synchronized void incrementOutstandingAfterFailedEmit() { outstandingRequests++; } public synchronized Subscriber<? super T>[] getSubscribers() { return subscribers; } /** * @return long outstandingRequests */ public synchronized long requestFromSubscriber(Subscriber<? super T> subscriber, long request) { Map<Subscriber<? super T>, AtomicLong> subs = ss; AtomicLong r = subs.get(subscriber); if (r == null) { subs.put(subscriber, new AtomicLong(request)); } else { do { long current = r.get(); if (current == Long.MAX_VALUE) { break; } long u = current + request; if (u < 0) { u = Long.MAX_VALUE; } if (r.compareAndSet(current, u)) { break; } } while (true); } return resetAfterSubscriberUpdate(subs); } public synchronized void removeSubscriber(Subscriber<? super T> subscriber) { Map<Subscriber<? super T>, AtomicLong> subs = ss; subs.remove(subscriber); resetAfterSubscriberUpdate(subs); } @SuppressWarnings("unchecked") private long resetAfterSubscriberUpdate(Map<Subscriber<? super T>, AtomicLong> subs) { Subscriber<? super T>[] subscriberArray = new Subscriber[subs.size()]; int i = 0; long lowest = -1; for (Map.Entry<Subscriber<? super T>, AtomicLong> e : subs.entrySet()) { subscriberArray[i++] = e.getKey(); AtomicLong l = e.getValue(); long c = l.get(); if (lowest == -1 || c < lowest) { lowest = c; } } this.subscribers = subscriberArray; /* * when receiving a request from a subscriber we reset 'outstanding' to the lowest of all subscribers */ outstandingRequests = lowest; return lowest; } } private static class RequestHandler<T> { private final NotificationLite<T> notifier = NotificationLite.instance(); private final State<T> state = new State<T>(); @SuppressWarnings("unused") volatile long wip; @SuppressWarnings("rawtypes") static final AtomicLongFieldUpdater<RequestHandler> WIP = AtomicLongFieldUpdater.newUpdater(RequestHandler.class, "wip"); public void requestFromChildSubscriber(Subscriber<? super T> subscriber, long request) { state.requestFromSubscriber(subscriber, request); OriginSubscriber<T> originSubscriber = state.getOrigin(); if(originSubscriber != null) { drainQueue(originSubscriber); } } public void emit(Object t) throws MissingBackpressureException { OriginSubscriber<T> originSubscriber = state.getOrigin(); if(originSubscriber == null) { // unsubscribed so break ... we are done return; } if (notifier.isCompleted(t)) { originSubscriber.buffer.onCompleted(); } else { originSubscriber.buffer.onNext(notifier.getValue(t)); } drainQueue(originSubscriber); } private void requestMoreAfterEmission(int emitted) { if (emitted > 0) { OriginSubscriber<T> origin = state.getOrigin(); if (origin != null) { long r = origin.originOutstanding.addAndGet(-emitted); if (r <= origin.THRESHOLD) { origin.requestMore(RxRingBuffer.SIZE - origin.THRESHOLD); } } } } public void drainQueue(OriginSubscriber<T> originSubscriber) { if (WIP.getAndIncrement(this) == 0) { State<T> localState = state; Map<Subscriber<? super T>, AtomicLong> localMap = localState.ss; RxRingBuffer localBuffer = originSubscriber.buffer; NotificationLite<T> nl = notifier; int emitted = 0; do { /* * Set to 1 otherwise it could have grown very large while in the last poll loop * and then we can end up looping all those times again here before exiting even once we've drained */ WIP.set(this, 1); /** * This is done in the most inefficient possible way right now and we can revisit the approach. * If we want to batch this then we need to account for new subscribers arriving with a lower request count * concurrently while iterating the batch ... or accept that they won't */ while (true) { if (localState.hasNoSubscriber()) { // Drop items due to no subscriber if (localBuffer.poll() == null) { // Exit due to no more item break; } else { // Keep dropping cached items. continue; } } boolean shouldEmit = localState.canEmitWithDecrement(); if (!shouldEmit) { break; } Object o = localBuffer.poll(); if (o == null) { // nothing in buffer so increment outstanding back again localState.incrementOutstandingAfterFailedEmit(); break; } for (Subscriber<? super T> s : localState.getSubscribers()) { AtomicLong req = localMap.get(s); if (req != null) { // null req indicates a concurrent unsubscription happened nl.accept(s, o); req.decrementAndGet(); } } emitted++; } } while (WIP.decrementAndGet(this) > 0); requestMoreAfterEmission(emitted); } } } }