package org.junit.experimental.theories; import java.lang.reflect.Field; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Modifier; import java.util.ArrayList; import java.util.List; import org.junit.Assert; import org.junit.experimental.theories.PotentialAssignment.CouldNotGenerateValueException; import org.junit.experimental.theories.internal.Assignments; import org.junit.experimental.theories.internal.ParameterizedAssertionError; import org.junit.internal.AssumptionViolatedException; import org.junit.runners.BlockJUnit4ClassRunner; import org.junit.runners.model.FrameworkMethod; import org.junit.runners.model.InitializationError; import org.junit.runners.model.Statement; import org.junit.runners.model.TestClass; public class Theories extends BlockJUnit4ClassRunner { public Theories(Class<?> klass) throws InitializationError { super(klass); } @Override protected void collectInitializationErrors(List<Throwable> errors) { super.collectInitializationErrors(errors); validateDataPointFields(errors); } private void validateDataPointFields(List<Throwable> errors) { Field[] fields = getTestClass().getJavaClass().getDeclaredFields(); for (Field field : fields) { if (field.getAnnotation(DataPoint.class) == null) { continue; } if (!Modifier.isStatic(field.getModifiers())) { errors.add(new Error("DataPoint field " + field.getName() + " must be static")); } if (!Modifier.isPublic(field.getModifiers())) { errors.add(new Error("DataPoint field " + field.getName() + " must be public")); } } } @Override protected void validateConstructor(List<Throwable> errors) { validateOnlyOneConstructor(errors); } @Override protected void validateTestMethods(List<Throwable> errors) { for (FrameworkMethod each : computeTestMethods()) { if (each.getAnnotation(Theory.class) != null) { each.validatePublicVoid(false, errors); } else { each.validatePublicVoidNoArg(false, errors); } } } @Override protected List<FrameworkMethod> computeTestMethods() { List<FrameworkMethod> testMethods = super.computeTestMethods(); List<FrameworkMethod> theoryMethods = getTestClass().getAnnotatedMethods(Theory.class); testMethods.removeAll(theoryMethods); testMethods.addAll(theoryMethods); return testMethods; } @Override public Statement methodBlock(final FrameworkMethod method) { return new TheoryAnchor(method, getTestClass()); } public static class TheoryAnchor extends Statement { private int successes = 0; private FrameworkMethod fTestMethod; private TestClass fTestClass; private List<AssumptionViolatedException> fInvalidParameters = new ArrayList<AssumptionViolatedException>(); public TheoryAnchor(FrameworkMethod method, TestClass testClass) { fTestMethod = method; fTestClass = testClass; } private TestClass getTestClass() { return fTestClass; } @Override public void evaluate() throws Throwable { runWithAssignment(Assignments.allUnassigned( fTestMethod.getMethod(), getTestClass())); if (successes == 0) { Assert .fail("Never found parameters that satisfied method assumptions. Violated assumptions: " + fInvalidParameters); } } protected void runWithAssignment(Assignments parameterAssignment) throws Throwable { if (!parameterAssignment.isComplete()) { runWithIncompleteAssignment(parameterAssignment); } else { runWithCompleteAssignment(parameterAssignment); } } protected void runWithIncompleteAssignment(Assignments incomplete) throws InstantiationException, IllegalAccessException, Throwable { for (PotentialAssignment source : incomplete .potentialsForNextUnassigned()) { runWithAssignment(incomplete.assignNext(source)); } } protected void runWithCompleteAssignment(final Assignments complete) throws InstantiationException, IllegalAccessException, InvocationTargetException, NoSuchMethodException, Throwable { new BlockJUnit4ClassRunner(getTestClass().getJavaClass()) { @Override protected void collectInitializationErrors( List<Throwable> errors) { // do nothing } @Override public Statement methodBlock(FrameworkMethod method) { final Statement statement = super.methodBlock(method); return new Statement() { @Override public void evaluate() throws Throwable { try { statement.evaluate(); handleDataPointSuccess(); } catch (AssumptionViolatedException e) { handleAssumptionViolation(e); } catch (Throwable e) { reportParameterizedError(e, complete .getArgumentStrings(nullsOk())); } } }; } @Override protected Statement methodInvoker(FrameworkMethod method, Object test) { return methodCompletesWithParameters(method, complete, test); } @Override public Object createTest() throws Exception { return getTestClass().getOnlyConstructor().newInstance( complete.getConstructorArguments(nullsOk())); } }.methodBlock(fTestMethod).evaluate(); } private Statement methodCompletesWithParameters( final FrameworkMethod method, final Assignments complete, final Object freshInstance) { return new Statement() { @Override public void evaluate() throws Throwable { try { final Object[] values = complete.getMethodArguments( nullsOk()); method.invokeExplosively(freshInstance, values); } catch (CouldNotGenerateValueException e) { // ignore } } }; } protected void handleAssumptionViolation(AssumptionViolatedException e) { fInvalidParameters.add(e); } protected void reportParameterizedError(Throwable e, Object... params) throws Throwable { if (params.length == 0) { throw e; } throw new ParameterizedAssertionError(e, fTestMethod.getName(), params); } private boolean nullsOk() { Theory annotation = fTestMethod.getMethod().getAnnotation( Theory.class); if (annotation == null) { return false; } return annotation.nullsAccepted(); } protected void handleDataPointSuccess() { successes++; } } }