/* * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program. If not, see <http://www.gnu.org/licenses/>. */ /* * MultiScheme.java * Copyright (C) 1999-2012 University of Waikato, Hamilton, New Zealand * */ package weka.classifiers.meta; import java.util.Enumeration; import java.util.Random; import java.util.Vector; import weka.classifiers.Classifier; import weka.classifiers.Evaluation; import weka.classifiers.RandomizableMultipleClassifiersCombiner; import weka.core.Instance; import weka.core.Instances; import weka.core.Option; import weka.core.OptionHandler; import weka.core.RevisionUtils; import weka.core.Utils; /** <!-- globalinfo-start --> * Class for selecting a classifier from among several using cross validation on the training data or the performance on the training data. Performance is measured based on percent correct (classification) or mean-squared error (regression). * <p/> <!-- globalinfo-end --> * <!-- options-start --> * Valid options are: <p/> * * <pre> -X <number of folds> * Use cross validation for model selection using the * given number of folds. (default 0, is to * use training error)</pre> * * <pre> -S <num> * Random number seed. * (default 1)</pre> * * <pre> -B <classifier specification> * Full class name of classifier to include, followed * by scheme options. May be specified multiple times. * (default: "weka.classifiers.rules.ZeroR")</pre> * * <pre> -D * If set, classifier is run in debug mode and * may output additional info to the console</pre> * <!-- options-end --> * * @author Len Trigg (trigg@cs.waikato.ac.nz) * @version $Revision: 8034 $ */ public class MultiScheme extends RandomizableMultipleClassifiersCombiner { /** for serialization */ static final long serialVersionUID = 5710744346128957520L; /** The classifier that had the best performance on training data. */ protected Classifier m_Classifier; /** The index into the vector for the selected scheme */ protected int m_ClassifierIndex; /** * Number of folds to use for cross validation (0 means use training * error for selection) */ protected int m_NumXValFolds; /** * Returns a string describing classifier * @return a description suitable for * displaying in the explorer/experimenter gui */ public String globalInfo() { return "Class for selecting a classifier from among several using cross " + "validation on the training data or the performance on the " + "training data. Performance is measured based on percent correct " + "(classification) or mean-squared error (regression)."; } /** * Returns an enumeration describing the available options. * * @return an enumeration of all the available options. */ public Enumeration listOptions() { Vector newVector = new Vector(1); newVector.addElement(new Option( "\tUse cross validation for model selection using the\n" + "\tgiven number of folds. (default 0, is to\n" + "\tuse training error)", "X", 1, "-X <number of folds>")); Enumeration enu = super.listOptions(); while (enu.hasMoreElements()) { newVector.addElement(enu.nextElement()); } return newVector.elements(); } /** * Parses a given list of options. <p/> * <!-- options-start --> * Valid options are: <p/> * * <pre> -X <number of folds> * Use cross validation for model selection using the * given number of folds. (default 0, is to * use training error)</pre> * * <pre> -S <num> * Random number seed. * (default 1)</pre> * * <pre> -B <classifier specification> * Full class name of classifier to include, followed * by scheme options. May be specified multiple times. * (default: "weka.classifiers.rules.ZeroR")</pre> * * <pre> -D * If set, classifier is run in debug mode and * may output additional info to the console</pre> * <!-- options-end --> * * @param options the list of options as an array of strings * @throws Exception if an option is not supported */ public void setOptions(String[] options) throws Exception { String numFoldsString = Utils.getOption('X', options); if (numFoldsString.length() != 0) { setNumFolds(Integer.parseInt(numFoldsString)); } else { setNumFolds(0); } super.setOptions(options); } /** * Gets the current settings of the Classifier. * * @return an array of strings suitable for passing to setOptions */ public String [] getOptions() { String [] superOptions = super.getOptions(); String [] options = new String [superOptions.length + 2]; int current = 0; options[current++] = "-X"; options[current++] = "" + getNumFolds(); System.arraycopy(superOptions, 0, options, current, superOptions.length); return options; } /** * Returns the tip text for this property * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String classifiersTipText() { return "The classifiers to be chosen from."; } /** * Sets the list of possible classifers to choose from. * * @param classifiers an array of classifiers with all options set. */ public void setClassifiers(Classifier [] classifiers) { m_Classifiers = classifiers; } /** * Gets the list of possible classifers to choose from. * * @return the array of Classifiers */ public Classifier [] getClassifiers() { return m_Classifiers; } /** * Gets a single classifier from the set of available classifiers. * * @param index the index of the classifier wanted * @return the Classifier */ public Classifier getClassifier(int index) { return m_Classifiers[index]; } /** * Gets the classifier specification string, which contains the class name of * the classifier and any options to the classifier * * @param index the index of the classifier string to retrieve, starting from * 0. * @return the classifier string, or the empty string if no classifier * has been assigned (or the index given is out of range). */ protected String getClassifierSpec(int index) { if (m_Classifiers.length < index) { return ""; } Classifier c = getClassifier(index); if (c instanceof OptionHandler) { return c.getClass().getName() + " " + Utils.joinOptions(((OptionHandler)c).getOptions()); } return c.getClass().getName(); } /** * Returns the tip text for this property * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String seedTipText() { return "The seed used for randomizing the data " + "for cross-validation."; } /** * Sets the seed for random number generation. * * @param seed the random number seed */ public void setSeed(int seed) { m_Seed = seed;; } /** * Gets the random number seed. * * @return the random number seed */ public int getSeed() { return m_Seed; } /** * Returns the tip text for this property * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String numFoldsTipText() { return "The number of folds used for cross-validation (if 0, " + "performance on training data will be used)."; } /** * Gets the number of folds for cross-validation. A number less * than 2 specifies using training error rather than cross-validation. * * @return the number of folds for cross-validation */ public int getNumFolds() { return m_NumXValFolds; } /** * Sets the number of folds for cross-validation. A number less * than 2 specifies using training error rather than cross-validation. * * @param numFolds the number of folds for cross-validation */ public void setNumFolds(int numFolds) { m_NumXValFolds = numFolds; } /** * Returns the tip text for this property * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String debugTipText() { return "Whether debug information is output to console."; } /** * Set debugging mode * * @param debug true if debug output should be printed */ public void setDebug(boolean debug) { m_Debug = debug; } /** * Get whether debugging is turned on * * @return true if debugging output is on */ public boolean getDebug() { return m_Debug; } /** * Get the index of the classifier that was determined as best during * cross-validation. * * @return the index in the classifier array */ public int getBestClassifierIndex() { return m_ClassifierIndex; } /** * Buildclassifier selects a classifier from the set of classifiers * by minimising error on the training data. * * @param data the training data to be used for generating the * boosted classifier. * @throws Exception if the classifier could not be built successfully */ public void buildClassifier(Instances data) throws Exception { if (m_Classifiers.length == 0) { throw new Exception("No base classifiers have been set!"); } // can classifier handle the data? getCapabilities().testWithFail(data); // remove instances with missing class Instances newData = new Instances(data); newData.deleteWithMissingClass(); Random random = new Random(m_Seed); newData.randomize(random); if (newData.classAttribute().isNominal() && (m_NumXValFolds > 1)) { newData.stratify(m_NumXValFolds); } Instances train = newData; // train on all data by default Instances test = newData; // test on training data by default Classifier bestClassifier = null; int bestIndex = -1; double bestPerformance = Double.NaN; int numClassifiers = m_Classifiers.length; for (int i = 0; i < numClassifiers; i++) { Classifier currentClassifier = getClassifier(i); Evaluation evaluation; if (m_NumXValFolds > 1) { evaluation = new Evaluation(newData); for (int j = 0; j < m_NumXValFolds; j++) { // We want to randomize the data the same way for every // learning scheme. train = newData.trainCV(m_NumXValFolds, j, new Random (1)); test = newData.testCV(m_NumXValFolds, j); currentClassifier.buildClassifier(train); evaluation.setPriors(train); evaluation.evaluateModel(currentClassifier, test); } } else { currentClassifier.buildClassifier(train); evaluation = new Evaluation(train); evaluation.evaluateModel(currentClassifier, test); } double error = evaluation.errorRate(); if (m_Debug) { System.err.println("Error rate: " + Utils.doubleToString(error, 6, 4) + " for classifier " + currentClassifier.getClass().getName()); } if ((i == 0) || (error < bestPerformance)) { bestClassifier = currentClassifier; bestPerformance = error; bestIndex = i; } } m_ClassifierIndex = bestIndex; if (m_NumXValFolds > 1) { bestClassifier.buildClassifier(newData); } m_Classifier = bestClassifier; } /** * Returns class probabilities. * * @param instance the instance to be classified * @return the distribution for the instance * @throws Exception if instance could not be classified * successfully */ public double[] distributionForInstance(Instance instance) throws Exception { return m_Classifier.distributionForInstance(instance); } /** * Output a representation of this classifier * @return a string representation of the classifier */ public String toString() { if (m_Classifier == null) { return "MultiScheme: No model built yet."; } String result = "MultiScheme selection using"; if (m_NumXValFolds > 1) { result += " cross validation error"; } else { result += " error on training data"; } result += " from the following:\n"; for (int i = 0; i < m_Classifiers.length; i++) { result += '\t' + getClassifierSpec(i) + '\n'; } result += "Selected scheme: " + getClassifierSpec(m_ClassifierIndex) + "\n\n" + m_Classifier.toString(); return result; } /** * Returns the revision string. * * @return the revision */ public String getRevision() { return RevisionUtils.extract("$Revision: 8034 $"); } /** * Main method for testing this class. * * @param argv should contain the following arguments: * -t training file [-T test file] [-c class index] */ public static void main(String [] argv) { runClassifier(new MultiScheme(), argv); } }