package org.javaee7;
import org.jboss.arquillian.container.test.api.Deployment;
import org.junit.rules.MethodRule;
import org.junit.runners.model.FrameworkMethod;
import org.junit.runners.model.Statement;
import javax.naming.InitialContext;
import javax.naming.NamingException;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.List;
/**
* Helper class for Parametrized tests as described here:
* http://blog.schauderhaft.de/2012/12/16/writing-parameterized-tests-with-junit-rules/
*
* @param <T>
*/
public class ParameterRule<T> implements MethodRule {
private final List<T> params;
public ParameterRule(List<T> params) {
if (params == null || params.size() == 0) {
throw new IllegalArgumentException("'params' must be specified and have more then zero length!");
}
this.params = params;
}
@Override
public Statement apply(final Statement base, final FrameworkMethod method, final Object target) {
return new Statement() {
@Override
public void evaluate() throws Throwable {
boolean runInContainer = getDeploymentMethod(target).getAnnotation(Deployment.class).testable();
if (runInContainer) {
evaluateParametersInContainer(base, target);
} else {
evaluateParametersInClient(base, target);
}
}
};
}
private Method getDeploymentMethod(Object target) throws NoSuchMethodException {
Method[] methods = target.getClass().getDeclaredMethods();
for (Method method : methods) {
if (method.getAnnotation(Deployment.class) != null) return method;
}
throw new IllegalStateException("No method with @Deployment annotation found!");
}
private void evaluateParametersInContainer(Statement base, Object target) throws Throwable {
if (isRunningInContainer()) {
evaluateParamsToTarget(base, target);
} else {
ignoreStatementExecution(base);
}
}
private void evaluateParametersInClient(Statement base, Object target) throws Throwable {
if (isRunningInContainer()) {
ignoreStatementExecution(base);
} else {
evaluateParamsToTarget(base, target);
}
}
private boolean isRunningInContainer() {
try {
new InitialContext().lookup("java:comp/env");
return true;
} catch (NamingException e) {
return false;
}
}
private void evaluateParamsToTarget(Statement base, Object target) throws Throwable {
for (Object param : params) {
Field targetField = getTargetField(target);
if (!targetField.isAccessible()) {
targetField.setAccessible(true);
}
targetField.set(target, param);
base.evaluate();
}
}
private Field getTargetField(Object target) throws NoSuchFieldException {
Field[] allFields = target.getClass().getDeclaredFields();
for (Field field : allFields) {
if (field.getAnnotation(Parameter.class) != null) return field;
}
throw new IllegalStateException("No field with @Parameter annotation found! Forgot to add it?");
}
private void ignoreStatementExecution(Statement base) {
try {
base.evaluate();
} catch (Throwable ignored) {}
}
}