package org.junit.experimental.theories.internal; import java.lang.reflect.Array; import java.lang.reflect.Field; import java.lang.reflect.Modifier; import java.util.ArrayList; import java.util.List; import org.junit.experimental.theories.DataPoint; import org.junit.experimental.theories.DataPoints; import org.junit.experimental.theories.ParameterSignature; import org.junit.experimental.theories.ParameterSupplier; import org.junit.experimental.theories.PotentialAssignment; import org.junit.runners.model.FrameworkMethod; import org.junit.runners.model.TestClass; /** * Supplies Theory parameters based on all public members of the target class. */ public class AllMembersSupplier extends ParameterSupplier { static class MethodParameterValue extends PotentialAssignment { private final FrameworkMethod fMethod; private MethodParameterValue(final FrameworkMethod dataPointMethod) { fMethod = dataPointMethod; } @Override public Object getValue() throws CouldNotGenerateValueException { try { return fMethod.invokeExplosively(null); } catch (IllegalArgumentException e) { throw new RuntimeException("unexpected: argument length is checked"); } catch (IllegalAccessException e) { throw new RuntimeException("unexpected: getMethods returned an inaccessible method"); } catch (Throwable e) { throw new CouldNotGenerateValueException(); // do nothing, just look for more values } } @Override public String getDescription() throws CouldNotGenerateValueException { return fMethod.getName(); } } private final TestClass fClass; /** * Constructs a new supplier for {@code type} */ public AllMembersSupplier(final TestClass type) { fClass = type; } @Override public List<PotentialAssignment> getValueSources(final ParameterSignature sig) { List<PotentialAssignment> list = new ArrayList<PotentialAssignment>(); addFields(sig, list); addSinglePointMethods(sig, list); addMultiPointMethods(sig, list); return list; } private void addMultiPointMethods(final ParameterSignature sig, final List<PotentialAssignment> list) { for (FrameworkMethod dataPointsMethod : fClass.getAnnotatedMethods(DataPoints.class)) { try { addMultiPointArrayValues(sig, dataPointsMethod.getName(), list, dataPointsMethod.invokeExplosively(null)); } catch (Throwable e) { // ignore and move on } } } private void addSinglePointMethods(final ParameterSignature sig, final List<PotentialAssignment> list) { for (FrameworkMethod dataPointMethod : fClass.getAnnotatedMethods(DataPoint.class)) { if (isCorrectlyTyped(sig, dataPointMethod.getType())) { list.add(new MethodParameterValue(dataPointMethod)); } } } private void addFields(final ParameterSignature sig, final List<PotentialAssignment> list) { for (final Field field : fClass.getJavaClass().getFields()) { if (Modifier.isStatic(field.getModifiers())) { Class<?> type = field.getType(); if (sig.canAcceptArrayType(type) && field.getAnnotation(DataPoints.class) != null) { try { addArrayValues(field.getName(), list, getStaticFieldValue(field)); } catch (Throwable e) { // ignore and move on } } else if (sig.canAcceptType(type) && field.getAnnotation(DataPoint.class) != null) { list.add(PotentialAssignment.forValue(field.getName(), getStaticFieldValue(field))); } } } } private void addArrayValues(final String name, final List<PotentialAssignment> list, final Object array) { for (int i = 0; i < Array.getLength(array); i++) { list.add(PotentialAssignment.forValue(name + "[" + i + "]", Array.get(array, i))); } } private void addMultiPointArrayValues(final ParameterSignature sig, final String name, final List<PotentialAssignment> list, final Object array) throws Throwable { for (int i = 0; i < Array.getLength(array); i++) { if (!isCorrectlyTyped(sig, Array.get(array, i).getClass())) { return; } list.add(PotentialAssignment.forValue(name + "[" + i + "]", Array.get(array, i))); } } private boolean isCorrectlyTyped(final ParameterSignature parameterSignature, final Class<?> type) { return parameterSignature.canAcceptType(type); } private Object getStaticFieldValue(final Field field) { try { return field.get(null); } catch (IllegalArgumentException e) { throw new RuntimeException("unexpected: field from getClass doesn't exist on object"); } catch (IllegalAccessException e) { throw new RuntimeException("unexpected: getFields returned an inaccessible field"); } } }