/** * PODD is an OWL ontology database used for scientific project management * * Copyright (C) 2009-2013 The University Of Queensland * * This program is free software: you can redistribute it and/or modify it under the terms of the * GNU Affero General Public License as published by the Free Software Foundation, either version 3 * of the License, or (at your option) any later version. * * This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without * even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU * Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License along with this program. * If not, see <http://www.gnu.org/licenses/>. */ package com.github.podd.junit.ext; import java.util.Map; import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; import java.util.concurrent.FutureTask; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import org.junit.internal.runners.statements.FailOnTimeout; import org.junit.runners.model.Statement; /** * Extension of the default JUnit {@link FailOnTimeout} statement to print the stack traces of all * active threads when a timeout occurs. * * Enhanced version of {@link FailOnTimeout} */ public class FailOnTimeoutWithStackTraces extends Statement { private class CallableStatement implements Callable<Throwable> { @Override public Throwable call() throws Exception { try { FailOnTimeoutWithStackTraces.this.fOriginalStatement.evaluate(); } catch(final Exception e) { throw e; } catch(final Throwable e) { return e; } return null; } } private final Statement fOriginalStatement; private final TimeUnit fTimeUnit; private final long fTimeout; public FailOnTimeoutWithStackTraces(final Statement originalStatement, final long millis) { this(originalStatement, millis, TimeUnit.MILLISECONDS); } public FailOnTimeoutWithStackTraces(final Statement originalStatement, final long timeout, final TimeUnit unit) { this.fOriginalStatement = originalStatement; this.fTimeout = timeout; this.fTimeUnit = unit; } private Exception createTimeoutException(final Thread thread) { final String allStackTraces = this.getStackTraces(); Exception exception; if(allStackTraces.length() == 0) { exception = new Exception(String.format("test timed out after %d %s", this.fTimeout, this.fTimeUnit.name() .toLowerCase())); } else { exception = new Exception(String.format( "test timed out after %d %s\nAll threads active when test timeout occurred:\n %s", this.fTimeout, this.fTimeUnit.name().toLowerCase(), allStackTraces)); } final StackTraceElement[] stackTrace = thread.getStackTrace(); if(stackTrace != null) { exception.setStackTrace(stackTrace); thread.interrupt(); } return exception; } @Override public void evaluate() throws Throwable { final FutureTask<Throwable> task = new FutureTask<Throwable>(new CallableStatement()); final Thread thread = new Thread(task, "Time-limited test"); thread.setDaemon(true); thread.start(); final Throwable throwable = this.getResult(task, thread); if(throwable != null) { throw throwable; } } /** * Wait for the test task, returning the exception thrown by the test if the test failed, an * exception indicating a timeout if the test timed out, or {@code null} if the test passed. */ private Throwable getResult(final FutureTask<Throwable> task, final Thread thread) { try { return task.get(this.fTimeout, this.fTimeUnit); } catch(final InterruptedException e) { return e; // caller will re-throw; no need to call // Thread.interrupt() } catch(final ExecutionException e) { // test failed; have caller re-throw the exception thrown by the // test return e.getCause(); } catch(final TimeoutException e) { return this.createTimeoutException(thread); } } /** * Gets all thread stack traces. * * @return string of all thread stack traces */ private String getStackTraces() { final StringBuilder sb = new StringBuilder(); final Map<Thread, StackTraceElement[]> stacks = Thread.getAllStackTraces(); for(final Thread t : stacks.keySet()) { sb.append(t.toString()).append('\n'); for(final StackTraceElement ste : t.getStackTrace()) { sb.append("\tat ").append(ste.toString()).append('\n'); } sb.append('\n'); } return sb.toString(); } }