package hex.drf;
import junit.framework.Assert;
import hex.drf.DRF.DRFModel;
import hex.trees.TreeTestWithBalanceAndCrossVal;
import org.junit.*;
import water.*;
import water.fvec.Frame;
import water.fvec.Vec;
/**
 * Additional DRF tests: a large, normally disabled run used for tracking memory
 * issues, and the DRF-specific implementation of the balance-classes +
 * cross-validation check declared by {@link TreeTestWithBalanceAndCrossVal}.
 */
public class DRFTest2 extends TreeTestWithBalanceAndCrossVal {
  //@BeforeClass public static void stall() { stall_till_cloudsize(1); }

  /**
   * A bigger DRF test, useful for tracking memory issues.
   *
   * <p>Disabled by default (the {@code @Test} annotation is commented out).
   * Despite the method name, the dataset currently loaded is covtype, and the
   * same model build is repeated 10 times through {@link DRFTest#basicDRF}.
   * The nested {@code ar(...)} arrays are presumably the expected confusion
   * matrix for the 7 response classes listed after them - confirm against
   * {@code DRFTest.basicDRF}.
   *
   * <p>The commented-out argument groups are alternative datasets (iris,
   * prostate) kept for manual experiments.
   */
  /*@Test*/ public void testAirlines() throws Throwable {
    for( int i=0; i<10; i++ ) {
      new DRFTest().basicDRF(
        //
        //"../demo/c5/row10000.csv.gz", "c5.hex", null, null,
        "../datasets/UCI/UCI-large/covtype/covtype.data", "covtype.hex", null, null,
        // Response is the last column of the frame.
        new DRFTest.PrepData() { @Override int prep(Frame fr) { return fr.numCols()-1; } },
        10/*ntree*/,
        ar( ar( 199019,   7697,    15,    0,  180,    45,   546),
            ar(   8012, 267788,   514,    7,  586,   329,   181),
            ar(     16,    707, 33424,  162,   53,   639,     0),
            ar(      1,      5,   353, 2211,    0,    99,     0),
            ar(    181,   1456,   134,    0, 7455,    43,     4),
            ar(     30,    540,  1171,   96,   33, 15109,     0),
            ar(    865,    167,     0,    0,    9,     0, 19075)),
        ar("1", "2", "3", "4", "5", "6", "7"),

        //"./smalldata/iris/iris_wheader.csv", "iris.hex", null, null,
        //new DRFTest.PrepData() { @Override int prep(Frame fr) { return fr.numCols()-1; } },
        //10/*ntree*/,
        //a( a( 50,  0,  0),
        //   a(  0, 50,  0),
        //   a(  0,  0, 50)),
        //s("Iris-setosa","Iris-versicolor","Iris-virginica"),

        //"./smalldata/logreg/prostate.csv", "prostate.hex", null, null,
        //new DRFTest.PrepData() { @Override int prep(Frame fr) {
        //  UKV.remove(fr.remove("ID")._key); return fr.find("CAPSULE");
        //  } },
        //10/*ntree*/,
        //a( a(170, 55),
        //   a( 60, 92)),
        //s("0","1"),

        99/*max_depth*/,
        20/*nbins*/,
        0 /*optflag*/ );
    }
  }

  /** Placeholder so that JUnit does not complain about a class without tests. */
  @Test @Ignore public void dummy_test() {
    /* this is just a dummy test to avoid JUnit complains about missing test */
  }

  /**
   * Builds a classification DRF model with class balancing and n-fold
   * cross-validation enabled, then checks that one cross-validation model per
   * fold was produced and that the resulting model's job state is DONE.
   *
   * @param dataset      dataset location handed to {@code parseFrame}
   * @param response     index of the response column in the parsed frame
   * @param ignored_cols columns assigned to {@code drf.ignored_cols}
   * @param ntrees       number of trees to build
   * @param nfolds       number of cross-validation folds; also the expected
   *                     number of cross-validation models
   */
  @Override
  protected void testBalanceWithCrossValidation(String dataset, int response, int[] ignored_cols, int ntrees, int nfolds) {
    Frame f = parseFrame(dataset);
    DRFModel model = null;
    DRF drf = new DRF();
    try {
      Vec respVec = f.vec(response);
      // Build a model
      drf.source = f;
      drf.response = respVec;
      drf.ignored_cols = ignored_cols;
      drf.classification = true;
      drf.ntrees = ntrees;
      drf.seed = 42; // fixed seed
      drf.balance_classes = true;
      drf.n_folds = nfolds;
      drf.keep_cross_validation_splits = false;
      drf.invoke();
      Assert.assertEquals("Number of cross validation model is wrong!", nfolds, drf.xval_models.length);
      model = UKV.get(drf.dest());
      // Fail with a readable message instead of a bare NPE on the next line.
      Assert.assertNotNull("DRF model was not found under the job destination key!", model);
      Assert.assertTrue(model.get_params().state == Job.JobState.DONE); //HEX-1817
    } finally {
      if (f != null) f.delete();
      // Clean up cross-validation models even when the main model was never
      // retrieved (e.g. an assertion above failed), so they do not leak.
      if (drf.xval_models != null) {
        for (Key k : drf.xval_models) {
          if (k == null) continue;
          Model m = UKV.get(k);
          // A missing model must not turn cleanup into an NPE that would mask
          // the original test failure.
          if (m != null) m.delete();
        }
      }
      if (model != null) model.delete();
    }
  }
}