package org.sky.auto.runner;

import org.junit.internal.builders.AllDefaultPossibilitiesBuilder;
import org.junit.runner.Runner;
import org.junit.runners.Suite;
import org.junit.runners.model.InitializationError;
import org.junit.runners.model.RunnerBuilder;
import org.junit.runners.model.RunnerScheduler;
import org.sky.auto.anno.ThreadRunner;

import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.CompletionService;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * JUnit {@link Suite} whose children are submitted to a fixed-size thread pool
 * instead of being run sequentially.
 *
 * <p>Child classes annotated with {@link ThreadRunner} are run with
 * {@code MyJUnitThreadRunner}; all other classes fall through to the standard
 * JUnit builders (ignored, annotated, suite-method, JUnit 3, JUnit 4), in that
 * order.
 *
 * <p>The pool size is {@link ThreadRunner#threads()} when the suite class itself
 * carries the annotation, otherwise 1.5 times the number of available
 * processors (truncated). The size is never allowed to drop below one.
 */
public final class MyJUnitThreadSuitRunner extends Suite {

    /**
     * Builds the suite for {@code klass} and installs the concurrent scheduler.
     *
     * @param klass the suite class being run
     * @throws InitializationError if the underlying {@link Suite} cannot be built
     */
    public MyJUnitThreadSuitRunner(final Class<?> klass) throws InitializationError {
        super(klass, new AllDefaultPossibilitiesBuilder(true) {
            @Override
            public Runner runnerForClass(Class<?> testClass) throws Throwable {
                // The ThreadRunner-aware builder is consulted first; the remaining
                // entries are the stock builders inherited from the parent class.
                List<RunnerBuilder> builders = Arrays.asList(
                        new RunnerBuilder() {
                            @Override
                            public Runner runnerForClass(Class<?> testClass) throws Throwable {
                                if (testClass.isAnnotationPresent(ThreadRunner.class)) {
                                    return new MyJUnitThreadRunner(testClass);
                                }
                                return null;
                            }
                        },
                        ignoredBuilder(),
                        annotatedBuilder(),
                        suiteMethodBuilder(),
                        junit3Builder(),
                        junit4Builder());

                // First builder that produces a runner wins.
                for (RunnerBuilder each : builders) {
                    Runner runner = each.safeRunnerForClass(testClass);
                    if (runner != null) {
                        return runner;
                    }
                }
                return null;
            }
        });

        setScheduler(new RunnerScheduler() {
            private final ExecutorService executorService = Executors.newFixedThreadPool(
                    resolveThreadCount(klass),
                    new NamedThreadFactory(klass.getSimpleName()));
            private final CompletionService<Void> completionService =
                    new ExecutorCompletionService<Void>(executorService);
            // Futures of children that have been submitted but not yet observed as complete.
            private final Queue<Future<Void>> tasks = new LinkedList<Future<Void>>();

            /** Submits one child to the pool and remembers its future. */
            public void schedule(Runnable childStatement) {
                tasks.offer(completionService.submit(childStatement, null));
            }

            /**
             * Waits until every submitted child has completed. If the waiting
             * thread is interrupted, its interrupt flag is restored, the children
             * still pending are cancelled, and the pool is shut down. The pool is
             * shut down in every case, so the scheduler accepts no further work
             * afterwards.
             */
            public void finished() {
                try {
                    while (!tasks.isEmpty()) {
                        // take() blocks until some child completes and returns its future.
                        tasks.remove(completionService.take());
                    }
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                } finally {
                    // Only non-empty when the wait above was cut short.
                    while (!tasks.isEmpty()) {
                        tasks.poll().cancel(true);
                    }
                    executorService.shutdownNow();
                }
            }
        });
    }

    /**
     * Determines the pool size for the given suite class.
     *
     * @param klass the suite class, possibly annotated with {@link ThreadRunner}
     * @return the annotation's {@code threads()} value when present, otherwise
     *         1.5 times the available processor count; always at least 1, because
     *         {@link Executors#newFixedThreadPool(int, ThreadFactory)} rejects
     *         non-positive sizes with an {@link IllegalArgumentException}
     */
    private static int resolveThreadCount(Class<?> klass) {
        ThreadRunner annotation = klass.getAnnotation(ThreadRunner.class);
        int requested = annotation != null
                ? annotation.threads()
                : (int) (Runtime.getRuntime().availableProcessors() * 1.5);
        return Math.max(1, requested);
    }

    /**
     * Thread factory that places its threads in a dedicated {@link ThreadGroup}
     * named {@code <poolName>-<poolIndex>} and names each thread
     * {@code <groupName>-thread-<threadIndex>}. Both indices start at 1.
     */
    static final class NamedThreadFactory implements ThreadFactory {
        /** Counter shared by all factories; numbers the pools. */
        static final AtomicInteger poolNumber = new AtomicInteger(1);
        /** Per-factory counter; numbers the threads within this pool. */
        final AtomicInteger threadNumber = new AtomicInteger(1);
        final ThreadGroup group;

        /**
         * @param poolName prefix for the thread group name
         */
        NamedThreadFactory(String poolName) {
            group = new ThreadGroup(poolName + "-" + poolNumber.getAndIncrement());
        }

        /** Creates a thread in this factory's group; a stack size of 0 requests the default. */
        public Thread newThread(Runnable r) {
            return new Thread(group, r, group.getName() + "-thread-" + threadNumber.getAndIncrement(), 0);
        }
    }
}