package hex.tree.drf;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import hex.Model;
import hex.grid.Grid;
import hex.grid.GridSearch;
import water.DKV;
import water.Job;
import water.Key;
import water.TestUtil;
import water.fvec.Frame;
import water.fvec.Vec;
import water.test.util.GridTestUtils;
import water.util.ArrayUtils;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static water.util.ArrayUtils.interval;
public class DRFGridTest extends TestUtil {
/**
 * One-time class fixture: calls {@code stall_till_cloudsize(1)} before any test
 * of this class runs (presumably waiting for a cloud of at least one node).
 */
@BeforeClass
public static void setup() {
  stall_till_cloudsize(1);
}
/**
 * Runs a DRF grid search on the cars data set (categorical {@code cylinders}
 * response) over a hyper space that mixes one legal and one illegal
 * {@code _sample_rate} value, then verifies that:
 * <ul>
 *   <li>built models plus failures add up to the size of the hyper space,</li>
 *   <li>the grid reports exactly the hyper parameter names that were requested,</li>
 *   <li>the built models cover the hyper space restricted to the legal sample rate,</li>
 *   <li>the failed parameter sets cover the hyper space restricted to the illegal one.</li>
 * </ul>
 */
@Test
public void testCarsGrid() {
  Grid<DRFModel.DRFParameters> grid = null;
  Frame fr = null;
  Vec old = null;
  try {
    fr = parse_test_file("smalldata/junit/cars.csv");
    fr.remove("name").remove(); // Remove unique id
    old = fr.remove("cylinders");
    fr.add("cylinders", old.toCategoricalVec()); // response to last column
    DKV.put(fr);

    // Setup hyperparameter search space
    final Double[] legalSampleRateOpts = new Double[]{0.5};
    final Double[] illegalSampleRateOpts = new Double[]{2.0};
    // Plain HashMap on purpose: double-brace initialization would create an anonymous
    // subclass that captures this test instance.
    HashMap<String, Object[]> hyperParms = new HashMap<>();
    hyperParms.put("_ntrees", new Integer[]{2, 4});
    hyperParms.put("_max_depth", new Integer[]{10, 20});
    hyperParms.put("_mtries", new Integer[]{-1, 4});
    hyperParms.put("_sample_rate", ArrayUtils.join(legalSampleRateOpts, illegalSampleRateOpts));

    // Name of used hyper parameters
    String[] hyperParamNames = hyperParms.keySet().toArray(new String[hyperParms.size()]);
    Arrays.sort(hyperParamNames);
    int hyperSpaceSize = ArrayUtils.crossProductSize(hyperParms);

    // Fire off a grid search
    DRFModel.DRFParameters params = new DRFModel.DRFParameters();
    params._train = fr._key;
    params._response_column = "cylinders";

    // Get the Grid for this modeling class and frame
    Job<Grid> gs = GridSearch.startGridSearch(null, params, hyperParms);
    grid = (Grid<DRFModel.DRFParameters>) gs.get();

    // Make sure number of produced models match size of specified hyper space
    Assert.assertEquals("Size of grid should match to size of hyper space", hyperSpaceSize,
        grid.getModelCount() + grid.getFailureCount());

    //
    // Make sure that names of used parameters match
    //
    // Sort a copy: the returned array may be owned by the grid, and the test
    // should not reorder it behind the grid's back.
    String[] gridHyperNames = grid.getHyperNames().clone();
    Arrays.sort(gridHyperNames);
    Assert.assertArrayEquals("Hyper parameters names should match!", hyperParamNames,
        gridHyperNames);

    //
    // Make sure that values of used parameters match as well to the specified values
    //
    Model[] ms = grid.getModels();
    Map<String, Set<Object>> usedModelParams = GridTestUtils.initMap(hyperParamNames);
    for (Model m : ms) {
      DRFModel drf = (DRFModel) m;
      // Log training MSE next to the hyper parameter values of the model
      System.out.println(
          drf._output._scored_train[drf._output._ntrees]._mse + " " + Arrays.deepToString(
              ArrayUtils.zip(grid.getHyperNames(), grid.getHyperValues(drf._parms))));
      GridTestUtils.extractParams(usedModelParams, drf._parms, hyperParamNames);
    }
    // Successful models are expected only for the legal sample rate
    hyperParms.put("_sample_rate", legalSampleRateOpts);
    GridTestUtils.assertParamsEqual("Grid models parameters have to cover specified hyper space",
        hyperParms,
        usedModelParams);

    // Verify model failure
    Map<String, Set<Object>> failedHyperParams = GridTestUtils.initMap(hyperParamNames);
    for (Model.Parameters failedParams : grid.getFailedParameters()) {
      GridTestUtils.extractParams(failedHyperParams, failedParams, hyperParamNames);
    }
    // Failures are expected only for the illegal sample rate
    hyperParms.put("_sample_rate", illegalSampleRateOpts);
    GridTestUtils
        .assertParamsEqual("Failed model parameters have to correspond to specified hyper space",
            hyperParms,
            failedHyperParams);
  } finally {
    // The original numeric "cylinders" vec is no longer part of the frame, remove it separately
    if (old != null) {
      old.remove();
    }
    if (fr != null) {
      fr.remove();
    }
    if (grid != null) {
      grid.remove();
    }
  }
}
//@Ignore("PUBDEV-1643")
/**
 * Runs a DRF grid search in which every hyper parameter lists the same value
 * twice and verifies that all models returned by the grid share the key of the
 * first one, i.e. duplicated hyper parameter combinations did not produce
 * additional models.
 */
@Test
public void testDuplicatesCarsGrid() {
  Grid grid = null;
  Frame fr = null;
  Vec old = null;
  try {
    fr = parse_test_file("smalldata/junit/cars_20mpg.csv");
    fr.remove("name").remove(); // Remove unique id
    old = fr.remove("economy");
    fr.add("economy", old); // response to last column
    DKV.put(fr);

    // Setup hyperparameter search space made of duplicated values only.
    // Plain HashMap on purpose: double-brace initialization would create an anonymous
    // subclass that captures this test instance.
    HashMap<String, Object[]> hyperParms = new HashMap<>();
    hyperParms.put("_ntrees", new Integer[]{5, 5});
    hyperParms.put("_max_depth", new Integer[]{2, 2});
    hyperParms.put("_mtries", new Integer[]{-1, -1});
    hyperParms.put("_sample_rate", new Double[]{.1, .1});

    // Fire off a grid search
    DRFModel.DRFParameters params = new DRFModel.DRFParameters();
    params._train = fr._key;
    params._response_column = "economy";

    // Get the Grid for this modeling class and frame
    Job<Grid> gs = GridSearch.startGridSearch(null, params, hyperParms);
    grid = gs.get();

    // Check that duplicate model have not been constructed
    Model[] models = grid.getModels();
    assertTrue("Number of returned models has to be > 0", models.length > 0);

    // But all of them should be same
    Key<Model> modelKey = models[0]._key;
    for (Model m : models) {
      // Compare keys by value: nothing visible here guarantees that two Key objects
      // naming the same model are the same instance, so '==' would be too strict.
      assertEquals("Number of constructed models has to be equal to 1", modelKey, m._key);
    }
  } finally {
    if (old != null) {
      old.remove();
    }
    if (fr != null) {
      fr.remove();
    }
    if (grid != null) {
      grid.remove();
    }
  }
}
//@Ignore("PUBDEV-1648")
/**
 * Runs a DRF grid search over a randomly generated hyper space (1-3 values per
 * dimension for {@code _ntrees}, {@code _max_depth}, {@code _mtries} and
 * {@code _sample_rate}), checks that built models plus failures add up to the
 * size of the hyper space, and then rebuilds one randomly picked parameter
 * combination as a stand-alone DRF model and prints its training MSE.
 *
 * <p>All randomness is driven by a single seed which is printed before the grid
 * search starts, so a failing run can be reproduced.
 */
@Test
public void testRandomCarsGrid() {
  Grid grid = null;
  DRFModel drfRebuilt = null;
  Frame fr = null;
  try {
    fr = parse_test_file("smalldata/junit/cars.csv");
    fr.remove("name").remove();
    Vec old = fr.remove("economy (mpg)");
    fr.add("economy (mpg)", old); // response to last column
    DKV.put(fr);

    // Construct random grid search space
    long seed = System.nanoTime();
    // Print the seed before anything can fail, otherwise a failing run is not reproducible
    System.out.println("Test seed: " + seed);
    Random rng = new Random(seed);

    // Limit to 1-3 randomly, 4 times. Average total number of models is
    // 2^4, or 16. Max is 81 models.
    int ntreesDim = rng.nextInt(3) + 1;
    int maxDepthDim = rng.nextInt(3) + 1;
    int mtriesDim = rng.nextInt(3) + 1;
    int sampleRateDim = rng.nextInt(3) + 1;

    Integer[] ntreesArr = interval(1, 15);
    Integer[] ntreesSpace = pickRandom(ntreesArr, ntreesDim, rng);

    Integer[] maxDepthArr = interval(1, 10);
    Integer[] maxDepthSpace = pickRandom(maxDepthArr, maxDepthDim, rng);

    Integer[] mtriesArr = interval(1, 5);
    Integer[] mtriesSpace = pickRandom(mtriesArr, mtriesDim, rng);

    Double[] sampleRateArr = interval(0.01, 0.99, 0.01);
    Double[] sampleRateSpace = pickRandom(sampleRateArr, sampleRateDim, rng);

    System.out.println("ntrees search space: " + Arrays.toString(ntreesSpace));
    System.out.println("max_depth search space: " + Arrays.toString(maxDepthSpace));
    System.out.println("mtries search space: " + Arrays.toString(mtriesSpace));
    System.out.println("sample_rate search space: " + Arrays.toString(sampleRateSpace));

    // Setup random hyperparameter search space
    HashMap<String, Object[]> hyperParms = new HashMap<>();
    hyperParms.put("_ntrees", ntreesSpace);
    hyperParms.put("_max_depth", maxDepthSpace);
    hyperParms.put("_mtries", mtriesSpace);
    hyperParms.put("_sample_rate", sampleRateSpace);

    // Fire off a grid search
    DRFModel.DRFParameters params = new DRFModel.DRFParameters();
    params._train = fr._key;
    params._response_column = "economy (mpg)";

    // Get the Grid for this modeling class and frame
    Job<Grid> gs = GridSearch.startGridSearch(null, params, hyperParms);
    grid = gs.get();

    // Check the cardinality of the grid: every point of the hyper space has to end up
    // either as a built model or as a recorded failure.
    Model[] ms = grid.getModels();
    int numModels = ms.length;
    System.out.println("Grid consists of " + numModels + " models");
    assertEquals("Number of models should match hyper space size",
        ntreesDim * maxDepthDim * sampleRateDim * mtriesDim,
        numModels + grid.getFailureCount());

    // Pick a random model from the grid: exactly one value per hyper parameter
    HashMap<String, Object[]> randomHyperParms = new HashMap<>();
    Integer ntreeVal = ntreesSpace[rng.nextInt(ntreesSpace.length)];
    randomHyperParms.put("_ntrees", new Integer[]{ntreeVal});
    Integer maxDepthVal = maxDepthSpace[rng.nextInt(maxDepthSpace.length)];
    randomHyperParms.put("_max_depth", new Integer[]{maxDepthVal});
    Integer mtriesVal = mtriesSpace[rng.nextInt(mtriesSpace.length)];
    randomHyperParms.put("_mtries", new Integer[]{mtriesVal});
    Double sampleRateVal = sampleRateSpace[rng.nextInt(sampleRateSpace.length)];
    randomHyperParms.put("_sample_rate", new Double[]{sampleRateVal});
    //TODO: DRFModel drfFromGrid = (DRFModel) g2.model(randomHyperParms).get();

    // Rebuild it with its parameters
    params._ntrees = ntreeVal;
    params._max_depth = maxDepthVal;
    params._mtries = mtriesVal;
    // floatValue() is assignable whether the field is declared float or double
    params._sample_rate = sampleRateVal.floatValue();
    drfRebuilt = new DRF(params).trainModel().get();

    // Make sure the MSE metrics match
    //double fromGridMSE = drfFromGrid._output._scored_train[drfFromGrid._output._ntrees]._mse;
    double rebuiltMSE = drfRebuilt._output._scored_train[drfRebuilt._output._ntrees]._mse;
    //System.out.println("The random grid model's MSE: " + fromGridMSE);
    System.out.println("The rebuilt model's MSE: " + rebuiltMSE);
    //assertEquals(fromGridMSE, rebuiltMSE);
  } finally {
    if (fr != null) {
      fr.remove();
    }
    if (grid != null) {
      grid.remove();
    }
    if (drfRebuilt != null) {
      drfRebuilt.remove();
    }
  }
}

/**
 * Returns {@code count} elements taken from distinct positions of {@code candidates},
 * selected by shuffling a copy of the candidates with {@code rng}.
 *
 * @param candidates values to choose from; left unmodified
 * @param count      number of values to return; must not exceed {@code candidates.length}
 * @param rng        source of randomness, so the selection is reproducible from its seed
 * @return a new array of length {@code count} with the same runtime component type
 *         as {@code candidates}
 */
private static <T> T[] pickRandom(T[] candidates, int count, Random rng) {
  ArrayList<T> shuffled = new ArrayList<>(Arrays.asList(candidates));
  Collections.shuffle(shuffled, rng);
  // Arrays.copyOf yields an array of the input's runtime type and exact target length,
  // which toArray then fills in place.
  return shuffled.subList(0, count).toArray(Arrays.copyOf(candidates, count));
}
/**
 * Checks that two DRF parameter sets which share the training frame, response
 * column, seed and sample rate but differ in {@code _mtries}, {@code _max_depth}
 * and {@code _ntrees} do not produce the same checksum.
 */
@Test
public void testCollisionOfDRFParamsChecksum() {
  Frame carsFrame = null;
  try {
    carsFrame = parse_test_file("smalldata/junit/cars.csv");
    carsFrame.remove("name").remove();
    Vec response = carsFrame.remove("economy (mpg)");
    carsFrame.add("economy (mpg)", response); // response to last column
    DKV.put(carsFrame);

    DRFModel.DRFParameters firstParams = checksumCollisionParams(carsFrame, 3, 15, 9);
    DRFModel.DRFParameters secondParams = checksumCollisionParams(carsFrame, 1, 1, 13);

    long firstChecksum = firstParams.checksum();
    long secondChecksum = secondParams.checksum();
    Assert.assertNotEquals("Checksums shoudl be different", firstChecksum, secondChecksum);
  } finally {
    if (carsFrame != null) {
      carsFrame.remove();
    }
  }
}

/**
 * Builds DRF parameters for the checksum collision test. Training frame, response
 * column, seed and sample rate are fixed; only the three tree-shape settings vary.
 *
 * @param train    frame whose key is used as the training frame
 * @param mtries   value for {@code _mtries}
 * @param maxDepth value for {@code _max_depth}
 * @param ntrees   value for {@code _ntrees}
 * @return a freshly created parameters object
 */
private static DRFModel.DRFParameters checksumCollisionParams(Frame train, int mtries,
                                                              int maxDepth, int ntrees) {
  DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
  parms._train = train._key;
  parms._response_column = "economy (mpg)";
  parms._seed = -4522296119273841674L;
  parms._mtries = mtries;
  parms._max_depth = maxDepth;
  parms._ntrees = ntrees;
  parms._sample_rate = 0.6499997f;
  return parms;
}
}