package hex.gbm;
import junit.framework.Assert;
import hex.gbm.GBM.GBMModel;
import hex.trees.TreeTestWithBalanceAndCrossVal;
import org.junit.BeforeClass;
import water.*;
import water.fvec.Frame;
import water.fvec.Vec;
/** Test for advanced GBM workflows including data rebalancing and cross validation. */
public class GBMTest2 extends TreeTestWithBalanceAndCrossVal {
  /** Blocks until the test cloud reaches a size of 1 before any test runs. */
  @BeforeClass public static void stall() { stall_till_cloudsize(1); }

  /**
   * Trains a classification GBM with {@code balance_classes} enabled and {@code nfolds}-fold
   * cross validation. It then checks that one cross-validation model was produced per fold and
   * that the final model's job ended in the {@code DONE} state.
   *
   * <p>The parsed frame, the main model and every cross-validation model are deleted in the
   * {@code finally} block. This cleanup happens even when an assertion fails, so a failing run
   * does not leak objects into the store.
   *
   * @param dataset      dataset location passed to {@code parseFrame}
   * @param response     index of the response column in the parsed frame
   * @param ignored_cols column indices passed through to {@code gbm.ignored_cols}
   * @param ntrees       number of trees to build
   * @param nfolds       number of cross-validation folds (and expected number of xval models)
   */
  @Override protected void testBalanceWithCrossValidation(String dataset, int response, int[] ignored_cols, int ntrees, int nfolds) {
    Frame f = parseFrame(dataset);
    GBMModel model = null;
    GBM gbm = new GBM();
    try {
      Vec respVec = f.vec(response);
      // Build a model
      gbm.source = f;
      gbm.response = respVec;
      gbm.ignored_cols = ignored_cols;
      gbm.classification = true;
      gbm.ntrees = ntrees;
      gbm.balance_classes = true;
      gbm.n_folds = nfolds;
      gbm.keep_cross_validation_splits = false;
      gbm.invoke();
      // Fetch the model before asserting anything so the finally block can always delete it.
      model = UKV.get(gbm.dest());
      Assert.assertNotNull("Model was not found under the job destination key!", model);
      Assert.assertNotNull("Cross validation models were not produced!", gbm.xval_models);
      Assert.assertEquals("Number of cross validation models is wrong!", nfolds, gbm.xval_models.length);
      Assert.assertEquals(Job.JobState.DONE, model.get_params().state); //HEX-1817
    } finally {
      if (f != null) f.delete();
      // Cross-validation models are cleaned up independently of the main model, so they are
      // removed even if the main model was never retrieved. Keys whose value is missing are skipped
      // instead of throwing an NPE that would mask the original failure.
      if (gbm.xval_models != null) {
        for (Key k : gbm.xval_models) {
          Model m = UKV.get(k);
          if (m != null) m.delete();
        }
      }
      if (model != null) model.delete();
    }
  }
}