package hex.drf;
import hex.drf.DRF.DRFModel;
import hex.gbm.GBM;
import org.junit.*;
import static org.junit.Assert.assertEquals;
import water.*;
import water.api.AUC;
import water.api.DRFModelView;
import water.fvec.Frame;
import water.fvec.RebalanceDataSet;
import water.fvec.Vec;
import water.util.Log;
public class DRFTest extends TestUtil {
private final void testHTML(DRF.DRFModel m) {
StringBuilder sb = new StringBuilder();
DRFModelView drfv = new DRFModelView();
drfv.drf_model = m;
drfv.toHTML(sb);
assert(sb.length() > 0);
}
@BeforeClass public static void stall() { stall_till_cloudsize(1); }
abstract static class PrepData { abstract int prep(Frame fr); }
static final String[] s(String...arr) { return arr; }
static final long[] a(long ...arr) { return arr; }
static final long[][] a(long[] ...arr) { return arr; }
@Test public void testClassIris1() throws Throwable {
// iris ntree=1
// the DRF should use only subset of rows since it is using oob validation
basicDRFTestOOBE(
"./smalldata/iris/iris.csv","iris.hex",
new PrepData() { @Override int prep(Frame fr) { return fr.numCols()-1; } },
1,
a( a(25, 0, 0),
a(0, 17, 1),
a(1, 2, 15)),
s("Iris-setosa","Iris-versicolor","Iris-virginica") );
}
@Test public void testClassIris5() throws Throwable {
// iris ntree=50
basicDRFTestOOBE(
"./smalldata/iris/iris.csv","iris.hex",
new PrepData() { @Override int prep(Frame fr) { return fr.numCols()-1; } },
5,
a( a(41, 0, 0),
a(0, 39, 3),
a(0, 4, 41)),
s("Iris-setosa","Iris-versicolor","Iris-virginica") );
}
@Test public void testClassCars1() throws Throwable {
// cars ntree=1
basicDRFTestOOBE(
"./smalldata/cars.csv","cars.hex",
new PrepData() { @Override int prep(Frame fr) { UKV.remove(fr.remove("name")._key); return fr.find("cylinders"); } },
1,
a( a(0, 0, 0, 0, 0),
a(0, 62, 0, 7, 0),
a(0, 1, 0, 0, 0),
a(0, 0, 0,31, 0),
a(0, 0, 0, 0,40)),
s("3", "4", "5", "6", "8"));
}
@Test public void testClassCars5() throws Throwable {
basicDRFTestOOBE(
"./smalldata/cars.csv","cars.hex",
new PrepData() { @Override int prep(Frame fr) { UKV.remove(fr.remove("name")._key); return fr.find("cylinders"); } },
5,
a( a(3, 0, 0, 0, 0),
a(0, 173, 2, 9, 0),
a(0, 1, 1, 0, 0),
a(0, 2, 2, 68, 2),
a(0, 0, 0, 2, 88)),
s("3", "4", "5", "6", "8"));
}
@Test public void testConstantCols() throws Throwable {
try {
basicDRFTestOOBE(
"./smalldata/poker/poker100","poker.hex",
new PrepData() { @Override int prep(Frame fr) {
for (int i=0; i<7;i++) UKV.remove(fr.remove(3)._key);
return 3;
} },
1,
null,
null);
Assert.fail();
} catch( IllegalArgumentException iae ) { /*pass*/ }
}
@Ignore @Test public void testBadData() throws Throwable {
basicDRFTestOOBE(
"./smalldata/test/drf_infinitys.csv","infinitys.hex",
new PrepData() { @Override int prep(Frame fr) { return fr.find("DateofBirth"); } },
1,
a( a(6, 0),
a(9, 1)),
s("0", "1"));
}
//@Test
public void testCreditSample1() throws Throwable {
basicDRFTestOOBE(
"./smalldata/kaggle/creditsample-training.csv.gz","credit.hex",
new PrepData() { @Override int prep(Frame fr) {
UKV.remove(fr.remove("MonthlyIncome")._key); return fr.find("SeriousDlqin2yrs");
} },
1,
a( a(46294, 202),
a( 3187, 107)),
s("0", "1"));
}
@Test
public void testCreditProstate1() throws Throwable {
basicDRFTestOOBE(
"./smalldata/logreg/prostate.csv","prostate.hex",
new PrepData() { @Override int prep(Frame fr) {
UKV.remove(fr.remove("ID")._key); return fr.find("CAPSULE");
} },
1,
a( a(62, 19),
a(31, 22)),
s("0", "1"));
}
@Test public void testAirlines() throws Throwable {
basicDRFTestOOBE(
"./smalldata/airlines/allyears2k_headers.zip","airlines.hex",
new PrepData() {
@Override int prep(Frame fr) {
UKV.remove(fr.remove("DepTime")._key);
UKV.remove(fr.remove("ArrTime")._key);
UKV.remove(fr.remove("ActualElapsedTime")._key);
UKV.remove(fr.remove("AirTime")._key);
UKV.remove(fr.remove("ArrDelay")._key);
UKV.remove(fr.remove("DepDelay")._key);
UKV.remove(fr.remove("Cancelled")._key);
UKV.remove(fr.remove("CancellationCode")._key);
UKV.remove(fr.remove("CarrierDelay")._key);
UKV.remove(fr.remove("WeatherDelay")._key);
UKV.remove(fr.remove("NASDelay")._key);
UKV.remove(fr.remove("SecurityDelay")._key);
UKV.remove(fr.remove("LateAircraftDelay")._key);
UKV.remove(fr.remove("IsArrDelayed")._key);
return fr.find("IsDepDelayed"); }
},
50,
a( a(13941, 6946),
a( 5885,17206)),
s("NO", "YES"));
}
// Put response as the last vector in the frame and return it.
// Also fill DRF.
static Vec unifyFrame(DRF drf, Frame fr, PrepData prep) {
int idx = prep.prep(fr);
if( idx < 0 ) { drf.classification = false; idx = ~idx; }
String rname = fr._names[idx];
drf.response = fr.vecs()[idx];
fr.remove(idx); // Move response to the end
fr.add(rname,drf.response);
return drf.response;
}
public void basicDRFTestOOBE(String fnametrain, String hexnametrain, PrepData prep, int ntree, long[][] expCM, String[] expRespDom) throws Throwable { basicDRF(fnametrain, hexnametrain, null, null, prep, ntree, expCM, expRespDom, 10/*max_depth*/, 20/*nbins*/, 0/*optflag*/); }
public void basicDRF(String fnametrain, String hexnametrain, String fnametest, String hexnametest, PrepData prep, int ntree, long[][] expCM, String[] expRespDom, int max_depth, int nbins, int optflags) throws Throwable {
DRF drf = new DRF();
Key destTrain = Key.make(hexnametrain);
Key destTest = hexnametest!=null?Key.make(hexnametest):null;
Frame frTest = null, pred = null;
DRFModel model = null;
try {
Frame frTrain = drf.source = parseFrame(destTrain, fnametrain);
unifyFrame(drf, frTrain, prep);
// Configure DRF
drf.classification = true;
drf.ntrees = ntree;
drf.max_depth = max_depth;
drf.min_rows = 1; // = nodesize
drf.nbins = nbins;
drf.mtries = -1;
drf.sample_rate = 0.66667f; // Simulated sampling with replacement
drf.seed = (1L<<32)|2;
drf.destination_key = Key.make("DRF_model_4_" + hexnametrain);
// Invoke DRF and block till the end
drf.invoke();
// Get the model
model = UKV.get(drf.dest());
Assert.assertTrue(model.get_params().state == Job.JobState.DONE); //HEX-1817
testHTML(model);
// And compare CMs
assertCM(expCM, model.cms[model.cms.length-1]._arr);
Assert.assertEquals("Number of trees differs!", ntree, model.errs.length-1);
String[] cmDom = model._domains[model._domains.length-1];
Assert.assertArrayEquals("CM domain differs!", expRespDom, cmDom);
frTest = fnametest!=null ? parseFrame(destTest, fnametest) : null;
pred = drf.score(frTest!=null?frTest:drf.source);
} finally {
drf.source.delete();
UKV.remove(drf.response._key);
drf.remove();
if (frTest!=null) frTest.delete();
if( model != null ) model.delete(); // Remove the model
if( pred != null ) pred.delete();
}
}
@Test public void testReproducibility() {
Frame tfr=null;
final int N = 5;
double[] mses = new double[N];
Scope.enter();
try {
// Load data, hack frames
tfr = parseFrame(Key.make("air.hex"), "./smalldata/covtype/covtype.20k.data");
// rebalance to 256 chunks
Key dest = Key.make("df.rebalanced.hex");
RebalanceDataSet rb = new RebalanceDataSet(tfr, dest, 256);
H2O.submitTask(rb);
rb.join();
tfr.delete();
tfr = DKV.get(dest).get();
for (int i=0; i<N; ++i) {
DRF parms = new DRF();
parms.source = tfr;
parms.response = tfr.lastVec();
parms.nbins = 1000;
parms.ntrees = 1;
parms.max_depth = 8;
parms.mtries = -1;
parms.min_rows = 10;
parms.classification = false;
parms.seed = 1234;
// Build a first model; all remaining models should be equal
DRFModel drf = parms.fork().get();
mses[i] = drf.mse();
drf.delete();
}
} finally{
if (tfr != null) tfr.delete();
}
Scope.exit();
for (int i=0; i<mses.length; ++i) {
Log.info("trial: " + i + " -> mse: " + mses[i]);
}
for (int i=0; i<mses.length; ++i) {
assertEquals(mses[i], mses[0], 1e-15);
}
}
public static class repro {
@Ignore
@Test public void testAirline() throws InterruptedException {
Frame tfr=null;
Frame test=null;
Scope.enter();
try {
// Load data, hack frames
tfr = parseFrame(Key.make("air.hex"), "/users/arno/sz_bench_data/train-1m.csv");
test = parseFrame(Key.make("airt.hex"), "/users/arno/sz_bench_data/test.csv");
for (int i : new int[]{0,1,2}) {
tfr.vecs()[i] = tfr.vecs()[i].toEnum();
test.vecs()[i] = test.vecs()[i].toEnum();
}
DRF parms = new DRF();
parms.source = tfr;
parms.validation = test;
// parms.ignored_cols_by_name = new int[]{4,5,6};
// parms.ignored_cols_by_name = new int[]{0,1,2,3,4,5,7};
parms.response = tfr.lastVec();
parms.nbins = 20;
parms.ntrees = 100;
parms.max_depth = 20;
parms.mtries = -1;
parms.sample_rate = 0.667f;
parms.min_rows = 10;
parms.classification = true;
parms.seed = 12;
DRFModel drf = parms.fork().get();
Frame pred = drf.score(test);
AUC auc = new AUC();
auc.vactual = test.lastVec();
auc.vpredict = pred.lastVec();
auc.invoke();
Log.info("Test set AUC: " + auc.data().AUC);
drf.delete();
} finally{
if (tfr != null) tfr.delete();
if (test != null) test.delete();
}
Scope.exit();
}
}
}