package water;
import org.testng.annotations.*;
import water.util.FileUtils;
import water.util.Log;
import java.io.*;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.sql.Connection;
import java.sql.DriverManager;
import java.util.*;
import static water.util.FileUtils.*;
/**
 * TestNG entry point for the H2O accuracy test suite.
 *
 * <p>{@link #accuracySuiteSetup} does the following:
 * <ul>
 *   <li>prepares the {@code results} and {@code results/h2ologs} directories under the log directory;</li>
 *   <li>opens the suite summary log;</li>
 *   <li>optionally connects to a MySQL results database;</li>
 *   <li>sets up the H2O cloud via {@code AccuracyTestingUtil.setupH2OCloud};</li>
 *   <li>loads the data-sets CSV and builds the list of test cases from the test-cases CSV.</li>
 * </ul>
 *
 * <p>{@link #accuracyTest} then executes every selected test case. It logs the outcome of each one and,
 * when a database is configured, saves the result there.
 */
public class AccuracyTestingSuite {
  private String logDir;
  private String resultsDBTableConfig;
  private int numH2ONodes;
  private String dataSetsCSVPath;
  private String testCasesCSVPath;
  private String testCasesFilterString;
  /** Suite summary log; public and static, presumably so other suite classes can write to it (verify). */
  public static PrintStream summaryLog;
  private Connection resultsDBTableConn;
  private boolean recordResults;
  /** Rows of the data-sets CSV with the header line removed; presumably read by other suite classes. */
  public static List<String> dataSetsCSVRows;
  private ArrayList<TestCase> testCasesList;

  /**
   * Prepares everything the suite needs before {@link #accuracyTest} runs. Any setup failure is logged
   * and terminates the JVM with status -1.
   *
   * @param logDir                directory (resolved with {@code locateFile}) under which {@code results/} is created
   * @param resultsDBTableConfig  path to a properties file with {@code db.host}, {@code db.port},
   *                              {@code db.databaseName}, {@code db.user} and {@code db.password};
   *                              empty means results are not recorded
   * @param numH2ONodes           number of H2O nodes to start, as a decimal string
   * @param dataSetsCSVPath       path to the data-sets CSV (first line is a header)
   * @param testCasesCSVPath      path to the test-cases CSV (first line is a header)
   * @param testCasesFilterString optional {@code "algos;testCaseIds"} selection, each part comma-separated
   */
  @BeforeClass
  @Parameters( {"logDir", "resultsDBTableConfig", "numH2ONodes", "dataSetsCSVPath", "testCasesCSVPath",
      "testCasesFilterString" } )
  private void accuracySuiteSetup(@org.testng.annotations.Optional("h2o-test-accuracy") String logDir,
                                  @org.testng.annotations.Optional("") String resultsDBTableConfig,
                                  @org.testng.annotations.Optional("1") String numH2ONodes,
                                  @org.testng.annotations.Optional("h2o-test-accuracy/src/test/resources/accuracyDataSets.csv")
                                  String dataSetsCSVPath,
                                  @org.testng.annotations.Optional("h2o-test-accuracy/src/test/resources/accuracyTestCases.csv")
                                  String testCasesCSVPath,
                                  @org.testng.annotations.Optional("") String testCasesFilterString) {
    // Logging: resolve the log directory once and reuse its canonical path.
    this.logDir = logDir;
    String logRoot = null;
    try {
      File logDirFile = locateFile(logDir);
      if (logDirFile == null) {
        // Guard: the original code dereferenced this result without checking for null.
        System.out.println("Couldn't locate the log directory: " + logDir);
        System.exit(-1);
      }
      logRoot = logDirFile.getCanonicalFile().toString();
    } catch (IOException e) {
      System.out.println("Couldn't create directory.");
      e.printStackTrace();
      System.exit(-1);
    }
    File resultsDir = new File(logRoot + "/results");
    File h2oLogsDir = new File(logRoot + "/results/h2ologs");
    resultsDir.mkdir();
    deleteFilesIn(resultsDir);
    h2oLogsDir.mkdir();
    deleteFilesIn(h2oLogsDir);
    try {
      File suiteSummary = new File(logRoot + "/results/accuracySuiteSummary.log");
      suiteSummary.createNewFile();
      summaryLog = new PrintStream(new FileOutputStream(suiteSummary, false));
    } catch (IOException e) {
      System.out.println("Couldn't create the accuracySuiteSummary.log");
      e.printStackTrace();
      System.exit(-1);
    }
    System.out.println("Commenced logging to " + logDir + "/results directory.");

    // Results database table
    this.resultsDBTableConfig = resultsDBTableConfig;
    if (this.resultsDBTableConfig.isEmpty()) {
      summaryLog.println("No results database configuration specified, so test case results will not be saved.");
      recordResults = false;
    } else {
      summaryLog.println("Results database configuration specified by: " + this.resultsDBTableConfig);
      resultsDBTableConn = makeResultsDBTableConn();
      recordResults = true;
    }

    // H2O Cloud
    this.numH2ONodes = Integer.parseInt(numH2ONodes);
    summaryLog.println("Setting up the H2O Cloud with " + this.numH2ONodes + " nodes.");
    AccuracyTestingUtil.setupH2OCloud(this.numH2ONodes, this.logDir);

    // Data sets
    this.dataSetsCSVPath = dataSetsCSVPath;
    File dataSetsFile = locateFile(this.dataSetsCSVPath);
    if (dataSetsFile == null) {
      summaryLog.println("Cannot locate the data sets file: " + this.dataSetsCSVPath);
      System.exit(-1);
    }
    try {
      dataSetsCSVRows = Files.readAllLines(dataSetsFile.toPath(), Charset.defaultCharset());
    } catch (IOException e) {
      summaryLog.println("Cannot read the lines of the data sets file: " + dataSetsFile.toPath());
      writeStackTrace(e, summaryLog);
      System.exit(-1);
    }
    if (!dataSetsCSVRows.isEmpty()) dataSetsCSVRows.remove(0); // remove the header

    // Test Cases
    this.testCasesCSVPath = testCasesCSVPath;
    this.testCasesFilterString = testCasesFilterString;
    testCasesList = makeTestCasesList();
  }

  /**
   * Deletes the direct children of {@code dir}.
   *
   * <p>Non-empty subdirectories are left in place, because {@link File#delete()} fails on them. This is a
   * no-op when {@code dir} cannot be listed.
   */
  private static void deleteFilesIn(File dir) {
    File[] children = dir.listFiles();
    if (children == null) return; // not a directory, or an I/O error occurred
    for (File child : children) child.delete();
  }

  /**
   * Runs every selected test case in order.
   *
   * <p>For each case, the validation metrics are printed and the result is saved to the results database
   * when one is configured. Failures (exceptions or assertion errors) are logged and the loop continues.
   * If any case failed, the JVM exits with status -1 rather than failing through a TestNG assertion.
   */
  @Test
  public void accuracyTest() {
    boolean suiteFailure = false;
    for (TestCase tc : testCasesList) {
      int id = tc.getTestCaseId();
      try {
        // Disabled: clearing cluster state between test cases via removeAll().
        //removeAll();
        summaryLog.println("\n-----------------------------");
        summaryLog.println("Accuracy Suite Test Case: " + id);
        summaryLog.println("-----------------------------\n");
        Log.info("-----------------------------");
        Log.info("Accuracy Suite Test Case: " + id);
        Log.info("-----------------------------");
        TestCaseResult tcResult = tc.execute();
        tcResult.printValidationMetrics(tc.isCrossVal());
        if (recordResults) {
          summaryLog.println("Recording test case " + id + " result.");
          tcResult.saveToAccuracyTable(resultsDBTableConn);
        }
      } catch (Exception e) {
        reportTestCaseFailure(id, stackTraceToString(e));
        suiteFailure = true;
      } catch (AssertionError ae) {
        reportTestCaseFailure(id, ae.getMessage());
        suiteFailure = true;
      }
    }
    if (suiteFailure) {
      System.out.println("The suite failed due to one or more test case failures.");
      System.exit(-1);
    }
  }

  /** Writes a test-case failure and its detail text to both the H2O log and the suite summary log. */
  private static void reportTestCaseFailure(int id, String detail) {
    Log.err("Test case " + id + " failed on: ");
    Log.err(detail);
    summaryLog.println("Test case " + id + " failed on: ");
    summaryLog.println(detail);
  }

  /**
   * Reads the test-cases CSV and builds a {@link TestCase} for every row that passes the filter.
   *
   * <p>The first line is skipped as a header, and blank lines are ignored. If algorithms are given in the
   * filter, column 1 is matched against them. Otherwise, if test-case ids are given, column 0 is matched
   * against them. Otherwise every row is loaded. Any read or parse failure is logged and terminates the
   * JVM with status -1.
   *
   * @return the test cases to run, in file order
   */
  private ArrayList<TestCase> makeTestCasesList() {
    String[] algorithms = filterForAlgos(testCasesFilterString);
    String[] testCases = filterForTestCases(testCasesFilterString);
    // Build the filter lists once instead of once per CSV row.
    List<String> algorithmFilter = (algorithms == null) ? null : Arrays.asList(algorithms);
    List<String> testCaseFilter = (testCases == null) ? null : Arrays.asList(testCases);
    List<String> testCaseEntries = null;
    try {
      summaryLog.println("Reading test cases from: " + testCasesCSVPath);
      File testCasesFile = getFile(testCasesCSVPath);
      testCaseEntries = Files.readAllLines(testCasesFile.toPath(), Charset.defaultCharset());
    } catch (Exception e) {
      summaryLog.println("Cannot read the test cases from: " + testCasesCSVPath);
      writeStackTrace(e, summaryLog);
      System.exit(-1);
    }
    if (!testCaseEntries.isEmpty()) testCaseEntries.remove(0); // remove header line
    ArrayList<TestCase> testCaseArray = new ArrayList<>();
    for (String t : testCaseEntries) {
      if (t.trim().isEmpty()) continue; // tolerate blank lines rather than failing to parse them
      String[] testCaseEntry = t.trim().split(",", -1);
      // If algorithms are specified in the testCaseFilterString, load all test cases for these algorithms.
      // Otherwise, if specific test cases are specified, then only load those. Else, load all the test cases.
      if (algorithmFilter != null) {
        if (!algorithmFilter.contains(testCaseEntry[1])) continue;
      } else if (testCaseFilter != null) {
        if (!testCaseFilter.contains(testCaseEntry[0])) continue;
      }
      summaryLog.println("Creating test case: " + t);
      try {
        testCaseArray.add(
            new TestCase(Integer.parseInt(testCaseEntry[0]), testCaseEntry[1], testCaseEntry[2],
                testCaseEntry[3].equals("1"), testCaseEntry[4], testCaseEntry[5], testCaseEntry[6],
                testCaseEntry[7].equals("1"), Integer.parseInt(testCaseEntry[8]),
                Integer.parseInt(testCaseEntry[9]), testCaseEntry[10])
        );
      } catch (Exception e) {
        summaryLog.println("Couldn't create test case: " + t);
        writeStackTrace(e, summaryLog);
        System.exit(-1);
      }
    }
    return testCaseArray;
  }

  /**
   * Returns the algorithm names from a {@code "algos;testCaseIds"} filter string.
   *
   * @return the trimmed, comma-separated entries before the first {@code ';'}, or null when none are given
   */
  private String[] filterForAlgos(String selectionString) {
    return filterSection(selectionString, 0);
  }

  /**
   * Returns the test-case ids from a {@code "algos;testCaseIds"} filter string.
   *
   * <p>A string without a {@code ';'} (for example {@code "gbm,drf"}) selects no ids. Before this fix,
   * such a string threw {@link ArrayIndexOutOfBoundsException}.
   *
   * @return the trimmed, comma-separated entries after the first {@code ';'}, or null when none are given
   */
  private String[] filterForTestCases(String selectionString) {
    return filterSection(selectionString, 1);
  }

  /**
   * Splits {@code selectionString} on {@code ';'} and returns the requested section split on {@code ','}
   * with each entry trimmed.
   *
   * @return the section's entries, or null if the string or the section is missing or blank
   */
  private static String[] filterSection(String selectionString, int section) {
    if (selectionString == null || selectionString.trim().isEmpty()) return null;
    String[] sections = selectionString.trim().split(";", -1);
    if (section >= sections.length) return null;
    String selected = sections[section].trim();
    if (selected.isEmpty()) return null;
    String[] entries = selected.split(",", -1);
    for (int i = 0; i < entries.length; i++) entries[i] = entries[i].trim();
    return entries;
  }

  /**
   * Opens a MySQL connection using the properties file named by {@code resultsDBTableConfig}.
   *
   * <p>The file must define {@code db.host}, {@code db.port}, {@code db.databaseName}, {@code db.user} and
   * {@code db.password}. Any failure is logged and terminates the JVM with status -1.
   *
   * @return the open connection
   */
  private Connection makeResultsDBTableConn() {
    Connection connection = null;
    try {
      summaryLog.println("Reading the database configuration settings from: " + resultsDBTableConfig);
      File configFile = new File(resultsDBTableConfig);
      Properties properties = new Properties();
      // try-with-resources: the reader was previously never closed.
      try (BufferedReader reader = new BufferedReader(new FileReader(configFile))) {
        properties.load(reader);
      }
      summaryLog.println("Establishing connection to the database.");
      Class.forName("com.mysql.jdbc.Driver");
      String url = String.format("jdbc:mysql://%s:%s/%s", properties.getProperty("db.host"),
          properties.getProperty("db.port"), properties.getProperty("db.databaseName"));
      connection = DriverManager.getConnection(url, properties.getProperty("db.user"),
          properties.getProperty("db.password"));
    } catch (Exception e) {
      summaryLog.println("Couldn't connect to the database.");
      writeStackTrace(e, summaryLog);
      System.exit(-1);
    }
    return connection;
  }

  /** Prints the full stack trace of {@code e} to {@code ps} as a single {@code println}. */
  private static void writeStackTrace(Throwable e, PrintStream ps) {
    ps.println(stackTraceToString(e));
  }

  /** Renders the full stack trace of {@code t} (including causes) as a string. */
  private static String stackTraceToString(Throwable t) {
    StringWriter stringWriter = new StringWriter();
    t.printStackTrace(new PrintWriter(stringWriter));
    return stringWriter.toString();
  }

  /**
   * Resets cluster state. In order, it:
   * <ul>
   *   <li>stops and removes every job returned by {@code Job.jobs()};</li>
   *   <li>runs {@code H2O.raw_clear()} and clears {@code Vec.ESPC} in an {@link MRTask} launched with
   *       {@code doAllNodes()};</li>
   *   <li>calls {@code H2O.getPM().getIce().cleanUp()}.</li>
   * </ul>
   *
   * <p>Currently unused: its call in {@link #accuracyTest} is commented out.
   */
  private void removeAll() {
    //FIXME: This was just copied over from RemoveAllHandler.
    summaryLog.println("Removing all.");
    Futures fs = new Futures();
    for( Job j : Job.jobs() ) { j.stop(); j.remove(fs); }
    fs.blockForPending();
    new MRTask(){
      @Override public void setupLocal() { H2O.raw_clear(); water.fvec.Vec.ESPC.clear(); }
    }.doAllNodes();
    H2O.getPM().getIce().cleanUp();
  }
}