package water; import org.junit.Assert; import org.junit.BeforeClass; import org.junit.Test; import java.io.File; import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; import hex.Model; import hex.ModelMetrics; import hex.glm.GLM; import hex.glm.GLMModel; import hex.tree.CompressedTree; import hex.tree.SharedTreeModel; import hex.tree.drf.DRF; import hex.tree.drf.DRFModel; import hex.tree.gbm.GBM; import hex.tree.gbm.GBMModel; import water.fvec.Frame; import static org.junit.Assert.assertArrayEquals; public class ModelSerializationTest extends TestUtil { @BeforeClass() public static void setup() { stall_till_cloudsize(1); } private static String[] ESA = new String[] {}; @Test public void testSimpleModel() throws IOException { // Create a model BlahModel.BlahParameters params = new BlahModel.BlahParameters(); BlahModel.BlahOutput output = new BlahModel.BlahOutput(false, false, false); Model model = new BlahModel(Key.make("BLAHModel"), params, output); DKV.put(model._key, model); // Create a serializer, save a model and reload it Model loadedModel = null; try { loadedModel = saveAndLoad(model); // And compare assertModelBinaryEquals(model, loadedModel); } finally { if (loadedModel != null) loadedModel.delete(); } } @Test public void testGBMModelMultinomial() throws IOException { GBMModel model, loadedModel = null; try { model = prepareGBMModel("smalldata/iris/iris.csv", ESA, "C5", true, 5); CompressedTree[][] trees = getTrees(model); loadedModel = saveAndLoad(model); // And compare assertModelBinaryEquals(model, loadedModel); CompressedTree[][] loadedTrees = getTrees(loadedModel); assertTreeEquals("Trees have to be binary same", trees, loadedTrees); } finally { if (loadedModel!=null) loadedModel.delete(); } } @Test public void testGBMModelBinomial() throws IOException { GBMModel model, loadedModel = null; try { model = prepareGBMModel("smalldata/logreg/prostate.csv", ar("ID"), "CAPSULE", true, 5); CompressedTree[][] trees = getTrees(model); loadedModel = saveAndLoad(model); // And compare assertModelBinaryEquals(model, loadedModel); CompressedTree[][] loadedTrees = getTrees(loadedModel); assertTreeEquals("Trees have to be binary same", trees, loadedTrees); } finally { if (loadedModel!=null) loadedModel.delete(); } } @Test public void testDRFModelMultinomial() throws IOException { DRFModel model, loadedModel = null; try { model = prepareDRFModel("smalldata/iris/iris.csv", ESA, "C5", true, 5); CompressedTree[][] trees = getTrees(model); loadedModel = saveAndLoad(model); // And compare assertModelBinaryEquals(model, loadedModel); CompressedTree[][] loadedTrees = getTrees(loadedModel); assertTreeEquals("Trees have to be binary same", trees, loadedTrees); } finally { if (loadedModel!=null) loadedModel.delete(); } } @Test public void testDRFModelBinomial() throws IOException { DRFModel model = null, loadedModel = null; try { model = prepareDRFModel("smalldata/logreg/prostate.csv", ar("ID"), "CAPSULE", true, 5); CompressedTree[][] trees = getTrees(model); loadedModel = saveAndLoad(model); // And compare assertModelBinaryEquals(model, loadedModel); CompressedTree[][] loadedTrees = getTrees(loadedModel); assertTreeEquals("Trees have to be binary same", trees, loadedTrees); } finally { if (model!=null) model.delete(); if (loadedModel!=null) loadedModel.delete(); } } @Test public void testGLMModel() throws IOException { GLMModel model, loadedModel = null; try { model = prepareGLMModel("smalldata/junit/cars.csv", ESA, "power (hp)", GLMModel.GLMParameters.Family.poisson); loadedModel = saveAndLoad(model); assertModelBinaryEquals(model, loadedModel); } finally { if (loadedModel!=null) loadedModel.delete(); } } private GBMModel prepareGBMModel(String dataset, String[] ignoredColumns, String response, boolean classification, int ntrees) { Frame f = parse_test_file(dataset); try { if (classification && !f.vec(response).isCategorical()) { f.replace(f.find(response), f.vec(response).toCategoricalVec()).remove(); DKV.put(f._key, f); } GBMModel.GBMParameters gbmParams = new GBMModel.GBMParameters(); gbmParams._train = f._key; gbmParams._ignored_columns = ignoredColumns; gbmParams._response_column = response; gbmParams._ntrees = ntrees; gbmParams._score_each_iteration = true; return new GBM(gbmParams).trainModel().get(); } finally { if (f!=null) f.delete(); } } private DRFModel prepareDRFModel(String dataset, String[] ignoredColumns, String response, boolean classification, int ntrees) { Frame f = parse_test_file(dataset); try { if (classification && !f.vec(response).isCategorical()) { f.replace(f.find(response), f.vec(response).toCategoricalVec()).remove(); DKV.put(f._key, f); } DRFModel.DRFParameters drfParams = new DRFModel.DRFParameters(); drfParams._train = f._key; drfParams._ignored_columns = ignoredColumns; drfParams._response_column = response; drfParams._ntrees = ntrees; drfParams._score_each_iteration = true; return new DRF(drfParams).trainModel().get(); } finally { if (f!=null) f.delete(); } } private GLMModel prepareGLMModel(String dataset, String[] ignoredColumns, String response, GLMModel.GLMParameters.Family family) { Frame f = parse_test_file(dataset); try { GLMModel.GLMParameters params = new GLMModel.GLMParameters(); params._train = f._key; params._ignored_columns = ignoredColumns; params._response_column = response; params._family = family; return new GLM(params).trainModel().get(); } finally { if (f!=null) f.delete(); } } /** Dummy model to test model serialization */ static class BlahModel extends Model<BlahModel, BlahModel.BlahParameters, BlahModel.BlahOutput> { public BlahModel(Key selfKey, BlahParameters params, BlahOutput output) { super(selfKey, params, output); } @Override public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) { return null; } @Override protected double[] score0(double[] data, double[] preds) { return new double[0]; } static class BlahParameters extends Model.Parameters { public String algoName() { return "Blah"; } public String fullName() { return "Blah"; } public String javaName() { return BlahModel.class.getName(); } @Override public long progressUnits() { return 0; } } static class BlahOutput extends Model.Output { public BlahOutput(boolean hasWeights, boolean hasOffset, boolean hasFold) { super(hasWeights, hasOffset, hasFold); } } } private <M extends Model> M saveAndLoad(M model) throws IOException { return saveAndLoad(model,true); } // Serialize to and from a file private <M extends Model> M saveAndLoad(M model, boolean deleteModel) throws IOException { File file = File.createTempFile(model.getClass().getSimpleName(),null); try { model.writeAll(new AutoBuffer(new FileOutputStream(file),true)).close(); if( deleteModel ) model.delete(); final AutoBuffer ab = new AutoBuffer(new FileInputStream(file)); ab.sourceName = file.getAbsolutePath(); return (M)Keyed.readAll(ab); } finally { file.delete(); } } public static void assertModelBinaryEquals(Model a, Model b) { assertArrayEquals("The serialized models are not binary same!", a.write(new AutoBuffer()).buf(), b.write(new AutoBuffer()).buf()); } public static void assertIcedBinaryEquals(String msg, Iced a, Iced b) { if (a == null) { Assert.assertEquals(msg, null, b); } else { assertArrayEquals(msg, a.write(new AutoBuffer()).buf(), b.write(new AutoBuffer()).buf()); } } public static void assertTreeEquals(String msg, CompressedTree[][] a, CompressedTree[][] b) { assertTreeEquals(msg, a, b, false); } public static void assertTreeEquals(String msg, CompressedTree[][] a, CompressedTree[][] b, boolean ignoreKeyField) { Assert.assertEquals("Number of trees has to match", a.length, b.length); for (int i = 0; i < a.length ; i++) { Assert.assertEquals("Number of trees per tree has to match", a[i].length, b[i].length); for (int j = 0; j < a[i].length; j++) { Key oldAKey = null; Key oldBKey = null; if (ignoreKeyField) { if (a[i][j] != null) { oldAKey = a[i][j]._key; a[i][j]._key = null; } if (b[i][j] != null) { oldBKey = b[i][j]._key; b[i][j]._key = null; } } assertIcedBinaryEquals(msg, a[i][j], b[i][j]); if (ignoreKeyField) { if (a[i][j] != null) { a[i][j]._key = oldAKey; } if (b[i][j] != null) { b[i][j]._key = oldBKey; } } } } } public static CompressedTree[][] getTrees(SharedTreeModel tm) { SharedTreeModel.SharedTreeOutput tmo = (SharedTreeModel.SharedTreeOutput) tm._output; int ntrees = tmo._ntrees; int nclasses = tmo.nclasses(); CompressedTree[][] result = new CompressedTree[ntrees][nclasses]; for (int i = 0; i < ntrees; i++) { for (int j = 0; j < nclasses; j++) { if (tmo._treeKeys[i][j] != null) result[i][j] = tmo.ctree(i, j); } } return result; } }