package hex.gbm;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import hex.gbm.GBM.Family;
import hex.gbm.GBM.GBMModel;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Ignore;
import org.junit.Test;
import water.*;
import water.api.ConfusionMatrix;
import water.api.GBMModelView;
import water.fvec.*;
import water.util.Log;
import java.io.File;
public class GBMTest extends TestUtil {

  /** Renders the HTML view of {@code m} and checks that some output was produced. */
  private void testHTML(GBMModel m) {
    StringBuilder sb = new StringBuilder();
    GBMModelView gbmv = new GBMModelView();
    gbmv.gbm_model = m;
    gbmv.toHTML(sb);
    // A JUnit assertion (not a Java `assert`) so the check also runs when -ea is off.
    Assert.assertTrue("GBM model HTML view should not be empty", sb.length() > 0);
  }

  @BeforeClass public static void stall() { stall_till_cloudsize(1); }

  /**
   * Per-dataset preparation hook used by {@code basicGBM}. Implementations may drop or
   * convert columns of the freshly parsed frame and must return the response column index.
   * Returning {@code ~idx} (a negative value) requests a regression on column {@code idx};
   * a non-negative value keeps the builder's default classification setting.
   */
  private static abstract class PrepData { abstract int prep(Frame fr); }

  /** Airlines columns removed from the frame before training on IsArrDelayed
   *  (presumably because they are only known after the flight / leak the response). */
  static final String[] ignored_aircols = new String[] { "DepTime", "ArrTime", "AirTime", "ArrDelay", "DepDelay", "TaxiIn", "TaxiOut", "Cancelled", "CancellationCode", "Diverted", "CarrierDelay", "WeatherDelay", "NASDelay", "SecurityDelay", "LateAircraftDelay", "IsDepDelayed"};

  /** Single-stump regression on one predictor; checks the training MSE against a known value. */
  @Test public void testGBMRegression() {
    File file = TestUtil.find_test_file("./smalldata/gbm_test/Mfgdata_gaussian_GBM_testing.csv");
    Key fkey = NFSFileVec.make(file);
    Key dest = Key.make("mfg.hex");
    GBM gbm = new GBM();              // The Builder
    GBM.GBMModel gbmmodel = null;     // The Model
    Frame preds = null;
    try {
      Frame fr = gbm.source = ParseDataset2.parse(dest,new Key[]{fkey});
      UKV.remove(fkey);
      gbm.classification = false;     // Regression
      gbm.family = GBM.Family.AUTO;
      gbm.response = fr.vecs()[1];    // Row in col 0, dependent in col 1, predictor in col 2
      gbm.ntrees = 1;
      gbm.max_depth = 1;
      gbm.min_rows = 1;
      gbm.nbins = 20;
      gbm.cols = new int[]{2};        // Just column 2
      gbm.validation = null;
      gbm.learn_rate = 1.0f;
      gbm.score_each_iteration=true;
      gbm.invoke();
      gbmmodel = UKV.get(gbm.dest());
      Assert.assertTrue(gbmmodel.get_params().state == Job.JobState.DONE); //HEX-1817
      preds = gbm.score(gbm.source);
      double sq_err = new CompErr().doAll(gbm.response,preds.vecs()[0])._sum;
      double mse = sq_err/preds.numRows();
      assertEquals(79152.1233,mse,0.1);
    } finally {
      // Cleanup runs even when an assertion fails; guards avoid masking the original error.
      if( preds != null ) preds.delete();
      if( gbm.source != null ) gbm.source.delete();                // Remove original hex frame key
      if( gbm.validation != null ) gbm.validation.delete();        // Remove validation dataset if specified
      if( gbmmodel != null ) gbmmodel.delete();                    // Remove the model
      gbm.remove();                                                // Remove GBM Job
    }
  }

  /** Sums the squared differences between two vecs (response vs. prediction). */
  private static class CompErr extends MRTask2<CompErr> {
    double _sum;
    @Override public void map( Chunk resp, Chunk pred ) {
      double sum = 0;
      for( int i=0; i<resp._len; i++ ) {
        double err = resp.at(i)-pred.at(i);
        sum += err*err;
      }
      // Accumulate rather than overwrite so the total stays correct even if a single
      // task instance is ever handed more than one chunk.
      _sum += sum;
    }
    @Override public void reduce( CompErr ce ) { _sum += ce._sum; }
  }

  @Test public void testBasicGBM() {
    // Regression tests
    basicGBM("./smalldata/cars.csv","cars.hex",
             new PrepData() { int prep(Frame fr ) { UKV.remove(fr.remove("name")._key); return ~fr.find("economy (mpg)"); }});

    // Classification tests
    basicGBM("./smalldata/test/test_tree.csv","tree.hex",
             new PrepData() { int prep(Frame fr) { return 1; } });
    basicGBM("./smalldata/test/test_tree_minmax.csv","tree_minmax.hex",
             new PrepData() { int prep(Frame fr) { return fr.find("response"); } });
    basicGBM("./smalldata/logreg/prostate.csv","prostate.hex",
             new PrepData() {
               int prep(Frame fr) {
                 assertEquals(380,fr.numRows());
                 // Remove patient ID vector
                 UKV.remove(fr.remove("ID")._key);
                 // Prostate: predict on CAPSULE
                 return fr.find("CAPSULE");
               }
             });
    basicGBM("./smalldata/cars.csv","cars.hex",
             new PrepData() { int prep(Frame fr) { UKV.remove(fr.remove("name")._key); return fr.find("cylinders"); } });
    basicGBM("./smalldata/airlines/allyears2k_headers.zip","air.hex",
             new PrepData() { int prep(Frame fr) {
               for( String s : ignored_aircols ) UKV.remove(fr.remove(s)._key);
               return fr.find("IsArrDelayed"); }
             });
  }

  @Test public void testBasicGBMFamily() {
    Scope.enter();
    try {
      // Classification with Bernoulli family
      basicGBM("./smalldata/logreg/prostate.csv","prostate.hex",
               new PrepData() {
                 int prep(Frame fr) {
                   assertEquals(380,fr.numRows());
                   // Remove patient ID vector
                   UKV.remove(fr.remove("ID")._key);
                   // Change CAPSULE and RACE to categoricals
                   Scope.track(fr.factor(fr.find("CAPSULE"))._key);
                   Scope.track(fr.factor(fr.find("RACE"   ))._key);
                   // Prostate: predict on CAPSULE
                   return fr.find("CAPSULE");
                 }
               }, Family.bernoulli);
    } finally {
      // Always release the tracked factor keys, even when the model build fails.
      Scope.exit();
    }
  }

  // ==========================================================================
  public GBMModel basicGBM(String fname, String hexname, PrepData prep) {
    return basicGBM(fname, hexname, prep, false, Family.AUTO);
  }

  public GBMModel basicGBM(String fname, String hexname, PrepData prep, boolean validation) {
    return basicGBM(fname, hexname, prep, validation, Family.AUTO);
  }

  public GBMModel basicGBM(String fname, String hexname, PrepData prep, Family family) {
    return basicGBM(fname, hexname, prep, false, family);
  }

  /**
   * Parses {@code fname} into {@code hexname}, lets {@code prep} choose the response, builds a
   * small GBM (4 trees, depth 4) on all columns, renders its HTML view and scores the training data.
   *
   * @param validation when true, a copy of the training frame is also used as validation data
   * @return the built model, or {@code null} if the data file is missing (test silently skipped).
   *         The model has already been removed from the K/V store when this method returns;
   *         only its in-memory fields (e.g. {@code errs}) remain usable.
   */
  public GBMModel basicGBM(String fname, String hexname, PrepData prep, boolean validation, Family family) {
    File file = TestUtil.find_test_file(fname);
    if( file == null ) return null; // Silently abort test if the file is missing
    Key fkey = NFSFileVec.make(file);
    Key dest = Key.make(hexname);
    GBM gbm = new GBM();              // The Builder
    GBM.GBMModel gbmmodel = null;     // The Model
    try {
      Frame fr = gbm.source = ParseDataset2.parse(dest,new Key[]{fkey});
      UKV.remove(fkey);
      int idx = prep.prep(fr);
      if( idx < 0 ) { gbm.classification = false; idx = ~idx; } // ~idx encodes "regression"
      gbm.response = fr.vecs()[idx];
      gbm.family = family;
      assert gbm.family != Family.bernoulli || gbm.classification;
      gbm.ntrees = 4;
      gbm.max_depth = 4;
      gbm.min_rows = 1;
      gbm.nbins = 50;
      gbm.cols = new int[fr.numCols()];
      for( int i=0; i<gbm.cols.length; i++ ) gbm.cols[i]=i;
      gbm.validation = validation ? new Frame(gbm.source) : null;
      gbm.learn_rate = .2f;
      gbm.score_each_iteration=true;
      gbm.invoke();
      gbmmodel = UKV.get(gbm.dest());
      testHTML(gbmmodel);
      Assert.assertTrue(gbmmodel.get_params().state == Job.JobState.DONE); //HEX-1817
      Frame preds = gbm.score(gbm.source);
      preds.delete();
      return gbmmodel;
    } finally {
      if( gbm.source != null ) gbm.source.delete();                // Remove original hex frame key
      if( gbm.validation != null ) gbm.validation.delete();        // Remove validation dataset if specified
      if( gbmmodel != null ) gbmmodel.delete();                    // Remove the model
      gbm.remove();                                                // Remove GBM Job
    }
  }

  // Train on ecology_model.csv, then score the separate ecology_eval.csv and print a
  // confusion matrix. Slow test, needed to build a good model.
  @Test public void testGBMTrainTest() {
    File file1 = TestUtil.find_test_file("smalldata/gbm_test/ecology_model.csv");
    File file2 = TestUtil.find_test_file("smalldata/gbm_test/ecology_eval.csv");
    // Silently ignore if either file is missing; checked before creating any keys so nothing leaks.
    if( file1 == null || file2 == null ) return;
    Key fkey1 = NFSFileVec.make(file1);
    Key dest1 = Key.make("train.hex");
    Key fkey2 = NFSFileVec.make(file2);
    Key dest2 = Key.make("test.hex");
    GBM gbm = new GBM();              // The Builder
    GBM.GBMModel gbmmodel = null;     // The Model
    Frame ftest = null, fpreds = null;
    try {
      // Attach the frame to the builder immediately so the finally block can always clean it up.
      Frame fr = gbm.source = ParseDataset2.parse(dest1,new Key[]{fkey1});
      UKV.remove(fr.remove("Site")._key); // Remove unique ID; too predictive
      gbm.response = fr.vecs()[fr.find("Angaus")]; // Train on the outcome
      gbm.ntrees = 5;
      gbm.max_depth = 10;
      gbm.learn_rate = 0.2f;
      gbm.min_rows = 10;
      gbm.nbins = 100;
      gbm.invoke();
      gbmmodel = UKV.get(gbm.dest());
      testHTML(gbmmodel);
      Assert.assertTrue(gbmmodel.get_params().state == Job.JobState.DONE); //HEX-1817

      // Score on the separate evaluation data
      ftest = ParseDataset2.parse(dest2,new Key[]{fkey2});
      fpreds = gbm.score(ftest);

      // Build a confusion matrix
      ConfusionMatrix CM = new ConfusionMatrix();
      CM.actual = ftest;
      CM.vactual = ftest.vecs()[ftest.find("Angaus")];
      CM.predict = fpreds;
      CM.vpredict = fpreds.vecs()[fpreds.find("predict")];
      CM.invoke();                 // Start it, do it

      StringBuilder sb = new StringBuilder();
      CM.toASCII(sb);
      System.out.println(sb);

    } finally {
      if( gbm.source != null ) gbm.source.delete();  // Remove the original hex frame key
      if( ftest  != null ) ftest .delete();
      if( fpreds != null ) fpreds.delete();
      if( gbmmodel != null ) gbmmodel.delete();      // Remove the model
      if( gbm.response != null ) UKV.remove(gbm.response._key);
      gbm.remove();                                  // Remove GBM Job
    }
  }

  // Adapt a trained model to a test dataset with different enums
  @Test public void testModelAdapt() {
    File file1 = TestUtil.find_test_file("./smalldata/kaggle/KDDTrain.arff.gz");
    Key fkey1 = NFSFileVec.make(file1);
    Key dest1 = Key.make("KDDTrain.hex");
    File file2 = TestUtil.find_test_file("./smalldata/kaggle/KDDTest.arff.gz");
    Key fkey2 = NFSFileVec.make(file2);
    Key dest2 = Key.make("KDDTest.hex");
    GBM gbm = new GBM();
    GBM.GBMModel gbmmodel = null;     // The Model
    Frame ftest = null, preds = null;
    try {
      gbm.source = ParseDataset2.parse(dest1,new Key[]{fkey1});
      gbm.response = gbm.source.vecs()[41]; // Response is col 41
      gbm.ntrees = 2;
      gbm.max_depth = 8;
      gbm.learn_rate = 0.2f;
      gbm.min_rows = 10;
      gbm.nbins = 50;
      gbm.invoke();
      gbmmodel = UKV.get(gbm.dest());
      testHTML(gbmmodel);
      Assert.assertTrue(gbmmodel.get_params().state == Job.JobState.DONE); //HEX-1817

      // The test data set has a few more enums than the train
      ftest = ParseDataset2.parse(dest2,new Key[]{fkey2});
      preds = gbm.score(ftest);

    } finally {
      // Delete scoring artifacts even when scoring or an assertion fails.
      if( ftest != null ) ftest.delete();
      if( preds != null ) preds.delete();
      if( gbmmodel != null ) gbmmodel.delete();       // Remove the model
      if( gbm.source != null ) gbm.source.delete();   // Remove original hex frame key
      if( gbm.response != null ) UKV.remove(gbm.response._key);
      gbm.remove();                                   // Remove GBM Job
    }
  }

  // A test of locking the input dataset during model building.
  @Test public void testModelLock() {
    GBM gbm = new GBM();
    GBM.GBMModel model = null;
    try {
      Frame fr = gbm.source = parseFrame(Key.make("air.hex"),"./smalldata/airlines/allyears2k_headers.zip");
      for( String s : ignored_aircols ) UKV.remove(fr.remove(s)._key);
      int idx =  fr.find("IsArrDelayed");
      gbm.response = fr.vecs()[idx];
      gbm.ntrees = 10;
      gbm.max_depth = 5;
      gbm.min_rows = 1;
      gbm.nbins = 20;
      gbm.cols = new int[fr.numCols()];
      for( int i=0; i<gbm.cols.length; i++ ) gbm.cols[i]=i;
      gbm.learn_rate = .2f;
      gbm.fork();
      // Give the forked build a moment to start and lock its input frame.
      try {
        Thread.sleep(100);
      } catch( InterruptedException ie ) {
        Thread.currentThread().interrupt(); // Preserve the interrupt status for callers
      }
      try {
        fr.delete();                // Attempted delete while model-build is active
        throw H2O.fail();           // Should toss IAE instead of reaching here
      } catch( IllegalArgumentException ignore ) {
        // Expected: the frame is locked by the running job.
      } catch( DException.DistributedException de ) {
        assertTrue( de.getMessage().contains("java.lang.IllegalArgumentException") );
      }

      model = gbm.get();
      // Check for null before dereferencing, rather than after.
      Assert.assertNotNull("GBM job returned no model", model);
      Assert.assertTrue(model.get_params().state == Job.JobState.DONE); //HEX-1817
    } finally {
      if( model != null ) model.delete();
      if( gbm.source != null ) gbm.source.delete(gbm.self(),0.0f); // Remove original hex frame key
      gbm.remove();                                                // Remove GBM Job
    }
  }

  /**
   * Builds the same model without and with a validation frame (a copy of the training data)
   * and asserts both runs report the same per-iteration MSEs. Skips silently if the data
   * file is missing, consistent with {@code basicGBM}.
   */
  private void checkMSEEqualityWithValidation(String fname, String hexname, PrepData prep) {
    GBMModel withoutVal = basicGBM(fname, hexname, prep, false);
    if( withoutVal == null ) return; // Data file missing; basicGBM already skipped the build
    GBMModel withVal = basicGBM(fname, hexname, prep, true );
    if( withVal == null ) return;
    Assert.assertArrayEquals("GBM has to report same list of MSEs for run without/with validation dataset (which is equal to training data)", withoutVal.errs, withVal.errs, 0.0001);
  }

  // MSE generated by GBM with/without validation dataset should be same
  @Test public void testModelMSEEqualityOnProstate() {
    final PrepData prostatePrep =
      new PrepData() {
        @Override int prep(Frame fr) {
          assertEquals(380,fr.numRows());
          // Remove patient ID vector
          UKV.remove(fr.remove("ID")._key);
          // Prostate: predict on CAPSULE
          return fr.find("CAPSULE");
        }
      };
    checkMSEEqualityWithValidation("./smalldata/logreg/prostate.csv","prostate.hex", prostatePrep);
  }

  @Ignore
  @Test public void testModelMSEEqualityOnTitanic() {
    final PrepData titanicPrep =
      new PrepData() {
        @Override int prep(Frame fr) {
          assertEquals(1309,fr.numRows());
          // Titanic: predict on survived
          return fr.find("survived");
        }
      };
    checkMSEEqualityWithValidation("./smalldata/titanicalt.csv","titanic.hex", titanicPrep);
  }

  /** Builds the same regression model N times on a 256-chunk frame; all MSEs must match. */
  @Test public void testReproducibility() {
    Frame tfr = null;
    final int N = 5;
    double[] mses = new double[N];

    Scope.enter();
    try {
      // Load data, hack frames
      tfr = parseFrame(Key.make("air.hex"), "./smalldata/covtype/covtype.20k.data");

      // rebalance to 256 chunks
      Key dest = Key.make("df.rebalanced.hex");
      RebalanceDataSet rb = new RebalanceDataSet(tfr, dest, 256);
      H2O.submitTask(rb);
      rb.join();
      tfr.delete();
      tfr = null; // Do not delete the original frame a second time if the lookup below fails
      tfr = DKV.get(dest).get();

      for( int i=0; i<N; ++i ) {
        GBM parms = new GBM();
        GBMModel gbm = null;
        try {
          parms.source = tfr;
          parms.response = tfr.lastVec();
          parms.nbins = 1000;
          parms.ntrees = 1;
          parms.max_depth = 8;
          parms.learn_rate = 0.1;
          parms.min_rows = 10;
          parms.family = Family.AUTO;
          parms.classification = false;

          // Build a first model; all remaining models should be equal
          gbm = parms.fork().get();
          mses[i] = gbm.mse();
        } finally {
          if( gbm != null ) gbm.delete();
          parms.remove(); // Remove GBM Job, as the other tests do
        }
      }
    } finally {
      if( tfr != null ) tfr.delete();
      Scope.exit();
    }

    for( int i=0; i<mses.length; ++i ) {
      Log.info("trial: " + i + " -> mse: " + mses[i]);
    }
    for( int i=0; i<mses.length; ++i ) {
      assertEquals(mses[0], mses[i], 1e-15);
    }
  }
}