package net.thucydides.junit.runners;

import ch.lambdaj.function.convert.Converter;
import com.google.common.base.Splitter;
import net.thucydides.core.csv.CSVTestDataSource;
import net.thucydides.core.csv.TestDataSource;
import net.thucydides.core.guice.Injectors;
import net.thucydides.core.model.DataTable;
import net.thucydides.core.steps.FilePathParser;
import net.thucydides.core.util.EnvironmentVariables;
import net.thucydides.junit.annotations.TestData;
import net.thucydides.junit.annotations.UseTestDataFrom;
import org.apache.commons.lang3.StringUtils;
import org.junit.runners.model.FrameworkMethod;
import org.junit.runners.model.TestClass;

import java.io.IOException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;

import static ch.lambdaj.Lambda.convert;

/**
 * Reads the data-driven test configuration declared on a test class.
 * <p>
 * Two sources of test data are supported:
 * <ul>
 *   <li>a {@code public static} method annotated with {@link TestData}, which is invoked
 *       reflectively and whose result is treated as a list of {@code Object[]} rows;</li>
 *   <li>a class-level {@link UseTestDataFrom} annotation, whose value is a list of candidate
 *       file paths (separated by {@code ;} or {@code ,}) loaded through
 *       {@link CSVTestDataSource}.</li>
 * </ul>
 */
public class DataDrivenAnnotations {

    /** Separators accepted between alternative data source paths in {@link UseTestDataFrom}. */
    private static final Pattern DATASOURCE_PATH_SEPARATORS = Pattern.compile("[;,]");

    private final EnvironmentVariables environmentVariables;

    private final TestClass testClass;

    /**
     * Creates an instance for the given Java class.
     *
     * @param testClass the test class whose annotations should be inspected
     * @return a new {@code DataDrivenAnnotations} wrapping {@code testClass}
     */
    public static DataDrivenAnnotations forClass(final Class<?> testClass) {
        return new DataDrivenAnnotations(testClass);
    }

    /**
     * Creates an instance for the given JUnit {@link TestClass}.
     *
     * @param testClass the JUnit test class model whose annotations should be inspected
     * @return a new {@code DataDrivenAnnotations} wrapping {@code testClass}
     */
    public static DataDrivenAnnotations forClass(final TestClass testClass) {
        return new DataDrivenAnnotations(testClass);
    }

    DataDrivenAnnotations(final Class<?> testClass) {
        this(new TestClass(testClass));
    }

    /** Uses the {@link EnvironmentVariables} instance supplied by the Guice injector. */
    DataDrivenAnnotations(final TestClass testClass) {
        this(testClass, Injectors.getInjector().getProvider(EnvironmentVariables.class).get());
    }

    DataDrivenAnnotations(final TestClass testClass, EnvironmentVariables environmentVariables) {
        this.testClass = testClass;
        this.environmentVariables = environmentVariables;
    }

    /**
     * Returns a copy of this object, for the same test class, that uses the given
     * environment variables. This instance is left unchanged.
     *
     * @param environmentVariables the environment variables the copy should use
     * @return a new instance bound to {@code environmentVariables}
     */
    DataDrivenAnnotations usingEnvironmentVariables(EnvironmentVariables environmentVariables) {
        return new DataDrivenAnnotations(this.testClass, environmentVariables);
    }

    /**
     * Builds a {@link DataTable} from the CSV data source named by {@link UseTestDataFrom}.
     *
     * @return a table whose headers and rows come from the data source
     * @throws IllegalArgumentException if none of the configured paths is a valid data path
     * @throws Throwable whatever the underlying data source raises while loading
     */
    public DataTable getParametersTableFromTestDataSource() throws Throwable {
        TestDataSource testDataSource = new CSVTestDataSource(findTestDataSource(), findTestDataSeparator());
        List<Map<String, String>> testData = testDataSource.getData();
        List<String> headers = testDataSource.getHeaders();
        return DataTable.withHeaders(headers)
                .andMappedRows(testData)
                .build();
    }

    /**
     * Builds a {@link DataTable} from the {@code public static} {@link TestData} method of
     * the test class. Column names are taken from {@link TestData#columnNames()}; when that
     * value is empty, names of the form {@code "Parameter N"} are generated.
     *
     * @return a table containing one row per {@code Object[]} returned by the method
     * @throws RuntimeException if the method cannot be found or invoked; the original
     *         exception is preserved as the cause
     */
    public DataTable getParametersTableFromTestDataAnnotation() {
        String columnNamesString;
        List<Object[]> parametersList;
        try {
            Method testDataMethod = getTestDataMethod().getMethod();
            columnNamesString = testDataMethod.getAnnotation(TestData.class).columnNames();
            // The method is static, hence the null receiver. Its declared return type is not
            // checked here; each element is expected to be an Object[] row.
            @SuppressWarnings("unchecked")
            List<Object[]> rows = (List<Object[]>) testDataMethod.invoke(null);
            parametersList = rows;
        } catch (Exception e) {
            throw new RuntimeException("Could not obtain test data from the test class", e);
        }
        return createParametersTableFrom(columnNamesString, convert(parametersList, toListOfObjects()));
    }

    /** Converter that exposes each {@code Object[]} row as a fixed-size list view. */
    private Converter<Object[], List<Object>> toListOfObjects() {
        return new Converter<Object[], List<Object>>() {
            @Override
            public List<Object> convert(Object[] parameters) {
                return Arrays.asList(parameters);
            }
        };
    }

    /**
     * Builds the table for annotation-provided data. The column count is taken from the
     * first row, or is zero when there are no rows.
     */
    private DataTable createParametersTableFrom(String columnNamesString, List<List<Object>> parametersList) {
        int numberOfColumns = parametersList.isEmpty() ? 0 : parametersList.get(0).size();
        List<String> columnNames = split(columnNamesString, numberOfColumns);
        return DataTable.withHeaders(columnNames)
                .andRows(parametersList)
                .build();
    }

    /**
     * Produces the column names for a table.
     *
     * @param columnNamesString comma-separated column names, or the empty string to request
     *                          generated names
     * @param numberOfColumns   number of generated names, and the {@code max} argument passed
     *                          to {@link StringUtils#split(String, String, int)} otherwise
     * @return the column names as a fixed-size list
     */
    private List<String> split(String columnNamesString, int numberOfColumns) {
        if (columnNamesString.isEmpty()) {
            String[] generatedNames = new String[numberOfColumns];
            for (int i = 0; i < numberOfColumns; i++) {
                generatedNames[i] = "Parameter " + (i + 1);
            }
            return Arrays.asList(generatedNames);
        }
        return Arrays.asList(StringUtils.split(columnNamesString, ",", numberOfColumns));
    }

    /**
     * Returns the {@code public static} method annotated with {@link TestData}.
     *
     * @return the first matching method found on the test class
     * @throws IllegalArgumentException if the test class has no such method
     * @throws Exception declared for backward compatibility with existing callers
     */
    public FrameworkMethod getTestDataMethod() throws Exception {
        FrameworkMethod method = findTestDataMethod();
        if (method == null) {
            // The lookup is for @TestData; FilePathParser is an unrelated helper class.
            throw new IllegalArgumentException("No public static @TestData method on class "
                    + testClass.getName());
        }
        return method;
    }

    /**
     * Looks for a {@link TestData}-annotated method that is both {@code public} and
     * {@code static}.
     *
     * @return the first matching method, or {@code null} if there is none
     */
    private FrameworkMethod findTestDataMethod() {
        List<FrameworkMethod> methods = testClass.getAnnotatedMethods(TestData.class);
        for (FrameworkMethod each : methods) {
            int modifiers = each.getMethod().getModifiers();
            if (Modifier.isStatic(modifiers) && Modifier.isPublic(modifiers)) {
                return each;
            }
        }
        return null;
    }

    /**
     * Picks the first configured data source path that {@link CSVTestDataSource} accepts.
     *
     * @return the first path for which {@code CSVTestDataSource.validTestDataPath} is true
     * @throws IllegalArgumentException if no configured path is valid
     */
    @SuppressWarnings("MalformedRegex")
    protected String findTestDataSource() {
        String paths = findTestDataSourcePaths();
        for (String path : Splitter.on(DATASOURCE_PATH_SEPARATORS).split(paths)) {
            if (CSVTestDataSource.validTestDataPath(path)) {
                return path;
            }
        }
        throw new IllegalArgumentException("No test data file found for path: " + paths);
    }

    /**
     * Returns the raw path list from {@link UseTestDataFrom#value()} after it has been
     * processed by {@link FilePathParser} with the configured environment variables.
     * <p>
     * Assumes the test class carries {@link UseTestDataFrom}; without it the annotation
     * lookup yields {@code null} and this method fails with a {@link NullPointerException}.
     */
    protected String findTestDataSourcePaths() {
        return new FilePathParser(environmentVariables).getInstanciatedPath(findUseTestDataFromAnnotation().value());
    }

    /** Returns the class-level {@link UseTestDataFrom} annotation, or {@code null} if absent. */
    private UseTestDataFrom findUseTestDataFromAnnotation() {
        return testClass.getJavaClass().getAnnotation(UseTestDataFrom.class);
    }

    /**
     * @return {@code true} if the test class has a {@code public static} {@link TestData} method
     */
    public boolean hasTestDataDefined() {
        return (findTestDataMethod() != null);
    }

    /**
     * Indicates whether the test class declares a usable {@link UseTestDataFrom} source.
     * <p>
     * When the annotation is present but none of its paths is valid, the
     * {@link #findTestDataSource()} implementation in this class throws
     * {@link IllegalArgumentException} instead of returning {@code null}, so that exception
     * propagates from here rather than {@code false} being returned. The {@code null} check
     * is kept because {@code findTestDataSource()} is overridable.
     *
     * @return {@code true} if the annotation is present and a data source path was found;
     *         {@code false} if the annotation is absent
     */
    public boolean hasTestDataSourceDefined() {
        return (findUseTestDataFromAnnotation() != null) && (findTestDataSource() != null);
    }

    /**
     * Loads the CSV data source and maps its rows onto instances of the given class.
     *
     * @param clazz the type to instantiate for each data row
     * @param <T>   the row type
     * @return the instances produced by the data source
     * @throws IOException if the data source reports an I/O failure
     */
    public <T> List<T> getDataAsInstancesOf(final Class<T> clazz) throws IOException {
        TestDataSource testdata = new CSVTestDataSource(findTestDataSource(), findTestDataSeparator());
        return testdata.getDataAsInstancesOf(clazz);
    }

    /**
     * Counts the rows available in the CSV data source.
     *
     * @return the number of data rows
     * @throws IOException if the data source reports an I/O failure
     */
    public int countDataEntries() throws IOException {
        TestDataSource testdata = new CSVTestDataSource(findTestDataSource(), findTestDataSeparator());
        return testdata.getData().size();
    }

    /**
     * Returns the column separator declared on {@link UseTestDataFrom}. Assumes the
     * annotation is present on the test class.
     */
    private char findTestDataSeparator() {
        return findUseTestDataFromAnnotation().separator();
    }
}