package com.twitter.common.testing.runner;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FilterOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.io.PrintStream;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import com.google.common.base.Charsets;
import com.google.common.base.Preconditions;
import com.google.common.base.Predicate;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.google.common.io.Closeables;
import com.google.common.io.Files;
import org.junit.runner.Description;
import org.junit.runner.JUnitCore;
import org.junit.runner.Request;
import org.junit.runner.Result;
import org.kohsuke.args4j.Argument;
import org.kohsuke.args4j.CmdLineException;
import org.kohsuke.args4j.CmdLineParser;
import org.kohsuke.args4j.Option;
import org.kohsuke.args4j.spi.StringArrayOptionHandler;
/**
* An alternative to {@link JUnitCore} with stream capture and junit-report xml output capabilities.
*/
public class JUnitConsoleRunner {
// Swappable sink that initially delegates to the System.out in effect when this class is loaded.
// run() installs a PrintStream over it as System.out, and StreamCapture.open() swaps its delegate
// to a per-test-class capture file.
private static final SwappableStream<PrintStream> SWAPPABLE_OUT =
new SwappableStream<PrintStream>(System.out);
// Same indirection for System.err.
private static final SwappableStream<PrintStream> SWAPPABLE_ERR =
new SwappableStream<PrintStream>(System.err);
/**
 * A stream that allows its underlying output to be swapped.
 *
 * <p>All writes are forwarded to whichever stream is currently installed; the stream passed at
 * construction time is remembered and can be retrieved with {@link #getOriginal()}.
 *
 * @param <T> the type of the stream this swappable stream is initially created with
 */
static class SwappableStream<T extends OutputStream> extends FilterOutputStream {
  private final T original;

  /**
   * @param out the initial delegate, also remembered as the original stream
   */
  SwappableStream(T out) {
    super(out);
    this.original = out;
  }

  /**
   * Installs a new delegate stream.
   *
   * @param out the stream subsequent writes should be forwarded to
   * @return the delegate that was installed before this call
   */
  OutputStream swap(OutputStream out) {
    OutputStream old = this.out;
    this.out = out;
    return old;
  }

  /**
   * Forwards the whole range to the current delegate in a single call.
   *
   * <p>The inherited {@link FilterOutputStream} implementation forwards one byte at a time, which
   * turns every print into a series of single byte writes when the delegate is unbuffered.
   */
  @Override
  public void write(byte[] b, int off, int len) throws IOException {
    out.write(b, off, len);
  }

  /**
   * Returns the original stream this swappable stream was created with.
   */
  public T getOriginal() {
    return original;
  }
}
/**
 * Captures a test's stderr and stdout streams to files, restoring the original streams once the
 * last user has called {@link #close()}.
 *
 * <p>A capture is shared by every test registered through {@link #incrementUseCount()}; the files
 * are only closed, and become readable, when the use count drops to zero or {@link #dispose()} is
 * called.
 */
static class StreamCapture {
  private final File out;
  private OutputStream outstream;
  private final File err;
  private OutputStream errstream;
  private int useCount;
  private boolean closed;

  /**
   * @param out the file stdout is captured to
   * @param err the file stderr is captured to
   */
  StreamCapture(File out, File err) throws IOException {
    this.out = out;
    this.err = err;
  }

  /**
   * Registers one more user that will eventually call {@link #close()}.
   */
  void incrementUseCount() {
    this.useCount++;
  }

  /**
   * Redirects the swappable stdout and stderr streams to this capture's files, creating the file
   * streams on first use.
   *
   * @throws FileNotFoundException if either capture file cannot be opened for writing
   */
  void open() throws FileNotFoundException {
    if (outstream == null) {
      outstream = new FileOutputStream(out);
    }
    if (errstream == null) {
      errstream = new FileOutputStream(err);
    }
    SWAPPABLE_OUT.swap(outstream);
    SWAPPABLE_ERR.swap(errstream);
  }

  /**
   * Releases one use of this capture. When the last use is released the capture files are
   * un-installed from the swappable streams and closed.
   */
  void close() throws IOException {
    if (--useCount <= 0 && !closed) {
      if (outstream != null) {
        // Un-install before closing so later output is not written to a closed file stream.
        uninstall(SWAPPABLE_OUT, outstream);
        Closeables.closeQuietly(outstream);
      }
      if (errstream != null) {
        uninstall(SWAPPABLE_ERR, errstream);
        Closeables.closeQuietly(errstream);
      }
      closed = true;
    }
  }

  /**
   * Restores the swappable stream's original delegate, but only if {@code captured} is the stream
   * currently installed; a delegate installed by another capture is left in place.
   */
  private static void uninstall(SwappableStream<PrintStream> swappable, OutputStream captured) {
    OutputStream displaced = swappable.swap(swappable.getOriginal());
    if (displaced != captured) {
      swappable.swap(displaced);
    }
  }

  /**
   * Forces this capture closed regardless of how many users are still outstanding.
   */
  void dispose() throws IOException {
    useCount = 0;
    close();
  }

  /**
   * Returns the captured stdout bytes.
   *
   * @throws IllegalStateException if the capture has not been closed yet
   */
  byte[] readOut() throws IOException {
    return read(out);
  }

  /**
   * Returns the captured stderr bytes.
   *
   * @throws IllegalStateException if the capture has not been closed yet
   */
  byte[] readErr() throws IOException {
    return read(err);
  }

  private byte[] read(File file) throws IOException {
    Preconditions.checkState(closed, "Capture must be closed by all users before it can be read");
    return Files.toByteArray(file);
  }
}
/**
 * A run listener that redirects stdout and stderr into per-test-class capture files while tests
 * execute and serves the captured bytes back through {@link StreamSource}.
 */
static class StreamCapturingListener extends ForwardingListener implements StreamSource {
  private final Map<Class<?>, StreamCapture> captures = Maps.newHashMap();
  private final File outdir;

  StreamCapturingListener(File outdir) {
    this.outdir = outdir;
  }

  @Override
  public void testRunStarted(Description description) throws Exception {
    registerTests(description.getChildren());
    super.testRunStarted(description);
  }

  /**
   * Walks the description tree depth first and counts one capture use per runnable description,
   * creating the capture for a test class the first time it is seen.
   */
  private void registerTests(Iterable<Description> tests) throws IOException {
    for (Description test : tests) {
      registerTests(test.getChildren());
      if (!Util.isRunnable(test)) {
        continue;
      }
      Class<?> testClass = test.getTestClass();
      StreamCapture existing = captures.get(testClass);
      if (existing != null) {
        existing.incrementUseCount();
        continue;
      }
      StreamCapture created = newCapture(test.getClassName());
      captures.put(testClass, created);
      created.incrementUseCount();
    }
  }

  /**
   * Creates a capture writing to {@code <prefix>.out.txt} and {@code <prefix>.err.txt} under the
   * output directory, making sure the parent directories exist.
   */
  private StreamCapture newCapture(String prefix) throws IOException {
    File stdoutFile = new File(outdir, prefix + ".out.txt");
    Files.createParentDirs(stdoutFile);
    File stderrFile = new File(outdir, prefix + ".err.txt");
    Files.createParentDirs(stderrFile);
    return new StreamCapture(stdoutFile, stderrFile);
  }

  @Override
  public void testRunFinished(Result result) throws Exception {
    // Force every capture closed so its content can be read, even if some users never closed it.
    for (StreamCapture remaining : captures.values()) {
      remaining.dispose();
    }
    super.testRunFinished(result);
  }

  @Override
  public void testStarted(Description description) throws Exception {
    StreamCapture capture = captures.get(description.getTestClass());
    capture.open();
    super.testStarted(description);
  }

  @Override
  public void testFinished(Description description) throws Exception {
    StreamCapture capture = captures.get(description.getTestClass());
    capture.close();
    super.testFinished(description);
  }

  @Override
  public byte[] readOut(Class<?> testClass) throws IOException {
    StreamCapture capture = captures.get(testClass);
    return capture.readOut();
  }

  @Override
  public byte[] readErr(Class<?> testClass) throws IOException {
    StreamCapture capture = captures.get(testClass);
    return capture.readErr();
  }
}
// Matches a spec containing exactly one '#' with non-empty text on both sides. parseRequests()
// treats group 1 as the class name and group 2 as the test method name.
private static final Pattern METHOD_PARSER = Pattern.compile("^([^#]+)#([^#]+)$");
// Passed to the AbortableListener created in run().
private final boolean failFast;
// Either of these two flags makes run() install a StreamCapturingListener; xmlReport additionally
// installs an AntJunitXmlReportListener.
private final boolean suppressOutput;
private final boolean xmlReport;
// Directory handed to the capturing and xml report listeners; run() creates it if missing.
private final File outdir;
/**
 * Creates a runner with the given settings.
 *
 * @param failFast passed on to the abortable listener that wraps all other listeners
 * @param suppressOutput whether test stdout/stderr should be captured to files under outdir
 * @param xmlReport whether ant compatible junit xml reports should be written to outdir
 * @param outdir the directory stream captures and xml reports are written to
 */
JUnitConsoleRunner(boolean failFast, boolean suppressOutput, boolean xmlReport, File outdir) {
  this.outdir = outdir;
  this.xmlReport = xmlReport;
  this.suppressOutput = suppressOutput;
  this.failFast = failFast;
}
/**
 * Runs the given test specs and finishes by calling {@link #exit(int)} with the total number of
 * failures across all requests.
 *
 * <p>System.out and System.err are replaced with print streams over the swappable streams and
 * are not put back by this method.
 *
 * @param tests test specs as understood by parseRequests: class names or class#method names
 * @throws IllegalStateException if the output directory is needed but cannot be created
 */
void run(Iterable<String> tests) {
// Keep the stdout in effect before redirection; discovery errors, the console listener and the
// shutdown hook below write to it directly.
final PrintStream out = System.out;
System.setOut(new PrintStream(SWAPPABLE_OUT));
System.setErr(new PrintStream(SWAPPABLE_ERR));
List<Request> requests = parseRequests(out, tests);
JUnitCore core = new JUnitCore();
// The abortable listener is the only listener registered with JUnit; every other listener is
// attached to it.
final AbortableListener abortableListener = new AbortableListener(failFast) {
@Override protected void abort(Result failureResult) {
exit(failureResult.getFailureCount());
}
};
core.addListener(abortableListener);
if (xmlReport || suppressOutput) {
if (!outdir.exists()) {
if (!outdir.mkdirs()) {
throw new IllegalStateException("Failed to create output directory: " + outdir);
}
}
// Added before the xml report listener, which is handed the capturing listener as its source of
// captured output.
StreamCapturingListener streamCapturingListener = new StreamCapturingListener(outdir);
abortableListener.addListener(streamCapturingListener);
if (xmlReport) {
AntJunitXmlReportListener xmlReportListener =
new AntJunitXmlReportListener(outdir, streamCapturingListener);
abortableListener.addListener(xmlReportListener);
}
}
abortableListener.addListener(new ConsoleListener(out));
// The hook is registered only for the duration of the test loop, so it runs when the VM shuts
// down before the loop completes.
// NOTE(review): abort(Result) above calls exit() while this hook is still registered, so the hook
// presumably also runs on that path - confirm against AbortableListener.
Thread abnormalExitHook = new Thread() {
@Override public void run() {
try {
abortableListener.abort(new UnknownError("Abnormal VM exit - test crashed."));
} catch (Exception e) {
out.println(e);
e.printStackTrace(out);
}
}
};
abnormalExitHook.setDaemon(true);
Runtime.getRuntime().addShutdownHook(abnormalExitHook);
int failures = 0;
for (Request request : requests) {
Result result = core.run(request);
failures += result.getFailureCount();
}
// Normal completion: drop the hook so the exit below is not reported as abnormal.
Runtime.getRuntime().removeShutdownHook(abnormalExitHook);
exit(failures);
}
/**
 * Turns test specs into JUnit requests.
 *
 * <p>A spec of the form {@code class#method} selects a single test method; any other spec is
 * treated as a class name. Classes that {@link #isTest(Class)} rejects are skipped. All selected
 * classes are combined into one request, listed first, followed by one request per selected
 * method. Repeated specs are only run once.
 *
 * @param out the stream class loading problems are reported to
 * @param specs the test specs to parse
 * @return the requests to run, in execution order
 * @throws RuntimeException if a class named by a spec cannot be loaded
 */
private List<Request> parseRequests(PrintStream out, Iterable<String> specs) {
  /**
   * Datatype representing an individual test method.
   */
  class TestMethod {
    private final Class<?> clazz;
    private final String name;

    TestMethod(Class<?> clazz, String name) {
      this.clazz = clazz;
      this.name = name;
    }

    // Value equality is required for the set below to drop a method spec that is given twice.
    @Override public boolean equals(Object o) {
      if (this == o) {
        return true;
      }
      if (!(o instanceof TestMethod)) {
        return false;
      }
      TestMethod other = (TestMethod) o;
      return clazz.equals(other.clazz) && name.equals(other.name);
    }

    @Override public int hashCode() {
      return 31 * clazz.hashCode() + name.hashCode();
    }
  }

  // Linked sets keep the order the specs were given in while dropping duplicates.
  Set<TestMethod> testMethods = Sets.newLinkedHashSet();
  Set<Class<?>> classes = Sets.newLinkedHashSet();
  for (String spec : specs) {
    Matcher matcher = METHOD_PARSER.matcher(spec);
    try {
      if (matcher.matches()) {
        Class<?> testClass = Class.forName(matcher.group(1));
        if (isTest(testClass)) {
          String method = matcher.group(2);
          testMethods.add(new TestMethod(testClass, method));
        }
      } else {
        Class<?> testClass = Class.forName(spec);
        if (isTest(testClass)) {
          classes.add(testClass);
        }
      }
    } catch (NoClassDefFoundError e) {
      notFoundError(spec, out, e);
    } catch (ClassNotFoundException e) {
      notFoundError(spec, out, e);
    }
  }

  List<Request> requests = Lists.newArrayList();
  if (!classes.isEmpty()) {
    requests.add(Request.classes(classes.toArray(new Class<?>[classes.size()])));
  }
  for (TestMethod testMethod : testMethods) {
    requests.add(Request.method(testMethod.clazz, testMethod.name));
  }
  return requests;
}
/**
 * Reports a class loading failure for a test spec and aborts discovery.
 *
 * @param spec the test spec that could not be loaded
 * @param out the stream the fatal message is printed to
 * @param t the class loading failure, kept as the cause of the thrown exception
 * @throws RuntimeException always
 */
private void notFoundError(String spec, PrintStream out, Throwable t) {
  String report = String.format("FATAL: Error during test discovery for %s: %s\n", spec, t);
  out.print(report);
  String reason = "Classloading error during test discovery for " + spec;
  throw new RuntimeException(reason, t);
}
/**
 * Launcher for JUnitConsoleRunner.
 *
 * <p>Parses the command line, expands {@code @argfile} arguments into the whitespace delimited
 * test specs they contain and runs the resulting tests. Exits with status 1 if the command line
 * cannot be parsed or an arg file cannot be read.
 *
 * @param args the command line options followed by test specs and/or {@code @argfile} paths
 */
public static void main(String[] args) {
  /**
   * Command line option bean.
   */
  class Options {
    private boolean failFast = false;
    private boolean suppressOutput = false;
    private boolean xmlReport = false;
    private File outdir = new File(System.getProperty("java.io.tmpdir"));
    private List<String> tests = Lists.newArrayList();

    @Option(name = "-fail-fast", usage = "Causes the test suite run to fail fast.")
    public void setFailFast(boolean failFast) {
      this.failFast = failFast;
    }

    @Option(name = "-suppress-output", usage = "Suppresses test output.")
    public void setSuppressOutput(boolean suppressOutput) {
      this.suppressOutput = suppressOutput;
    }

    @Option(name = "-xmlreport",
            usage = "Create ant compatible junit xml report files in -outdir.")
    public void setXmlReport(boolean xmlReport) {
      this.xmlReport = xmlReport;
    }

    @Option(name = "-outdir",
            usage = "Directory to output test captures too.  Only used if -suppress-output or "
                    + "-xmlreport is set.")
    public void setOutdir(File outdir) {
      this.outdir = outdir;
    }

    @Argument(usage = "Names of junit test classes or test methods to run.  Names prefixed "
                      + "with @ are considered arg file paths and these will be loaded and the "
                      + "whitespace delimited arguments found inside added to the list",
              required = true,
              metaVar = "TESTS",
              handler = StringArrayOptionHandler.class)
    public void setTests(String[] tests) {
      this.tests = Arrays.asList(tests);
    }
  }

  Options options = new Options();
  CmdLineParser parser = new CmdLineParser(options);
  try {
    parser.parseArgument(args);
  } catch (CmdLineException e) {
    // Say what was wrong with the command line before showing the usage.
    System.err.println(e.getMessage());
    e.getParser().printUsage(System.out);
    exit(1);
  }

  JUnitConsoleRunner runner =
      new JUnitConsoleRunner(options.failFast,
          options.suppressOutput,
          options.xmlReport,
          options.outdir);

  List<String> tests = Lists.newArrayList();
  for (String test : options.tests) {
    if (test.startsWith("@")) {
      try {
        String argFileContents = Files.toString(new File(test.substring(1)), Charsets.UTF_8);
        for (String token : argFileContents.split("\\s+")) {
          // An empty file or leading whitespace produces an empty token; it is not a test spec.
          if (!token.isEmpty()) {
            tests.add(token);
          }
        }
      } catch (IOException e) {
        System.err.printf("Failed to load args from arg file %s: %s\n", test, e.getMessage());
        exit(1);
      }
    } else {
      tests.add(test);
    }
  }
  runner.run(tests);
}
/**
 * Decides whether a class can be run as a junit test.
 *
 * @param clazz the candidate class
 * @return {@code true} if the class is a public, concrete, non-interface class that either
 *     belongs to the junit 3.x Test hierarchy or has a public method annotated with the junit 4.x
 *     Test annotation
 */
private static boolean isTest(final Class<?> clazz) {
  // Only public concrete classes can be runnable junit tests.
  boolean concrete = !clazz.isInterface() && !Modifier.isAbstract(clazz.getModifiers());
  if (!concrete || !Modifier.isPublic(clazz.getModifiers())) {
    return false;
  }

  // junit 3.x: anything in the Test hierarchy qualifies.
  if (junit.framework.Test.class.isAssignableFrom(clazz)) {
    return true;
  }

  // junit 4.x: look for at least one public @Test annotated method.
  for (Method candidate : clazz.getMethods()) {
    if (Modifier.isPublic(candidate.getModifiers())
        && candidate.isAnnotationPresent(org.junit.Test.class)) {
      return true;
    }
  }
  return false;
}
/**
 * Terminates the VM with the given status, capped at 255.
 *
 * <p>Callers pass a failure count. POSIX systems only report the low 8 bits of an exit status, so
 * an uncapped count of 256 would be seen as 0, i.e. success.
 *
 * @param code the desired exit status; values above 255 are reported as 255
 */
private static void exit(int code) {
  int status = Math.min(code, 255);
  // We're a main - its fine to exit.
  // SUPPRESS CHECKSTYLE RegexpSinglelineJava
  System.exit(status);
}
}