package water.serial; import hex.*; import hex.FrameTask.DataInfo; import hex.drf.DRF; import hex.drf.DRF.DRFModel; import hex.gbm.GBM; import hex.gbm.GBM.GBMModel; import hex.glm.*; import hex.glm.GLM2.Source; import hex.glm.GLMParams.Family; import java.io.File; import java.io.IOException; import org.junit.Test; import water.*; import water.fvec.Frame; import water.fvec.Vec; public class ModelSerializationTest extends TestUtil { private static int[] EIA = new int[] {}; @Test public void testSimpleModel() throws IOException { // Create a model Model model = new BlahModel(null, null, ar("ColumnBlah", "response"), new String[2][]); // Create a serializer, save a model and reload it Model loadedModel = saveAndLoad(model); // And compare assertModelEquals(model, loadedModel); } @Test public void testGBMModelMultinomial() throws IOException { GBMModel model = null, loadedModel = null; try { model = prepareGBMModel("smalldata/iris/iris.csv", EIA, 4, true, 5); loadedModel = saveAndLoad(model); // And compare assertTreeModelEquals(model, loadedModel); assertModelBinaryEquals(model, loadedModel); } finally { if (model!=null) model.delete(); if (loadedModel!=null) loadedModel.delete(); } } @Test public void testGBMModelBinomial() throws IOException { GBMModel model = null, loadedModel = null; try { model = prepareGBMModel("smalldata/logreg/prostate.csv", ari(0), 1, true, 5); loadedModel = saveAndLoad(model); // And compare assertTreeModelEquals(model, loadedModel); assertModelBinaryEquals(model, loadedModel); } finally { if (model!=null) model.delete(); if (loadedModel!=null) loadedModel.delete(); } } @Test public void testDRFModelMultinomial() throws IOException { DRFModel model = null, loadedModel = null; try { model = prepareDRFModel("smalldata/iris/iris.csv", EIA, 4, true, 5); loadedModel = saveAndLoad(model); // And compare assertTreeModelEquals(model, loadedModel); assertModelBinaryEquals(model, loadedModel); } finally { if (model!=null) model.delete(); if (loadedModel!=null) loadedModel.delete(); } } @Test public void testDRFModelBinomial() throws IOException { DRFModel model = null, loadedModel = null; try { model = prepareDRFModel("smalldata/logreg/prostate.csv", ari(0), 1, true, 5); loadedModel = saveAndLoad(model); // And compare assertTreeModelEquals(model, loadedModel); assertModelBinaryEquals(model, loadedModel); } finally { if (model!=null) model.delete(); if (loadedModel!=null) loadedModel.delete(); } } @Test public void testGLMModel() throws IOException { GLMModel model = null, loadedModel = null; try { model = prepareGLMModel("smalldata/cars.csv", EIA, 4, Family.poisson); loadedModel = saveAndLoad(model); assertModelBinaryEquals(model, loadedModel); GLMTest2.testHTML(loadedModel); } finally { if (model!=null) model.delete(); if (loadedModel!=null) loadedModel.delete(); } } private GBMModel prepareGBMModel(String dataset, int[] ignores, int response, boolean classification, int ntrees) { Frame f = parseFrame(dataset); try { GBM gbm = new GBM(); Vec respVec = f.vec(response); gbm.source = f; gbm.response = respVec; gbm.classification = classification; gbm.ntrees = ntrees; gbm.score_each_iteration = true; gbm.invoke(); return UKV.get(gbm.dest()); } finally { if (f!=null) f.delete(); } } private DRFModel prepareDRFModel(String dataset, int[] ignores, int response, boolean classification, int ntrees) { Frame f = parseFrame(dataset); try { DRF drf = new DRF(); Vec respVec = f.vec(response); drf.source = f; drf.response = respVec; drf.classification = classification; drf.ntrees = ntrees; drf.score_each_iteration = true; drf.invoke(); return UKV.get(drf.dest()); } finally { if (f!=null) f.delete(); } } private GLMModel prepareGLMModel(String dataset, int[] ignores, int response, Family family) { Frame f = parseFrame(dataset); Key modelKey = Key.make("GLM_model_for_"+dataset); try { new GLM2("GLM test on "+dataset,Key.make(),modelKey,new Source(f,f.vec(response),true),family).doInit().fork().get(); return DKV.get(modelKey).get(); } finally { if (f!=null) f.delete(); } } static class BlahModel extends Model { //static final int DEBUG_WEAVER = 1; final Key[] keys; final VarImp varimp; public BlahModel(Key selfKey, Key dataKey, String[] names, String[][] domains) { super(selfKey, dataKey, names, domains, null, null); keys = new Key[3]; varimp = new VarImp.VarImpRI(arf(1f, 1f, 1f)); } @Override protected float[] score0(double[] data, float[] preds) { throw new RuntimeException("TODO Auto-generated method stub"); } } private <M extends Model> M saveAndLoad(M model) throws IOException { return saveAndLoad(model,true); } private <M extends Model> M saveAndLoad(M model, boolean deleteModel) throws IOException { // Serialize to a file File file = File.createTempFile("H2O_ModelSerializationTest", ".model"); file.deleteOnExit(); new Model2FileBinarySerializer().save(model, file); // Delete model if (deleteModel) model.delete(); // Deserialize M m = (M) new Model2FileBinarySerializer().load(file); // And return return m; } }