/* * Apache License * Version 2.0, January 2004 * http://www.apache.org/licenses/ * * Copyright 2013 Aurelian Tutuianu * Copyright 2014 Aurelian Tutuianu * Copyright 2015 Aurelian Tutuianu * Copyright 2016 Aurelian Tutuianu * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * */ package rapaio.ml.classifier; import org.junit.Test; import rapaio.data.*; import rapaio.datasets.Datasets; import rapaio.io.JavaIO; import rapaio.ml.classifier.bayes.NaiveBayes; import rapaio.ml.classifier.bayes.estimator.KernelPdf; import rapaio.ml.classifier.rule.OneRule; import rapaio.ml.classifier.tree.CTree; import rapaio.ml.eval.Confusion; import java.io.File; import java.io.IOException; import java.net.URISyntaxException; import static org.junit.Assert.assertEquals; /** * Created by <a href="mailto:padreati@yahoo.com">Aurelian Tutuianu</a> on 9/15/15. 
*/
public class ClassifierSerializationTest {

    /**
     * Trains, serializes, restores and re-fits several classifiers on the iris and mushrooms
     * datasets. It checks that every restored model reproduces the original model's fit summary.
     * At the end it prints a table with one row per (dataset, model) pair, holding the
     * accuracy of the original model on its training data.
     * <p>
     * Despite its name, this test is not limited to OneRule on iris. It runs every model
     * configuration from {@code testAllModels} on both datasets.
     */
    @Test
    public void testOneRuleIris() throws IOException, URISyntaxException, ClassNotFoundException {
        Var varModel = Nominal.empty();
        Var varData = Nominal.empty();
        Var varAcc = Numeric.empty();

        testAllModels(Datasets.loadIrisDataset(), "class", "iris", varModel, varData, varAcc);
        testAllModels(Datasets.loadMushrooms(), "classes", "mushrooms", varModel, varData, varAcc);

        SolidFrame.byVars(varData, varModel, varAcc).printLines();
    }

    /**
     * Runs {@code testModel} on {@code df} for each classifier configuration covered by this test:
     * {@link OneRule}, {@link NaiveBayes} with a {@link KernelPdf} numeric estimator,
     * {@code CTree.newC45()} and {@code CTree.newCART()}.
     *
     * @param df       data used for both training and fitting
     * @param target   name of the target variable in {@code df}
     * @param dataName label recorded in {@code varData} for every row added
     * @param varModel receives one model name per configuration
     * @param varData  receives {@code dataName} once per configuration
     * @param varAcc   receives one accuracy value per configuration
     */
    private void testAllModels(Frame df, String target, String dataName,
                               Var varModel, Var varData, Var varAcc)
            throws IOException, ClassNotFoundException {
        testModel(new OneRule(), df, target, dataName, varModel, varData, varAcc);
        testModel(new NaiveBayes().withNumEstimator(new KernelPdf()), df, target, dataName, varModel, varData, varAcc);
        testModel(CTree.newC45(), df, target, dataName, varModel, varData, varAcc);
        testModel(CTree.newCART(), df, target, dataName, varModel, varData, varAcc);
    }

    /**
     * Trains {@code model} on {@code df} and round-trips it through {@link JavaIO} using a
     * temporary file. It then asserts that the restored copy has the same class as the original
     * and produces the same fit summary on {@code df}. Finally it appends one row (dataset name,
     * model name, accuracy of the original model's {@code firstClasses()} against the target)
     * to the supplied vars.
     *
     * @param model    classifier to train and round-trip
     * @param df       data used for both training and fitting
     * @param target   name of the target variable in {@code df}
     * @param dataName label appended to {@code varData}
     * @param varModel receives {@code model.name()}
     * @param varData  receives {@code dataName}
     * @param varAcc   receives the accuracy computed by {@link Confusion}
     * @throws IOException            if the model cannot be written to or read from the temp file
     * @throws ClassNotFoundException if deserialization cannot resolve a class
     */
    private <T extends Classifier> void testModel(T model, Frame df, String target, String dataName,
                                                  Var varModel, Var varData, Var varAcc)
            throws IOException, ClassNotFoundException {
        model.train(df, target);
        model.printSummary();

        // createTempFile appends the suffix verbatim, so the dot must be part of it.
        File tmp = File.createTempFile("model-", ".ser");
        T shadow;
        try {
            JavaIO.storeToFile(model, tmp);
            // Unchecked: T is erased, so this cast is only checked against Classifier at runtime.
            // The class-equality assertion below makes the round-trip type check explicit.
            @SuppressWarnings("unchecked")
            T restored = (T) JavaIO.restoreFromFile(tmp);
            shadow = restored;
        } finally {
            // Don't leave serialized models behind in the temp directory.
            if (!tmp.delete()) {
                tmp.deleteOnExit();
            }
        }
        assertEquals(model.getClass(), shadow.getClass());

        CFit modelFit = model.fit(df);
        CFit shadowFit = shadow.fit(df);
        modelFit.printSummary();
        assertEquals(modelFit.summary(), shadowFit.summary());

        varData.addLabel(dataName);
        varModel.addLabel(model.name());
        varAcc.addValue(new Confusion(df.var(target), modelFit.firstClasses()).accuracy());
    }
}