package com.tddinaction.concurrency.waitforthreads; import java.lang.reflect.Method; import java.util.ArrayList; import java.util.Calendar; import java.util.Collections; import java.util.Date; import java.util.List; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import org.junit.Assert; import org.junit.internal.runners.InitializationError; import org.junit.internal.runners.TestClassMethodsRunner; import org.junit.runner.Description; import org.junit.runner.notification.Failure; import org.junit.runner.notification.RunNotifier; /** * A spawned threads-aware test runner implementation for JUnit 4. * * @author lkoskela (Based on work by Greg Vaughn <gvaughn@delphis.com>) */ public class ThreadedClassMethodsRunner extends TestClassMethodsRunner implements ThreadedExecutionContext { private static final String MAIN_TEST_THREAD = "ThreadRunner-main"; public boolean fTestThreadsRunning; private Description currentTest; public List<Failure> failures; private long endTime; private Method method; public ThreadedClassMethodsRunner(Class<?> test) throws InitializationError { super(test); } @Override protected void invokeTestMethod(Method method, RunNotifier notifier) { failures = Collections .synchronizedList(new ArrayList<Failure>()); this.method = method; setCurrentTest(methodDescription(method)); try { ThreadGroup grp = new FailureCatchingThreadGroup(this); startEndWatcherThread(notifier, grp); waitForTimeoutOrThreadCompletion(); assertThreadGroupHasStopped(notifier, grp); } catch (Throwable any) { any.printStackTrace(); } } private void assertThreadGroupHasStopped( final RunNotifier notifier, ThreadGroup group) { if (group.activeCount() > 1) { // Only EndWatcher should still run Throwable exception = new AssertionError( "Not all spawned threads were stopped: " + getNamesOfActiveThreads(group)); notifier.fireTestFailure(new Failure(getCurrentTest(), exception)); notifier.fireTestFinished(getCurrentTest()); } } public static List<String> getNamesOfActiveThreads(ThreadGroup grp) { List<String> names = ThreadUtils.namesOfActiveThreadsIn(grp); names.remove(MAIN_TEST_THREAD); return names; } private void waitForTimeoutOrThreadCompletion() { // Now the main JUnit thread waits to either timeout or be // notified. (Spawned threads might call notifyAll and wake the // JUnit thread.) long fTimeout = 10000; endTime = fTimeout + System.currentTimeMillis(); while (fTestThreadsRunning && endTime > System.currentTimeMillis()) { try { synchronized (this) { wait(fTimeout); } } catch (InterruptedException e) { e.printStackTrace(); } } } private ThreadGroupListener startEndWatcherThread( final RunNotifier notifier, ThreadGroup grp) { final Semaphore endWatcherStarted = new Semaphore(0); ThreadGroupListener endWatcher = new ThreadGroupListener( new LifecycleCallback() { public void run() { endWatcherStarted.release(); ThreadedClassMethodsRunner.this.runSuper( method, notifier); } public void before() { ThreadedClassMethodsRunner.this.fTestThreadsRunning = true; } public void after() { synchronized (ThreadedClassMethodsRunner.this) { ThreadedClassMethodsRunner.this.fTestThreadsRunning = false; ThreadedClassMethodsRunner.this .notify(); // main JUnit thread } } }); new Thread(grp, endWatcher, MAIN_TEST_THREAD).start(); try { endWatcherStarted.tryAcquire(1, TimeUnit.SECONDS); } catch (InterruptedException e) { throw new RuntimeException( "Problems starting EndWatcher thread?"); } return endWatcher; } public static void log(String msg) { System.out.println("[" + Thread.currentThread().getName() + "] [" + new Date(Calendar.getInstance().getTimeInMillis()) + "] " + msg); } protected void runSuper(final Method method, final RunNotifier notifier) { // TODO: encapsulate 'failures' behind an interface super.invokeTestMethod(method, new DelayedFailureRunNotifier( notifier, failures)); } public static void waitForSpawnedThreads() { ThreadGroup grp = Thread.currentThread().getThreadGroup(); assertThreadGroupIsWatched(grp); ThreadUtils.waitWhileActiveThreadCountIsHigherThan(1, grp); } private static void assertThreadGroupIsWatched(ThreadGroup grp) { Thread[] threads = new Thread[grp.activeCount() * 2]; int threadCount = grp.enumerate(threads); boolean endWatcherFound = false; for (int i = 0; i < threadCount; i++) { if (threads[i].getName().equals(MAIN_TEST_THREAD)) { endWatcherFound = true; break; } } Assert.assertTrue("No EndWatcher thread in ThreadGroup." + " Have you defined ThreadedRunner for @RunWith?", endWatcherFound); } public synchronized Description getCurrentTest() { return currentTest; } public synchronized void setCurrentTest(Description desc) { currentTest = desc; } public synchronized void add(Failure failure) { failures.add(failure); } }