package hex.drf; import hex.drf.DRF.DRFModel; import org.junit.*; import water.*; import water.fvec.Frame; import water.fvec.Vec; public class DRFCheckpointTest extends TestUtil { /** Test if reconstructed initial frame match the last iteration * of DRF model builder. * * <p>This test verify multinominal model.</p> */ @Test public void testCheckpointReconstruction4Multinomial() { testCheckPointReconstruction("smalldata/iris/iris.csv", 4, true, 5, 3); } /** Test if reconstructed initial frame match the last iteration * of DRF model builder. * * <p>This test verify binominal model.</p> */ @Test public void testCheckpointReconstruction4Binomial() { testCheckPointReconstruction("smalldata/logreg/prostate.csv", 1, true, 5, 3); } /** Test if reconstructed initial frame match the last iteration * of DRF model builder. * * <p>This test verify regression model.</p> */ @Test public void testCheckpointReconstruction4Regression() { testCheckPointReconstruction("smalldata/logreg/prostate.csv", 8, false, 5, 3); } private void testCheckPointReconstruction(String dataset, int response, boolean classification, int ntreesInPriorModel, int ntreesInANewModel) { Frame f = parseFrame(dataset); DRFModel model = null; DRFModel modelFromCheckpoint = null; DRFModel modelFinal = null; try { Vec respVec = f.vec(response); // Build a model DRFWithHooks drf = new DRFWithHooks(); drf.source = f; drf.response = respVec; drf.classification = classification; drf.ntrees = ntreesInPriorModel; drf.collectPoint = WhereToCollect.AFTER_BUILD; drf.seed = 42; drf.invoke(); model = UKV.get(drf.dest()); DRFWithHooks drfFromCheckpoint = new DRFWithHooks(); drfFromCheckpoint.source = f; drfFromCheckpoint.response = respVec; drfFromCheckpoint.classification = classification; drfFromCheckpoint.ntrees = ntreesInANewModel; drfFromCheckpoint.collectPoint = WhereToCollect.AFTER_RECONSTRUCTION; drfFromCheckpoint.checkpoint = drf.dest(); drfFromCheckpoint.seed = 42; drfFromCheckpoint.invoke(); modelFromCheckpoint = UKV.get(drf.dest()); //System.err.println(Arrays.toString(modelFromCheckpoint.errs)); Assert.assertArrayEquals("Tree data produced by drf run and reconstructed from a model do not match!", drf.treesCols, drfFromCheckpoint.treesCols); DRF drfFinal = new DRF(); drfFinal.source = f; drfFinal.response = respVec; drfFinal.classification = classification; drfFinal.ntrees = ntreesInANewModel + ntreesInPriorModel; drfFinal.score_each_iteration = true; drfFinal.seed = 42; drfFinal.invoke(); modelFinal = UKV.get(drfFinal.dest()); //System.err.println(Arrays.toString(modelFinal.errs)); // Compare resulting model with the model produced from checkpoint assertTreeModelEquals(modelFinal, modelFromCheckpoint); } finally { if (f!=null) f.delete(); if (model!=null) model.delete(); if (modelFromCheckpoint!=null) modelFromCheckpoint.delete(); if (modelFinal!=null) modelFinal.delete(); } } private enum WhereToCollect { NONE, AFTER_BUILD, AFTER_RECONSTRUCTION } // Helper class with a hook to collect tree cols static class DRFWithHooks extends DRF { WhereToCollect collectPoint; public float[][] treesCols; @Override protected void initWorkFrame(DRFModel initialModel, Frame fr) { super.initWorkFrame(initialModel, fr); if (collectPoint==WhereToCollect.AFTER_RECONSTRUCTION) treesCols = collectTreeCols(fr); } // Collect ntrees temporary results in expensive way @Override protected void cleanUp(Frame fr, Timer t_build) { if (collectPoint==WhereToCollect.AFTER_BUILD) treesCols = collectTreeCols(fr); super.cleanUp(fr, t_build); } private float[][] collectTreeCols(Frame fr) { float[][] r = new float[(int) _nrows][_nclass]; for (int c=0; c<_nclass; c++) { Vec ctree = vec_tree(fr, c); for (int row=0; row<_nrows; row++) { r[row][c] = ctree.at8(row); } } return r; } } }