package water;
import hex.*;
import hex.deeplearning.DeepLearning;
import hex.deeplearning.DeepLearningModel;
import hex.genmodel.utils.DistributionFamily;
import hex.glm.GLM;
import hex.glm.GLMModel;
import hex.grid.Grid;
import hex.grid.GridSearch;
import hex.grid.HyperSpaceSearchCriteria;
import hex.tree.SharedTreeModel;
import hex.tree.drf.DRF;
import hex.tree.drf.DRFModel;
import hex.tree.gbm.GBM;
import hex.tree.gbm.GBMModel;
import water.api.SchemaServer;
import water.fvec.FVecTest;
import water.fvec.Frame;
import water.parser.ParseDataset;
import hex.schemas.HyperSpaceSearchCriteriaV99.RandomDiscreteValueSearchCriteriaV99;
import hex.schemas.GBMV3.GBMParametersV3;
import hex.schemas.DeepLearningV3.DeepLearningParametersV3;
import hex.schemas.GLMV3.GLMParametersV3;
import hex.schemas.DRFV3.DRFParametersV3;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
public class TestCase {
// Identifier of this test case; returned by getTestCaseId() and passed to every TestCaseResult.
private int testCaseId;
// The following fields are public only so that TestCaseResult can update its table.
// Algorithm name; the parameter builders support "drf", "glm", "gbm" and "dl".
public String algo;
// ';'-separated "name=value" tokens parsed by the make*ModelParameters methods.
public String algoParameters;
// true to run a grid search instead of training a single model.
public boolean grid;
// ';'-separated "name=v1|v2|..." tokens parsed by the make*GridParameters methods.
public String gridParameters;
// ';'-separated "key=value" tokens parsed by makeSearchCriteria() (e.g. strategy, max_models).
public String searchParameters;
// Metric name (a key of the map built by getMetrics) used to pick the best model of a grid.
public String modelSelectionCriteria;
// Passed to DataSet.load() for both the training and the testing data set.
public boolean regression;
private int trainingDataSetId;
private int testingDataSetId;
private String testCaseDescription;
// Model parameters; null until makeModelParameters() runs during execute().
private Model.Parameters params;
private DataSet trainingDataSet;
private DataSet testingDataSet;
// Grid hyper-parameter name -> candidate values; built by makeGridParameters() for grid test cases.
private HashMap<String, Object[]> hyperParms;
// Grid search criteria; built by makeSearchCriteria() for grid test cases.
private HyperSpaceSearchCriteria searchCriteria;
// One-time guards for the PUBDEV-2812 workaround in the grid-search code path: each algorithm's
// builder and parameters schema are instantiated at most once.
private static boolean glmRegistered = false;
private static boolean gbmRegistered = false;
private static boolean drfRegistered = false;
private static boolean dlRegistered = false;
/**
 * Creates a test case from its stored definition.
 *
 * <p>Only the training and testing {@code DataSet} handles are created here; the data sets are loaded later,
 * when {@code execute()} runs.
 *
 * @throws Exception propagated from the {@code DataSet} constructor
 */
public TestCase(int testCaseId, String algo, String algoParameters, boolean grid, String gridParameters,
    String searchParameters, String modelSelectionCriteria, boolean regression, int trainingDataSetId,
    int testingDataSetId, String testCaseDescription) throws Exception {
  // identity
  this.testCaseId = testCaseId;
  this.testCaseDescription = testCaseDescription;
  // single-model configuration
  this.algo = algo;
  this.algoParameters = algoParameters;
  this.regression = regression;
  // grid configuration
  this.grid = grid;
  this.gridParameters = gridParameters;
  this.searchParameters = searchParameters;
  this.modelSelectionCriteria = modelSelectionCriteria;
  // data sets
  this.trainingDataSetId = trainingDataSetId;
  this.testingDataSetId = testingDataSetId;
  this.trainingDataSet = new DataSet(trainingDataSetId);
  this.testingDataSet = new DataSet(testingDataSetId);
}
/**
 * Returns the identifier this test case was constructed with.
 *
 * @return the test case id
 */
public int getTestCaseId() {
  final int id = this.testCaseId;
  return id;
}
/**
 * Reports whether the model parameters request cross-validation ({@code _nfolds > 0}).
 * Only meaningful once the parameters have been built; {@code params} is null before that.
 */
public boolean isCrossVal() {
  final int folds = this.params._nfolds;
  return folds > 0;
}
/**
 * Runs this test case and returns its result.
 *
 * <p>Loads both data sets, builds the model parameters from {@link #algoParameters}, then either trains a
 * single model or (when {@link #grid} is true) runs a grid search and keeps the model with the best
 * validation value of {@link #modelSelectionCriteria}. The result holds:
 * <ul>
 *   <li>the training metrics;</li>
 *   <li>the cross-validation metrics when {@code _nfolds > 0}, otherwise the validation metrics;</li>
 *   <li>the training wall-clock time in milliseconds;</li>
 *   <li>the chosen model's parameters as JSON.</li>
 * </ul>
 *
 * <p>Once both data sets have loaded, their frames are removed again whether or not training succeeds, so a
 * failing test case does not leave them behind.
 *
 * @return the result of this test case
 * @throws Exception if the parameters cannot be built or training fails (training failures are wrapped)
 * @throws AssertionError if this is a grid test case with an empty model selection criterion and assertions
 *     are enabled
 */
public TestCaseResult execute() throws Exception, AssertionError {
  loadTestCaseDataSets();
  TrainedModel trained;
  try {
    makeModelParameters();
    trained = grid ? trainGridAndPickBest() : trainSingleModel();
  } catch (Throwable t) {
    // Training did not complete: still release the frames, then propagate the original failure unchanged.
    try {
      removeTestCaseDataSetFrames();
    } catch (RuntimeException cleanupFailure) {
      t.addSuppressed(cleanupFailure);
    }
    throw t;
  }
  removeTestCaseDataSetFrames();
  HashMap<String, Double> trainingMetrics = getMetrics(trained.output._training_metrics);
  // With cross-validation report the CV metrics; otherwise the metrics on the validation (testing) frame.
  HashMap<String, Double> holdoutMetrics = params._nfolds > 0
      ? getMetrics(trained.output._cross_validation_metrics)
      : getMetrics(trained.output._validation_metrics);
  return new TestCaseResult(testCaseId, trainingMetrics, holdoutMetrics, trained.runTimeMillis,
      trained.modelJson, this, trainingDataSet, testingDataSet);
}

/**
 * Trains one {@link #algo} model with {@link #params}.
 *
 * <p>The trained model is deleted before this method returns; only its output object, training time and
 * parameters JSON are kept.
 *
 * @throws Exception wrapping any failure raised while building or training the model
 */
private TrainedModel trainSingleModel() throws Exception {
  Model model = null;
  try {
    double startTime;
    switch (algo) {
      case "drf":
        DRF drfJob = new DRF((DRFModel.DRFParameters) params);
        AccuracyTestingSuite.summaryLog.println("Training DRF model.");
        startTime = System.currentTimeMillis();
        model = drfJob.trainModel().get();
        break;
      case "glm":
        GLM glmJob = new GLM((GLMModel.GLMParameters) params, Key.<GLMModel>make("GLMModel"));
        AccuracyTestingSuite.summaryLog.println("Training GLM model.");
        startTime = System.currentTimeMillis();
        model = glmJob.trainModel().get();
        break;
      case "gbm":
        GBM gbmJob = new GBM((GBMModel.GBMParameters) params);
        AccuracyTestingSuite.summaryLog.println("Training GBM model.");
        startTime = System.currentTimeMillis();
        model = gbmJob.trainModel().get();
        break;
      case "dl":
        DeepLearning dlJob = new DeepLearning((DeepLearningModel.DeepLearningParameters) params);
        AccuracyTestingSuite.summaryLog.println("Training DL model.");
        startTime = System.currentTimeMillis();
        model = dlJob.trainModel().get();
        break;
      default:
        throw new Exception("No algo: " + algo);
    }
    double stopTime = System.currentTimeMillis();
    return new TrainedModel(model._output, stopTime - startTime, model._parms.toJsonString());
  } catch (Exception e) {
    throw new Exception(e);
  } finally {
    if (model != null) {
      model.delete();
    }
  }
}

/**
 * Runs a grid search for {@link #algo} over {@link #hyperParms} and returns the grid model with the best
 * validation value of {@link #modelSelectionCriteria}.
 *
 * <p>The grid is deleted before this method returns.
 *
 * @throws Exception if the grid parameters or search criteria cannot be built, or (wrapped) if the search
 *     fails or yields no model with a comparable score
 */
private TrainedModel trainGridAndPickBest() throws Exception {
  assert !modelSelectionCriteria.equals("");
  makeGridParameters();
  makeSearchCriteria();
  Grid gridResult = null;
  try {
    SchemaServer.registerAllSchemasIfNecessary();
    // TODO: Hack for PUBDEV-2812 - instantiate each algo's builder and parameters schema at most once,
    // guarded by the static *Registered flags.
    switch (algo) {
      case "drf":
        if (!drfRegistered) {
          new DRF(true);
          new DRFParametersV3();
          drfRegistered = true;
        }
        break;
      case "glm":
        if (!glmRegistered) {
          new GLM(true);
          new GLMParametersV3();
          glmRegistered = true;
        }
        break;
      case "gbm":
        if (!gbmRegistered) {
          new GBM(true);
          new GBMParametersV3();
          gbmRegistered = true;
        }
        break;
      case "dl":
        if (!dlRegistered) {
          new DeepLearning(true);
          new DeepLearningParametersV3();
          dlRegistered = true;
        }
        break;
    }
    double startTime = System.currentTimeMillis();
    // TODO: ModelParametersBuilderFactory parameter must be instantiated properly
    Job<Grid> gs = GridSearch.startGridSearch(null, params, hyperParms,
        new GridSearch.SimpleParametersBuilderFactory<>(), searchCriteria);
    gridResult = gs.get();
    double stopTime = System.currentTimeMillis();

    boolean higherIsBetter = higherIsBetter(modelSelectionCriteria);
    double bestScore = higherIsBetter ? -Double.MAX_VALUE : Double.MAX_VALUE;
    Model bestModel = null;
    for (Model m : gridResult.getModels()) {
      double validationMetricScore = getMetrics(m._output._validation_metrics).get(modelSelectionCriteria);
      AccuracyTestingSuite.summaryLog.println(modelSelectionCriteria + " for model " + m._key.toString() + " is " +
          validationMetricScore);
      if (higherIsBetter ? validationMetricScore > bestScore : validationMetricScore < bestScore) {
        bestScore = validationMetricScore;
        bestModel = m;
      }
    }
    if (bestModel == null) {
      // Empty grid, or no score ever compared better (e.g. every score was NaN).
      throw new Exception("No grid model has a usable " + modelSelectionCriteria + " validation score");
    }
    String bestModelJson = bestModel._parms.toJsonString();
    AccuracyTestingSuite.summaryLog.println("Best model: " + bestModel._key.toString());
    AccuracyTestingSuite.summaryLog.println("Best model parameters: " + bestModelJson);
    return new TrainedModel(bestModel._output, stopTime - startTime, bestModelJson);
  } catch (Exception e) {
    throw new Exception(e);
  } finally {
    if (gridResult != null) {
      gridResult.delete();
    }
  }
}

/** What execute() keeps from a training run: the chosen model's output, its training time and its parameters. */
private static final class TrainedModel {
  final Model.Output output;
  final double runTimeMillis;
  final String modelJson;

  TrainedModel(Model.Output output, double runTimeMillis, String modelJson) {
    this.output = output;
    this.runTimeMillis = runTimeMillis;
    this.modelJson = modelJson;
  }
}
/**
 * Loads the training data set and then the testing data set, passing {@link #regression} to each.
 *
 * @throws IOException propagated from {@code DataSet.load}
 */
private void loadTestCaseDataSets() throws IOException {
  final boolean isRegression = this.regression;
  this.trainingDataSet.load(isRegression);
  this.testingDataSet.load(isRegression);
}
/** Removes the frame of the training data set, then the frame of the testing data set. */
private void removeTestCaseDataSetFrames() {
  this.trainingDataSet.removeFrame();
  this.testingDataSet.removeFrame();
}
/**
 * Builds {@link #params} for {@link #algo} by delegating to the matching make*ModelParameters method.
 *
 * @throws Exception if the algorithm is unknown, or propagated from the selected builder
 */
private void makeModelParameters() throws Exception {
  switch (algo) {
    case "glm":
      params = makeGlmModelParameters();
      return;
    case "gbm":
      params = makeGbmModelParameters();
      return;
    case "drf":
      params = makeDrfModelParameters();
      return;
    case "dl":
      params = makeDlModelParameters();
      return;
    default:
      throw new Exception("No algo: " + algo);
  }
}
/**
 * Builds {@link #hyperParms} for {@link #algo} by delegating to the matching make*GridParameters method.
 *
 * @throws Exception if the algorithm is not supported, or propagated from the selected builder
 */
private void makeGridParameters() throws Exception {
  switch (algo) {
    case "glm":
      hyperParms = makeGlmGridParameters();
      return;
    case "gbm":
      hyperParms = makeGbmGridParameters();
      return;
    case "drf":
      hyperParms = makeDrfGridParameters();
      return;
    case "dl":
      hyperParms = makeDlGridParameters();
      return;
    default:
      throw new Exception("Algo not supported: " + algo);
  }
}
/**
 * Builds {@link #searchCriteria} from {@link #searchParameters}.
 *
 * <p>{@code searchParameters} is a {@code ;}-separated list of {@code key=value} tokens. Only
 * {@code strategy=RandomDiscrete} is supported. The optional keys {@code seed}, {@code stopping_rounds},
 * {@code stopping_tolerance}, {@code max_runtime_secs} and {@code max_models} are copied into the criteria.
 * Empty tokens, such as the one produced by a trailing {@code ;}, are skipped.
 *
 * @throws Exception if the strategy is missing or is not RandomDiscrete
 */
private void makeSearchCriteria() throws Exception {
  AccuracyTestingSuite.summaryLog.println("Making Grid Search Criteria.");
  HashMap<String, String> tokenMap = new HashMap<>();
  for (String token : searchParameters.trim().split(";", -1)) {
    if (token.isEmpty()) {
      continue;
    }
    String[] keyAndValue = token.split("=", -1);
    tokenMap.put(keyAndValue[0], keyAndValue[1]);
  }
  if (!"RandomDiscrete".equals(tokenMap.get("strategy"))) {
    throw new Exception(tokenMap.get("strategy") + " search strategy is not supported for grid search test cases");
  }
  RandomDiscreteValueSearchCriteriaV99 sc = new RandomDiscreteValueSearchCriteriaV99();
  if (tokenMap.containsKey("seed"))
    sc.seed = Integer.parseInt(tokenMap.get("seed"));
  if (tokenMap.containsKey("stopping_rounds"))
    sc.stopping_rounds = Integer.parseInt(tokenMap.get("stopping_rounds"));
  if (tokenMap.containsKey("stopping_tolerance"))
    sc.stopping_tolerance = Double.parseDouble(tokenMap.get("stopping_tolerance"));
  if (tokenMap.containsKey("max_runtime_secs"))
    sc.max_runtime_secs = Double.parseDouble(tokenMap.get("max_runtime_secs"));
  if (tokenMap.containsKey("max_models"))
    sc.max_models = Integer.parseInt(tokenMap.get("max_models"));
  searchCriteria = sc.createAndFillImpl();
}
/**
 * Flattens a model-metrics object into a map from metric name to value.
 *
 * <p>{@code "MSE"} is always present. Further entries depend on the runtime type of {@code mm}:
 * <ul>
 *   <li>regression metrics add the mean residual deviance;</li>
 *   <li>binomial metrics add AUC, Gini, logloss and confusion-matrix statistics;</li>
 *   <li>multinomial metrics add logloss and confusion-matrix statistics;</li>
 *   <li>GLM regression and binomial metrics also add deviances, degrees of freedom and AIC.</li>
 * </ul>
 *
 * @param mm the metrics to flatten; must not be null ({@code mse()} is called unconditionally)
 * @return a new map of metric name to value
 */
private HashMap<String, Double> getMetrics(ModelMetrics mm) {
  HashMap<String, Double> metrics = new HashMap<String, Double>();
  metrics.put("MSE", mm.mse());

  if (mm instanceof ModelMetricsRegression) {
    ModelMetricsRegression reg = (ModelMetricsRegression) mm;
    metrics.put("MeanResidualDeviance", reg._mean_residual_deviance);
  }

  if (mm instanceof ModelMetricsBinomial) {
    ModelMetricsBinomial bin = (ModelMetricsBinomial) mm;
    metrics.put("AUC", bin.auc());
    metrics.put("Gini", bin._auc._gini);
    metrics.put("Logloss", bin.logloss());
    metrics.put("F1", bin.cm().f1());
    metrics.put("F2", bin.cm().f2());
    metrics.put("F0point5", bin.cm().f0point5());
    metrics.put("Accuracy", bin.cm().accuracy());
    metrics.put("Error", bin.cm().err());
    metrics.put("Precision", bin.cm().precision());
    metrics.put("Recall", bin.cm().recall());
    metrics.put("MCC", bin.cm().mcc());
    metrics.put("MaxPerClassError", bin.cm().max_per_class_error());
  }

  if (mm instanceof ModelMetricsMultinomial) {
    ModelMetricsMultinomial multi = (ModelMetricsMultinomial) mm;
    metrics.put("Logloss", multi.logloss());
    metrics.put("Error", multi.cm().err());
    metrics.put("MaxPerClassError", multi.cm().max_per_class_error());
    metrics.put("Accuracy", multi.cm().accuracy());
  }

  // GLM-specific extras; these subclasses also matched the generic checks above.
  if (mm instanceof ModelMetricsRegressionGLM) {
    ModelMetricsRegressionGLM glmReg = (ModelMetricsRegressionGLM) mm;
    metrics.put("ResidualDeviance", glmReg._resDev);
    metrics.put("ResidualDegreesOfFreedom", (double) glmReg._residualDegressOfFreedom);
    metrics.put("NullDeviance", glmReg._nullDev);
    metrics.put("NullDegreesOfFreedom", (double) glmReg._nullDegressOfFreedom);
    metrics.put("AIC", glmReg._AIC);
  }
  if (mm instanceof ModelMetricsBinomialGLM) {
    ModelMetricsBinomialGLM glmBin = (ModelMetricsBinomialGLM) mm;
    metrics.put("ResidualDeviance", glmBin._resDev);
    metrics.put("ResidualDegreesOfFreedom", (double) glmBin._residualDegressOfFreedom);
    metrics.put("NullDeviance", glmBin._nullDev);
    metrics.put("NullDegreesOfFreedom", (double) glmBin._nullDegressOfFreedom);
    metrics.put("AIC", glmBin._AIC);
  }
  return metrics;
}
/**
 * Builds GLM parameters from {@link #algoParameters}.
 *
 * <p>{@code algoParameters} is a {@code ;}-separated list of {@code name=value} tokens.
 * {@code _beta_constraints} takes {@code lower|upper} bounds. The flag parameters ({@code _ignore_const_cols},
 * {@code _lambda_search}, {@code _standardize}, {@code _non_negative}, {@code _intercept}) are set to
 * {@code true} whenever they appear. Their value is not inspected, so such a token can only enable an option.
 * Empty tokens, such as the one produced by an empty string or a trailing {@code ;}, are skipped. The training
 * and validation frames and the response column come from this test case's data sets.
 *
 * @return the populated GLM parameters
 * @throws Exception if a parameter, family or solver is not supported
 */
private GLMModel.GLMParameters makeGlmModelParameters() throws Exception {
  AccuracyTestingSuite.summaryLog.println("Making GLM model parameters.");
  GLMModel.GLMParameters glmParams = new GLMModel.GLMParameters();
  for (String token : algoParameters.trim().split(";", -1)) {
    if (token.isEmpty()) {
      continue;
    }
    String[] nameAndValue = token.split("=", -1);
    String parameterName = nameAndValue[0];
    String parameterValue = nameAndValue[1];
    switch (parameterName) {
      case "_family":
        switch (parameterValue) {
          case "gaussian":
            glmParams._family = GLMModel.GLMParameters.Family.gaussian;
            break;
          case "binomial":
            glmParams._family = GLMModel.GLMParameters.Family.binomial;
            break;
          case "multinomial":
            glmParams._family = GLMModel.GLMParameters.Family.multinomial;
            break;
          case "poisson":
            glmParams._family = GLMModel.GLMParameters.Family.poisson;
            break;
          case "gamma":
            glmParams._family = GLMModel.GLMParameters.Family.gamma;
            break;
          case "tweedie":
            glmParams._family = GLMModel.GLMParameters.Family.tweedie;
            break;
          default:
            throw new Exception(parameterValue + " family is not supported for glm test cases");
        }
        break;
      case "_solver":
        switch (parameterValue) {
          case "AUTO":
            glmParams._solver = GLMModel.GLMParameters.Solver.AUTO;
            break;
          case "irlsm":
            glmParams._solver = GLMModel.GLMParameters.Solver.IRLSM;
            break;
          case "lbfgs":
            glmParams._solver = GLMModel.GLMParameters.Solver.L_BFGS;
            break;
          case "coordinate_descent_naive":
            glmParams._solver = GLMModel.GLMParameters.Solver.COORDINATE_DESCENT_NAIVE;
            break;
          case "coordinate_descent":
            glmParams._solver = GLMModel.GLMParameters.Solver.COORDINATE_DESCENT;
            break;
          default:
            throw new Exception(parameterValue + " solver is not supported for glm test cases");
        }
        break;
      case "_nfolds":
        glmParams._nfolds = Integer.parseInt(parameterValue);
        break;
      case "_fold_column":
        // The column name is the value, not the whole "name=value" token.
        glmParams._fold_column = parameterValue;
        break;
      case "_ignore_const_cols":
        glmParams._ignore_const_cols = true;
        break;
      case "_offset_column":
        glmParams._offset_column = parameterValue;
        break;
      case "_weights_column":
        glmParams._weights_column = parameterValue;
        break;
      case "_alpha":
        glmParams._alpha = new double[]{Double.parseDouble(parameterValue)};
        break;
      case "_lambda":
        glmParams._lambda = new double[]{Double.parseDouble(parameterValue)};
        break;
      case "_lambda_search":
        glmParams._lambda_search = true;
        break;
      case "_standardize":
        glmParams._standardize = true;
        break;
      case "_non_negative":
        glmParams._non_negative = true;
        break;
      case "_intercept":
        glmParams._intercept = true;
        break;
      case "_prior":
        glmParams._prior = Double.parseDouble(parameterValue);
        break;
      case "_max_active_predictors":
        glmParams._max_active_predictors = Integer.parseInt(parameterValue);
        break;
      case "_beta_constraints":
        // Value is "<lower>|<upper>". split() takes a regex, so '|' must be escaped; an unescaped "|"
        // matches the empty string and would split between every character.
        String[] bounds = parameterValue.split("\\|", -1);
        glmParams._beta_constraints =
            makeBetaConstraints(Double.parseDouble(bounds[0]), Double.parseDouble(bounds[1]));
        break;
      default:
        throw new Exception(parameterName + " parameter is not supported for glm test cases");
    }
  }
  // _train, _valid, _response
  glmParams._train = trainingDataSet.getFrame()._key;
  glmParams._valid = testingDataSet.getFrame()._key;
  glmParams._response_column = trainingDataSet.getFrame()._names[trainingDataSet.getResponseColumn()];
  return glmParams;
}
/**
 * Creates a beta-constraints table that bounds every predictor coefficient to {@code [lowerBound, upperBound]}
 * and parses it into a frame.
 *
 * <p>The response column and constant columns get no rows. A categorical column gets one row per domain
 * level, named {@code <column>.<level>*; a numeric column gets a single row named after the column.
 *
 * @param lowerBound lower bound written for every coefficient
 * @param upperBound upper bound written for every coefficient
 * @return key of the parsed constraints frame ({@code beta_constraints.hex})
 */
private Key<Frame> makeBetaConstraints(double lowerBound, double upperBound) {
  Frame frame = trainingDataSet.getFrame();
  int responseColumn = trainingDataSet.getResponseColumn();
  StringBuilder csv = new StringBuilder("names, lower_bounds, upper_bounds\n");
  for (String name : frame._names) {
    // Only predictors are constrained: skip the response and any constant column.
    if (name.equals(frame._names[responseColumn]) || frame.vec(name).isConst()) {
      continue;
    }
    if (frame.vec(name).isCategorical()) {
      // Each categorical level has its own coefficient.
      for (String level : frame.vec(name).domain()) {
        csv.append(String.format("%s.%s,%s,%s\n", name, level, lowerBound, upperBound));
      }
    } else {
      csv.append(String.format("%s,%s,%s\n", name, lowerBound, upperBound));
    }
  }
  Key betaConsKey = Key.make("beta_constraints");
  FVecTest.makeByteVec(betaConsKey, csv.toString());
  return ParseDataset.parse(Key.make("beta_constraints.hex"), betaConsKey)._key;
}
/**
 * Builds the GLM grid hyper-parameters from {@link #gridParameters}.
 *
 * <p>{@code gridParameters} is a {@code ;}-separated list of {@code name=v1|v2|...} tokens; only
 * {@code _alpha} and {@code _lambda} are supported. Empty tokens are skipped, so a parameter that follows one
 * is still read (previously parsing stopped at the first empty token).
 *
 * @return map of hyper-parameter name to candidate values
 * @throws Exception if a grid parameter is not supported
 */
private HashMap<String, Object[]> makeGlmGridParameters() throws Exception {
  AccuracyTestingSuite.summaryLog.println("Making GLM grid parameters.");
  HashMap<String, Object[]> glmHyperParms = new HashMap<String, Object[]>();
  for (String token : gridParameters.trim().split(";", -1)) {
    if (token.equals("")) {
      continue;
    }
    String gridParameterName = token.split("=", -1)[0];
    String[] gridParameterValues = token.split("=", -1)[1].split("\\|", -1);
    switch (gridParameterName) {
      case "_alpha":
        glmHyperParms.put("_alpha", stringArrayToDoubleAA(gridParameterValues));
        break;
      case "_lambda":
        glmHyperParms.put("_lambda", stringArrayToDoubleAA(gridParameterValues));
        break;
      default:
        throw new Exception(gridParameterName + " grid parameter is not supported for glm test cases");
    }
  }
  return glmHyperParms;
}
/**
 * Builds GBM parameters from {@link #algoParameters}.
 *
 * <p>{@code algoParameters} is a {@code ;}-separated list of {@code name=value} tokens. The flag parameters
 * ({@code _ignore_const_cols}, {@code _score_each_iteration}, {@code _balance_classes},
 * {@code _build_tree_one_node}) are set to {@code true} whenever they appear. Their value is not inspected,
 * so such a token can only enable an option. Empty tokens, such as the one produced by an empty string or a
 * trailing {@code ;}, are skipped. The training and validation frames and the response column come from
 * this test case's data sets.
 *
 * @return the populated GBM parameters
 * @throws Exception if a parameter, distribution or histogram type is not supported
 */
private GBMModel.GBMParameters makeGbmModelParameters() throws Exception {
  AccuracyTestingSuite.summaryLog.println("Making GBM model parameters.");
  GBMModel.GBMParameters gbmParams = new GBMModel.GBMParameters();
  for (String token : algoParameters.trim().split(";", -1)) {
    if (token.isEmpty()) {
      continue;
    }
    String[] nameAndValue = token.split("=", -1);
    String parameterName = nameAndValue[0];
    String parameterValue = nameAndValue[1];
    switch (parameterName) {
      case "_distribution":
        switch (parameterValue) {
          case "AUTO":
            gbmParams._distribution = DistributionFamily.AUTO;
            break;
          case "gaussian":
            gbmParams._distribution = DistributionFamily.gaussian;
            break;
          case "bernoulli":
            gbmParams._distribution = DistributionFamily.bernoulli;
            break;
          case "multinomial":
            gbmParams._distribution = DistributionFamily.multinomial;
            break;
          case "poisson":
            gbmParams._distribution = DistributionFamily.poisson;
            break;
          case "gamma":
            gbmParams._distribution = DistributionFamily.gamma;
            break;
          case "tweedie":
            gbmParams._distribution = DistributionFamily.tweedie;
            break;
          default:
            throw new Exception(parameterValue + " distribution is not supported for gbm test cases");
        }
        break;
      case "_histogram_type":
        switch (parameterValue) {
          case "AUTO":
            gbmParams._histogram_type = SharedTreeModel.SharedTreeParameters.HistogramType.AUTO;
            break;
          case "UniformAdaptive":
            gbmParams._histogram_type = SharedTreeModel.SharedTreeParameters.HistogramType.UniformAdaptive;
            break;
          case "Random":
            gbmParams._histogram_type = SharedTreeModel.SharedTreeParameters.HistogramType.Random;
            break;
          default:
            throw new Exception(parameterValue + " histogram_type is not supported for gbm test cases");
        }
        break;
      case "_nfolds":
        gbmParams._nfolds = Integer.parseInt(parameterValue);
        break;
      case "_fold_column":
        // The column name is the value, not the whole "name=value" token.
        gbmParams._fold_column = parameterValue;
        break;
      case "_ignore_const_cols":
        gbmParams._ignore_const_cols = true;
        break;
      case "_offset_column":
        gbmParams._offset_column = parameterValue;
        break;
      case "_weights_column":
        gbmParams._weights_column = parameterValue;
        break;
      case "_ntrees":
        gbmParams._ntrees = Integer.parseInt(parameterValue);
        break;
      case "_max_depth":
        gbmParams._max_depth = Integer.parseInt(parameterValue);
        break;
      case "_min_rows":
        gbmParams._min_rows = Double.parseDouble(parameterValue);
        break;
      case "_nbins":
        gbmParams._nbins = Integer.parseInt(parameterValue);
        break;
      case "_nbins_cats":
        gbmParams._nbins_cats = Integer.parseInt(parameterValue);
        break;
      case "_learn_rate":
        gbmParams._learn_rate = Float.parseFloat(parameterValue);
        break;
      case "_score_each_iteration":
        gbmParams._score_each_iteration = true;
        break;
      case "_balance_classes":
        gbmParams._balance_classes = true;
        break;
      case "_max_confusion_matrix_size":
        gbmParams._max_confusion_matrix_size = Integer.parseInt(parameterValue);
        break;
      case "_build_tree_one_node":
        gbmParams._build_tree_one_node = true;
        break;
      case "_sample_rate":
        gbmParams._sample_rate = Float.parseFloat(parameterValue);
        break;
      case "_col_sample_rate":
        gbmParams._col_sample_rate = Float.parseFloat(parameterValue);
        break;
      case "_col_sample_rate_per_tree":
        gbmParams._col_sample_rate_per_tree = Double.parseDouble(parameterValue);
        break;
      case "_col_sample_rate_change_per_level":
        gbmParams._col_sample_rate_change_per_level = Float.parseFloat(parameterValue);
        break;
      case "_min_split_improvement":
        gbmParams._min_split_improvement = Double.parseDouble(parameterValue);
        break;
      case "_learn_rate_annealing":
        gbmParams._learn_rate_annealing = Double.parseDouble(parameterValue);
        break;
      case "_max_abs_leafnode_pred":
        gbmParams._max_abs_leafnode_pred = Double.parseDouble(parameterValue);
        break;
      case "_score_tree_interval":
        gbmParams._score_tree_interval = Integer.parseInt(parameterValue);
        break;
      default:
        throw new Exception(parameterName + " parameter is not supported for gbm test cases");
    }
  }
  // _train, _valid, _response
  gbmParams._train = trainingDataSet.getFrame()._key;
  gbmParams._valid = testingDataSet.getFrame()._key;
  gbmParams._response_column = trainingDataSet.getFrame()._names[trainingDataSet.getResponseColumn()];
  return gbmParams;
}
/**
 * Builds the GBM grid hyper-parameters from {@link #gridParameters}.
 *
 * <p>{@code gridParameters} is a {@code ;}-separated list of {@code name=v1|v2|...} tokens. Empty tokens,
 * such as the one produced by an empty string or a trailing {@code ;}, are skipped, matching the GLM and DL
 * grid parsers (previously they caused an ArrayIndexOutOfBoundsException).
 *
 * @return map of hyper-parameter name to candidate values
 * @throws Exception if a grid parameter is not supported
 */
private HashMap<String, Object[]> makeGbmGridParameters() throws Exception {
  AccuracyTestingSuite.summaryLog.println("Making GBM grid parameters.");
  HashMap<String, Object[]> gbmHyperParms = new HashMap<String, Object[]>();
  for (String token : gridParameters.trim().split(";", -1)) {
    if (token.equals("")) {
      continue;
    }
    String gridParameterName = token.split("=", -1)[0];
    String[] gridParameterValues = token.split("=", -1)[1].split("\\|", -1);
    switch (gridParameterName) {
      case "_ntrees":
        gbmHyperParms.put("_ntrees", stringArrayToIntegerArray(gridParameterValues));
        break;
      case "_max_depth":
        gbmHyperParms.put("_max_depth", stringArrayToIntegerArray(gridParameterValues));
        break;
      case "_min_rows":
        gbmHyperParms.put("_min_rows", stringArrayToDoubleArray(gridParameterValues));
        break;
      case "_nbins":
        gbmHyperParms.put("_nbins", stringArrayToIntegerArray(gridParameterValues));
        break;
      case "_nbins_cats":
        gbmHyperParms.put("_nbins_cats", stringArrayToIntegerArray(gridParameterValues));
        break;
      case "_learn_rate":
        gbmHyperParms.put("_learn_rate", stringArrayToFloatArray(gridParameterValues));
        break;
      case "_balance_classes":
        gbmHyperParms.put("_balance_classes", stringArrayToBooleanArray(gridParameterValues));
        break;
      case "_r2_stopping":
        gbmHyperParms.put("_r2_stopping", stringArrayToDoubleArray(gridParameterValues));
        break;
      case "_build_tree_one_node":
        gbmHyperParms.put("_build_tree_one_node", stringArrayToBooleanArray(gridParameterValues));
        break;
      case "_sample_rate":
        gbmHyperParms.put("_sample_rate", stringArrayToFloatArray(gridParameterValues));
        break;
      case "_col_sample_rate":
        gbmHyperParms.put("_col_sample_rate", stringArrayToFloatArray(gridParameterValues));
        break;
      case "_col_sample_rate_per_tree":
        gbmHyperParms.put("_col_sample_rate_per_tree", stringArrayToDoubleArray(gridParameterValues));
        break;
      case "_col_sample_rate_change_per_level":
        gbmHyperParms.put("_col_sample_rate_change_per_level", stringArrayToFloatArray(gridParameterValues));
        break;
      case "_min_split_improvement":
        gbmHyperParms.put("_min_split_improvement", stringArrayToDoubleArray(gridParameterValues));
        break;
      case "_learn_rate_annealing":
        gbmHyperParms.put("_learn_rate_annealing", stringArrayToDoubleArray(gridParameterValues));
        break;
      case "_max_abs_leafnode_pred":
        gbmHyperParms.put("_max_abs_leafnode_pred", stringArrayToDoubleArray(gridParameterValues));
        break;
      case "_score_tree_interval":
        gbmHyperParms.put("_score_tree_interval", stringArrayToIntegerArray(gridParameterValues));
        break;
      default:
        throw new Exception(gridParameterName + " grid parameter is not supported for gbm test cases");
    }
  }
  return gbmHyperParms;
}
/**
 * Builds Deep Learning parameters from {@link #algoParameters}.
 *
 * <p>{@code algoParameters} is a {@code ;}-separated list of {@code name=value} tokens. {@code _hidden}
 * takes a {@code :}-separated list of layer sizes (e.g. {@code _hidden=200:200}). Boolean options (such as
 * {@code _variable_importances}, {@code _balance_classes}, {@code _reproducible}) are set to {@code true}
 * whenever they appear. Their value is not inspected, so such a token can only enable an option. Empty
 * tokens, such as the one produced by an empty string or a trailing {@code ;}, are skipped. The training and
 * validation frames and the response column come from this test case's data sets.
 *
 * @return the populated Deep Learning parameters
 * @throws Exception if a parameter, distribution, activation or loss is not supported
 */
private DeepLearningModel.Parameters makeDlModelParameters() throws Exception {
  AccuracyTestingSuite.summaryLog.println("Making DL model parameters.");
  DeepLearningModel.DeepLearningParameters dlParams = new DeepLearningModel.DeepLearningParameters();
  for (String token : algoParameters.trim().split(";", -1)) {
    if (token.isEmpty()) {
      continue;
    }
    String[] nameAndValue = token.split("=", -1);
    String parameterName = nameAndValue[0];
    String parameterValue = nameAndValue[1];
    switch (parameterName) {
      case "_distribution":
        switch (parameterValue) {
          case "AUTO":
            dlParams._distribution = DistributionFamily.AUTO;
            break;
          case "gaussian":
            dlParams._distribution = DistributionFamily.gaussian;
            break;
          case "bernoulli":
            dlParams._distribution = DistributionFamily.bernoulli;
            break;
          case "multinomial":
            dlParams._distribution = DistributionFamily.multinomial;
            break;
          case "poisson":
            dlParams._distribution = DistributionFamily.poisson;
            break;
          case "gamma":
            dlParams._distribution = DistributionFamily.gamma;
            break;
          case "tweedie":
            dlParams._distribution = DistributionFamily.tweedie;
            break;
          default:
            throw new Exception(parameterValue + " distribution is not supported for dl test cases");
        }
        break;
      case "_activation":
        switch (parameterValue) {
          case "tanh":
            dlParams._activation = DeepLearningModel.DeepLearningParameters.Activation.Tanh;
            break;
          case "tanhwithdropout":
            dlParams._activation = DeepLearningModel.DeepLearningParameters.Activation.TanhWithDropout;
            break;
          case "rectifier":
            dlParams._activation = DeepLearningModel.DeepLearningParameters.Activation.Rectifier;
            break;
          case "rectifierwithdropout":
            dlParams._activation = DeepLearningModel.DeepLearningParameters.Activation.RectifierWithDropout;
            break;
          case "maxout":
            dlParams._activation = DeepLearningModel.DeepLearningParameters.Activation.Maxout;
            break;
          case "maxoutwithdropout":
            dlParams._activation = DeepLearningModel.DeepLearningParameters.Activation.MaxoutWithDropout;
            break;
          default:
            throw new Exception(parameterValue + " activation is not supported for dl test cases");
        }
        break;
      case "_loss":
        switch (parameterValue) {
          case "AUTO":
            dlParams._loss = DeepLearningModel.DeepLearningParameters.Loss.Automatic;
            break;
          case "crossentropy":
            dlParams._loss = DeepLearningModel.DeepLearningParameters.Loss.CrossEntropy;
            break;
          case "quadratic":
            dlParams._loss = DeepLearningModel.DeepLearningParameters.Loss.Quadratic;
            break;
          case "huber":
            dlParams._loss = DeepLearningModel.DeepLearningParameters.Loss.Huber;
            break;
          case "modified_huber":
            dlParams._loss = DeepLearningModel.DeepLearningParameters.Loss.ModifiedHuber;
            break;
          case "absolute":
            dlParams._loss = DeepLearningModel.DeepLearningParameters.Loss.Absolute;
            break;
          default:
            throw new Exception(parameterValue + " loss is not supported for dl test cases");
        }
        break;
      case "_nfolds":
        dlParams._nfolds = Integer.parseInt(parameterValue);
        break;
      case "_hidden":
        // Split the value only; splitting the whole token would make the first layer "_hidden=N".
        String[] hidden = parameterValue.trim().split(":", -1);
        dlParams._hidden = stringArrayTointArray(hidden);
        break;
      case "_epochs":
        dlParams._epochs = Double.parseDouble(parameterValue);
        break;
      case "_variable_importances":
        dlParams._variable_importances = true;
        break;
      case "_fold_column":
        // The column name is the value, not the whole "name=value" token.
        dlParams._fold_column = parameterValue;
        break;
      case "_weights_column":
        dlParams._weights_column = parameterValue;
        break;
      case "_balance_classes":
        dlParams._balance_classes = true;
        break;
      case "_max_confusion_matrix_size":
        dlParams._max_confusion_matrix_size = Integer.parseInt(parameterValue);
        break;
      case "_use_all_factor_levels":
        dlParams._use_all_factor_levels = true;
        break;
      case "_train_samples_per_iteration":
        dlParams._train_samples_per_iteration = Long.parseLong(parameterValue);
        break;
      case "_adaptive_rate":
        dlParams._adaptive_rate = true;
        break;
      case "_input_dropout_ratio":
        dlParams._input_dropout_ratio = Double.parseDouble(parameterValue);
        break;
      case "_l1":
        dlParams._l1 = Double.parseDouble(parameterValue);
        break;
      case "_l2":
        dlParams._l2 = Double.parseDouble(parameterValue);
        break;
      case "_score_interval":
        dlParams._score_interval = Double.parseDouble(parameterValue);
        break;
      case "_score_training_samples":
        dlParams._score_training_samples = Long.parseLong(parameterValue);
        break;
      case "_score_duty_cycle":
        dlParams._score_duty_cycle = Double.parseDouble(parameterValue);
        break;
      case "_replicate_training_data":
        dlParams._replicate_training_data = true;
        break;
      case "_autoencoder":
        dlParams._autoencoder = true;
        break;
      case "_target_ratio_comm_to_comp":
        dlParams._target_ratio_comm_to_comp = Double.parseDouble(parameterValue);
        break;
      case "_seed":
        dlParams._seed = Long.parseLong(parameterValue);
        break;
      case "_rho":
        dlParams._rho = Double.parseDouble(parameterValue);
        break;
      case "_epsilon":
        dlParams._epsilon = Double.parseDouble(parameterValue);
        break;
      case "_max_w2":
        dlParams._max_w2 = Float.parseFloat(parameterValue);
        break;
      case "_regression_stop":
        dlParams._regression_stop = Double.parseDouble(parameterValue);
        break;
      case "_diagnostics":
        dlParams._diagnostics = true;
        break;
      case "_fast_mode":
        dlParams._fast_mode = true;
        break;
      case "_force_load_balance":
        dlParams._force_load_balance = true;
        break;
      case "_single_node_mode":
        dlParams._single_node_mode = true;
        break;
      case "_shuffle_training_data":
        dlParams._shuffle_training_data = true;
        break;
      case "_quiet_mode":
        dlParams._quiet_mode = true;
        break;
      case "_sparse":
        dlParams._sparse = true;
        break;
      case "_col_major":
        dlParams._col_major = true;
        break;
      case "_average_activation":
        dlParams._average_activation = Double.parseDouble(parameterValue);
        break;
      case "_sparsity_beta":
        dlParams._sparsity_beta = Double.parseDouble(parameterValue);
        break;
      case "_max_categorical_features":
        dlParams._max_categorical_features = Integer.parseInt(parameterValue);
        break;
      case "_reproducible":
        dlParams._reproducible = true;
        break;
      case "_export_weights_and_biases":
        dlParams._export_weights_and_biases = true;
        break;
      default:
        throw new Exception(parameterName + " parameter is not supported for dl test cases");
    }
  }
  // _train, _valid, _response
  dlParams._train = trainingDataSet.getFrame()._key;
  dlParams._valid = testingDataSet.getFrame()._key;
  dlParams._response_column = trainingDataSet.getFrame()._names[trainingDataSet.getResponseColumn()];
  return dlParams;
}
/**
 * Builds the Deep Learning grid hyper-parameters from {@link #gridParameters}.
 *
 * <p>{@code gridParameters} is a {@code ;}-separated list of {@code name=v1|v2|...} tokens; only
 * {@code _hidden} and {@code _epochs} are supported. Empty tokens are skipped, so a parameter that follows
 * one is still read (previously parsing stopped at the first empty token).
 *
 * @return map of hyper-parameter name to candidate values
 * @throws Exception if a grid parameter is not supported
 */
private HashMap<String, Object[]> makeDlGridParameters() throws Exception {
  AccuracyTestingSuite.summaryLog.println("Making DL grid parameters.");
  HashMap<String, Object[]> dlHyperParms = new HashMap<String, Object[]>();
  for (String token : gridParameters.trim().split(";", -1)) {
    if (token.equals("")) {
      continue;
    }
    String gridParameterName = token.split("=", -1)[0];
    String[] gridParameterValues = token.split("=", -1)[1].split("\\|", -1);
    switch (gridParameterName) {
      case "_hidden":
        dlHyperParms.put("_hidden", hiddenStringArrayTointAA(gridParameterValues));
        break;
      case "_epochs":
        dlHyperParms.put("_epochs", stringArrayToIntegerArray(gridParameterValues));
        break;
      default:
        throw new Exception(gridParameterName + " grid parameter is not supported for dl test cases");
    }
  }
  return dlHyperParms;
}
/**
 * Builds DRF model parameters from this test case's {@code algoParameters} string.
 *
 * <p>{@code algoParameters} is a ';'-separated list of {@code name=value} entries. Flag-style
 * parameters (e.g. {@code _ignore_const_cols}, {@code _balance_classes}) are switched on whenever
 * their name is present; the text after '=' is not inspected for them. Blank entries (an empty
 * string, a trailing ';') are skipped.
 *
 * <p>The training frame, validation frame and response column are always set from this test
 * case's training and testing data sets.
 *
 * @return the populated DRF parameters
 * @throws Exception if an entry is not of the form {@code name=value}, names an unsupported
 *     parameter, or specifies an unsupported distribution or histogram type
 */
private DRFModel.DRFParameters makeDrfModelParameters() throws Exception {
    AccuracyTestingSuite.summaryLog.println("Making DRF model parameters.");
    DRFModel.DRFParameters drfParams = new DRFModel.DRFParameters();
    for (String rawToken : algoParameters.trim().split(";", -1)) {
        String token = rawToken.trim();
        if (token.isEmpty()) continue;
        // Split on the first '=' only so the value is kept whole.
        String[] nameAndValue = token.split("=", 2);
        if (nameAndValue.length < 2) {
            throw new Exception(token + " is not a valid name=value parameter for drf test cases");
        }
        String parameterName = nameAndValue[0];
        String parameterValue = nameAndValue[1];
        switch (parameterName) {
            case "_distribution":
                switch (parameterValue) {
                    case "AUTO":
                        drfParams._distribution = DistributionFamily.AUTO;
                        break;
                    case "gaussian":
                        drfParams._distribution = DistributionFamily.gaussian;
                        break;
                    case "bernoulli":
                        drfParams._distribution = DistributionFamily.bernoulli;
                        break;
                    case "multinomial":
                        drfParams._distribution = DistributionFamily.multinomial;
                        break;
                    case "poisson":
                        drfParams._distribution = DistributionFamily.poisson;
                        break;
                    case "gamma":
                        drfParams._distribution = DistributionFamily.gamma;
                        break;
                    case "tweedie":
                        drfParams._distribution = DistributionFamily.tweedie;
                        break;
                    default:
                        throw new Exception(parameterValue + " distribution is not supported for drf test cases");
                }
                break;
            case "_histogram_type":
                switch (parameterValue) {
                    case "AUTO":
                        drfParams._histogram_type = SharedTreeModel.SharedTreeParameters.HistogramType.AUTO;
                        break;
                    case "UniformAdaptive":
                        drfParams._histogram_type = SharedTreeModel.SharedTreeParameters.HistogramType.UniformAdaptive;
                        break;
                    case "Random":
                        drfParams._histogram_type = SharedTreeModel.SharedTreeParameters.HistogramType.Random;
                        break;
                    default:
                        throw new Exception(parameterValue + " histogram_type is not supported for drf test cases");
                }
                break;
            case "_nfolds":
                drfParams._nfolds = Integer.parseInt(parameterValue);
                break;
            case "_fold_column":
                // The column name is the value, not the whole "name=value" token.
                drfParams._fold_column = parameterValue;
                break;
            case "_ignore_const_cols":
                drfParams._ignore_const_cols = true;
                break;
            case "_offset_column":
                drfParams._offset_column = parameterValue;
                break;
            case "_weights_column":
                drfParams._weights_column = parameterValue;
                break;
            case "_ntrees":
                drfParams._ntrees = Integer.parseInt(parameterValue);
                break;
            case "_max_depth":
                drfParams._max_depth = Integer.parseInt(parameterValue);
                break;
            case "_min_rows":
                drfParams._min_rows = Double.parseDouble(parameterValue);
                break;
            case "_nbins":
                drfParams._nbins = Integer.parseInt(parameterValue);
                break;
            case "_nbins_cats":
                drfParams._nbins_cats = Integer.parseInt(parameterValue);
                break;
            case "_score_each_iteration":
                drfParams._score_each_iteration = true;
                break;
            case "_balance_classes":
                drfParams._balance_classes = true;
                break;
            case "_max_confusion_matrix_size":
                drfParams._max_confusion_matrix_size = Integer.parseInt(parameterValue);
                break;
            case "_build_tree_one_node":
                drfParams._build_tree_one_node = true;
                break;
            case "_binomial_double_trees":
                drfParams._binomial_double_trees = true;
                break;
            case "_nbins_top_level":
                drfParams._nbins_top_level = Integer.parseInt(parameterValue);
                break;
            default:
                throw new Exception(parameterName + " parameter is not supported for drf test cases");
        }
    }
    // _train, _valid, _response
    drfParams._train = trainingDataSet.getFrame()._key;
    drfParams._valid = testingDataSet.getFrame()._key;
    drfParams._response_column = trainingDataSet.getFrame()._names[trainingDataSet.getResponseColumn()];
    return drfParams;
}
/**
 * Builds the DRF hyper-parameter space from this test case's {@code gridParameters}.
 *
 * <p>{@code gridParameters} is a ';'-separated list of {@code name=v1|v2|...} entries. Each value
 * list is converted to Integer, Double, Float or Boolean values depending on the parameter.
 * Blank entries, including an empty {@code gridParameters}, are skipped.
 *
 * @return map from hyper-parameter name to its candidate values
 * @throws Exception if an entry is not of the form {@code name=values} or names an unsupported
 *     grid parameter
 */
private HashMap<String, Object[]> makeDrfGridParameters() throws Exception {
    AccuracyTestingSuite.summaryLog.println("Making DRF grid parameters.");
    HashMap<String, Object[]> drfHyperParms = new HashMap<String, Object[]>();
    for (String rawToken : gridParameters.trim().split(";", -1)) {
        String token = rawToken.trim();
        // An empty grid string or a trailing ';' yields blank tokens; ignore them.
        if (token.isEmpty()) continue;
        // Split on the first '=' only; everything after it is the '|'-separated value list.
        String[] nameAndValues = token.split("=", 2);
        if (nameAndValues.length < 2) {
            throw new Exception(token + " is not a valid name=values grid parameter for drf test cases");
        }
        String gridParameterName = nameAndValues[0];
        String[] gridParameterValues = nameAndValues[1].split("\\|", -1);
        switch (gridParameterName) {
            case "_ntrees":
                drfHyperParms.put("_ntrees", stringArrayToIntegerArray(gridParameterValues));
                break;
            case "_max_depth":
                drfHyperParms.put("_max_depth", stringArrayToIntegerArray(gridParameterValues));
                break;
            case "_min_rows":
                drfHyperParms.put("_min_rows", stringArrayToDoubleArray(gridParameterValues));
                break;
            case "_nbins":
                drfHyperParms.put("_nbins", stringArrayToIntegerArray(gridParameterValues));
                break;
            case "_nbins_cats":
                drfHyperParms.put("_nbins_cats", stringArrayToIntegerArray(gridParameterValues));
                break;
            case "_balance_classes":
                drfHyperParms.put("_balance_classes", stringArrayToBooleanArray(gridParameterValues));
                break;
            case "_r2_stopping":
                drfHyperParms.put("_r2_stopping", stringArrayToDoubleArray(gridParameterValues));
                break;
            case "_build_tree_one_node":
                drfHyperParms.put("_build_tree_one_node", stringArrayToBooleanArray(gridParameterValues));
                break;
            case "_mtries":
                drfHyperParms.put("_mtries", stringArrayToIntegerArray(gridParameterValues));
                break;
            case "_sample_rate":
                drfHyperParms.put("_sample_rate", stringArrayToFloatArray(gridParameterValues));
                break;
            case "_binomial_double_trees":
                drfHyperParms.put("_binomial_double_trees", stringArrayToBooleanArray(gridParameterValues));
                break;
            case "_col_sample_rate_per_tree":
                drfHyperParms.put("_col_sample_rate_per_tree", stringArrayToFloatArray(gridParameterValues));
                break;
            case "_min_split_improvement":
                drfHyperParms.put("_min_split_improvement", stringArrayToDoubleArray(gridParameterValues));
                break;
            default:
                throw new Exception(gridParameterName + " grid parameter is not supported for drf test cases");
        }
    }
    return drfHyperParms;
}
/**
 * Parses every string as a base-10 integer and returns the boxed results in the same order.
 *
 * @throws NumberFormatException if any entry is not a valid integer
 */
static Integer[] stringArrayToIntegerArray(String[] sa) {
    final Integer[] boxed = new Integer[sa.length];
    int idx = 0;
    for (String text : sa) {
        boxed[idx++] = Integer.valueOf(text);
    }
    return boxed;
}
/**
 * Parses every string as a base-10 integer into a primitive array, preserving order.
 *
 * @throws NumberFormatException if any entry is not a valid integer
 */
static int[] stringArrayTointArray(String[] sa) {
    final int count = sa.length;
    final int[] values = new int[count];
    int k = 0;
    while (k < count) {
        values[k] = Integer.parseInt(sa[k]);
        k++;
    }
    return values;
}
/**
 * Converts hidden-layer specs such as {@code "10:20"} into int arrays of layer sizes.
 * Each spec is trimmed and split on ':'; one int[] row is produced per spec, in order.
 *
 * @throws NumberFormatException if any layer size is not a valid integer
 */
static int[][] hiddenStringArrayTointAA(String[] sa) {
    final int[][] layers = new int[sa.length][];
    int row = 0;
    for (String spec : sa) {
        String[] sizes = spec.trim().split(":", -1);
        layers[row++] = stringArrayTointArray(sizes);
    }
    return layers;
}
/**
 * Parses every string as a double and returns the boxed results in the same order.
 *
 * @throws NumberFormatException if any entry is not a valid double
 */
static Double[] stringArrayToDoubleArray(String[] sa) {
    final Double[] parsed = new Double[sa.length];
    int slot = 0;
    for (String text : sa) {
        parsed[slot] = Double.valueOf(text);
        slot++;
    }
    return parsed;
}
/**
 * Parses every string as a double and wraps each value in its own one-element row.
 *
 * @return an array with one {@code double[1]} row per input string, in order
 * @throws NumberFormatException if any entry is not a valid double
 */
static double[][] stringArrayToDoubleAA(String[] sa) {
    // Leave rows unallocated: each one is created below, so pre-sizing them would be wasted work.
    double[][] daa = new double[sa.length][];
    for (int v = 0; v < sa.length; v++) daa[v] = new double[]{Double.parseDouble(sa[v])};
    return daa;
}
/**
 * Parses every string as a float and returns the boxed results in the same order.
 *
 * @throws NumberFormatException if any entry is not a valid float
 */
static Float[] stringArrayToFloatArray(String[] sa) {
    final Float[] out = new Float[sa.length];
    int pos = 0;
    while (pos < sa.length) {
        out[pos] = Float.valueOf(sa[pos]);
        pos++;
    }
    return out;
}
/**
 * Parses every string as a boolean and returns the boxed results in the same order.
 *
 * <p>Each entry is trimmed and must be "true" or "false" (case-insensitive). Unlike
 * {@link Boolean#parseBoolean(String)}, any other text is rejected rather than silently read as
 * false, so a typo in a grid definition cannot quietly change the searched configuration.
 *
 * @throws IllegalArgumentException if an entry is neither "true" nor "false"
 */
static Boolean[] stringArrayToBooleanArray(String[] sa) {
    Boolean[] ba = new Boolean[sa.length];
    for (int v = 0; v < sa.length; v++) {
        String s = sa[v] == null ? "" : sa[v].trim();
        if (s.equalsIgnoreCase("true")) {
            ba[v] = Boolean.TRUE;
        } else if (s.equalsIgnoreCase("false")) {
            ba[v] = Boolean.FALSE;
        } else {
            throw new IllegalArgumentException("'" + sa[v] + "' is not a boolean value; expected true or false");
        }
    }
    return ba;
}
/**
 * Reports whether larger values of the named metric indicate a better model.
 * Matching is exact and case-sensitive; unknown metrics are treated as lower-is-better.
 *
 * @throws NullPointerException if {@code metric} is null
 */
static boolean higherIsBetter(String metric) {
    switch (metric) {
        case "R2":
        case "AUC":
        case "Precision":
        case "Recall":
        case "F1":
        case "F2":
        case "F0point5":
        case "Accuracy":
        case "Gini":
        case "MCC":
            return true;
        default:
            return false;
    }
}
}