package com.github.davidmoten.rx.testing; import static com.github.davidmoten.util.Optional.of; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; import junit.framework.TestCase; import junit.framework.TestSuite; import org.junit.runner.RunWith; import org.junit.runners.Suite; import org.junit.runners.Suite.SuiteClasses; import rx.Observable; import rx.Subscriber; import rx.functions.Action0; import rx.functions.Func1; import com.github.davidmoten.util.Optional; import com.github.davidmoten.util.Preconditions; /** * Testing utility functions. */ public final class TestingHelper { private static final Optional<Long> ABSENT = Optional.absent(); /** * Sets the transformation to be tested and returns a builder to create test * cases. * * @param function * the transformation to be tested * @param <T> * generic type of the from side of the transformation being * tested * @param <R> * generic type of the to side of the transformation being tested * @return builder for creating test cases */ public static <T, R> Builder<T, R> function(Func1<Observable<T>, Observable<R>> function) { return new Builder<T, R>().function(function); } /** * Defines test cases and builds a JUnit test suite. * * @param <T> * generic type of the from side of the transformation being * tested * @param <R> * generic type of the to side of the transformation being tested */ public static class Builder<T, R> { private final List<Case<T, R>> cases = new ArrayList<Case<T, R>>(); private Func1<Observable<T>, Observable<R>> function; private long waitForUnusbscribeMs = 100; private long waitForTerminalEventMs = 10000; private long waitForMoreTerminalEventsMs = 50; private Builder() { // must instantiate via TestingHelper.function method above } /** * Sets transformation to be tested and returns the current builder. * * @param function * transformation to be tested * @return builder */ public Builder<T, R> function(Func1<Observable<T>, Observable<R>> function) { Preconditions.checkNotNull(function, "function cannot be null"); this.function = function; return this; } /** * Sets duration to wait for unusubscription to occur (either of source * or of downstream subscriber). * * @param duration * number of time units * @param unit * time unit * @return builder */ public Builder<T, R> waitForUnsubscribe(long duration, TimeUnit unit) { Preconditions.checkNotNull(unit, "unit cannot be null"); waitForUnusbscribeMs = unit.toMillis(duration); return this; } /** * Sets duration to wait for a terminal event (completed or error) when * one is expected. * * @param duration * number of time units * @param unit * time unit * @return builder */ public Builder<T, R> waitForTerminalEvent(long duration, TimeUnit unit) { Preconditions.checkNotNull(unit, "unit cannot be null"); waitForTerminalEventMs = unit.toMillis(duration); return this; } /** * Sets duration to wait for more terminal events after one has been * received. * * @param duration * number of time units * @param unit * time unit * @return builder */ public Builder<T, R> waitForMoreTerminalEvents(long duration, TimeUnit unit) { Preconditions.checkNotNull(unit, "unit cannot be null"); waitForMoreTerminalEventsMs = unit.toMillis(duration); return this; } /** * Sets the name of the test which is used in the name of a junit test. * * @param name * name of the test * @return case builder */ public CaseBuilder<T, R> name(String name) { Preconditions.checkNotNull(name, "name cannot be null"); return new CaseBuilder<T, R>(this, Observable.<T> empty(), name); } /** * Returns the JUnit {@link TestSuite} comprised of the test cases * created so far. The cases will be listed under the root test named * according to the given class. * * @param cls * class corresponding to the tests root * @return test suite */ public TestSuite testSuite(Class<?> cls) { Preconditions.checkNotNull(cls, "cls cannot be null"); return new TestSuiteFromCases<T, R>(cls, new ArrayList<Case<T, R>>(this.cases)); } private Builder<T, R> expect(Observable<T> from, Optional<List<R>> expected, boolean ordered, Optional<Long> expectSize, boolean checkSourceUnsubscribed, String name, Optional<Integer> unsubscribeAfter, Optional<Class<? extends Throwable>> expectError, Optional<Class<? extends RuntimeException>> expectException) { cases.add(new Case<T, R>(from, expected, ordered, expectSize, checkSourceUnsubscribed, function, name, unsubscribeAfter, expectError, waitForUnusbscribeMs, waitForTerminalEventMs, waitForMoreTerminalEventsMs, expectException)); return this; } } public static class CaseBuilder<T, R> { private final Builder<T, R> builder; private String name; private Observable<T> from = Observable.empty(); private boolean checkSourceUnsubscribed = true; private Optional<Integer> unsubscribeAfter = Optional.absent(); private CaseBuilder(Builder<T, R> builder, Observable<T> from, String name) { Preconditions.checkNotNull(builder); Preconditions.checkNotNull(from); Preconditions.checkNotNull(name); this.builder = builder; this.from = from; this.name = name; } public CaseBuilder<T, R> name(String name) { Preconditions.checkNotNull(name, "name cannot be null"); this.name = name; return this; } public CaseBuilder<T, R> fromEmpty() { from = Observable.empty(); return this; } public CaseBuilder<T, R> from(T... source) { Preconditions.checkNotNull(source, "source cannot be null"); from = Observable.from(source); return this; } public CaseBuilder<T, R> from(Observable<T> source) { Preconditions.checkNotNull(source, "source cannot be null"); from = source; return this; } public CaseBuilder<T, R> fromError() { from = Observable.error(new TestingException()); return this; } public CaseBuilder<T, R> fromErrorAfter(T... source) { Preconditions.checkNotNull(source, "source cannot be null"); from = Observable.from(source).concatWith(Observable.<T> error(new TestingException())); return this; } public CaseBuilder<T, R> fromErrorAfter(Observable<T> source) { Preconditions.checkNotNull(source, "source cannot be null"); from = source; return this; } public CaseBuilder<T, R> skipUnsubscribedCheck() { this.checkSourceUnsubscribed = false; return this; } public Builder<T, R> expectEmpty() { return expect(Collections.<R> emptyList()); } public Builder<T, R> expectError() { return expectError(TestingException.class); } @SuppressWarnings("unchecked") public Builder<T, R> expectError(Class<? extends Throwable> cls) { Preconditions.checkNotNull(cls, "cls cannot be null"); return builder.expect(from, Optional.<List<R>> absent(), true, ABSENT, checkSourceUnsubscribed, name, unsubscribeAfter, (Optional<Class<? extends Throwable>>) (Optional<?>) of(cls), Optional.<Class<? extends RuntimeException>> absent()); } public Builder<T, R> expect(R... source) { Preconditions.checkNotNull(source, "source cannot be null"); return expect(Arrays.asList(source)); } public Builder<T, R> expectSize(long n) { return builder.expect(from, Optional.<List<R>> absent(), true, of(n), checkSourceUnsubscribed, name, unsubscribeAfter, Optional.<Class<? extends Throwable>> absent(), Optional.<Class<? extends RuntimeException>> absent()); } public Builder<T, R> expect(List<R> source) { Preconditions.checkNotNull(source, "source cannot be null"); return expect(source, true); } private Builder<T, R> expect(List<R> items, boolean ordered) { return builder.expect(from, of(items), ordered, ABSENT, checkSourceUnsubscribed, name, unsubscribeAfter, Optional.<Class<? extends Throwable>> absent(), Optional.<Class<? extends RuntimeException>> absent()); } public Builder<T, R> expectAnyOrder(R... source) { Preconditions.checkNotNull(source, "source cannot be null"); return expect(Arrays.asList(source), false); } public CaseBuilder<T, R> unsubscribeAfter(int n) { unsubscribeAfter = of(n); return this; } @SuppressWarnings("unchecked") public Builder<T, R> expectException(Class<? extends RuntimeException> cls) { return builder.expect(from, Optional.<List<R>> absent(), true, ABSENT, checkSourceUnsubscribed, name, unsubscribeAfter, Optional.<Class<? extends Throwable>> absent(), (Optional<Class<? extends RuntimeException>>) (Optional<?>) Optional.of(cls)); } } private static class Case<T, R> { final String name; final Observable<T> from; final Optional<List<R>> expected; final boolean checkSourceUnsubscribed; final Func1<Observable<T>, Observable<R>> function; final Optional<Integer> unsubscribeAfter; final boolean ordered; final Optional<Long> expectSize; final Optional<Class<? extends Throwable>> expectError; final long waitForUnusbscribeMs; final long waitForTerminalEventMs; final long waitForMoreTerminalEventsMs; final Optional<Class<? extends RuntimeException>> expectedException; Case(Observable<T> from, Optional<List<R>> expected, boolean ordered, Optional<Long> expectSize, boolean checkSourceUnsubscribed, Func1<Observable<T>, Observable<R>> function, String name, Optional<Integer> unsubscribeAfter, Optional<Class<? extends Throwable>> expectError, long waitForUnusbscribeMs, long waitForTerminalEventMs, long waitForMoreTerminalEventsMs, Optional<Class<? extends RuntimeException>> expectedException) { Preconditions.checkNotNull(from); Preconditions.checkNotNull(expected); Preconditions.checkNotNull(expectSize); Preconditions.checkNotNull(function); Preconditions.checkNotNull(name); Preconditions.checkNotNull(unsubscribeAfter); Preconditions.checkNotNull(expectError); Preconditions.checkNotNull(expectedException); this.from = from; this.expected = expected; this.ordered = ordered; this.expectSize = expectSize; this.checkSourceUnsubscribed = checkSourceUnsubscribed; this.function = function; this.name = name; this.unsubscribeAfter = unsubscribeAfter; this.expectError = expectError; this.waitForUnusbscribeMs = waitForUnusbscribeMs; this.waitForTerminalEventMs = waitForTerminalEventMs; this.waitForMoreTerminalEventsMs = waitForMoreTerminalEventsMs; this.expectedException = expectedException; } } private static <T, R> void runTest(Case<T, R> c, TestType testType) { try { CountDownLatch sourceUnsubscribeLatch = new CountDownLatch(1); MyTestSubscriber<R> sub = createTestSubscriber(testType, c.unsubscribeAfter); c.function.call(c.from.doOnUnsubscribe(countDown(sourceUnsubscribeLatch))) .subscribe(sub); if (c.unsubscribeAfter.isPresent()) { waitForUnsubscribe(sourceUnsubscribeLatch, c.waitForUnusbscribeMs, TimeUnit.MILLISECONDS); // if unsubscribe has occurred there is no mandated behaviour in // terms of terminal events so we don't check them } else { sub.awaitTerminalEvent(c.waitForTerminalEventMs, TimeUnit.MILLISECONDS); if (c.expectError.isPresent()) { sub.assertError(c.expectError.get()); // wait for more terminal events pause(c.waitForMoreTerminalEventsMs, TimeUnit.MILLISECONDS); if (sub.numOnCompletedEvents() > 0) throw new UnexpectedOnCompletedException(); } else { sub.assertNoErrors(); // wait for more terminal events pause(c.waitForMoreTerminalEventsMs, TimeUnit.MILLISECONDS); if (sub.numOnCompletedEvents() > 1) throw new TooManyOnCompletedException(); sub.assertNoErrors(); } } if (c.expected.isPresent()) sub.assertReceivedOnNext(c.expected.get(), c.ordered); if (c.expectSize.isPresent()) sub.assertReceivedCountIs(c.expectSize.get()); sub.assertUnsubscribed(); if (c.checkSourceUnsubscribed) waitForUnsubscribe(sourceUnsubscribeLatch, c.waitForUnusbscribeMs, TimeUnit.MILLISECONDS); if (c.expectedException.isPresent()) throw new ExpectedExceptionNotThrownException(); } catch (RuntimeException e) { if (!c.expectedException.isPresent() || !c.expectedException.get().isInstance(e)) throw e; // otherwise was expected } } private static Action0 countDown(final CountDownLatch latch) { return new Action0() { @Override public void call() { latch.countDown(); } }; } private static <T> void waitForUnsubscribe(CountDownLatch latch, long duration, TimeUnit unit) { try { if (!latch.await(duration, unit)) throw new UnsubscriptionFromSourceTimeoutException(); } catch (InterruptedException e) { // do nothing } } public static class UnsubscriptionFromSourceTimeoutException extends RuntimeException { private static final long serialVersionUID = -1142604414390722544L; } private static void pause(long duration, TimeUnit unit) { try { Thread.sleep(unit.toMillis(duration)); } catch (InterruptedException e) { // do nothing } } private static final class MyTestSubscriber<T> extends Subscriber<T> { private final List<T> next = new ArrayList<T>(); private final Optional<Long> onStartRequest; private final Optional<Long> onNextRequest; private final Optional<Integer> unsubscribeAfter; private final CountDownLatch terminalLatch; private int completed = 0; private int count = 0; private int errors = 0; private final AtomicLong expected = new AtomicLong(); private Optional<Throwable> lastError = Optional.absent(); private Optional<Long> onNextRequest2; MyTestSubscriber(Optional<Integer> unsubscribeAfter, final Optional<Long> onStartRequest, final Optional<Long> onNextRequest, final Optional<Long> onNextRequest2) { this.unsubscribeAfter = unsubscribeAfter; this.onStartRequest = onStartRequest; this.onNextRequest = onNextRequest; this.onNextRequest2 = onNextRequest2; this.terminalLatch = new CountDownLatch(1); } MyTestSubscriber(Optional<Integer> unsubscribeAfter) { this(unsubscribeAfter, ABSENT, ABSENT, ABSENT); } @Override public void onStart() { if (!onStartRequest.isPresent()) // if nothing requested in onStart then must be requesting all expected.set(Long.MAX_VALUE); else expected.set(0); if (onStartRequest.isPresent()) requestMore(onStartRequest.get()); } private void requestMore(long n) { if (expected.get() != Long.MAX_VALUE) { if (n > 0) expected.addAndGet(n); // allow zero or negative requests to pass through as a test request(n); } } @Override public void onCompleted() { completed++; terminalLatch.countDown(); } @Override public void onError(Throwable e) { errors++; lastError = of(e); terminalLatch.countDown(); } @Override public void onNext(T t) { final long exp; if (expected.get() != Long.MAX_VALUE) exp = expected.decrementAndGet(); else exp = expected.get(); next.add(t); count++; if (exp < 0) onError(new DeliveredMoreThanRequestedException()); else if (unsubscribeAfter.isPresent() && count == unsubscribeAfter.get()) unsubscribe(); else { if (onNextRequest.isPresent()) requestMore(onNextRequest.get()); if (onNextRequest2.isPresent()) requestMore(onNextRequest2.get()); } } void assertError(Class<?> cls) { if (errors != 1 || !cls.isInstance(lastError.get())) throw new ExpectedErrorNotReceivedException(); } void assertReceivedCountIs(long count) { if (count != next.size()) throw new WrongOnNextCountException(); } void awaitTerminalEvent(long duration, TimeUnit unit) { try { if (!terminalLatch.await(duration, unit)) throw new TerminalEventTimeoutException(); } catch (InterruptedException e) { // do nothing } } void assertReceivedOnNext(List<T> expected, boolean ordered) { if (!TestingHelper.equals(expected, next, ordered)) throw new UnexpectedOnNextException("expected=" + expected + ", actual=" + next); } void assertUnsubscribed() { if (!isUnsubscribed()) throw new DownstreamUnsubscriptionDidNotOccurException(); } int numOnCompletedEvents() { return completed; } void assertNoErrors() { if (errors > 0) { lastError.get().printStackTrace(); throw new UnexpectedOnErrorException(); } } } public static class TerminalEventTimeoutException extends RuntimeException { private static final long serialVersionUID = -7355281653999339840L; } public static class ExpectedErrorNotReceivedException extends RuntimeException { private static final long serialVersionUID = -567146145612029349L; } public static class ExpectedExceptionNotThrownException extends RuntimeException { private static final long serialVersionUID = -104410457605712970L; } public static class WrongOnNextCountException extends RuntimeException { private static final long serialVersionUID = 984672575527784559L; } public static class UnexpectedOnCompletedException extends RuntimeException { private static final long serialVersionUID = 7164517608988798969L; } public static class UnexpectedOnErrorException extends RuntimeException { private static final long serialVersionUID = -813740137771756205L; } public static class TooManyOnCompletedException extends RuntimeException { private static final long serialVersionUID = -405328882928962333L; } public static class DownstreamUnsubscriptionDidNotOccurException extends RuntimeException { private static final long serialVersionUID = 7218646111664183642L; } public static class UnexpectedOnNextException extends RuntimeException { private static final long serialVersionUID = -3656406263739222767L; public UnexpectedOnNextException(String message) { super(message); } } private static enum TestType { WITHOUT_BACKP, BACKP_INITIAL_REQUEST_MAX, BACKP_INITIAL_REQUEST_MAX_THEN_BY_ONE, BACKP_ONE_BY_ONE, BACKP_TWO_BY_TWO, BACKP_REQUEST_ZERO, BACKP_FIVE_BY_FIVE, BACKP_FIFTY_BY_FIFTY, BACKP_THOUSAND_BY_THOUSAND, BACKP_REQUEST_OVERFLOW; } private static <T> MyTestSubscriber<T> createTestSubscriber(Optional<Integer> unsubscribeAfter, long onStartRequest, Optional<Long> onNextRequest) { return new MyTestSubscriber<T>(unsubscribeAfter, of(onStartRequest), onNextRequest, ABSENT); } private static <T> MyTestSubscriber<T> createTestSubscriber(TestType testType, final Optional<Integer> unsubscribeAfter) { if (testType == TestType.WITHOUT_BACKP) return new MyTestSubscriber<T>(unsubscribeAfter); else if (testType == TestType.BACKP_INITIAL_REQUEST_MAX) return createTestSubscriber(unsubscribeAfter, Long.MAX_VALUE, ABSENT); else if (testType == TestType.BACKP_INITIAL_REQUEST_MAX_THEN_BY_ONE) return createTestSubscriber(unsubscribeAfter, Long.MAX_VALUE, of(1L)); else if (testType == TestType.BACKP_ONE_BY_ONE) return createTestSubscriber(unsubscribeAfter, 1L, of(1L)); else if (testType == TestType.BACKP_REQUEST_ZERO) return new MyTestSubscriber<T>(unsubscribeAfter, of(1L), of(0L), of(1L)); else if (testType == TestType.BACKP_REQUEST_OVERFLOW) return new MyTestSubscriber<T>(unsubscribeAfter, of(1L), of(Long.MAX_VALUE / 3 * 2), of(Long.MAX_VALUE / 3 * 2)); else if (testType == TestType.BACKP_TWO_BY_TWO) return createTestSubscriberWithBackpNbyN(unsubscribeAfter, 2); else if (testType == TestType.BACKP_FIVE_BY_FIVE) return createTestSubscriberWithBackpNbyN(unsubscribeAfter, 5); else if (testType == TestType.BACKP_FIFTY_BY_FIFTY) return createTestSubscriberWithBackpNbyN(unsubscribeAfter, 50); else if (testType == TestType.BACKP_THOUSAND_BY_THOUSAND) return createTestSubscriberWithBackpNbyN(unsubscribeAfter, 1000); else throw new RuntimeException(testType + " not implemented"); } private static <T> MyTestSubscriber<T> createTestSubscriberWithBackpNbyN( final Optional<Integer> unsubscribeAfter, final long requestSize) { return new MyTestSubscriber<T>(unsubscribeAfter, of(requestSize), ABSENT, of(requestSize)); } @RunWith(Suite.class) @SuiteClasses({}) private static class TestSuiteFromCases<T, R> extends TestSuite { TestSuiteFromCases(Class<?> cls, List<Case<T, R>> cases) { super(cls); for (Case<T, R> c : cases) { for (TestType testType : TestType.values()) if (testType != TestType.BACKP_REQUEST_OVERFLOW) addTest(new MyTestCase<T, R>(c.name + "_" + testType.name(), c, testType)); } } } private static class MyTestCase<T, R> extends TestCase { private final Case<T, R> c; private final TestType testType; MyTestCase(String name, Case<T, R> c, TestType testType) { super(name); this.c = c; this.testType = testType; } @Override protected void runTest() throws Throwable { TestingHelper.runTest(c, testType); } } private static <T> boolean equals(Collection<T> a, Collection<T> b, boolean ordered) { if (a == null) return b == null; else if (b == null) return a == null; else if (a.size() != b.size()) return false; else if (ordered) return a.equals(b); else { List<T> list = new ArrayList<T>(a); for (T t : b) { if (!list.remove(t)) return false; } return true; } } private static class TestingException extends RuntimeException { private static final long serialVersionUID = 4467514769366847747L; } /** * RuntimeException implementation to represent the situation of more items * being delivered by a source than are requested via backpressure. */ public static class DeliveredMoreThanRequestedException extends RuntimeException { private static final long serialVersionUID = 1369440545774454215L; public DeliveredMoreThanRequestedException() { super("more items arrived than requested"); } } /** * RuntimeException implementation to represent an assertion failure. */ public static class AssertionException extends RuntimeException { private static final long serialVersionUID = -6846674323693517388L; public AssertionException(String message) { super(message); } } /** * Returns a {@link Func1} For use with {@code Observable.to()}. Enables * method chaining from observable to assertions. * * @param <T> * type of item in observable stream * @return Func1 */ public static <T> Func1<Observable<T>, TestSubscriber2<T>> test() { return TestSubscriber2.test(); } /** * Returns a {@link Func1} For use with {@code Observable.to()}. Enables * method chaining from observable to assertions. * * @param initialRequest * amount to be requested in the {@code onStart} method of the * subscriber. * @param <T> * type of item in observable stream * @return Func1 */ public static <T> Func1<Observable<T>, TestSubscriber2<T>> testWithRequest( long initialRequest) { return TestSubscriber2.testWithRequest(initialRequest); } }