package com.tngtech.archunit.junit; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import com.tngtech.archunit.integration.junit.ExpectedViolationFrom; import org.junit.runner.Description; import org.junit.runner.notification.Failure; import org.junit.runner.notification.RunNotifier; import org.junit.runners.model.InitializationError; import org.junit.runners.model.Statement; /** * Doesn't reside within integration.junit, because this forces extended visibility on production code */ public class ArchUnitIntegrationTestRunner extends ArchUnitRunner { private ExpectedViolation expectedViolation; public ArchUnitIntegrationTestRunner(Class<?> testClass) throws InitializationError { super(testClass); } @Override protected void runChild(final ArchTestExecution child, final RunNotifier notifier) { expectedViolation = ExpectedViolation.none(); Description description = describeChild(child); notifier.fireTestStarted(description); try { extractExpectedConfiguration(child).configure(expectedViolation); expectedViolation.apply(new IntegrationTestStatement(child), description).evaluate(); } catch (Throwable throwable) { notifier.fireTestFailure(new Failure(description, throwable)); } finally { notifier.fireTestFinished(description); } } private ExpectedViolationDefinition extractExpectedConfiguration(ArchTestExecution child) { ExpectedViolationFrom annotation = child.getAnnotation(ExpectedViolationFrom.class); if (annotation == null) { throw new RuntimeException("IntegrationTests need to annotate their @" + ArchTest.class.getSimpleName() + "'s with @" + ExpectedViolationFrom.class.getSimpleName()); } return new ExpectedViolationDefinition(annotation); } private class IntegrationTestStatement extends Statement { private final ArchRuleExecution child; public IntegrationTestStatement(ArchTestExecution child) { this.child = (ArchRuleExecution) child; } @Override public void evaluate() throws Throwable { FailureSniffer sniffer = new FailureSniffer(); ArchUnitIntegrationTestRunner.super.runChild(child, sniffer); sniffer.rethrowIfFailure(); } } private static class FailureSniffer extends RunNotifier { private Throwable exception; @Override public void fireTestFailure(Failure failure) { exception = failure.getException(); } void rethrowIfFailure() throws Throwable { if (exception != null) { throw exception; } } } private static class ExpectedViolationDefinition { private final Class<?> location; private final String method; public ExpectedViolationDefinition(ExpectedViolationFrom annotation) { location = annotation.location(); method = annotation.method(); } public void configure(ExpectedViolation expectedViolation) { try { Method expectViolation = location.getDeclaredMethod(method, ExpectedViolation.class); expectViolation.setAccessible(true); expectViolation.invoke(null, expectedViolation); } catch (NoSuchMethodException e) { throw new RuntimeException("Cannot find method '" + method + "' on " + location.getSimpleName()); } catch (InvocationTargetException | IllegalAccessException e) { throw new RuntimeException("Can't call method '" + method + "' on " + location.getSimpleName()); } } } }