/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ /* * Copyright 2010 University of Waikato */ package weka.classifiers.misc; import weka.classifiers.AbstractClassifierTest; import weka.classifiers.Classifier; import weka.classifiers.misc.InputMappedClassifier; import weka.core.Attribute; import weka.core.Instances; import weka.core.TestInstances; import weka.filters.Filter; import weka.filters.unsupervised.attribute.Reorder; import weka.filters.unsupervised.attribute.SwapValues; import junit.framework.Test; import junit.framework.TestSuite; /** * Tests InputMappedClassifier. Run from the command line with:<p> * java weka.classifiers.misc.InputMappedClassifierTest * * @author Mark Hall (mhall{[at]}pentaho{[dot]}com) * @version $Revision: 6803 $ */ public class InputMappedClassifierTest extends AbstractClassifierTest { public InputMappedClassifierTest(String name) { super(name); } /** Creates a default InputMappedClassifier */ public Classifier getClassifier() { InputMappedClassifier toUse = new InputMappedClassifier(); toUse.setClassifier(new weka.classifiers.trees.J48()); toUse.setSuppressMappingReport(true); return toUse; } protected Instances reorderAtts(Instances data) throws Exception { Reorder r = new Reorder(); String range = "last"; for (int i = data.numAttributes() - 1; i > 0; i--) { range += "," + i; } r.setAttributeIndices(range); r.setInputFormat(data); data = Filter.useFilter(data, r); return data; } protected Instances swapValues(int attIndex, Instances data) throws Exception { SwapValues s = new SwapValues(); s.setAttributeIndex("" + attIndex); s.setFirstValueIndex("first"); s.setSecondValueIndex("last"); s.setInputFormat(data); data = Filter.useFilter(data, s); return data; } protected Instances generateData(boolean nomClass, int numClasses, int numNominal, int numNumeric) throws Exception { TestInstances generator = new TestInstances(); if (nomClass) { generator.setClassType(Attribute.NOMINAL); generator.setNumClasses(numClasses); } else { generator.setClassType(Attribute.NUMERIC); } generator.setNumNominal(numNominal); generator.setNumNumeric(numNumeric); generator.setNumDate(0); generator.setNumString(0); generator.setNumRelational(0); generator.setNumInstances(100); generator.setClassIndex(TestInstances.CLASS_IS_LAST); Instances data = generator.generate(); return data; } protected void performTest(boolean nomClass, int numClassesTrain, int numTrainAtts, boolean reorderAtts, boolean reorderNomLabels, boolean reorderClassLabels) { Instances train = null; Instances test = null; try { train = generateData(nomClass, numClassesTrain, numTrainAtts, 3); } catch (Exception ex) { fail("Generating training data failed: " + ex); } test = new Instances(train); if (reorderNomLabels) { // do the first attribute try { test = swapValues(1, test); } catch (Exception ex) { fail("Reordering nominal labels failed: " + ex); } } if (reorderClassLabels && nomClass) { try { test = swapValues(7, test); } catch (Exception ex) { fail("Reordering class labels failed: " + ex); } } if (reorderAtts) { try { test = reorderAtts(test); } catch (Exception ex) { fail("Reordering test data failed: " + ex); } } InputMappedClassifier toUse = null; try { toUse = trainClassifier(train, nomClass); } catch (Exception ex) { fail("Training classifier failed: " + ex); } double[] resultsOnTrainingStructure = null; try { resultsOnTrainingStructure = testClassifier(train, toUse); } catch (Exception ex) { fail("Testing classifier on training data failed: " + ex); } double[] resultsOnTestStructure = null; try { resultsOnTestStructure = testClassifier(test, toUse); } catch (Exception ex) { fail("Testing classifier on test data failed: " + ex); } try { for (int i = 0; i < resultsOnTrainingStructure.length; i++) { if (resultsOnTrainingStructure[i] != resultsOnTestStructure[i]) { throw new Exception("Result #" + (i+1) + " differs!"); } } } catch (Exception ex) { fail("Comparing results failed " + ex); } } public void testNominaClass() { performTest(true, 4, 3, false, false, false); } /* public void testNominalClassDifferingNumClassValues() { performTest(true, 4, 6, 3, 3, false, false, false); } */ public void testNominaClassReorderedAtts() { performTest(true, 4, 3, true, false, false); } public void testNominalClassSwapNominalValues() { performTest(true, 4, 3, false, true, false); } public void testNominalClassSwapNominalValuesReorderAtts() { performTest(true, 4, 3, true, true, false); } public void testNominalClassSwapClassValues() { performTest(true, 4, 3, false, false, true); } public void testNominalClassSwapNominalValuesSwapClassValues() { performTest(true, 4, 3, false, true, true); } public void testNominalClassSwapNominalValuesSwapClassValuesReorderAtts() { performTest(true, 4, 3, true, true, true); } public void testNumericClass() { performTest(false, 4, 3, false, false, false); } public void testNumericClassReorderedAtts() { performTest(false, 4, 3, true, false, false); } public void testNumericClassSwapNominalValues() { performTest(false, 4, 3, false, true, false); } public void testNumericClassSwapNominalValuesReorderAtts() { performTest(false, 4, 3, true, true, false); } protected InputMappedClassifier trainClassifier(Instances data, boolean nominalClass) { InputMappedClassifier toUse = new InputMappedClassifier(); if (nominalClass) { toUse.setClassifier(new weka.classifiers.trees.J48()); } else { toUse.setClassifier(new weka.classifiers.functions.LinearRegression()); } toUse.setSuppressMappingReport(true); try { toUse.buildClassifier(data); } catch (Exception ex) { fail("Training InputMappedClassifier failed: " + ex); return null; } return toUse; } protected double[] testClassifier(Instances test, InputMappedClassifier classifier) { double[] result = new double[test.numInstances()]; try { for (int i = 0; i < test.numInstances(); i++) { result[i] = classifier.classifyInstance(test.instance(i)); } } catch (Exception ex) { fail("Testing InputMappedClassifier failed: " + ex); return null; } return result; } public static Test suite() { return new TestSuite(InputMappedClassifierTest.class); } public static void main(String[] args){ junit.textui.TestRunner.run(suite()); } }