package hex.tree.drf;
import hex.Model;
import hex.ModelMetricsBinomial;
import hex.ModelMetricsRegression;
import hex.SplitFrame;
import hex.tree.SharedTreeModel;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Ignore;
import org.junit.Test;
import water.*;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;
import water.fvec.RebalanceDataSet;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.Triple;
import water.util.VecUtils;
import java.io.BufferedWriter;
import java.io.FileWriter;
import java.util.*;
import static org.junit.Assert.assertEquals;
public class DRFTest extends TestUtil {
// Block until the H2O cloud has formed (at least 1 node) before any test in this class runs.
@BeforeClass public static void stall() { stall_till_cloudsize(1); }
// Test hook: mutates the parsed frame (dropping/re-encoding columns) and returns the
// index of the response column (a negative value is the one's complement of the index —
// see unifyFrame, which decodes it).
abstract static class PrepData { abstract int prep(Frame fr); }
// Varargs-to-array shorthand for writing expected response domains inline.
static String[] s(String...arr) { return arr; }
@Test public void testClassIris1() throws Throwable {
  // iris with a single tree; OOB validation means DRF only scores a subset of the rows
  final double[][] expectedCM = ard(
      ard(15, 0, 0),
      ard(0, 18, 0),
      ard(0, 1, 17));
  basicDRFTestOOBE_Classification(
      "./smalldata/iris/iris.csv", "iris.hex",
      new PrepData() {
        @Override int prep(Frame fr) { return fr.numCols() - 1; } // response = last column
      },
      1,  // ntrees
      20, // nbins
      1,  // min_rows
      20, // max_depth
      expectedCM,
      s("Iris-setosa", "Iris-versicolor", "Iris-virginica"));
}
@Test public void testClassIris5() throws Throwable {
  // iris with 5 trees (the old "ntree=50" comment was stale — the argument is 5)
  final double[][] expectedCM = ard(
      ard(43, 0, 0),
      ard(0, 37, 4),
      ard(0, 4, 39));
  basicDRFTestOOBE_Classification(
      "./smalldata/iris/iris.csv", "iris5.hex",
      new PrepData() {
        @Override int prep(Frame fr) { return fr.numCols() - 1; } // response = last column
      },
      5,  // ntrees
      20, // nbins
      1,  // min_rows
      20, // max_depth
      expectedCM,
      s("Iris-setosa", "Iris-versicolor", "Iris-virginica"));
}
@Test public void testClassCars1() throws Throwable {
  // cars with a single tree, predicting the number of cylinders
  final double[][] expectedCM = ard(
      ard(0, 2, 0, 0, 0),
      ard(0, 58, 6, 4, 0),
      ard(0, 1, 0, 0, 0),
      ard(1, 3, 4, 25, 1),
      ard(0, 0, 0, 2, 37));
  basicDRFTestOOBE_Classification(
      "./smalldata/junit/cars.csv", "cars.hex",
      new PrepData() {
        @Override int prep(Frame fr) {
          fr.remove("name").remove(); // the name column carries no signal
          return fr.find("cylinders");
        }
      },
      1,  // ntrees
      20, // nbins
      1,  // min_rows
      20, // max_depth
      expectedCM,
      s("3", "4", "5", "6", "8"));
}
@Test public void testClassCars5() throws Throwable {
  // cars with 5 trees, predicting the number of cylinders
  final double[][] expectedCM = ard(
      ard(1, 2, 0, 0, 0),
      ard(0, 177, 1, 5, 0),
      ard(0, 2, 0, 0, 0),
      ard(0, 6, 1, 67, 1),
      ard(0, 0, 0, 2, 84));
  basicDRFTestOOBE_Classification(
      "./smalldata/junit/cars.csv", "cars5.hex",
      new PrepData() {
        @Override int prep(Frame fr) {
          fr.remove("name").remove(); // the name column carries no signal
          return fr.find("cylinders");
        }
      },
      5,  // ntrees
      20, // nbins
      1,  // min_rows
      20, // max_depth
      expectedCM,
      s("3", "4", "5", "6", "8"));
}
@Test public void testConstantCols() throws Throwable {
  // After dropping 7 vectors the remaining training data is degenerate; model
  // building must be rejected with an H2OModelBuilderIllegalArgumentException.
  try {
    basicDRFTestOOBE_Classification(
        "./smalldata/poker/poker100", "poker.hex",
        new PrepData() {
          @Override int prep(Frame fr) {
            // Drop the vector at index 3 seven times (columns shift left after each removal).
            for (int i = 0; i < 7; i++) fr.remove(3).remove();
            return 3;
          }
        },
        1,  // ntrees
        20, // nbins
        1,  // min_rows
        20, // max_depth
        null,
        null);
    Assert.fail();
  } catch (H2OModelBuilderIllegalArgumentException expected) {
    // pass: the builder rejected the invalid input as intended
  }
}
// Disabled: exercises a dataset containing infinities; expected CM kept for reference.
@Ignore @Test public void testBadData() throws Throwable {
basicDRFTestOOBE_Classification(
"./smalldata/junit/drf_infinities.csv", "infinitys.hex",
new PrepData() { @Override int prep(Frame fr) { return fr.find("DateofBirth"); } },
1,
20,
1,
20,
ard(ard(6, 0),
ard(9, 1)),
s("0", "1"));
}
// Disabled (the @Test annotation is commented out): large Kaggle credit dataset,
// too heavy for routine runs; expected CM kept for reference.
//@Test
public void testCreditSample1() throws Throwable {
basicDRFTestOOBE_Classification(
"./smalldata/kaggle/creditsample-training.csv.gz", "credit.hex",
new PrepData() {
@Override
int prep(Frame fr) {
// MonthlyIncome contains missing values; drop it and predict serious delinquency
fr.remove("MonthlyIncome").remove();
return fr.find("SeriousDlqin2yrs");
}
},
1,
20,
1,
20,
ard(ard(46294, 202),
ard(3187, 107)),
s("0", "1"));
}
@Test public void testCreditProstate1() throws Throwable {
  // prostate, single tree, binomial response CAPSULE
  final double[][] expectedCM = ard(
      ard(0, 70),
      ard(0, 59));
  basicDRFTestOOBE_Classification(
      "./smalldata/logreg/prostate.csv", "prostate.hex",
      new PrepData() {
        @Override int prep(Frame fr) {
          fr.remove("ID").remove(); // ID is a row identifier, not a feature
          return fr.find("CAPSULE");
        }
      },
      1,  // ntrees
      20, // nbins
      1,  // min_rows
      20, // max_depth
      expectedCM,
      s("0", "1"));
}
@Test public void testCreditProstateRegression1() throws Throwable {
  // prostate regression on AGE, single tree; expected OOB MSE pinned below
  basicDRFTestOOBE_Regression(
      "./smalldata/logreg/prostate.csv", "prostateRegression.hex",
      new PrepData() {
        @Override int prep(Frame fr) {
          fr.remove("ID").remove(); // ID is a row identifier, not a feature
          return fr.find("AGE");
        }
      },
      1,  // ntrees
      20, // nbins
      1,  // min_rows
      10, // max_depth
      63.13182273942728 // expected OOB MSE
  );
}
@Test public void testCreditProstateRegression5() throws Throwable {
  // prostate regression on AGE with 5 trees; more trees should reduce the OOB MSE
  basicDRFTestOOBE_Regression(
      "./smalldata/logreg/prostate.csv", "prostateRegression5.hex",
      new PrepData() {
        @Override int prep(Frame fr) {
          fr.remove("ID").remove(); // ID is a row identifier, not a feature
          return fr.find("AGE");
        }
      },
      5,  // ntrees
      20, // nbins
      1,  // min_rows
      10, // max_depth
      59.713095855920244 // expected OOB MSE
  );
}
@Test public void testCreditProstateRegression50() throws Throwable {
  // prostate regression on AGE with 50 trees; OOB MSE keeps improving with the ensemble size
  basicDRFTestOOBE_Regression(
      "./smalldata/logreg/prostate.csv", "prostateRegression50.hex",
      new PrepData() {
        @Override int prep(Frame fr) {
          fr.remove("ID").remove(); // ID is a row identifier, not a feature
          return fr.find("AGE");
        }
      },
      50, // ntrees
      20, // nbins
      1,  // min_rows
      10, // max_depth
      47.00716017021814 // expected OOB MSE
  );
}
@Test public void testCzechboard() throws Throwable {
  // checkerboard pattern; the expected CM shows the model predicting a single class
  final double[][] expectedCM = ard(
      ard(0, 45000),
      ard(0, 45000));
  basicDRFTestOOBE_Classification(
      "./smalldata/gbm_test/czechboard_300x300.csv", "czechboard_300x300.hex",
      new PrepData() {
        @Override int prep(Frame fr) {
          // Re-encode C2 as a categorical in place, then predict C3.
          Vec raw = fr.remove("C2");
          fr.add("C2", VecUtils.toCategoricalVec(raw));
          raw.remove();
          return fr.find("C3");
        }
      },
      50, // ntrees
      20, // nbins
      1,  // min_rows
      20, // max_depth
      expectedCM,
      s("0", "1"));
}
// Regression on a dataset with a high-cardinality categorical; OOB MSE pinned below.
@Test public void test30kUnseenLevels() throws Throwable {
basicDRFTestOOBE_Regression(
"./smalldata/gbm_test/30k_cattest.csv", "cat30k",
new PrepData() {
@Override
int prep(Frame fr) {
return fr.find("C3");
}
},
50, //ntrees
20, //bins
10, //min_rows
5, //max_depth
0.25040633586487); // expected OOB MSE
}
// Zipped prostate data: several numeric columns are re-encoded as categoricals before
// training a shallow 4-tree forest on CAPSULE.
@Test public void testProstate() throws Throwable {
basicDRFTestOOBE_Classification(
"./smalldata/prostate/prostate.csv.zip", "prostate2.zip.hex",
new PrepData() {
@Override
int prep(Frame fr) {
// Snapshot the names first: remove() shifts indices, so the clone keeps the
// original positions for re-adding the converted vecs under their old names.
String[] names = fr.names().clone();
Vec[] en = fr.remove(new int[]{1,4,5,8});
fr.add(names[1], VecUtils.toCategoricalVec(en[0])); //CAPSULE
fr.add(names[4], VecUtils.toCategoricalVec(en[1])); //DPROS
fr.add(names[5], VecUtils.toCategoricalVec(en[2])); //DCAPS
fr.add(names[8], VecUtils.toCategoricalVec(en[3])); //GLEASON
// Free the raw numeric vecs now that categorical copies are in the frame.
for (Vec v : en) v.remove();
fr.remove(0).remove(); //drop ID
// Index 4 points at CAPSULE after the removals/re-adds above.
return 4; //CAPSULE
}
},
4, //ntrees
2, //bins
1, //min_rows
1, //max_depth
null, // no expected CM — only the response domain is checked
s("0", "1"));
}
@Test public void testAlphabet() throws Throwable {
  // alphabet_cattest is perfectly separable — the expected CM is diagonal
  final double[][] expectedCM = ard(
      ard(670, 0),
      ard(0, 703));
  basicDRFTestOOBE_Classification(
      "./smalldata/gbm_test/alphabet_cattest.csv", "alphabetClassification.hex",
      new PrepData() {
        @Override int prep(Frame fr) { return fr.find("y"); }
      },
      1,  // ntrees
      20, // nbins
      1,  // min_rows
      20, // max_depth
      expectedCM,
      s("0", "1"));
}
@Test public void testAlphabetRegression() throws Throwable {
  // Regression flavor of the alphabet test: the data is perfectly fittable, so MSE is 0
  basicDRFTestOOBE_Regression(
      "./smalldata/gbm_test/alphabet_cattest.csv", "alphabetRegression.hex",
      new PrepData() {
        @Override int prep(Frame fr) { return fr.find("y"); }
      },
      1,  // ntrees
      20, // nbins
      1,  // min_rows
      10, // max_depth
      0.0); // expected OOB MSE — exact fit
}
// With 26 bins (one per letter) even a depth-1 tree can split the alphabet exactly → MSE 0.
@Test public void testAlphabetRegression2() throws Throwable {
basicDRFTestOOBE_Regression(
"./smalldata/gbm_test/alphabet_cattest.csv", "alphabetRegression2.hex",
new PrepData() {
@Override
int prep(Frame fr) {
return fr.find("y");
}
},
1,
26, // enough bins to resolve the alphabet
1,
1, // depth 1 is enough since nbins_cats == nbins == 26 (enough)
0.0);
}
// Counterpart to testAlphabetRegression2: with only 25 bins the categorical cannot be
// fully resolved at depth 1, so a nonzero residual MSE is expected.
@Test public void testAlphabetRegression3() throws Throwable {
basicDRFTestOOBE_Regression(
"./smalldata/gbm_test/alphabet_cattest.csv", "alphabetRegression3.hex",
new PrepData() {
@Override
int prep(Frame fr) {
return fr.find("y");
}
},
1,
25, // not enough bins to resolve the alphabet
1,
1, // depth 1 is not enough since nbins_cats == nbins < 26
0.24007225096411577);
}
// Disabled: results differ between 1-node and 5-node clouds (likely different chunking).
@Ignore //1-vs-5 node discrepancy (parsing into different number of chunks?)
@Test public void testAirlines() throws Throwable {
basicDRFTestOOBE_Classification(
"./smalldata/airlines/allyears2k_headers.zip", "airlines.hex",
new PrepData() {
@Override
int prep(Frame fr) {
// Drop columns that leak the outcome (delay/cancellation info) or are unusable.
for (String s : new String[]{
"DepTime", "ArrTime", "ActualElapsedTime",
"AirTime", "ArrDelay", "DepDelay", "Cancelled",
"CancellationCode", "CarrierDelay", "WeatherDelay",
"NASDelay", "SecurityDelay", "LateAircraftDelay", "IsArrDelayed"
}) {
fr.remove(s).remove();
}
return fr.find("IsDepDelayed");
}
},
7,
20, 1, 20, ard(ard(7958, 11707), //1-node
ard(2709, 19024)),
// a(a(7841, 11822), //5-node
// a(2666, 19053)),
s("NO", "YES"));
}
// Runs PrepData on the frame, moves the chosen column to the end of the frame as the
// response (re-encoding it as categorical for classification), and records its name in
// the DRF parameters. Returns the removed original Vec for the caller to clean up, or
// null in the regression case (where the same vec is simply re-added at the end).
static Vec unifyFrame(DRFModel.DRFParameters drf, Frame fr, PrepData prep, boolean classification) {
int idx = prep.prep(fr);
// A negative index is the one's complement (~) encoding of the real position — decode it.
if( idx < 0 ) { idx = ~idx; }
String rname = fr._names[idx];
drf._response_column = fr.names()[idx];
Vec resp = fr.vecs()[idx];
Vec ret = null;
if (classification) {
// Replace the response with a categorical copy; hand the original back for cleanup.
ret = fr.remove(idx);
fr.add(rname, VecUtils.toCategoricalVec(resp));
} else {
// Regression: relocate the same vec to the last position.
fr.remove(idx);
fr.add(rname,resp);
}
return ret;
}
// Classification wrapper around basicDRF: no separate test file, expected confusion
// matrix + response domain are checked, expMSE is passed as -1 (unused sentinel).
public void basicDRFTestOOBE_Classification(String fnametrain, String hexnametrain, PrepData prep, int ntree, int nbins, int min_rows, int max_depth, double[][] expCM, String[] expRespDom) throws Throwable {
basicDRF(fnametrain, hexnametrain, null, prep, ntree, max_depth, nbins, true, min_rows, expCM, -1, expRespDom);
}
// Regression wrapper around basicDRF: no separate test file, only the expected OOB MSE
// is checked (no CM / response domain).
public void basicDRFTestOOBE_Regression(String fnametrain, String hexnametrain, PrepData prep, int ntree, int nbins, int min_rows, int max_depth, double expMSE) throws Throwable {
basicDRF(fnametrain, hexnametrain, null, prep, ntree, max_depth, nbins, false, min_rows, null, expMSE, null);
}
// Core test driver: parses the training file, applies PrepData, trains a DRF with the
// given hyper-parameters, and validates (a) the expected OOB confusion matrix / MSE,
// (b) the response domain, and (c) that the generated POJO scores identically to the
// in-H2O model. All keyed objects are cleaned up in the finally block.
// NOTE(review): hexnametrain is currently unused in this body — confirm whether it was
// meant to key the parsed frame.
public void basicDRF(String fnametrain, String hexnametrain, String fnametest, PrepData prep, int ntree, int max_depth, int nbins, boolean classification, int min_rows, double[][] expCM, double expMSE, String[] expRespDom) throws Throwable {
Scope.enter();
DRFModel.DRFParameters drf = new DRFModel.DRFParameters();
Frame frTest = null, pred = null;
Frame frTrain = null;
Frame test = null, res = null;
DRFModel model = null;
try {
frTrain = parse_test_file(fnametrain);
// unifyFrame may hand back a vec that must be tracked for cleanup.
Vec removeme = unifyFrame(drf, frTrain, prep, classification);
if (removeme != null) Scope.track(removeme);
// Re-publish the frame since unifyFrame mutated its columns.
DKV.put(frTrain._key, frTrain);
// Configure DRF
drf._train = frTrain._key;
drf._response_column = ((Frame)DKV.getGet(drf._train)).lastVecName();
drf._ntrees = ntree;
drf._max_depth = max_depth;
drf._min_rows = min_rows;
drf._stopping_rounds = 0; //no early stopping
// drf._binomial_double_trees = new Random().nextBoolean();
drf._nbins = nbins;
drf._nbins_cats = nbins;
drf._mtries = -1;
drf._sample_rate = 0.66667f; // Simulated sampling with replacement
// Fixed seed so the expected CMs/MSEs in the callers are reproducible.
drf._seed = (1L<<32)|2;
// Invoke DRF and block till the end
DRF job = new DRF(drf);
// Get the model
model = job.trainModel().get();
// StreamingSchema ss = new StreamingSchema(model.getMojo(), "model.zip");
// FileOutputStream fos = new FileOutputStream("model.zip");
// ss.getStreamWriter().writeTo(fos);
Log.info(model._output);
Assert.assertTrue(job.isStopped()); //HEX-1817
hex.ModelMetrics mm;
if (fnametest != null) {
frTest = parse_test_file(fnametest);
pred = model.score(frTest);
mm = hex.ModelMetrics.getFromDKV(model, frTest);
// Check test set CM
} else {
// No separate test set: use the OOB training metrics.
mm = hex.ModelMetrics.getFromDKV(model, frTrain);
}
Assert.assertEquals("Number of trees differs!", ntree, model._output._ntrees);
// Re-parse the training file so scoring below is non-OOB (fresh frame, same data).
test = parse_test_file(fnametrain);
res = model.score(test);
// Build a POJO, validate same results
Assert.assertTrue(model.testJavaScoring(test,res,1e-15));
if (classification && expCM != null) {
Assert.assertTrue("Expected: " + Arrays.deepToString(expCM) + ", Got: " + Arrays.deepToString(mm.cm()._cm),
Arrays.deepEquals(mm.cm()._cm, expCM));
// The response domain is the last entry in the model's domains.
String[] cmDom = model._output._domains[model._output._domains.length - 1];
Assert.assertArrayEquals("CM domain differs!", expRespDom, cmDom);
Log.info("\nOOB Training CM:\n" + mm.cm().toASCII());
Log.info("\nTraining CM:\n" + hex.ModelMetrics.getFromDKV(model, test).cm().toASCII());
} else if (!classification) {
// Relative tolerance so large and tiny expected MSEs are compared fairly.
Assert.assertTrue("Expected: " + expMSE + ", Got: " + mm.mse(), Math.abs(expMSE-mm.mse()) <= 1e-10*Math.abs(expMSE+mm.mse()));
Log.info("\nOOB Training MSE: " + mm.mse());
Log.info("\nTraining MSE: " + hex.ModelMetrics.getFromDKV(model, test).mse());
}
hex.ModelMetrics.getFromDKV(model, test);
} finally {
if (frTrain!=null) frTrain.remove();
if (frTest!=null) frTest.remove();
if( model != null ) model.delete(); // Remove the model
if( pred != null ) pred.delete();
if( test != null ) test.delete();
if( res != null ) res.delete();
Scope.exit();
}
}
@Ignore
@Test public void testAutoRebalance() {
  // Benchmark harness (normally @Ignore'd, uses local files): trains DRF over a grid of
  // max_depth x chunk-count settings and dumps per-run timings as CSV for plotting in R.
  // First pass to warm up the JVM/cluster so the timings below are comparable.
  boolean warmUp = true;
  if (warmUp) {
    int[] warmUpChunks = {1, 2, 3, 4, 5};
    for (int chunk : warmUpChunks) {
      Frame tfr = null;
      Scope.enter();
      try {
        // Load data, hack frames
        tfr = parse_test_file("/Users/ludirehak/Downloads/train.csv.zip");
        DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
        parms._train = tfr._key;
        parms._response_column = "Sales";
        parms._nbins = 1000;
        parms._ntrees = 10;
        parms._max_depth = 20;
        parms._mtries = -1;
        parms._min_rows = 10;
        parms._seed = 1234;
        // parms._rebalance_me = true;
        // parms._nchunks = 22;
        // Build a first model; all remaining models should be equal
        DRF job = new DRF(parms);
        DRFModel drf = job.trainModel().get();
        drf.delete();
      } finally {
        if (tfr != null) tfr.remove();
      }
      Scope.exit();
    }
  }
  int[] max_depths = {2,5,10,15,20};
  int[] chunks = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32};
  boolean[] rebalanceMes = {true};
  int[] ntrees = {10};
  int totalLength = chunks.length*max_depths.length*rebalanceMes.length*ntrees.length;
  double[] executionTimes = new double[totalLength];
  int[] outputchunks = new int[totalLength];
  int[] outputdepths = new int[totalLength];
  boolean[] outputrebalanceme = new boolean[totalLength];
  int[] outputntrees = new int[totalLength];
  int c = 0;
  for (int max_depth : max_depths) {
    for (int ntree: ntrees) {
      for (boolean rebalanceMe: rebalanceMes) {
        for (int chunk : chunks) {
          long startTime = System.currentTimeMillis();
          Scope.enter();
          // Load data, hack frames
          Frame tfr = parse_test_file("/Users/ludirehak/Downloads/train.csv.zip");
          DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
          parms._train = tfr._key;
          parms._response_column = "Sales";
          parms._nbins = 1000;
          parms._mtries = -1;
          parms._min_rows = 10;
          parms._seed = 1234;
          parms._ntrees = ntree;
          parms._max_depth = max_depth;
          // parms._rebalance_me = rebalanceMe;
          // parms._nchunks = chunk;
          // Build a first model
          DRF job = new DRF(parms);
          DRFModel drf = job.trainModel().get();
          assertEquals(drf._output._ntrees, parms._ntrees);
          // The cast doubles as a sanity check that training produced regression metrics.
          ModelMetricsRegression mm = (ModelMetricsRegression) drf._output._training_metrics;
          int actualChunk = job.train().anyVec().nChunks();
          // Record model outputs BEFORE deleting the model (previously read after delete()).
          outputntrees[c] = drf._output._ntrees;
          drf.delete();
          tfr.remove();
          Scope.exit();
          executionTimes[c] = (System.currentTimeMillis() - startTime) / 1000d;
          if (!rebalanceMe) assert actualChunk == 22;
          outputchunks[c] = actualChunk;
          outputdepths[c] = max_depth;
          outputrebalanceme[c] = rebalanceMe;
          Log.info("Iteration " + (c + 1) + " out of " + executionTimes.length);
          Log.info(" DEPTH: " + outputdepths[c] + " NTREES: "+ outputntrees[c] + " CHUNKS: " + outputchunks[c] + " EXECUTION TIME: " + executionTimes[c] + " Rebalanced: " + rebalanceMe + " WarmedUp: " + warmUp);
          c++;
        }
      }
    }
  }
  String fileName = "/Users/ludirehak/Desktop/DRFTestRebalance3.txt";
  //R code for plotting: plot(chunks,execution_time,t='n',main='Execution Time of DRF on Rebalanced Data');
  // for (i in 1:length(unique(max_depth))) {s = which(max_depth ==unique(max_depth)[i]);
  // points(chunks[s],execution_time[s],col=i)};
  // legend('topright', legend= c('max_depth',unique(max_depth)),col = 0:length(unique(max_depth)),pch=1);
  // try-with-resources guarantees the writer is closed even if a write fails.
  try (BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(fileName))) {
    bufferedWriter.write("max_depth,ntrees,nbins,min_rows,chunks,execution_time,rebalanceMe,warmUp");
    bufferedWriter.newLine();
    for (int i = 0; i < executionTimes.length; i++) {
      // Fixed: a stray extra "," used to emit an empty field after execution_time,
      // misaligning the rows against the 8-column header above.
      bufferedWriter.write(outputdepths[i] + "," + outputntrees[i] + "," + 1000 + "," + 10 + ","
          + outputchunks[i] + "," + executionTimes[i] + ","
          + (outputrebalanceme[i] ? 1 : 0) + "," + (warmUp ? 1 : 0));
      bufferedWriter.newLine();
    }
  } catch (Exception e) {
    // Keep the details instead of a bare "Fail" — this is diagnostic output only.
    Log.info("Failed to write timing results to " + fileName + ": " + e);
  }
}
// PUBDEV-2476 Check reproducibility for the same # of chunks (i.e., same # of nodes) and same parameters:
// the trained model's MSE must be independent of how the data is chunked (auto-rebalance off).
@Test public void testChunks() {
Frame tfr;
final int N = 4;
double[] mses = new double[N];
// NOTE(review): 5 chunk counts are listed but only the first N=4 are exercised —
// confirm whether 500 is intentionally skipped.
int[] chunks = new int[]{1,13,19,39,500};
for (int i=0; i<N; ++i) {
Scope.enter();
// Load data, hack frames
tfr = parse_test_file("smalldata/covtype/covtype.20k.data");
// rebalance to chunks[i] chunks
Key dest = Key.make("df.rebalanced.hex");
RebalanceDataSet rb = new RebalanceDataSet(tfr, dest, chunks[i]);
H2O.submitTask(rb);
rb.join();
tfr.delete();
tfr = DKV.get(dest).get();
// Response (column 54) must be categorical for classification.
Scope.track(tfr.replace(54, tfr.vecs()[54].toCategoricalVec()));
DKV.put(tfr);
DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
parms._train = tfr._key;
parms._response_column = "C55";
parms._ntrees = 10;
parms._seed = 1234;
parms._auto_rebalance = false;
// Build a first model; all remaining models should be equal
DRF job = new DRF(parms);
DRFModel drf = job.trainModel().get();
assertEquals(drf._output._ntrees, parms._ntrees);
mses[i] = drf._output._scored_train[drf._output._scored_train.length-1]._mse;
drf.delete();
if (tfr != null) tfr.remove();
Scope.exit();
}
for (int i=0; i<mses.length; ++i) {
Log.info("trial: " + i + " -> MSE: " + mses[i]);
}
// All chunkings must yield the same MSE.
for(double mse : mses)
assertEquals(mse, mses[0], 1e-10);
}
// Repeated training with identical parameters and a fixed seed must produce identical MSEs.
@Test public void testReproducibility() {
Frame tfr=null;
final int N = 5;
double[] mses = new double[N];
Scope.enter();
try {
// Load data, hack frames
tfr = parse_test_file("smalldata/covtype/covtype.20k.data");
// rebalance to 256 chunks
Key dest = Key.make("df.rebalanced.hex");
RebalanceDataSet rb = new RebalanceDataSet(tfr, dest, 256);
H2O.submitTask(rb);
rb.join();
tfr.delete();
tfr = DKV.get(dest).get();
// Scope.track(tfr.replace(54, tfr.vecs()[54].toCategoricalVec())._key);
// DKV.put(tfr);
for (int i=0; i<N; ++i) {
DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
parms._train = tfr._key;
parms._response_column = "C55";
parms._nbins = 1000;
parms._ntrees = 1;
parms._max_depth = 8;
parms._mtries = -1;
parms._min_rows = 10;
parms._seed = 1234; // same seed each trial → identical models expected
// Build a first model; all remaining models should be equal
DRFModel drf = new DRF(parms).trainModel().get();
assertEquals(drf._output._ntrees, parms._ntrees);
mses[i] = drf._output._scored_train[drf._output._scored_train.length-1]._mse;
drf.delete();
}
} finally{
if (tfr != null) tfr.remove();
}
Scope.exit();
for (int i=0; i<mses.length; ++i) {
Log.info("trial: " + i + " -> MSE: " + mses[i]);
}
// Bitwise-level reproducibility: extremely tight tolerance.
for(double mse : mses)
assertEquals(mse, mses[0], 1e-15);
}
// PUBDEV-557 Test dependency on # nodes (for small number of bins, but fixed number of chunks):
// the MSE is pinned to a constant so a 1-node and a 5-node cloud must agree (to 1e-4).
@Test public void testReproducibilityAirline() {
Frame tfr=null;
final int N = 1;
double[] mses = new double[N];
Scope.enter();
try {
// Load data, hack frames
tfr = parse_test_file("./smalldata/airlines/allyears2k_headers.zip");
// rebalance to fixed number of chunks
Key dest = Key.make("df.rebalanced.hex");
RebalanceDataSet rb = new RebalanceDataSet(tfr, dest, 256);
H2O.submitTask(rb);
rb.join();
tfr.delete();
tfr = DKV.get(dest).get();
// Scope.track(tfr.replace(54, tfr.vecs()[54].toCategoricalVec())._key);
// DKV.put(tfr);
// Drop columns that leak the outcome (delay/cancellation info).
for (String s : new String[]{
"DepTime", "ArrTime", "ActualElapsedTime",
"AirTime", "ArrDelay", "DepDelay", "Cancelled",
"CancellationCode", "CarrierDelay", "WeatherDelay",
"NASDelay", "SecurityDelay", "LateAircraftDelay", "IsArrDelayed"
}) {
tfr.remove(s).remove();
}
DKV.put(tfr);
for (int i=0; i<N; ++i) {
DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
parms._train = tfr._key;
parms._response_column = "IsDepDelayed";
parms._nbins = 10;
parms._nbins_cats = 1024;
parms._ntrees = 7;
parms._max_depth = 10;
parms._binomial_double_trees = false;
parms._mtries = -1;
parms._min_rows = 1;
parms._sample_rate = 0.632f; // Simulated sampling with replacement
parms._balance_classes = true;
parms._seed = (1L<<32)|2;
// Build a first model; all remaining models should be equal
DRFModel drf = new DRF(parms).trainModel().get();
assertEquals(drf._output._ntrees, parms._ntrees);
mses[i] = drf._output._training_metrics.mse();
drf.delete();
}
} finally{
if (tfr != null) tfr.remove();
}
Scope.exit();
for (int i=0; i<mses.length; ++i) {
Log.info("trial: " + i + " -> MSE: " + mses[i]);
}
for (int i=0; i<mses.length; ++i) {
assertEquals(0.20377446328850304, mses[i], 1e-4); //check for the same result on 1 nodes and 5 nodes
}
}
// HEXDEV-319
// Disabled: benchmark against local (non-repo) data; pins train/validation AUC values.
@Ignore
@Test public void testAirline() {
Frame tfr=null;
Frame test=null;
Scope.enter();
try {
// Load data, hack frames
tfr = parse_test_file(Key.make("air.hex"), "/users/arno/sz_bench_data/train-1m.csv");
test = parse_test_file(Key.make("airt.hex"), "/users/arno/sz_bench_data/test.csv");
// for (int i : new int[]{0,1,2}) {
// tfr.vecs()[i] = tfr.vecs()[i].toCategoricalVec();
// test.vecs()[i] = test.vecs()[i].toCategoricalVec();
// }
DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
parms._train = tfr._key;
parms._valid = test._key;
parms._ignored_columns = new String[]{"Origin","Dest"};
// parms._ignored_columns = new String[]{"UniqueCarrier","Origin","Dest"};
// parms._ignored_columns = new String[]{"UniqueCarrier","Origin"};
// parms._ignored_columns = new String[]{"Month","DayofMonth","DayOfWeek","DepTime","UniqueCarrier","Origin","Distance"};
parms._response_column = "dep_delayed_15min";
parms._nbins = 20;
parms._nbins_cats = 1024;
parms._binomial_double_trees = new Random().nextBoolean(); //doesn't matter!
parms._ntrees = 1;
parms._max_depth = 3;
parms._mtries = -1;
parms._sample_rate = 0.632f;
parms._min_rows = 10;
parms._seed = 12;
// Build a first model; all remaining models should be equal
DRFModel drf = new DRF(parms).trainModel().get();
Log.info("Training set AUC: " + drf._output._training_metrics.auc_obj()._auc);
Log.info("Validation set AUC: " + drf._output._validation_metrics.auc_obj()._auc);
// all numerical
assertEquals(drf._output._training_metrics.auc_obj()._auc, 0.6498819479528417, 1e-8);
assertEquals(drf._output._validation_metrics.auc_obj()._auc, 0.6479974533672835, 1e-8);
drf.delete();
} finally{
if (tfr != null) tfr.remove();
if (test != null) test.remove();
}
Scope.exit();
}
// Expected OOB training metrics shared by the row-weight tests below: the unweighted,
// all-ones-weight, all-twos-weight, and tiny-weight runs all assert against these.
static double _AUC = 1.0;
static double _MSE = 0.041294642857142856;
static double _LogLoss = 0.14472835908293025;
@Test
public void testNoRowWeights() {
  // Baseline for the row-weight tests: train without a weights column and check the
  // OOB training metrics against the shared expected values (_AUC/_MSE/_LogLoss).
  Frame tfr = null, vfr = null;
  DRFModel drf = null;
  Scope.enter();
  try {
    tfr = parse_test_file("smalldata/junit/no_weights.csv");
    DKV.put(tfr);
    DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
    parms._train = tfr._key;
    parms._response_column = "response";
    parms._seed = 234;
    parms._min_rows = 1;
    parms._max_depth = 2;
    parms._ntrees = 3;
    // Build a first model; all remaining models should be equal
    drf = new DRF(parms).trainModel().get();
    // OOB
    ModelMetricsBinomial mm = (ModelMetricsBinomial)drf._output._training_metrics;
    assertEquals(_AUC, mm.auc_obj()._auc, 1e-8);
    assertEquals(_MSE, mm.mse(), 1e-8);
    assertEquals(_LogLoss, mm.logloss(), 1e-6);
  } finally {
    if (tfr != null) tfr.remove();
    if (vfr != null) vfr.remove();
    // Use delete() (not remove()) so model internals are fully cleaned up,
    // matching the sibling row-weight tests below.
    if (drf != null) drf.delete();
    Scope.exit();
  }
}
@Test
public void testRowWeightsOne() {
  // A weight column of all ones must reproduce exactly the unweighted metrics.
  Frame tfr = null, vfr = null;
  DRFModel drf = null;
  Scope.enter();
  try {
    tfr = parse_test_file("smalldata/junit/weights_all_ones.csv");
    DKV.put(tfr);
    DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
    parms._train = tfr._key;
    parms._response_column = "response";
    parms._weights_column = "weight";
    parms._ntrees = 3;
    parms._max_depth = 2;
    parms._min_rows = 1;
    parms._seed = 234;
    // Train a single model and inspect its OOB training metrics.
    drf = new DRF(parms).trainModel().get();
    ModelMetricsBinomial mm = (ModelMetricsBinomial) drf._output._training_metrics;
    assertEquals(_AUC, mm.auc_obj()._auc, 1e-8);
    assertEquals(_MSE, mm.mse(), 1e-8);
    assertEquals(_LogLoss, mm.logloss(), 1e-6);
  } finally {
    if (tfr != null) tfr.remove();
    if (vfr != null) vfr.remove();
    if (drf != null) drf.delete();
    Scope.exit();
  }
}
@Test
public void testRowWeightsTwo() {
  // Uniform weights of 2 with min_rows doubled (weighted rows) must match the baseline.
  Frame tfr = null, vfr = null;
  DRFModel drf = null;
  Scope.enter();
  try {
    tfr = parse_test_file("smalldata/junit/weights_all_twos.csv");
    DKV.put(tfr);
    DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
    parms._train = tfr._key;
    parms._response_column = "response";
    parms._weights_column = "weight";
    parms._ntrees = 3;
    parms._max_depth = 2;
    parms._min_rows = 2; // in terms of weighted rows (each row counts twice)
    parms._seed = 234;
    // Train a single model and inspect its OOB training metrics.
    drf = new DRF(parms).trainModel().get();
    ModelMetricsBinomial mm = (ModelMetricsBinomial) drf._output._training_metrics;
    assertEquals(_AUC, mm.auc_obj()._auc, 1e-8);
    assertEquals(_MSE, mm.mse(), 1e-8);
    assertEquals(_LogLoss, mm.logloss(), 1e-6);
  } finally {
    if (tfr != null) tfr.remove();
    if (vfr != null) vfr.remove();
    if (drf != null) drf.delete();
    Scope.exit();
  }
}
@Test
public void testRowWeightsTiny() {
  // Uniformly tiny weights with a matching tiny min_rows must match the baseline.
  Frame tfr = null, vfr = null;
  DRFModel drf = null;
  Scope.enter();
  try {
    tfr = parse_test_file("smalldata/junit/weights_all_tiny.csv");
    DKV.put(tfr);
    DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
    parms._train = tfr._key;
    parms._response_column = "response";
    parms._weights_column = "weight";
    parms._ntrees = 3;
    parms._max_depth = 2;
    parms._min_rows = 0.01242; // in terms of weighted rows
    parms._seed = 234;
    // Train a single model and inspect its OOB training metrics.
    drf = new DRF(parms).trainModel().get();
    ModelMetricsBinomial mm = (ModelMetricsBinomial) drf._output._training_metrics;
    assertEquals(_AUC, mm.auc_obj()._auc, 1e-8);
    assertEquals(_MSE, mm.mse(), 1e-8);
    assertEquals(_LogLoss, mm.logloss(), 1e-6);
  } finally {
    if (tfr != null) tfr.remove();
    if (vfr != null) vfr.remove();
    if (drf != null) drf.delete();
    Scope.exit();
  }
}
@Test
public void testNoRowWeightsShuffled() {
  // Same data as the baseline but with shuffled rows: shuffling changes the row
  // sampling, so the expected metrics differ from _AUC/_MSE/_LogLoss.
  Frame tfr = null, vfr = null;
  DRFModel drf = null;
  Scope.enter();
  try {
    tfr = parse_test_file("smalldata/junit/no_weights_shuffled.csv");
    DKV.put(tfr);
    DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
    parms._train = tfr._key;
    parms._response_column = "response";
    parms._ntrees = 3;
    parms._max_depth = 2;
    parms._min_rows = 1;
    parms._seed = 234;
    // Train a single model and inspect its OOB training metrics.
    drf = new DRF(parms).trainModel().get();
    ModelMetricsBinomial mm = (ModelMetricsBinomial) drf._output._training_metrics;
    assertEquals(1.0, mm.auc_obj()._auc, 1e-8);
    assertEquals(0.0290178571428571443, mm.mse(), 1e-8);
    assertEquals(0.10824081452821664, mm.logloss(), 1e-6);
  } finally {
    if (tfr != null) tfr.remove();
    if (vfr != null) vfr.remove();
    if (drf != null) drf.delete();
    Scope.exit();
  }
}
// Non-uniform row weights: OOB metrics differ from the baseline (weighting changes the
// effective sampling), and the model is additionally re-scored on the training frame
// to pin the non-OOB metrics.
@Test
public void testRowWeights() {
Frame tfr = null, vfr = null;
DRFModel drf = null;
Scope.enter();
try {
tfr = parse_test_file("smalldata/junit/weights.csv");
DKV.put(tfr);
DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
parms._train = tfr._key;
parms._response_column = "response";
parms._weights_column = "weight";
parms._seed = 234;
parms._min_rows = 1;
parms._max_depth = 2;
parms._ntrees = 3;
// Build a first model; all remaining models should be equal
drf = new DRF(parms).trainModel().get();
// OOB
// Reduced number of rows changes the row sampling -> results differ
ModelMetricsBinomial mm = (ModelMetricsBinomial)drf._output._training_metrics;
assertEquals(1.0, mm.auc_obj()._auc, 1e-8);
assertEquals(0.05823863636363636, mm.mse(), 1e-8);
assertEquals(0.21035264541934587, mm.logloss(), 1e-6);
// test set scoring (on the same dataset, but without normalizing the weights)
Frame pred = drf.score(parms.train());
hex.ModelMetricsBinomial mm2 = hex.ModelMetricsBinomial.getFromDKV(drf, parms.train());
// Non-OOB
assertEquals(1, mm2.auc_obj()._auc, 1e-8);
assertEquals(0.0154320987654321, mm2.mse(), 1e-8);
assertEquals(0.08349430638608361, mm2.logloss(), 1e-8);
// NOTE(review): pred is only removed on the success path — if an assert above
// fails it leaks; consider moving this cleanup into the finally block.
pred.remove();
} finally {
if (tfr != null) tfr.remove();
if (vfr != null) vfr.remove();
if (drf != null) drf.delete();
Scope.exit();
}
}
// Disabled: 3-fold cross-validation on airlines; the pinned CV metrics are for a
// 1-node cloud (see inline comment) and may differ on other cloud sizes.
@Ignore
@Test
public void testNFold() {
Frame tfr = null, vfr = null;
DRFModel drf = null;
Scope.enter();
try {
tfr = parse_test_file("./smalldata/airlines/allyears2k_headers.zip");
// Drop columns that leak the outcome (delay/cancellation info).
for (String s : new String[]{
"DepTime", "ArrTime", "ActualElapsedTime",
"AirTime", "ArrDelay", "DepDelay", "Cancelled",
"CancellationCode", "CarrierDelay", "WeatherDelay",
"NASDelay", "SecurityDelay", "LateAircraftDelay", "IsArrDelayed"
}) {
tfr.remove(s).remove();
}
DKV.put(tfr);
DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
parms._train = tfr._key;
parms._response_column = "IsDepDelayed";
parms._seed = 234;
parms._min_rows = 2;
parms._nfolds = 3;
parms._max_depth = 5;
parms._ntrees = 5;
// Build a first model; all remaining models should be equal
drf = new DRF(parms).trainModel().get();
ModelMetricsBinomial mm = (ModelMetricsBinomial)drf._output._cross_validation_metrics;
assertEquals(0.7276154565296726, mm.auc_obj()._auc, 1e-8); // 1 node
assertEquals(0.21211607823987555, mm.mse(), 1e-8);
assertEquals(0.6121968624307211, mm.logloss(), 1e-6);
} finally {
if (tfr != null) tfr.remove();
if (vfr != null) vfr.remove();
if (drf != null) {
// CV builds per-fold models; those must be deleted along with the main model.
drf.deleteCrossValidationModels();
drf.delete();
}
Scope.exit();
}
}
// Smoke test: 3-fold CV combined with balance_classes must train without error
// (no metric values are asserted here).
@Test
public void testNFoldBalanceClasses() {
Frame tfr = null, vfr = null;
DRFModel drf = null;
Scope.enter();
try {
tfr = parse_test_file("./smalldata/airlines/allyears2k_headers.zip");
// Drop columns that leak the outcome (delay/cancellation info).
for (String s : new String[]{
"DepTime", "ArrTime", "ActualElapsedTime",
"AirTime", "ArrDelay", "DepDelay", "Cancelled",
"CancellationCode", "CarrierDelay", "WeatherDelay",
"NASDelay", "SecurityDelay", "LateAircraftDelay", "IsArrDelayed"
}) {
tfr.remove(s).remove();
}
DKV.put(tfr);
DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
parms._train = tfr._key;
parms._response_column = "IsDepDelayed";
parms._seed = 234;
parms._min_rows = 2;
parms._nfolds = 3;
parms._max_depth = 5;
parms._balance_classes = true;
parms._ntrees = 5;
// Build a first model; all remaining models should be equal
drf = new DRF(parms).trainModel().get();
} finally {
if (tfr != null) tfr.remove();
if (vfr != null) vfr.remove();
if (drf != null) {
// CV builds per-fold models; those must be deleted along with the main model.
drf.deleteCrossValidationModels();
drf.delete();
}
Scope.exit();
}
}
// Leave-one-out CV (nfolds == numRows) with Modulo fold assignment: two identically
// configured runs must produce identical CV metrics.
@Test
public void testNfoldsOneVsRest() {
Frame tfr = null;
DRFModel drf1 = null;
DRFModel drf2 = null;
Scope.enter();
try {
tfr = parse_test_file("smalldata/junit/weights.csv");
DKV.put(tfr);
DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
parms._train = tfr._key;
parms._response_column = "response";
parms._seed = 9999;
parms._min_rows = 2;
parms._nfolds = (int) tfr.numRows();
parms._fold_assignment = Model.Parameters.FoldAssignmentScheme.Modulo;
parms._max_depth = 5;
parms._ntrees = 5;
drf1 = new DRF(parms).trainModel().get();
// parms._nfolds = (int) tfr.numRows() + 1; //this is now an error
drf2 = new DRF(parms).trainModel().get();
ModelMetricsBinomial mm1 = (ModelMetricsBinomial)drf1._output._cross_validation_metrics;
ModelMetricsBinomial mm2 = (ModelMetricsBinomial)drf2._output._cross_validation_metrics;
assertEquals(mm1.auc_obj()._auc, mm2.auc_obj()._auc, 1e-12);
assertEquals(mm1.mse(), mm2.mse(), 1e-12);
assertEquals(mm1.logloss(), mm2.logloss(), 1e-12);
//TODO: add check: the correct number of individual models were built. PUBDEV-1690
} finally {
if (tfr != null) tfr.remove();
if (drf1 != null) {
drf1.deleteCrossValidationModels();
drf1.delete();
}
if (drf2 != null) {
drf2.deleteCrossValidationModels();
drf2.delete();
}
Scope.exit();
}
}
// nfolds validation: 0 means no CV (must succeed); 1 and negative values must be
// rejected with H2OModelBuilderIllegalArgumentException.
@Test
public void testNfoldsInvalidValues() {
Frame tfr = null;
DRFModel drf1 = null;
DRFModel drf2 = null;
DRFModel drf3 = null;
Scope.enter();
try {
tfr = parse_test_file("./smalldata/airlines/allyears2k_headers.zip");
// Drop columns that leak the outcome (delay/cancellation info).
for (String s : new String[]{
"DepTime", "ArrTime", "ActualElapsedTime",
"AirTime", "ArrDelay", "DepDelay", "Cancelled",
"CancellationCode", "CarrierDelay", "WeatherDelay",
"NASDelay", "SecurityDelay", "LateAircraftDelay", "IsArrDelayed"
}) {
tfr.remove(s).remove();
}
DKV.put(tfr);
DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
parms._train = tfr._key;
parms._response_column = "IsDepDelayed";
parms._seed = 234;
parms._min_rows = 2;
parms._max_depth = 5;
parms._ntrees = 5;
parms._nfolds = 0; // valid: cross-validation disabled
drf1 = new DRF(parms).trainModel().get();
parms._nfolds = 1; // invalid: a single fold is meaningless
try {
Log.info("Trying nfolds==1.");
drf2 = new DRF(parms).trainModel().get();
Assert.fail("Should toss H2OModelBuilderIllegalArgumentException instead of reaching here");
} catch(H2OModelBuilderIllegalArgumentException e) {}
parms._nfolds = -99; // invalid: negative fold count
try {
Log.info("Trying nfolds==-99.");
drf3 = new DRF(parms).trainModel().get();
Assert.fail("Should toss H2OModelBuilderIllegalArgumentException instead of reaching here");
} catch(H2OModelBuilderIllegalArgumentException e) {}
} finally {
if (tfr != null) tfr.remove();
if (drf1 != null) drf1.delete();
if (drf2 != null) drf2.delete();
if (drf3 != null) drf3.delete();
Scope.exit();
}
}
@Test
public void testNfoldsCVAndValidation() {
  Scope.enter();
  Frame train = null;
  Frame valid = null;
  DRFModel model = null;
  try {
    train = parse_test_file("smalldata/junit/weights.csv");
    valid = parse_test_file("smalldata/junit/weights.csv");
    DKV.put(train);
    DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
    parms._train = train._key;
    parms._valid = valid._key;
    parms._response_column = "response";
    parms._min_rows = 2;
    parms._max_depth = 2;
    parms._nfolds = 2;
    parms._ntrees = 3;
    parms._seed = 11233;
    // Supplying BOTH cross-validation folds and a validation frame must be legal.
    try {
      Log.info("Trying N-fold cross-validation AND Validation dataset provided.");
      model = new DRF(parms).trainModel().get();
    } catch (H2OModelBuilderIllegalArgumentException e) {
      Assert.fail("Should not toss H2OModelBuilderIllegalArgumentException.");
    }
  } finally {
    if (train != null) train.remove();
    if (valid != null) valid.remove();
    if (model != null) {
      model.deleteCrossValidationModels();
      model.delete();
    }
    Scope.exit();
  }
}
@Test
public void testNfoldsConsecutiveModelsSame() {
  Scope.enter();
  Frame cars = null;
  Vec rawResponse = null;
  DRFModel runA = null;
  DRFModel runB = null;
  try {
    cars = parse_test_file("smalldata/junit/cars_20mpg.csv");
    cars.remove("name").remove(); // Remove unique id
    cars.remove("economy").remove();
    rawResponse = cars.remove("economy_20mpg");
    cars.add("economy_20mpg", VecUtils.toCategoricalVec(rawResponse)); // response to last column
    DKV.put(cars);
    DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
    parms._train = cars._key;
    parms._response_column = "economy_20mpg";
    parms._min_rows = 2;
    parms._max_depth = 2;
    parms._nfolds = 3;
    parms._ntrees = 3;
    parms._seed = 77777;
    // Two back-to-back builds with identical parameters and seed must agree.
    runA = new DRF(parms).trainModel().get();
    runB = new DRF(parms).trainModel().get();
    ModelMetricsBinomial mmA = (ModelMetricsBinomial) runA._output._cross_validation_metrics;
    ModelMetricsBinomial mmB = (ModelMetricsBinomial) runB._output._cross_validation_metrics;
    assertEquals(mmA.auc_obj()._auc, mmB.auc_obj()._auc, 1e-12);
    assertEquals(mmA.mse(), mmB.mse(), 1e-12);
    assertEquals(mmA.logloss(), mmB.logloss(), 1e-12);
  } finally {
    if (cars != null) cars.remove();
    if (rawResponse != null) rawResponse.remove();
    if (runA != null) {
      runA.deleteCrossValidationModels();
      runA.delete();
    }
    if (runB != null) {
      runB.deleteCrossValidationModels();
      runB.delete();
    }
    Scope.exit();
  }
}
@Test
public void testMTrys() {
  Frame cars = null;
  Vec rawResponse = null;
  DRFModel model = null;
  // Sweep mtries from 1 to 6; each setting must train and yield CV metrics with an AUC.
  for (int mtries = 1; mtries <= 6; ++mtries) {
    Scope.enter();
    try {
      cars = parse_test_file("smalldata/junit/cars_20mpg.csv");
      cars.remove("name").remove(); // Remove unique id
      cars.remove("economy").remove();
      rawResponse = cars.remove("economy_20mpg");
      cars.add("economy_20mpg", VecUtils.toCategoricalVec(rawResponse)); // response to last column
      DKV.put(cars);
      DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
      parms._train = cars._key;
      parms._response_column = "economy_20mpg";
      parms._min_rows = 2;
      parms._ntrees = 5;
      parms._max_depth = 5;
      parms._nfolds = 3;
      parms._mtries = mtries;
      model = new DRF(parms).trainModel().get();
      ModelMetricsBinomial mm = (ModelMetricsBinomial) model._output._cross_validation_metrics;
      Assert.assertTrue(mm._auc != null);
    } finally {
      if (cars != null) cars.remove();
      if (rawResponse != null) rawResponse.remove();
      if (model != null) {
        model.deleteCrossValidationModels();
        model.delete();
      }
      Scope.exit();
    }
  }
}
/**
 * Trains a stochastic DRF regression model (row subsampling via _sample_rate plus
 * column subsampling via _mtries) on the cars data and pins the training MSE to a
 * known-good value for the fixed seed.
 *
 * Fix: removed the unused local {@code vfr} — it was declared, never assigned,
 * and then pointlessly "cleaned up" in the finally block.
 */
@Test
public void testStochasticDRFEquivalent() {
  Frame tfr = null;
  DRFModel drf = null;
  Scope.enter();
  try {
    tfr = parse_test_file("./smalldata/junit/cars.csv");
    // Drop the unique-id column before training.
    tfr.remove("name").remove();
    DKV.put(tfr);
    DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
    parms._train = tfr._key;
    parms._response_column = "cylinders"; //regression
    parms._seed = 234;
    parms._min_rows = 2;
    parms._max_depth = 5;
    parms._ntrees = 5;
    parms._mtries = 3;
    parms._sample_rate = 0.5f;
    // Build a first model; all remaining models should be equal
    drf = new DRF(parms).trainModel().get();
    ModelMetricsRegression mm = (ModelMetricsRegression) drf._output._training_metrics;
    assertEquals(0.12358322821934015, mm.mse(), 1e-4);
  } finally {
    if (tfr != null) tfr.remove();
    if (drf != null) drf.delete();
    Scope.exit();
  }
}
// Sweeps three stochastic sampling knobs — row sample rate, per-split column sample
// rate (expressed through _mtries), and per-tree column sample rate — training a
// 2-tree DRF on a 50/50 split of the ecology data for every combination, then logs
// the validation MSE of each combination from best to worst. The final assertions
// about the worst combination are disabled (commented out).
@Test
public void testColSamplingPerTree() {
Frame tfr = null;
// Initialized to an empty array so the cleanup loop in finally is safe even if
// the split below never runs.
Key[] ksplits = new Key[0];
try{
tfr=parse_test_file("./smalldata/gbm_test/ecology_model.csv");
SplitFrame sf = new SplitFrame(tfr,new double[] { 0.5, 0.5 }, new Key[] { Key.make("train.hex"), Key.make("test.hex")});
// Invoke the job
sf.exec().get();
ksplits = sf._destination_frames;
DRFModel drf = null;
float[] sample_rates = new float[]{0.2f, 0.4f, 0.6f, 0.8f, 1.0f};
float[] col_sample_rates = new float[]{0.4f, 0.6f, 0.8f, 1.0f};
float[] col_sample_rates_per_tree = new float[]{0.4f, 0.6f, 0.8f, 1.0f};
// TreeMap keyed by validation MSE -> entrySet() iterates in ascending MSE order.
Map<Double, Triple<Float>> hm = new TreeMap<>();
for (float sample_rate : sample_rates) {
for (float col_sample_rate : col_sample_rates) {
for (float col_sample_rate_per_tree : col_sample_rates_per_tree) {
Scope.enter();
try {
DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
parms._train = ksplits[0];
parms._valid = ksplits[1];
parms._response_column = "Angaus"; //regression
parms._seed = 12345;
parms._min_rows = 1;
parms._max_depth = 15;
parms._ntrees = 2;
// Translate the fractional column sample rate into a predictor count (>= 1).
parms._mtries = Math.max(1,(int)(col_sample_rate*(tfr.numCols()-1)));
parms._col_sample_rate_per_tree = col_sample_rate_per_tree;
parms._sample_rate = sample_rate;
// Build a first model; all remaining models should be equal
DRF job = new DRF(parms);
drf = job.trainModel().get();
// too slow, but passes (now)
// // Build a POJO, validate same results
// Frame pred = drf.score(tfr);
// Assert.assertTrue(drf.testJavaScoring(tfr,pred,1e-15));
// pred.remove();
ModelMetricsRegression mm = (ModelMetricsRegression)drf._output._validation_metrics;
hm.put(mm.mse(), new Triple<>(sample_rate, col_sample_rate, col_sample_rate_per_tree));
} finally {
// Each model is deleted per-iteration; only frames/keys survive to the outer finally.
if (drf != null) drf.delete();
Scope.exit();
}
}
}
}
Iterator<Map.Entry<Double, Triple<Float>>> it;
// After the loop, 'last' holds the combination with the worst (largest) MSE.
Triple<Float> last = null;
// iterator over results (min to max MSE) - best to worst
for (it=hm.entrySet().iterator(); it.hasNext();) {
Map.Entry<Double, Triple<Float>> n = it.next();
Log.info( "MSE: " + n.getKey()
+ ", row sample: " + n.getValue().v1
+ ", col sample: " + n.getValue().v2
+ ", col sample per tree: " + n.getValue().v3);
last=n.getValue();
}
// worst validation MSE should belong to the most overfit case (1.0, 1.0, 1.0)
// Assert.assertTrue(last.v1==sample_rates[sample_rates.length-1]);
// Assert.assertTrue(last.v2==col_sample_rates[col_sample_rates.length-1]);
// Assert.assertTrue(last.v3==col_sample_rates_per_tree[col_sample_rates_per_tree.length-1]);
} finally {
if (tfr != null) tfr.remove();
for (Key k : ksplits)
if (k!=null) k.remove();
}
}
/**
 * Sweeps _min_split_improvement over several magnitudes on the covtype data and
 * asserts that some non-zero setting achieves a better (lower) validation logloss
 * than msi == 0 (i.e. the min-index of the logloss array is not 0).
 *
 * Fixes: (1) the finally block dereferenced {@code ksplits[0]}/{@code ksplits[1]}
 * while {@code ksplits} is initialized to null — an early parse/split failure would
 * be masked by an NPE; now guarded. (2) {@code drf} was deleted inside the loop and
 * then deleted a second time by the finally block; the reference is now nulled out
 * after the in-loop delete.
 */
@Test public void minSplitImprovement() {
  Frame tfr = null;
  Key[] ksplits = null;
  DRFModel drf = null;
  try {
    Scope.enter();
    tfr = parse_test_file("smalldata/covtype/covtype.20k.data");
    int resp = 54;
//      tfr = parse_test_file("bigdata/laptop/mnist/train.csv.gz");
//      int resp = 784;
    // Response must be categorical so the model is a classifier with logloss scoring.
    Scope.track(tfr.replace(resp, tfr.vecs()[resp].toCategoricalVec()));
    DKV.put(tfr);
    SplitFrame sf = new SplitFrame(tfr, new double[]{0.5, 0.5}, new Key[]{Key.make("train.hex"), Key.make("valid.hex")});
    // Invoke the job
    sf.exec().get();
    ksplits = sf._destination_frames;
    double[] msi = new double[]{0, 1e-10, 1e-8, 1e-6, 1e-4, 1e-2};
    final int N = msi.length;
    double[] loglosses = new double[N];
    for (int i = 0; i < N; ++i) {
      // Load data, hack frames
      DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
      parms._train = ksplits[0];
      parms._valid = ksplits[1];
      parms._response_column = tfr.names()[resp];
      parms._min_split_improvement = msi[i];
      parms._ntrees = 20;
      parms._score_tree_interval = parms._ntrees; // score once, after the last tree
      parms._max_depth = 15;
      parms._seed = 1234;
      DRF job = new DRF(parms);
      drf = job.trainModel().get();
      // The last scoring-history entry holds the final validation logloss.
      loglosses[i] = drf._output._scored_valid[drf._output._scored_valid.length - 1]._logloss;
      drf.delete();
      drf = null; // prevent a second delete() in the finally block
    }
    for (int i = 0; i < msi.length; ++i) {
      Log.info("min_split_improvement: " + msi[i] + " -> validation logloss: " + loglosses[i]);
    }
    int idx = ArrayUtils.minIndex(loglosses);
    Log.info("Optimal min_split_improvement: " + msi[idx]);
    Assert.assertTrue(0 != idx);
  } finally {
    if (drf != null) drf.delete();
    if (tfr != null) tfr.delete();
    if (ksplits != null) { // guard: parse/split may have failed before assignment
      if (ksplits[0] != null) ksplits[0].remove();
      if (ksplits[1] != null) ksplits[1].remove();
    }
    Scope.exit();
  }
}
/**
 * Trains one DRF per available HistogramType on the covtype data and asserts that
 * the type at index 4 of HistogramType.values() yields the best validation logloss
 * for this fixed seed (per the original comment, a quantile-based type).
 *
 * Fixes: (1) the finally block dereferenced {@code ksplits[0]}/{@code ksplits[1]}
 * while {@code ksplits} is initialized to null — an early parse/split failure would
 * be masked by an NPE; now guarded. (2) {@code drf} was deleted inside the loop and
 * then deleted a second time by the finally block; the reference is now nulled out
 * after the in-loop delete.
 */
@Test public void histoTypes() {
  Frame tfr = null;
  Key[] ksplits = null;
  DRFModel drf = null;
  try {
    Scope.enter();
    tfr = parse_test_file("smalldata/covtype/covtype.20k.data");
    int resp = 54;
//      tfr = parse_test_file("bigdata/laptop/mnist/train.csv.gz");
//      int resp = 784;
    // Response must be categorical so the model is a classifier with logloss scoring.
    Scope.track(tfr.replace(resp, tfr.vecs()[resp].toCategoricalVec()));
    DKV.put(tfr);
    SplitFrame sf = new SplitFrame(tfr, new double[]{0.5, 0.5}, new Key[]{Key.make("train.hex"), Key.make("valid.hex")});
    // Invoke the job
    sf.exec().get();
    ksplits = sf._destination_frames;
    SharedTreeModel.SharedTreeParameters.HistogramType[] histoType = SharedTreeModel.SharedTreeParameters.HistogramType.values();
    final int N = histoType.length;
    double[] loglosses = new double[N];
    for (int i = 0; i < N; ++i) {
      // Load data, hack frames
      DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
      parms._train = ksplits[0];
      parms._valid = ksplits[1];
      parms._response_column = tfr.names()[resp];
      parms._histogram_type = histoType[i];
      parms._ntrees = 10;
      parms._score_tree_interval = parms._ntrees; // score once, after the last tree
      parms._max_depth = 10;
      parms._seed = 12345;
      parms._nbins = 20;
      parms._nbins_top_level = 20;
      DRF job = new DRF(parms);
      drf = job.trainModel().get();
      // The last scoring-history entry holds the final validation logloss.
      loglosses[i] = drf._output._scored_valid[drf._output._scored_valid.length - 1]._logloss;
      drf.delete();
      drf = null; // prevent a second delete() in the finally block
    }
    for (int i = 0; i < histoType.length; ++i) {
      Log.info("histoType: " + histoType[i] + " -> validation logloss: " + loglosses[i]);
    }
    int idx = ArrayUtils.minIndex(loglosses);
    Log.info("Optimal randomization: " + histoType[idx]);
    Assert.assertTrue(4 == idx); //Quantiles are best
  } finally {
    if (drf != null) drf.delete();
    if (tfr != null) tfr.delete();
    if (ksplits != null) { // guard: parse/split may have failed before assignment
      if (ksplits[0] != null) ksplits[0].remove();
      if (ksplits[1] != null) ksplits[1].remove();
    }
    Scope.exit();
  }
}
/**
 * Smoke test for per-class row sampling: builds a single DRF on the covtype data
 * with a distinct sample rate per response class (array of 7 rates) and only
 * requires training to complete without error — no metric assertions.
 *
 * Fixes: (1) the finally block dereferenced {@code ksplits[0]}/{@code ksplits[1]}
 * while {@code ksplits} is initialized to null — an early parse/split failure would
 * be masked by an NPE; now guarded. (2) the model was deleted both at the end of
 * the try body and again in the finally block; it is now deleted exactly once.
 */
@Test public void sampleRatePerClass() {
  Frame tfr = null;
  Key[] ksplits = null;
  DRFModel drf = null;
  try {
    Scope.enter();
    tfr = parse_test_file("smalldata/covtype/covtype.20k.data");
    int resp = 54;
//      tfr = parse_test_file("bigdata/laptop/mnist/train.csv.gz");
//      int resp = 784;
    // Response must be categorical so the model is a classifier.
    Scope.track(tfr.replace(resp, tfr.vecs()[resp].toCategoricalVec()));
    DKV.put(tfr);
    SplitFrame sf = new SplitFrame(tfr, new double[]{0.5, 0.5}, new Key[]{Key.make("train.hex"), Key.make("valid.hex")});
    // Invoke the job
    sf.exec().get();
    ksplits = sf._destination_frames;
    // Load data, hack frames
    DRFModel.DRFParameters parms = new DRFModel.DRFParameters();
    parms._train = ksplits[0];
    parms._valid = ksplits[1];
    parms._response_column = tfr.names()[resp];
    parms._min_split_improvement = 1e-5;
    parms._ntrees = 20;
    parms._score_tree_interval = parms._ntrees; // score once, after the last tree
    parms._max_depth = 15;
    parms._seed = 1234;
    // One sampling rate per response class.
    parms._sample_rate_per_class = new double[]{0.1f,0.1f,0.2f,0.4f,1f,0.3f,0.2f};
    DRF job = new DRF(parms);
    drf = job.trainModel().get();
  } finally {
    if (drf != null) drf.delete(); // single point of deletion
    if (tfr != null) tfr.delete();
    if (ksplits != null) { // guard: parse/split may have failed before assignment
      if (ksplits[0] != null) ksplits[0].remove();
      if (ksplits[1] != null) ksplits[1].remove();
    }
    Scope.exit();
  }
}
}