package hex.tree.drf;
import hex.tree.CompressedTree;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Ignore;
import org.junit.Test;
import water.DKV;
import water.Key;
import water.TestUtil;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Frame;
import water.fvec.Vec;
import static water.ModelSerializationTest.assertTreeEquals;
import static water.ModelSerializationTest.getTrees;
/**
 * Tests restarting DRF model building from a checkpoint.
 *
 * <p>Each scenario trains a prior model, continues training from it via the
 * {@code _checkpoint} parameter, and compares the resulting trees with the trees
 * of a model trained from scratch with the same total number of trees.</p>
 */
public class DRFCheckpointTest extends TestUtil {

  @BeforeClass public static void stall() { stall_till_cloudsize(1); }

  /** Test if reconstructed initial frame matches the last iteration
   * of DRF model builder.
   *
   * <p>This test verifies a multinomial model.</p>
   */
  @Test
  public void testCheckpointReconstruction4Multinomial() {
    testCheckPointReconstruction("smalldata/iris/iris.csv", 4, true, 5, 3);
  }

  /** Same as {@link #testCheckpointReconstruction4Multinomial()}, run on the cars dataset
   * with the column at index 1 as a categorical response.
   */
  @Test
  public void testCheckpointReconstruction4Multinomial2() {
    testCheckPointReconstruction("smalldata/junit/cars_20mpg.csv", 1, true, 5, 3);
  }

  /** Test if reconstructed initial frame matches the last iteration
   * of DRF model builder.
   *
   * <p>This test verifies a binomial model.</p>
   */
  @Test
  public void testCheckpointReconstruction4Binomial() {
    testCheckPointReconstruction("smalldata/logreg/prostate.csv", 1, true, 5, 3);
  }

  /** Same as {@link #testCheckpointReconstruction4Binomial()}, run on the cars dataset
   * with the column at index 7 as a categorical response and a single tree
   * in both the prior and the continued part.
   */
  @Test
  public void testCheckpointReconstruction4Binomial2() {
    testCheckPointReconstruction("smalldata/junit/cars_20mpg.csv", 7, true, 1, 1);
  }

  /** Test throwing the right exception if a non-modifiable parameter is specified.
   *
   * <p>The model continued from the checkpoint is requested with a sample rate (0.67)
   * that differs from the one used by the prior model (0.2).</p>
   */
  @Test(expected = H2OIllegalArgumentException.class)
  public void testCheckpointWrongParams() {
    testCheckPointReconstruction("smalldata/iris/iris.csv", 4, true, 5, 3, 0.2f, 0.67f);
  }

  /** Test if reconstructed initial frame matches the last iteration
   * of DRF model builder.
   *
   * <p>This test verifies a regression model.</p>
   */
  @Test
  public void testCheckpointReconstruction4Regression() {
    testCheckPointReconstruction("smalldata/logreg/prostate.csv", 8, false, 4, 3);
  }

  /** Same as {@link #testCheckpointReconstruction4Regression()}, run on the cars dataset
   * with the column at index 1 as a numeric response.
   */
  @Test
  public void testCheckpointReconstruction4Regression2() {
    testCheckPointReconstruction("smalldata/junit/cars_20mpg.csv", 1, false, 4, 3);
  }

  /**
   * Runs the checkpoint scenario with a sample rate of 0.632 for both the prior
   * model and the model continued from the checkpoint.
   *
   * @param dataset            path of the dataset to parse
   * @param responseIdx        index of the response column (resolved after the
   *                           "economy" column, if present, has been dropped)
   * @param classification     if {@code true}, the response is converted to a categorical column
   * @param ntreesInPriorModel number of trees in the prior (checkpointed) model
   * @param ntreesInNewModel   number of trees added on top of the prior model
   */
  private void testCheckPointReconstruction(String dataset,
                                            int responseIdx,
                                            boolean classification,
                                            int ntreesInPriorModel, int ntreesInNewModel) {
    testCheckPointReconstruction(dataset, responseIdx, classification, ntreesInPriorModel, ntreesInNewModel, 0.632f, 0.632f);
  }

  /**
   * Trains a prior model, continues it from the checkpoint, trains a third model from
   * scratch with the same total number of trees, and asserts that the checkpointed and
   * the from-scratch model contain equal trees stored under different keys.
   *
   * <p>The frame and all models that were created are deleted when the method exits,
   * whether it completes normally or by exception.</p>
   *
   * @param dataset                path of the dataset to parse
   * @param responseIdx            index of the response column (resolved after the
   *                               "economy" column, if present, has been dropped)
   * @param classification         if {@code true}, the response is converted to a categorical column
   * @param ntreesInPriorModel     number of trees in the prior (checkpointed) model
   * @param ntreesInNewModel       number of trees added on top of the prior model
   * @param sampleRateInPriorModel sample rate of the prior model and of the from-scratch model
   * @param sampleRateInNewModel   sample rate requested for the model continued from the checkpoint
   */
  private void testCheckPointReconstruction(String dataset,
                                            int responseIdx,
                                            boolean classification,
                                            int ntreesInPriorModel, int ntreesInNewModel,
                                            float sampleRateInPriorModel, float sampleRateInNewModel) {
    Frame f = null;
    DRFModel model = null;
    DRFModel modelFromCheckpoint = null;
    DRFModel modelFinal = null;
    // The frame setup is inside the try block so that a failure while preparing
    // the data still reaches the cleanup in the finally block.
    try {
      f = parse_test_file(dataset);
      Vec v = f.remove("economy"); if (v!=null) v.remove(); //avoid overfitting for binomial case for cars dataset
      DKV.put(f);
      // If classification turn response into categorical
      if (classification) {
        Vec respVec = f.vec(responseIdx);
        f.replace(responseIdx, respVec.toCategoricalVec()).remove();
        DKV.put(f._key, f);
      }

      int ntreesTotal = ntreesInPriorModel + ntreesInNewModel;

      // Prior model which serves as the checkpoint
      DRFModel.DRFParameters drfParams = makeParams(f, responseIdx, ntreesInPriorModel, sampleRateInPriorModel);
      model = new DRF(drfParams,Key.<DRFModel>make("Initial model")).trainModel().get();

      // Model continued from the checkpoint up to the total number of trees
      DRFModel.DRFParameters drfFromCheckpointParams = makeParams(f, responseIdx, ntreesTotal, sampleRateInNewModel);
      drfFromCheckpointParams._checkpoint = model._key;
      modelFromCheckpoint = new DRF(drfFromCheckpointParams,Key.<DRFModel>make("Model from checkpoint")).trainModel().get();

      // Compute a separated model containing the same number of trees as a model built from checkpoint
      DRFModel.DRFParameters drfFinalParams = makeParams(f, responseIdx, ntreesTotal, sampleRateInPriorModel);
      modelFinal = new DRF(drfFinalParams,Key.<DRFModel>make("Validation model")).trainModel().get();

      CompressedTree[][] treesFromCheckpoint = getTrees(modelFromCheckpoint);
      CompressedTree[][] treesFromFinalModel = getTrees(modelFinal);
      assertTreeEquals("The model created from checkpoint and corresponding model created from scratch should have the same trees!",
              treesFromCheckpoint, treesFromFinalModel, true);
      // Make sure we are not re-using trees
      for (int tree = 0; tree < treesFromCheckpoint.length; tree++) {
        for (int clazz = 0; clazz < treesFromCheckpoint[tree].length; clazz++) {
          if (treesFromCheckpoint[tree][clazz] !=null) { // We already verify equality of models
            CompressedTree a = treesFromCheckpoint[tree][clazz];
            CompressedTree b = treesFromFinalModel[tree][clazz];
            Assert.assertNotEquals(a._key, b._key);
          }
        }
      }
    } finally {
      if (f!=null) f.delete();
      if (model!=null) model.delete();
      if (modelFromCheckpoint!=null) modelFromCheckpoint.delete();
      if (modelFinal!=null) modelFinal.delete();
    }
  }

  /**
   * Creates DRF parameters with the settings shared by every model in this test:
   * seed 42, maximum depth 10, and scoring after each iteration.
   *
   * @param f           training frame
   * @param responseIdx index of the response column in {@code f}
   * @param ntrees      number of trees to build
   * @param sampleRate  sample rate to use
   * @return freshly created parameters; the checkpoint is left unset
   */
  private static DRFModel.DRFParameters makeParams(Frame f, int responseIdx, int ntrees, float sampleRate) {
    DRFModel.DRFParameters params = new DRFModel.DRFParameters();
    params._train = f._key;
    params._response_column = f.name(responseIdx);
    params._ntrees = ntrees;
    params._seed = 42;
    params._max_depth = 10;
    params._score_each_iteration = true;
    params._sample_rate = sampleRate;
    return params;
  }
}