package hex.gbm; import hex.gbm.GBM.GBMModel; import static org.junit.Assert.*; import org.junit.Test; import water.*; import water.fvec.Frame; import water.fvec.Vec; public class GBMCheckpointTest extends TestUtil { // Test for multinomial @Test public void testCheckpointReconstruction4Multinomial() { testCheckPointReconstruction("smalldata/iris/iris.csv", 4, true, 5, 3); } // Binomial model checkpointing @Test public void testCheckpointReconstruction4Binomial() { testCheckPointReconstruction("smalldata/logreg/prostate.csv", 1, true, 5, 3); } // And then test regression @Test public void testCheckpointReconstruction4Regression() { testCheckPointReconstruction("smalldata/logreg/prostate.csv", 8, false, 5, 3); } private void testCheckPointReconstruction(String dataset, int response, boolean classification, int ntreesInPriorModel, int ntreesInANewModel) { Frame f = parseFrame(dataset); GBMModel model = null; GBMModel modelFromCheckpoint = null; GBMModel modelFinal = null; try { Vec respVec = f.vec(response); // Build a model GBMWithHooks gbm = new GBMWithHooks(); gbm.source = f; gbm.response = respVec; gbm.classification = classification; gbm.ntrees = ntreesInPriorModel; gbm.collectPoint = WhereToCollect.AFTER_BUILD; gbm.score_each_iteration = true; gbm.invoke(); model = UKV.get(gbm.dest()); // Build a checkpointed model GBMWithHooks gbmFromCheckpoint = new GBMWithHooks(); gbmFromCheckpoint.source = f; gbmFromCheckpoint.response = respVec; gbmFromCheckpoint.classification = classification; gbmFromCheckpoint.ntrees = ntreesInANewModel; gbmFromCheckpoint.collectPoint = WhereToCollect.AFTER_RECONSTRUCTION; gbmFromCheckpoint.checkpoint = gbm.dest(); gbmFromCheckpoint.score_each_iteration = true; gbmFromCheckpoint.invoke(); modelFromCheckpoint = UKV.get(gbmFromCheckpoint.dest()); // Check if reconstructed frame computation data are same assertArrayEquals("Tree data produced by drf run and reconstructed from a model do not match!", gbm.treesCols, gbmFromCheckpoint.treesCols); // Build a model which contains old+new trees and compare prediction results GBM gbmFinal = new GBM(); gbmFinal.source = f; gbmFinal.response = respVec; gbmFinal.classification = classification; gbmFinal.ntrees = ntreesInANewModel + ntreesInPriorModel; gbmFinal.score_each_iteration = true; gbmFinal.invoke(); modelFinal = UKV.get(gbmFinal.dest()); assertTreeModelEquals(modelFinal, modelFromCheckpoint); } finally { if (f!=null) f.delete(); if (model!=null) model.delete(); if (modelFromCheckpoint!=null) modelFromCheckpoint.delete(); if (modelFinal!=null) modelFinal.delete(); } } private enum WhereToCollect { NONE, AFTER_BUILD, AFTER_RECONSTRUCTION } static class GBMWithHooks extends GBM { WhereToCollect collectPoint; public float[][] treesCols; @Override protected void initWorkFrame(GBMModel initialModel, Frame fr) { super.initWorkFrame(initialModel, fr); if (collectPoint==WhereToCollect.AFTER_RECONSTRUCTION) { //debugPrintTreeColumns(fr); treesCols = collectTreeCols(fr); } } // Collect ntrees temporary results in expensive way @Override protected void cleanUp(Frame fr, Timer t_build) { if (collectPoint==WhereToCollect.AFTER_BUILD) { //debugPrintTreeColumns(fr); treesCols = collectTreeCols(fr); } super.cleanUp(fr, t_build); } private float[][] collectTreeCols(Frame fr) { float[][] r = new float[(int) _nrows][_nclass]; for (int c=0; c<_nclass; c++) { Vec ctree = vec_tree(fr, c); for (int row=0; row<_nrows; row++) { r[row][c] = (float) ctree.at(row); } } return r; } } }