package io.vertx.ext.unit.impl;

import io.vertx.core.AsyncResult;
import io.vertx.core.Handler;
import io.vertx.ext.unit.Async;
import io.vertx.ext.unit.TestContext;

import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * {@link TestContext} implementation that executes a test as a sequence of {@link Step}s.
 * <p>
 * At most one step is current at a time: {@link #run} refuses to start a new step while
 * another one is still current. The {@code current} field and the attribute map are guarded
 * by this object's monitor; each step's own state is guarded by the step's monitor. The two
 * monitors are never held at the same time, which avoids lock-order deadlocks between a
 * thread ending a step and a thread reporting a failure or creating an {@link Async}.
 *
 * @author <a href="mailto:julien@julienviet.com">Julien Viet</a>
 */
public class TestContextImpl implements TestContext {

  private static final AtomicInteger threadCount = new AtomicInteger(0);

  private final Map<String, Object> attributes;
  private final Handler<Throwable> unhandledFailureHandler;
  private Step current;

  /**
   * A step in the test.
   * <p>
   * A step ends once its body is no longer running (it returned, threw, or timed out) and
   * either all the {@link Async}s it handed out have completed or a failure was recorded.
   * When it ends, the step's {@code endHandler} receives the recorded failure, or
   * {@code null} if there was none.
   */
  private class Step {

    private final Handler<Throwable> endHandler;
    private final LinkedList<AsyncImpl> asyncs = new LinkedList<>();
    private boolean running = true;
    private boolean complete;
    private Throwable failure;

    public Step(Handler<Throwable> endHandler, Throwable failure) {
      this.endHandler = endHandler;
      this.failure = failure;
    }

    /**
     * Ends the step if it is ready to end and has not ended yet. On the thread that ends the
     * step: pending asyncs are released and {@code endHandler} is invoked.
     */
    private void tryEnd() {
      List<AsyncImpl> pending;
      synchronized (this) {
        if (complete || running || (!asyncs.isEmpty() && failure == null)) {
          return;
        }
        complete = true;
        // Snapshot then clear, so releasing the asyncs below cannot race with
        // internalComplete() mutating the list (avoids a ConcurrentModificationException).
        pending = new ArrayList<>(asyncs);
        asyncs.clear();
        // Wake the timeout thread, if any, so it observes completion and exits.
        this.notifyAll();
      }
      // Clear the current step only after releasing this step's monitor: acquiring the
      // context monitor while holding the step monitor would invert the lock order used
      // elsewhere and could deadlock. It must still happen before endHandler runs, since
      // the end handler may start the next step via run().
      synchronized (TestContextImpl.this) {
        if (current == this) {
          current = null;
        }
      }
      for (AsyncImpl a : pending) {
        a.release();
      }
      endHandler.handle(failure);
    }

    /**
     * Records a failure for this step (the first failure wins) and tries to end it.
     *
     * @return {@code false} if the step had already completed, so the failure was not recorded
     */
    private boolean failed(Throwable t) {
      synchronized (this) {
        if (complete) {
          return false;
        }
        if (failure == null) {
          failure = t;
        }
      }
      tryEnd();
      return true;
    }

    /**
     * Creates a new async for this step. It is tracked (and delays the step's end) only
     * when no failure has been recorded yet.
     *
     * @throws IllegalStateException if the step has already completed
     */
    private AsyncImpl async(int count) {
      synchronized (this) {
        if (complete) {
          throw new IllegalStateException("Test already completed");
        }
        AsyncImpl async = new AsyncImpl(count);
        if (failure == null) {
          asyncs.add(async);
        }
        return async;
      }
    }

    /**
     * Runs the step body on the calling thread. When {@code timeout > 0}, a dedicated thread
     * fails the step with a {@link TimeoutException} if the step has not completed within
     * {@code timeout} milliseconds.
     */
    private void run(long timeout, Handler<TestContext> test) {
      if (timeout > 0) {
        long deadline = System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(timeout);
        Runnable cancel = () -> {
          try {
            synchronized (this) {
              // Wait in a loop: Object.wait may wake up spuriously, and a single
              // wait(timeout) would then report a timeout too early.
              while (!complete) {
                long remaining = deadline - System.nanoTime();
                if (remaining <= 0) {
                  break;
                }
                TimeUnit.NANOSECONDS.timedWait(this, remaining);
              }
              if (complete) {
                return;
              }
              running = false;
            }
            failed(new TimeoutException());
          } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
          }
        };
        Thread timeoutThread = new Thread(cancel);
        timeoutThread.setName("vert.x-unit-timeout-thread-" + threadCount.incrementAndGet());
        timeoutThread.start();
      }
      try {
        test.handle(TestContextImpl.this);
      } catch (Throwable t) {
        failed(t);
      } finally {
        synchronized (this) {
          running = false;
        }
        tryEnd();
      }
    }

    /**
     * An {@link Async} bound to its enclosing step; it completes after being counted down
     * {@code initialCount} times or after a single {@link #complete()}.
     */
    class AsyncImpl extends CompletionImpl<Void> implements Async {

      private final int initialCount;
      private final AtomicInteger counter;

      public AsyncImpl(int initialCount) {
        this.initialCount = initialCount;
        this.counter = new AtomicInteger(initialCount);
      }

      @Override
      public int count() {
        return counter.get();
      }

      @Override
      public void countDown() {
        // Only the 1 -> 0 transition completes; extra calls once at zero are no-ops.
        int previous = counter.getAndUpdate(v -> v > 0 ? v - 1 : 0);
        if (previous == 1) {
          completable.complete(null);
          internalComplete();
        }
      }

      @Override
      public void complete() {
        int value = counter.getAndSet(0);
        if (value > 0) {
          completable.complete(null);
          internalComplete();
        } else {
          throw new IllegalStateException("The Async complete method has been called more than " + initialCount + " times, check your test.");
        }
      }

      /** Completes the underlying completion, exceptionally if the step has a failure. */
      private void release() {
        if (Step.this.failure != null) {
          completable.completeExceptionally(Step.this.failure);
        } else {
          completable.complete(null);
        }
      }

      /** Stops tracking this async in its step and, if it was tracked, tries to end the step. */
      void internalComplete() {
        boolean complete;
        synchronized (Step.this) {
          complete = asyncs.remove(this);
        }
        if (complete) {
          tryEnd();
        }
        release();
      }
    }
  }

  public TestContextImpl(Map<String, Object> attributes, Handler<Throwable> unhandledFailureHandler) {
    this.attributes = attributes;
    this.unhandledFailureHandler = unhandledFailureHandler;
  }

  @Override
  @SuppressWarnings("unchecked") // attribute values are stored untyped; the caller picks T
  public synchronized <T> T get(String key) {
    return (T) attributes.get(key);
  }

  @Override
  @SuppressWarnings("unchecked") // attribute values are stored untyped; the caller picks T
  public synchronized <T> T put(String key, Object value) {
    if (value != null) {
      return (T) attributes.put(key, value);
    } else {
      return (T) attributes.remove(key);
    }
  }

  @Override
  @SuppressWarnings("unchecked") // attribute values are stored untyped; the caller picks T
  public synchronized <T> T remove(String key) {
    return (T) attributes.remove(key);
  }

  /**
   * Runs a new step of the test on the calling thread.
   *
   * @param failed     a failure to pre-record on the step, or {@code null}
   * @param timeout    timeout in milliseconds; {@code <= 0} disables it
   * @param test       the step body
   * @param endHandler receives the step failure, or {@code null}, once the step ends
   * @throws IllegalStateException if another step is still current
   */
  public void run(Throwable failed, long timeout, Handler<TestContext> test, Handler<Throwable> endHandler) {
    Step step;
    synchronized (this) {
      if (current != null) {
        throw new IllegalStateException("Wrong status");
      }
      current = step = new Step(endHandler, failed);
    }
    step.run(timeout, test);
  }

  /**
   * Reports a failure to the current step. If there is no current step, or it has already
   * completed, the failure is passed to the unhandled failure handler (when one is set).
   */
  public void failed(Throwable t) {
    Step step;
    synchronized (this) {
      step = current;
    }
    // Call into the step without holding this monitor (see class documentation).
    boolean reported = step != null && step.failed(t);
    if (!reported && unhandledFailureHandler != null) {
      unhandledFailureHandler.handle(t);
    }
  }

  @Override
  public Async async() {
    return async(1);
  }

  @Override
  public Async async(int count) {
    if (count < 1) {
      throw new IllegalArgumentException("Async completion count must be > 0");
    }
    Step step;
    synchronized (this) {
      step = current;
    }
    if (step == null) {
      throw new IllegalStateException();
    }
    // Call into the step without holding this monitor (see class documentation).
    return step.async(count);
  }

  @Override
  public TestContext assertNull(Object expected) {
    return assertNull(expected, null);
  }

  @Override
  public TestContext assertNull(Object expected, String message) {
    if (expected != null) {
      throw reportAssertionError(formatMessage(message, "Expected null"));
    }
    return this;
  }

  @Override
  public TestContext assertNotNull(Object expected) {
    return assertNotNull(expected, null);
  }

  @Override
  public TestContext assertNotNull(Object expected, String message) {
    if (expected == null) {
      throw reportAssertionError(formatMessage(message, "Expected not null"));
    }
    return this;
  }

  @Override
  public TestContext assertTrue(boolean condition, String message) {
    if (!condition) {
      throw reportAssertionError(formatMessage(message, "Expected true"));
    }
    return this;
  }

  public TestContext assertTrue(boolean condition) {
    return assertTrue(condition, null);
  }

  @Override
  public TestContext assertFalse(boolean condition) {
    return assertFalse(condition, null);
  }

  @Override
  public TestContext assertFalse(boolean condition, String message) {
    if (condition) {
      throw reportAssertionError(formatMessage(message, "Expected false"));
    }
    return this;
  }

  @Override
  public void fail() {
    fail((String) null);
  }

  public void fail(String message) {
    throw reportAssertionError(message != null ? message : "Test failed");
  }

  public void fail(Throwable cause) {
    failed(cause);
    Helper.uncheckedThrow(cause);
  }

  @Override
  public Handler<Throwable> exceptionHandler() {
    return this::failed;
  }

  @Override
  public TestContext assertEquals(Object expected, Object actual) {
    return assertEquals(expected, actual, null);
  }

  @Override
  public TestContext assertEquals(Object expected, Object actual, String message) {
    if (actual == null) {
      if (expected != null) {
        throw reportAssertionError(formatMessage(message, "Expected " + expected + " got null"));
      }
    } else {
      if (expected == null) {
        throw reportAssertionError(formatMessage(message, "Expected null instead of " + actual));
      } else if (!expected.equals(actual)) {
        throw reportAssertionError(formatMessage(message, "Not equals : " + expected + " != " + actual));
      }
    }
    return this;
  }

  @Override
  public TestContext assertInRange(double expected, double actual, double delta) {
    return assertInRange(expected, actual, delta, null);
  }

  /**
   * Asserts that {@code actual} lies within {@code delta} of {@code expected}. Identical values
   * (per {@link Double#compare}, so NaN equals NaN and equal infinities match) always pass.
   */
  @Override
  public TestContext assertInRange(double expected, double actual, double delta, String message) {
    // "!(x <= delta)" rather than "x > delta": a NaN difference (or NaN delta) must fail
    // instead of slipping through because every comparison with NaN is false.
    if (Double.compare(expected, actual) != 0 && !(Math.abs(actual - expected) <= delta)) {
      throw reportAssertionError(formatMessage(message, "Expected " + actual + " to belong to [" + (expected - delta) + "," + (expected + delta) + "]"));
    }
    return this;
  }

  @Override
  public TestContext assertNotEquals(Object first, Object second, String message) {
    if (first == null) {
      if (second == null) {
        throw reportAssertionError(formatMessage(message, "Expected null != null"));
      }
    } else {
      if (first.equals(second)) {
        throw reportAssertionError(formatMessage(message, "Expected different values " + first + " != " + second));
      }
    }
    return this;
  }

  @Override
  public <T> Handler<AsyncResult<T>> asyncAssertSuccess() {
    return asyncAssertSuccess(result -> {
    });
  }

  @Override
  public <T> Handler<AsyncResult<T>> asyncAssertSuccess(Handler<T> resultHandler) {
    Async async = async();
    return ar -> {
      if (ar.succeeded()) {
        T result = ar.result();
        try {
          resultHandler.handle(result);
          async.complete();
        } catch (Throwable e) {
          failed(e);
        }
      } else {
        failed(ar.cause());
      }
    };
  }

  @Override
  public <T> Handler<AsyncResult<T>> asyncAssertFailure() {
    return asyncAssertFailure(cause -> {
    });
  }

  @Override
  public <T> Handler<AsyncResult<T>> asyncAssertFailure(Handler<Throwable> causeHandler) {
    Async async = async();
    return ar -> {
      if (ar.failed()) {
        Throwable result = ar.cause();
        try {
          causeHandler.handle(result);
          async.complete();
        } catch (Throwable e) {
          failed(e);
        }
      } else {
        // Reported rather than thrown: this runs inside an asynchronous callback.
        reportAssertionError("Was expecting a failure instead of of success");
      }
    };
  }

  @Override
  public TestContext assertNotEquals(Object first, Object second) {
    return assertNotEquals(first, second, null);
  }

  /**
   * Create and report an assertion error, the returned throwable can be thrown to change
   * the control flow.
   *
   * @return an assertion error to eventually throw
   */
  private AssertionError reportAssertionError(String message) {
    AssertionError err = new AssertionError(message);
    failed(err);
    return err;
  }

  /** Prefixes {@code defaultMessage} with the caller-provided message, when there is one. */
  private static String formatMessage(String providedMessage, String defaultMessage) {
    return providedMessage == null ? defaultMessage : (providedMessage + ". " + defaultMessage);
  }
}