package hex.tree.gbm; import hex.tree.CompressedTree; import org.junit.Assert; import org.junit.BeforeClass; import org.junit.Ignore; import org.junit.Test; import water.DKV; import water.Key; import water.TestUtil; import water.exceptions.H2OIllegalArgumentException; import water.fvec.Frame; import water.fvec.Vec; import water.util.VecUtils; import static water.ModelSerializationTest.assertTreeEquals; import static water.ModelSerializationTest.getTrees; public class GBMCheckpointTest extends TestUtil { @BeforeClass public static void stall() { stall_till_cloudsize(1); } @Test public void testCheckpointReconstruction4Multinomial() { testCheckPointReconstruction("smalldata/iris/iris.csv", 4, true, 5, 3); } @Test public void testCheckpointReconstruction4Multinomial2() { testCheckPointReconstruction("smalldata/junit/cars_20mpg.csv", 1, true, 5, 3); } @Test public void testCheckpointReconstruction4Binomial() { testCheckPointReconstruction("smalldata/logreg/prostate.csv", 1, true, 5, 3); } @Test public void testCheckpointReconstruction4Binomial2() { testCheckPointReconstruction("smalldata/junit/cars_20mpg.csv", 7, true, 2, 2); } /** Test throwing the right exception if non-modifiable parameter is specified. */ @Test(expected = H2OIllegalArgumentException.class) @Ignore public void testCheckpointWrongParams() { testCheckPointReconstruction("smalldata/iris/iris.csv", 4, true, 5, 3, 0.2f, 0.67f); } @Test public void testCheckpointReconstruction4Regression() { testCheckPointReconstruction("smalldata/logreg/prostate.csv", 8, false, 5, 3); } @Test public void testCheckpointReconstruction4Regression2() { testCheckPointReconstruction("smalldata/junit/cars_20mpg.csv", 1, false, 5, 3); } private void testCheckPointReconstruction(String dataset, int responseIdx, boolean classification, int ntreesInPriorModel, int ntreesInNewModel) { testCheckPointReconstruction(dataset, responseIdx, classification, ntreesInPriorModel, ntreesInNewModel, 0.632f, 0.632f); } private void testCheckPointReconstruction(String dataset, int responseIdx, boolean classification, int ntreesInPriorModel, int ntreesInNewModel, float sampleRateInPriorModel, float sampleRateInNewModel) { Frame f = parse_test_file(dataset); Vec v = f.remove("economy"); if (v!=null) v.remove(); //avoid overfitting for binomial case for cars dataset DKV.put(f); // If classification turn response into categorical if (classification) { Vec respVec = f.vec(responseIdx); f.replace(responseIdx, VecUtils.toCategoricalVec(respVec)).remove(); DKV.put(f._key, f); } GBMModel model = null; GBMModel modelFromCheckpoint = null; GBMModel modelFinal = null; try { GBMModel.GBMParameters gbmParams = new GBMModel.GBMParameters(); gbmParams._train = f._key; gbmParams._response_column = f.name(responseIdx); gbmParams._ntrees = ntreesInPriorModel; gbmParams._seed = 42; gbmParams._max_depth = 5; gbmParams._learn_rate_annealing = 0.9; gbmParams._score_each_iteration = true; model = new GBM(gbmParams, Key.<GBMModel>make("Initial model") ).trainModel().get(); GBMModel.GBMParameters gbmFromCheckpointParams = new GBMModel.GBMParameters(); gbmFromCheckpointParams._train = f._key; gbmFromCheckpointParams._response_column = f.name(responseIdx); gbmFromCheckpointParams._ntrees = ntreesInPriorModel + ntreesInNewModel; gbmFromCheckpointParams._seed = 42; gbmFromCheckpointParams._checkpoint = model._key; gbmFromCheckpointParams._score_each_iteration = true; gbmFromCheckpointParams._max_depth = 5; gbmFromCheckpointParams._learn_rate_annealing = 0.9; modelFromCheckpoint = new GBM(gbmFromCheckpointParams,Key.<GBMModel>make("Model from checkpoint")).trainModel().get(); // Compute a separated model containing the same numnber of trees as a model built from checkpoint GBMModel.GBMParameters gbmFinalParams = new GBMModel.GBMParameters(); gbmFinalParams._train = f._key; gbmFinalParams._response_column = f.name(responseIdx); gbmFinalParams._ntrees = ntreesInPriorModel + ntreesInNewModel; gbmFinalParams._seed = 42; gbmFinalParams._score_each_iteration = true; gbmFinalParams._max_depth = 5; gbmFinalParams._learn_rate_annealing = 0.9; modelFinal = new GBM(gbmFinalParams,Key.<GBMModel>make("Validation model")).trainModel().get(); // System.err.println(modelFromCheckpoint.toJava(false,true)); // System.err.println(modelFinal.toJava(false,true)); CompressedTree[][] treesFromCheckpoint = getTrees(modelFromCheckpoint); CompressedTree[][] treesFromFinalModel = getTrees(modelFinal); assertTreeEquals("The model created from checkpoint and corresponding model created from scratch should have the same trees!", treesFromCheckpoint, treesFromFinalModel, true); // Make sure we are not re-using trees for (int tree = 0; tree < treesFromCheckpoint.length; tree++) { for (int clazz = 0; clazz < treesFromCheckpoint[tree].length; clazz++) { if (treesFromCheckpoint[tree][clazz] !=null) { // We already verify equality of models CompressedTree a = treesFromCheckpoint[tree][clazz]; CompressedTree b = treesFromFinalModel[tree][clazz]; Assert.assertNotEquals(a._key, b._key); } } } } finally { if (f!=null) f.delete(); if (model!=null) model.delete(); if (modelFromCheckpoint!=null) modelFromCheckpoint.delete(); if (modelFinal!=null) modelFinal.delete(); } } @Ignore("PUBDEV-1829") public void testCheckpointReconstruction4BinomialPUBDEV1829() { Frame tr = parse_test_file("smalldata/jira/gbm_checkpoint_train.csv"); Frame val = parse_test_file("smalldata/jira/gbm_checkpoint_valid.csv"); Vec old = null; tr.remove("name").remove(); tr.remove("economy").remove(); val.remove("name").remove(); val.remove("economy").remove(); old = tr.remove("economy_20mpg"); tr.add("economy_20mpg", old); DKV.put(tr); old = val.remove("economy_20mpg"); val.add("economy_20mpg", old); DKV.put(val); GBMModel model = null; GBMModel modelFromCheckpoint = null; GBMModel modelFinal = null; try { GBMModel.GBMParameters gbmParams = new GBMModel.GBMParameters(); gbmParams._train = tr._key; gbmParams._valid = val._key; gbmParams._response_column = "economy_20mpg"; gbmParams._ntrees = 5; gbmParams._max_depth = 5; gbmParams._min_rows = 10; gbmParams._score_each_iteration = true; gbmParams._seed = 42; model = new GBM(gbmParams,Key.<GBMModel>make("Initial model")).trainModel().get(); GBMModel.GBMParameters gbmFromCheckpointParams = new GBMModel.GBMParameters(); gbmFromCheckpointParams._train = tr._key; gbmFromCheckpointParams._valid = val._key; gbmFromCheckpointParams._response_column = "economy_20mpg"; gbmFromCheckpointParams._ntrees = 10; gbmFromCheckpointParams._checkpoint = model._key; gbmFromCheckpointParams._score_each_iteration = true; gbmFromCheckpointParams._max_depth = 5; gbmFromCheckpointParams._min_rows = 10; gbmFromCheckpointParams._seed = 42; modelFromCheckpoint = new GBM(gbmFromCheckpointParams,Key.<GBMModel>make("Model from checkpoint")).trainModel().get(); // Compute a separated model containing the same number of trees as a model built from checkpoint GBMModel.GBMParameters gbmFinalParams = new GBMModel.GBMParameters(); gbmFinalParams._train = tr._key; gbmFinalParams._valid = val._key; gbmFinalParams._response_column = "economy_20mpg"; gbmFinalParams._ntrees = 10; gbmFinalParams._score_each_iteration = true; gbmFinalParams._max_depth = 5; gbmFinalParams._min_rows = 10; gbmFinalParams._seed = 42; modelFinal = new GBM(gbmFinalParams,Key.<GBMModel>make("Validation model")).trainModel().get(); CompressedTree[][] treesFromCheckpoint = getTrees(modelFromCheckpoint); CompressedTree[][] treesFromFinalModel = getTrees(modelFinal); assertTreeEquals("The model created from checkpoint and corresponding model created from scratch should have the same trees!", treesFromCheckpoint, treesFromFinalModel, true); } finally { if (tr!=null) tr.delete(); if (val!=null) val.delete(); if (old != null) old.remove(); if (model!=null) model.delete(); if (modelFromCheckpoint!=null) modelFromCheckpoint.delete(); if (modelFinal!=null) modelFinal.delete(); } } }