/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 * Copyright (C) 2007 University of Waikato, Hamilton, New Zealand
 */

package weka.classifiers.misc;

import weka.classifiers.Classifier;
import weka.classifiers.evaluation.EvaluationUtils;
import weka.core.Attribute;
import weka.core.CheckOptionHandler;
import weka.core.FastVector;
import weka.core.Instances;
import weka.core.SerializationHelper;
import weka.core.TestInstances;
import weka.test.Regression;

import java.io.File;

import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;

/**
 * Tests SerializedClassifier. Run from the command line with:<p>
 * java weka.classifiers.misc.SerializedClassifierTest
 *
 * @author FracPete (fracpete at waikato dot ac dot nz)
 * @version $Revision: 1.1 $
 */
public class SerializedClassifierTest
  extends TestCase {

  /** the filename for temporary serialized models */
  public final static String MODEL_FILENAME =
    System.getProperty("user.dir") + "/" + "temp.model";

  /** the setup classifier */
  protected SerializedClassifier m_Classifier;

  /** the OptionHandler tester */
  protected CheckOptionHandler m_OptionTester;

  /**
   * initializes the test
   *
   * @param name the name of the test
   */
  public SerializedClassifierTest(String name) {
    super(name);
  }

  /**
   * Deletes the temporary model file, if it exists. Shared by
   * {@link #setUp()} and {@link #tearDown()} so that every test starts and
   * ends without a stale model on disk. The outcome of the deletion is
   * not checked (best effort).
   */
  private void deleteModelFile() {
    File file = new File(MODEL_FILENAME);
    if (file.exists())
      file.delete();
  }

  /**
   * Called by JUnit before each test method. Resets the classifier,
   * creates a fresh (silent) option tester and removes any left-over
   * temporary model file.
   *
   * @throws Exception declared for compatibility with the JUnit signature
   */
  protected void setUp() throws Exception {
    m_Classifier = null;

    // NOTE(review): no OptionHandler is assigned to the tester here, so the
    // option tests below run against whatever handler CheckOptionHandler
    // uses by default -- confirm that this is intended.
    m_OptionTester = new CheckOptionHandler();
    m_OptionTester.setSilent(true);

    deleteModelFile();
  }

  /**
   * Called by JUnit after each test method. Releases the references held
   * by the test and removes the temporary model file.
   */
  protected void tearDown() {
    m_Classifier   = null;
    m_OptionTester = null;

    deleteModelFile();
  }

  /**
   * creates a classifier, trains and serializes it to
   * {@link #MODEL_FILENAME}. Any exception makes the test fail.
   *
   * @param data	the data to use (J48 with nominal class, M5P with
   * 			numeric class)
   * @return		the predictions of the trained classifier for each
   * 			instance of the data
   */
  protected double[] trainAndSerializeClassifier(Instances data) {
    Classifier	classifier;
    double[]	result;
    int		i;

    try {
      // build
      if (data.classAttribute().isNominal())
	classifier = new weka.classifiers.trees.J48();
      else
	classifier = new weka.classifiers.trees.M5P();
      classifier.buildClassifier(data);

      // record predictions
      result = new double[data.numInstances()];
      for (i = 0; i < result.length; i++)
	result[i] = classifier.classifyInstance(data.instance(i));

      // save
      SerializationHelper.write(MODEL_FILENAME, classifier);
    }
    catch (Exception e) {
      fail("Training base classifier failed: " + e);
      // fail() always throws; the return only satisfies the compiler
      return null;
    }

    return result;
  }

  /**
   * performs the actual test: generates data, trains and serializes a base
   * classifier, loads it through SerializedClassifier and checks that both
   * produce the same predictions on the generated data.
   *
   * @param nomClass	whether to use a nominal class with J48 (TRUE) or
   * 			a numeric one with M5P (FALSE)
   */
  protected void performTest(boolean nomClass) {
    TestInstances	test;
    Instances		data;
    double[]		originalResults;
    double[]		testResults;
    int			i;

    // generate data
    try {
      test = new TestInstances();
      if (nomClass) {
	test.setClassType(Attribute.NOMINAL);
	test.setNumNominal(5);
	test.setNumNominalValues(4);
	test.setNumNumeric(0);
      }
      else {
	test.setClassType(Attribute.NUMERIC);
	test.setNumNominal(0);
	test.setNumNumeric(5);
      }
      test.setNumDate(0);
      test.setNumString(0);
      test.setNumRelational(0);
      test.setNumInstances(100);
      data = test.generate();
    }
    catch (Exception e) {
      fail("Generating test data failed: " + e);
      return;
    }

    // train and save base classifier
    try {
      originalResults = trainAndSerializeClassifier(data);
    }
    catch (Exception e) {
      fail("Training base classifier failed: " + e);
      return;
    }

    // test loading
    try {
      m_Classifier = new SerializedClassifier();
      m_Classifier.setModelFile(new File(MODEL_FILENAME));
      m_Classifier.buildClassifier(data);
    }
    catch (Exception e) {
      fail("Loading/testing of classifier failed: " + e);
    }

    // compare results
    try {
      // get results from serialized model
      testResults = new double[data.numInstances()];
      for (i = 0; i < testResults.length; i++)
	testResults[i] = m_Classifier.classifyInstance(data.instance(i));

      // compare
      for (i = 0; i < originalResults.length; i++) {
	if (originalResults[i] != testResults[i])
	  throw new Exception("Result #" + (i+1) + " differs!");
      }
    }
    catch (Exception e) {
      fail("Comparing results failed: " + e);
    }
  }

  /**
   * tests a serialized classifier (J48) handling nominal classes
   */
  public void testNominalClass() {
    performTest(true);
  }

  /**
   * tests a serialized classifier (M5P) handling numeric classes
   */
  public void testNumericClass() {
    // FALSE selects the numeric class attribute, i.e. the M5P branch
    performTest(false);
  }

  /**
   * Returns a string containing all the predictions.
   *
   * @param predictions a <code>FastVector</code> containing the predictions
   * @return a <code>String</code> representing the vector of predictions.
   */
  protected String predictionsToString(FastVector predictions) {
    StringBuffer sb = new StringBuffer();
    sb.append(predictions.size()).append(" predictions\n");
    for (int i = 0; i < predictions.size(); i++) {
      sb.append(predictions.elementAt(i)).append('\n');
    }
    return sb.toString();
  }

  /**
   * Runs a regression test -- this checks that the output of the tested
   * object matches that in a reference version. When this test is
   * run without any pre-existing reference output, the reference version
   * is created. Uses J48 for this purpose.
   */
  public void testRegression() {
    Regression 		reg;
    Instances 		train;
    Instances 		test;
    Instances 		data;
    TestInstances	testInst;
    int			tot;
    int			mid;
    EvaluationUtils 	evaluation;
    FastVector		regressionResults;

    reg = new Regression(this.getClass());

    // generate test data
    try {
      testInst = new TestInstances();
      testInst.setClassType(Attribute.NOMINAL);
      testInst.setNumNominal(5);
      testInst.setNumNominalValues(4);
      testInst.setNumNumeric(0);
      testInst.setNumDate(0);
      testInst.setNumString(0);
      testInst.setNumRelational(0);
      testInst.setNumInstances(100);
      data = testInst.generate();
    }
    catch (Exception e) {
      fail("Failed generating data: " + e);
      return;
    }

    // split data into train/test
    tot   = data.numInstances();
    mid   = tot / 2;
    train = null;
    test  = null;

    try {
      train = new Instances(data, 0, mid);
      test  = new Instances(data, mid, tot - mid);
      m_Classifier = new SerializedClassifier();
      m_Classifier.setModelFile(new File(MODEL_FILENAME));
    }
    catch (Exception e) {
      e.printStackTrace();
      fail("Problem setting up to use classifier: " + e);
    }

    evaluation = new EvaluationUtils();
    try {
      // the model file is written here, before the predictions are obtained
      trainAndSerializeClassifier(train);
      regressionResults = evaluation.getTrainTestPredictions(m_Classifier, train, test);
      reg.println(predictionsToString(regressionResults));
    }
    catch (Exception e) {
      fail("Failed obtaining classifier predictions: " + e);
    }

    try {
      String diff = reg.diff();
      if (diff == null) {
        System.err.println("Warning: No reference available, creating.");
      }
      else if (!diff.equals("")) {
        fail("Regression test failed. Difference:\n" + diff);
      }
    }
    catch (java.io.IOException ex) {
      fail("Problem during regression testing.\n" + ex);
    }
  }

  /**
   * tests the listing of the options
   */
  public void testListOptions() {
    if (!m_OptionTester.checkListOptions())
      fail("Options cannot be listed via listOptions.");
  }

  /**
   * tests the setting of the options
   */
  public void testSetOptions() {
    if (!m_OptionTester.checkSetOptions())
      fail("setOptions method failed.");
  }

  /**
   * tests whether there are any remaining options
   */
  public void testRemainingOptions() {
    if (!m_OptionTester.checkRemainingOptions())
      fail("There were 'left-over' options.");
  }

  /**
   * tests whether the user-supplied options stay the same after setting,
   * getting, and re-setting again.
   */
  public void testCanonicalUserOptions() {
    if (!m_OptionTester.checkCanonicalUserOptions())
      fail("setOptions method failed");
  }

  /**
   * tests the resetting of the options to the default ones
   */
  public void testResettingOptions() {
    // previously called checkSetOptions(), which only duplicated
    // testSetOptions() and never exercised the resetting check
    if (!m_OptionTester.checkResettingOptions())
      fail("Resetting of options failed");
  }

  /**
   * Returns a test suite containing all the tests of this class.
   *
   * @return the test suite
   */
  public static Test suite() {
    return new TestSuite(SerializedClassifierTest.class);
  }

  /**
   * Runs the test suite from the command line using the JUnit text runner.
   *
   * @param args ignored
   */
  public static void main(String[] args){
    junit.textui.TestRunner.run(suite());
  }
}