/*
* Copyright (c) 2011-2014 The original author or authors
* ------------------------------------------------------
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the Eclipse Public License v1.0
* and Apache License v2.0 which accompanies this distribution.
*
* The Eclipse Public License is available at
* http://www.eclipse.org/legal/epl-v10.html
*
* The Apache License v2.0 is available at
* http://www.opensource.org/licenses/apache2.0.php
*
* You may elect to redistribute this code under either of these licenses.
*/
package io.vertx.test.core;
import io.vertx.core.AsyncResult;
import io.vertx.core.Handler;
import io.vertx.core.logging.Logger;
import io.vertx.core.logging.LoggerFactory;
import org.hamcrest.Matcher;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.internal.ArrayComparisonFailure;
import org.junit.rules.TestName;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.function.BooleanSupplier;
import java.util.function.Consumer;
/**
 * Base class for JUnit tests whose assertions may run on threads other than the test's main thread.
 * <p>
 * The protected {@code assertXxx}/{@code fail} methods wrap {@link Assert}. A failing assertion is
 * recorded, its stack trace is printed, and the completion latch is counted down so that a blocked
 * {@link #await()} wakes up. The {@link AssertionError} is then rethrown on the calling thread, and
 * {@link #await()} rethrows the recorded failure on the main thread once the latch opens.
 * <p>
 * Every assertion also records the name of the thread it ran on. At tear-down, when thread checks are
 * enabled, any recorded thread other than the main one whose name does not start with
 * {@code "vert.x-"} fails the test.
 *
 * @author <a href="http://tfox.org">Tim Fox</a>
 */
public class AsyncTestBase {

  private static final Logger log = LoggerFactory.getLogger(AsyncTestBase.class);

  // Written under the monitor (waitFor/waitForMore) but read without it from other threads
  // (testComplete, handleThrowable, await), so it must be volatile for cross-thread visibility.
  private volatile CountDownLatch latch;
  private volatile Throwable throwable;
  private volatile Thread thrownThread;
  private volatile boolean testCompleteCalled;
  private volatile boolean awaitCalled;
  private boolean threadChecksEnabled = true;
  private volatile boolean tearingDown;
  private volatile String mainThreadName;
  // Thread name -> an Exception captured at the (latest) check on that thread, kept for its stack trace.
  private final Map<String, Exception> threadNames = new ConcurrentHashMap<>();

  @Rule
  public TestName name = new TestName();

  /**
   * Resets per-test state.
   * <p>
   * It remembers the current thread as the main thread, arms a latch that expects a single completion,
   * and clears any recorded failure, the completion and await flags, and the recorded thread names.
   */
  protected void setUp() throws Exception {
    log.info("Starting test: " + this.getClass().getSimpleName() + "#" + name.getMethodName());
    mainThreadName = Thread.currentThread().getName();
    tearingDown = false;
    waitFor(1);
    throwable = null;
    testCompleteCalled = false;
    awaitCalled = false;
    threadNames.clear();
  }

  /** Enters tear-down mode, so that later completions or assertions fail, and runs the post-test checks. */
  protected void tearDown() throws Exception {
    tearingDown = true;
    afterAsyncTestBase();
  }

  @Before
  public void before() throws Exception {
    setUp();
  }

  @After
  public void after() throws Exception {
    tearDown();
  }

  /** Replaces the completion latch with one that must be counted down {@code count} times. */
  protected synchronized void waitFor(int count) {
    latch = new CountDownLatch(count);
  }

  /**
   * Replaces the completion latch with one expecting {@code count} more completions than the current
   * latch still has outstanding.
   * <p>
   * NOTE: a thread that is already blocked in {@link #await()} keeps waiting on the previous latch
   * instance.
   */
  protected synchronized void waitForMore(int count) {
    latch = new CountDownLatch(count + (int) latch.getCount());
  }

  /**
   * Counts the completion latch down by one. When it reaches zero, the test is marked complete.
   *
   * @throws IllegalStateException if called during tear-down or after the test is already complete
   */
  protected synchronized void complete() {
    if (tearingDown) {
      throw new IllegalStateException("testComplete called after test has completed");
    }
    checkThread();
    if (testCompleteCalled) {
      throw new IllegalStateException("already complete");
    }
    latch.countDown();
    if (latch.getCount() == 0) {
      testCompleteCalled = true;
    }
  }

  /**
   * Marks the test complete and counts the latch down once.
   * <p>
   * A latch armed with {@code waitFor(n)} for {@code n > 1} is therefore not fully released. The method
   * is synchronized like {@link #complete()}, so the "already called" check and the flag update are
   * atomic when several threads race to finish the test.
   *
   * @throws IllegalStateException if called during tear-down or more than once
   */
  protected synchronized void testComplete() {
    if (tearingDown) {
      throw new IllegalStateException("testComplete called after test has completed");
    }
    checkThread();
    if (testCompleteCalled) {
      throw new IllegalStateException("testComplete() already called");
    }
    testCompleteCalled = true;
    latch.countDown();
  }

  /** Waits up to two minutes for the test to complete; see {@link #await(long, TimeUnit)}. */
  protected void await() {
    await(2, TimeUnit.MINUTES);
  }

  /**
   * Blocks until the completion latch reaches zero, then rethrows any failure recorded by an assertion.
   * It may be called only once per test.
   *
   * @param delay maximum time to wait
   * @param timeUnit unit of {@code delay}
   * @throws IllegalStateException if called twice, on timeout, or if the waiting thread is interrupted;
   *     on interrupt, the interrupt status is restored and the InterruptedException is kept as the cause
   */
  public void await(long delay, TimeUnit timeUnit) {
    if (awaitCalled) {
      throw new IllegalStateException("await() already called");
    }
    awaitCalled = true;
    boolean completed;
    try {
      completed = latch.await(delay, timeUnit);
    } catch (InterruptedException e) {
      // Preserve the interrupt for callers further up the stack, and keep the original cause.
      Thread.currentThread().interrupt();
      throw new IllegalStateException("Test thread was interrupted!", e);
    }
    if (!completed) {
      throw new IllegalStateException("Timed out in waiting for test complete");
    }
    rethrowError();
  }

  /** Rethrows the recorded failure: Errors and RuntimeExceptions as-is, anything else wrapped. */
  private void rethrowError() {
    if (throwable != null) {
      if (throwable instanceof Error) {
        throw (Error) throwable;
      } else if (throwable instanceof RuntimeException) {
        throw (RuntimeException) throwable;
      } else {
        // Unexpected throwable- Should never happen
        throw new IllegalStateException(throwable);
      }
    }
  }

  /** Disables the tear-down check that rejects assertions made on non-Vert.x threads. */
  protected void disableThreadChecks() {
    threadChecksEnabled = false;
  }

  /**
   * Post-test verification, run from {@link #tearDown()}.
   * <p>
   * It fails in two cases:
   * <ul>
   *   <li>a failure was recorded on a thread other than the current one, but {@link #await()} was never
   *       called to surface it;</li>
   *   <li>thread checks are enabled and an assertion ran on a thread that is neither the main thread nor
   *       named {@code "vert.x-..."}. The resulting exception carries the stack trace captured on that
   *       thread.</li>
   * </ul>
   */
  protected void afterAsyncTestBase() {
    if (throwable != null && thrownThread != Thread.currentThread() && !awaitCalled) {
      // Throwable caught from non main thread
      throw new IllegalStateException("Assert or failure from non main thread but no await() on main thread", throwable);
    }
    for (Map.Entry<String, Exception> entry : threadNames.entrySet()) {
      String threadName = entry.getKey();
      boolean foreignThread = !threadName.equals(mainThreadName) && !threadName.startsWith("vert.x-");
      if (foreignThread && threadChecksEnabled) {
        IllegalStateException is = new IllegalStateException("Non Vert.x thread! :" + threadName);
        is.setStackTrace(entry.getValue().getStackTrace());
        throw is;
      }
    }
  }

  /**
   * Records a failure so that {@link #await()} can rethrow it on the main thread, then releases one latch
   * count. AssertionErrors are also rethrown on the calling thread; other throwables are only recorded.
   */
  private void handleThrowable(Throwable t) {
    if (tearingDown) {
      throw new IllegalStateException("assert or failure occurred after test has completed");
    }
    throwable = t;
    t.printStackTrace();
    thrownThread = Thread.currentThread();
    latch.countDown();
    if (t instanceof AssertionError) {
      throw (AssertionError) t;
    }
  }

  /** Forgets any recorded failure. */
  protected void clearThrown() {
    throwable = null;
  }

  /** Records the current thread name, with a stack trace, for the tear-down thread check. */
  protected void checkThread() {
    threadNames.put(Thread.currentThread().getName(), new Exception());
  }

  /**
   * Shared body of all assertion wrappers.
   * <p>
   * It records the calling thread, runs the JUnit assertion, and routes any {@link AssertionError},
   * including {@link ArrayComparisonFailure}, through {@link #handleThrowable(Throwable)}.
   */
  private void runAssertion(Runnable assertion) {
    checkThread();
    try {
      assertion.run();
    } catch (AssertionError e) {
      handleThrowable(e);
    }
  }

  // ---------------------------------------------------------------------------------------------
  // Thread-aware wrappers around org.junit.Assert. Each one delegates to the Assert method of the same
  // name and signature via runAssertion(...).
  // ---------------------------------------------------------------------------------------------

  protected void assertTrue(boolean condition) {
    runAssertion(() -> Assert.assertTrue(condition));
  }

  protected void assertTrue(String message, boolean condition) {
    runAssertion(() -> Assert.assertTrue(message, condition));
  }

  protected void assertFalse(boolean condition) {
    runAssertion(() -> Assert.assertFalse(condition));
  }

  protected void assertFalse(String message, boolean condition) {
    runAssertion(() -> Assert.assertFalse(message, condition));
  }

  protected void assertNull(Object object) {
    runAssertion(() -> Assert.assertNull(object));
  }

  protected void assertNull(String message, Object object) {
    runAssertion(() -> Assert.assertNull(message, object));
  }

  protected void assertNotNull(Object object) {
    runAssertion(() -> Assert.assertNotNull(object));
  }

  protected void assertNotNull(String message, Object object) {
    runAssertion(() -> Assert.assertNotNull(message, object));
  }

  protected void assertSame(Object expected, Object actual) {
    runAssertion(() -> Assert.assertSame(expected, actual));
  }

  protected void assertSame(String message, Object expected, Object actual) {
    runAssertion(() -> Assert.assertSame(message, expected, actual));
  }

  protected void assertNotSame(Object unexpected, Object actual) {
    runAssertion(() -> Assert.assertNotSame(unexpected, actual));
  }

  protected void assertNotSame(String message, Object unexpected, Object actual) {
    runAssertion(() -> Assert.assertNotSame(message, unexpected, actual));
  }

  protected void assertEquals(Object expected, Object actual) {
    runAssertion(() -> Assert.assertEquals(expected, actual));
  }

  protected void assertEquals(String message, Object expected, Object actual) {
    runAssertion(() -> Assert.assertEquals(message, expected, actual));
  }

  protected void assertEquals(long expected, long actual) {
    runAssertion(() -> Assert.assertEquals(expected, actual));
  }

  protected void assertEquals(String message, long expected, long actual) {
    runAssertion(() -> Assert.assertEquals(message, expected, actual));
  }

  protected void assertEquals(double expected, double actual, double delta) {
    runAssertion(() -> Assert.assertEquals(expected, actual, delta));
  }

  protected void assertEquals(String message, double expected, double actual, double delta) {
    runAssertion(() -> Assert.assertEquals(message, expected, actual, delta));
  }

  @Deprecated
  protected void assertEquals(double expected, double actual) {
    runAssertion(() -> Assert.assertEquals(expected, actual));
  }

  @Deprecated
  protected void assertEquals(String message, double expected, double actual) {
    runAssertion(() -> Assert.assertEquals(message, expected, actual));
  }

  @Deprecated
  protected void assertEquals(Object[] expecteds, Object[] actuals) {
    runAssertion(() -> Assert.assertEquals(expecteds, actuals));
  }

  @Deprecated
  protected void assertEquals(String message, Object[] expecteds, Object[] actuals) {
    runAssertion(() -> Assert.assertEquals(message, expecteds, actuals));
  }

  protected void assertArrayEquals(Object[] expecteds, Object[] actuals) {
    runAssertion(() -> Assert.assertArrayEquals(expecteds, actuals));
  }

  protected void assertArrayEquals(String message, Object[] expecteds, Object[] actuals) throws ArrayComparisonFailure {
    runAssertion(() -> Assert.assertArrayEquals(message, expecteds, actuals));
  }

  protected void assertArrayEquals(byte[] expecteds, byte[] actuals) {
    runAssertion(() -> Assert.assertArrayEquals(expecteds, actuals));
  }

  protected void assertArrayEquals(String message, byte[] expecteds, byte[] actuals) throws ArrayComparisonFailure {
    runAssertion(() -> Assert.assertArrayEquals(message, expecteds, actuals));
  }

  protected void assertArrayEquals(char[] expecteds, char[] actuals) {
    runAssertion(() -> Assert.assertArrayEquals(expecteds, actuals));
  }

  protected void assertArrayEquals(String message, char[] expecteds, char[] actuals) throws ArrayComparisonFailure {
    runAssertion(() -> Assert.assertArrayEquals(message, expecteds, actuals));
  }

  protected void assertArrayEquals(short[] expecteds, short[] actuals) {
    runAssertion(() -> Assert.assertArrayEquals(expecteds, actuals));
  }

  protected void assertArrayEquals(String message, short[] expecteds, short[] actuals) throws ArrayComparisonFailure {
    runAssertion(() -> Assert.assertArrayEquals(message, expecteds, actuals));
  }

  protected void assertArrayEquals(int[] expecteds, int[] actuals) {
    runAssertion(() -> Assert.assertArrayEquals(expecteds, actuals));
  }

  protected void assertArrayEquals(String message, int[] expecteds, int[] actuals) throws ArrayComparisonFailure {
    runAssertion(() -> Assert.assertArrayEquals(message, expecteds, actuals));
  }

  protected void assertArrayEquals(long[] expecteds, long[] actuals) {
    runAssertion(() -> Assert.assertArrayEquals(expecteds, actuals));
  }

  protected void assertArrayEquals(String message, long[] expecteds, long[] actuals) throws ArrayComparisonFailure {
    runAssertion(() -> Assert.assertArrayEquals(message, expecteds, actuals));
  }

  protected void assertArrayEquals(float[] expecteds, float[] actuals, float delta) {
    runAssertion(() -> Assert.assertArrayEquals(expecteds, actuals, delta));
  }

  protected void assertArrayEquals(String message, float[] expecteds, float[] actuals, float delta) throws ArrayComparisonFailure {
    runAssertion(() -> Assert.assertArrayEquals(message, expecteds, actuals, delta));
  }

  protected void assertArrayEquals(double[] expecteds, double[] actuals, double delta) {
    runAssertion(() -> Assert.assertArrayEquals(expecteds, actuals, delta));
  }

  protected void assertArrayEquals(String message, double[] expecteds, double[] actuals, double delta) throws ArrayComparisonFailure {
    runAssertion(() -> Assert.assertArrayEquals(message, expecteds, actuals, delta));
  }

  protected <T> void assertThat(T actual, Matcher<T> matcher) {
    runAssertion(() -> Assert.assertThat(actual, matcher));
  }

  protected <T> void assertThat(String reason, T actual, Matcher<T> matcher) {
    runAssertion(() -> Assert.assertThat(reason, actual, matcher));
  }

  protected void fail() {
    runAssertion(() -> Assert.fail());
  }

  protected void fail(String message) {
    runAssertion(() -> Assert.fail(message));
  }

  /**
   * Records {@code cause} as the test failure.
   * <p>
   * The cause is rethrown on the calling thread only if it is an {@link AssertionError}; otherwise it is
   * surfaced later by {@link #await()}.
   */
  protected void fail(Throwable cause) {
    checkThread();
    handleThrowable(cause);
  }

  // ---------------------------------------------------------------------------------------------
  // Async helpers
  // ---------------------------------------------------------------------------------------------

  /** Returns a handler that asserts the result failed and passes its cause to {@code consumer}. */
  protected <T> Handler<AsyncResult<T>> onFailure(Consumer<Throwable> consumer) {
    return result -> {
      assertFalse(result.succeeded());
      consumer.accept(result.cause());
    };
  }

  /** Asserts, via {@link #assertTrue(boolean)}, that {@code latch} reaches zero within 10 seconds. */
  protected void awaitLatch(CountDownLatch latch) throws InterruptedException {
    assertTrue(latch.await(10, TimeUnit.SECONDS));
  }

  /** Like {@link #assertWaitUntil(BooleanSupplier, long)} with a 10000 ms timeout. */
  protected void assertWaitUntil(BooleanSupplier supplier) {
    assertWaitUntil(supplier, 10000);
  }

  /** Like {@link #waitUntil(BooleanSupplier, long)} with a 10000 ms timeout; the result is ignored. */
  protected void waitUntil(BooleanSupplier supplier) {
    waitUntil(supplier, 10000);
  }

  /**
   * Polls {@code supplier} until it returns true.
   *
   * @throws IllegalStateException with message "Timed out" if it does not return true within
   *     {@code timeout} milliseconds
   */
  protected void assertWaitUntil(BooleanSupplier supplier, long timeout) {
    if (!waitUntil(supplier, timeout)) {
      throw new IllegalStateException("Timed out");
    }
  }

  /**
   * Polls {@code supplier} every 10 ms until it returns true or {@code timeout} milliseconds have elapsed.
   * <p>
   * Elapsed time is measured with {@link System#nanoTime()}, which is monotonic, so wall-clock changes
   * cannot shorten or extend the wait. If the polling thread is interrupted, its interrupt status is
   * restored and the result of one final check is returned, instead of the interrupt being swallowed and
   * the thread spinning until the timeout.
   *
   * @return true if the condition became true, false on timeout or interrupt with the condition still false
   */
  protected boolean waitUntil(BooleanSupplier supplier, long timeout) {
    long start = System.nanoTime();
    long timeoutNanos = TimeUnit.MILLISECONDS.toNanos(timeout);
    while (true) {
      if (supplier.getAsBoolean()) {
        return true;
      }
      try {
        Thread.sleep(10);
      } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
        return supplier.getAsBoolean();
      }
      if (System.nanoTime() - start > timeoutNanos) {
        return false;
      }
    }
  }

  /**
   * Returns a handler that passes a successful result to {@code consumer}.
   * <p>
   * On a failed result, it prints the cause's stack trace and fails the test with the cause's message.
   */
  protected <T> Handler<AsyncResult<T>> onSuccess(Consumer<T> consumer) {
    return result -> {
      if (result.failed()) {
        result.cause().printStackTrace();
        fail(result.cause().getMessage());
      } else {
        consumer.accept(result.result());
      }
    };
  }
}