/* * Copyright (C) 2011-2014 Chris Vest (mr.chrisvest@gmail.com) * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package stormpot.slow; import org.junit.rules.TestRule; import org.junit.runner.Description; import org.junit.runners.model.Statement; import java.util.*; import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicInteger; import static org.hamcrest.Matchers.is; import static org.junit.Assert.assertThat; class ExecutorTestRule implements TestRule { private ExecutorService executor; private List<Future<?>> futuresToPrintOnFailure = new ArrayList<>(); ExecutorService getExecutorService() { return executor; } @Override public Statement apply(final Statement base, final Description description) { return new Statement() { @Override public void evaluate() throws Throwable { String testName = description.getDisplayName(); TestThreadFactory threadFactory = new TestThreadFactory(testName); executor = createExecutor(threadFactory); try { base.evaluate(); executor.shutdown(); if (!executor.awaitTermination(5000, TimeUnit.SECONDS)) { throw new Exception( "ExecutorService.shutdown timed out after 5 second"); } threadFactory.verifyAllThreadsTerminatedSuccessfully(); } catch (Throwable throwable) { threadFactory.dumpAllThreads(); printFuturesForFailure(); throw throwable; } } }; } private ExecutorService createExecutor(ThreadFactory threadFactory) { return Executors.newCachedThreadPool(threadFactory); } void printOnFailure(List<Future<?>> futures) { futuresToPrintOnFailure.addAll(futures); } void printOnFailure(Future<?> future) { futuresToPrintOnFailure.add(future); } private void printFuturesForFailure() { System.err.println( "\n===[ Dumping all registered futures ]===\n"); for (Future<?> future : futuresToPrintOnFailure) { System.err.printf( "future = %s, isDone? %s, isCancelled? %s%n", future, future.isDone(), future.isCancelled()); if (future.isDone()) { System.err.print(" result: "); try { System.err.println(future.get()); } catch (Exception e) { e.printStackTrace(); } } } System.err.println( "\n===[ End dumping all registered futures ]===\n"); } private class TestThreadFactory implements ThreadFactory { private final String testName; private final AtomicInteger threadCounter = new AtomicInteger(); private final List<Thread> threads = Collections.synchronizedList(new ArrayList<>()); private TestThreadFactory(String testName) { this.testName = testName; } @Override public Thread newThread(Runnable runnable) { int id = threadCounter.incrementAndGet(); Thread thread = new Thread(runnable, "TestThread#" + id + "[" + testName + "]"); threads.add(thread); return thread; } void verifyAllThreadsTerminatedSuccessfully() { synchronized (threads) { for (Thread thread : threads) { // The Thread.State is updated asynchronously by the JVM, // so we occasionally have to do a couple of retries before we // observe the state change. Thread.State state = thread.getState(); int tries = 100; while (state != Thread.State.TERMINATED && tries --> 0) { try { thread.join(10); } catch (InterruptedException e) { throw new AssertionError(e); } state = thread.getState(); } if (tries == 0) { // Okay, this is odd. Let's ask everybody to come to a safe-point // before we pass our final judgement on the thread state. System.gc(); state = thread.getState(); } assertThat( "Unexpected thread state: " + thread + " (id " + thread.getId() + ")", state, is(Thread.State.TERMINATED)); } } } void dumpAllThreads() throws Exception { synchronized (threads) { System.err.println( "\n===[ Dumping stack traces for all created threads ]===\n"); for (Thread thread : threads) { StackTraceElement[] stackTrace = thread.getStackTrace(); printStackTrace(thread, stackTrace); } System.err.println( "\n===[ End stack traces for all created threads ]===\n"); System.err.println( "\n===[ Dumping stack traces for all other threads ]===\n"); Set<Map.Entry<Thread, StackTraceElement[]>> entries = Thread.getAllStackTraces().entrySet(); for (Map.Entry<Thread,StackTraceElement[]> entry : entries) { printStackTrace(entry.getKey(), entry.getValue()); } System.err.println( "\n===[ End stack traces for all other threads ]===\n"); } } private void printStackTrace( Thread thread, StackTraceElement[] stackTrace) { Exception printer = new Exception( "Stack trace for " + thread + " (id " + thread.getId() + "), state = " + thread.getState()); printer.setStackTrace(stackTrace); printer.printStackTrace(); } } }