/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ /* * CVParameterSelection.java * Copyright (C) 1999 Len Trigg * */ package weka.classifiers.meta; import weka.classifiers.Evaluation; import weka.classifiers.Classifier; import weka.classifiers.rules.ZeroR; import java.io.*; import java.util.*; import weka.core.*; /** * Class for performing parameter selection by cross-validation for any * classifier. For more information, see<p> * * R. Kohavi (1995). <i>Wrappers for Performance * Enhancement and Oblivious Decision Graphs</i>. PhD * Thesis. Department of Computer Science, Stanford University. <p> * * Valid options are:<p> * * -D <br> * Turn on debugging output.<p> * * -W classname <br> * Specify the full class name of classifier to perform cross-validation * selection on.<p> * * -X num <br> * Number of folds used for cross validation (default 10). <p> * * -S seed <br> * Random number seed (default 1).<p> * * -P "N 1 5 10" <br> * Sets an optimisation parameter for the classifier with name -N, * lower bound 1, upper bound 5, and 10 optimisation steps. * The upper bound may be the character 'A' or 'I' to substitute * the number of attributes or instances in the training data, * respectively. * This parameter may be supplied more than once to optimise over * several classifier options simultaneously. 
<p> * * Options after -- are passed to the designated sub-classifier. <p> * * @author Len Trigg (trigg@cs.waikato.ac.nz) * @version $Revision: 1.1.1.1 $ */ public class CVParameterSelection extends Classifier implements OptionHandler, Summarizable { /* * A data structure to hold values associated with a single * cross-validation search parameter */ protected class CVParameter { /** Char used to identify the option of interest */ private char m_ParamChar; /** Lower bound for the CV search */ private double m_Lower; /** Upper bound for the CV search */ private double m_Upper; /** Increment during the search */ private double m_Steps; /** The parameter value with the best performance */ private double m_ParamValue; /** True if the parameter should be added at the end of the argument list */ private boolean m_AddAtEnd; /** True if the parameter should be rounded to an integer */ private boolean m_RoundParam; /** * Constructs a CVParameter. */ public CVParameter(String param) throws Exception { // Tokenize the string into it's parts StreamTokenizer st = new StreamTokenizer(new StringReader(param)); if (st.nextToken() != StreamTokenizer.TT_WORD) { throw new Exception("CVParameter " + param + ": Character parameter identifier expected"); } m_ParamChar = st.sval.charAt(0); if (st.nextToken() != StreamTokenizer.TT_NUMBER) { throw new Exception("CVParameter " + param + ": Numeric lower bound expected"); } m_Lower = st.nval; if (st.nextToken() == StreamTokenizer.TT_NUMBER) { m_Upper = st.nval; if (m_Upper < m_Lower) { throw new Exception("CVParameter " + param + ": Upper bound is less than lower bound"); } } else if (st.ttype == StreamTokenizer.TT_WORD) { if (st.sval.toUpperCase().charAt(0) == 'A') { m_Upper = m_Lower - 1; } else if (st.sval.toUpperCase().charAt(0) == 'I') { m_Upper = m_Lower - 2; } else { throw new Exception("CVParameter " + param + ": Upper bound must be numeric, or 'A' or 'N'"); } } else { throw new Exception("CVParameter " + param + ": Upper bound must be 
numeric, or 'A' or 'N'"); } if (st.nextToken() != StreamTokenizer.TT_NUMBER) { throw new Exception("CVParameter " + param + ": Numeric number of steps expected"); } m_Steps = st.nval; if (st.nextToken() == StreamTokenizer.TT_WORD) { if (st.sval.toUpperCase().charAt(0) == 'R') { m_RoundParam = true; } } } /** * Returns a CVParameter as a string. */ public String toString() { String result = m_ParamChar + " " + m_Lower + " "; switch ((int)(m_Lower - m_Upper + 0.5)) { case 1: result += "A"; break; case 2: result += "I"; break; default: result += m_Upper; break; } result += " " + m_Steps; if (m_RoundParam) { result += " R"; } return result; } } /** The generated base classifier */ protected Classifier m_Classifier = new weka.classifiers.rules.ZeroR(); /** * The base classifier options (not including those being set * by cross-validation) */ protected String [] m_ClassifierOptions; /** The set of all classifier options as determined by cross-validation */ protected String [] m_BestClassifierOptions; /** The cross-validated performance of the best options */ protected double m_BestPerformance; /** The set of parameters to cross-validate over */ protected FastVector m_CVParams; /** The number of attributes in the data */ protected int m_NumAttributes; /** The number of instances in a training fold */ protected int m_TrainFoldSize; /** The number of folds used in cross-validation */ protected int m_NumFolds = 10; /** Random number seed */ protected int m_Seed = 1; /** Debugging mode, gives extra output if true */ protected boolean m_Debug; /** * Create the options array to pass to the classifier. The parameter * values and positions are taken from m_ClassifierOptions and * m_CVParams. 
   * 
   * @return the options array
   */
  protected String [] createOptions() {
    
    // One slot per static option token plus a flag+value pair for each
    // cross-validated parameter
    String [] options = new String [m_ClassifierOptions.length 
                                    + 2 * m_CVParams.size()];
    int start = 0, end = options.length;

    // Add the cross-validation parameters and their values
    for (int i = 0; i < m_CVParams.size(); i++) {
      CVParameter cvParam = (CVParameter)m_CVParams.elementAt(i);
      double paramValue = cvParam.m_ParamValue;
      if (cvParam.m_RoundParam) {
        // Round to nearest integer when the specifier carried the 'R' flag
        paramValue = (double)((int) (paramValue + 0.5));
      }
      if (cvParam.m_AddAtEnd) {
        // Fill from the back of the array (value first, then flag)
        options[--end] = "" 
          + Utils.doubleToString(paramValue,4);
        options[--end] = "-" + cvParam.m_ParamChar;
      } else {
        options[start++] = "-" + cvParam.m_ParamChar;
        options[start++] = ""
          + Utils.doubleToString(paramValue,4);
      }
    }
    // Add the static parameters
    System.arraycopy(m_ClassifierOptions, 0,
                     options, start,
                     m_ClassifierOptions.length);

    return options;
  }

  /**
   * Finds the best parameter combination. (recursive for each parameter
   * being optimised).
   *
   * @param depth the index of the parameter to be optimised at this level
   * @param trainData the training data over which to cross-validate
   * @exception Exception if an error occurs
   */
  protected void findParamsByCrossValidation(int depth, Instances trainData)
    throws Exception {

    if (depth < m_CVParams.size()) {
      CVParameter cvParam = (CVParameter)m_CVParams.elementAt(depth);

      double upper;
      // Decode the 'A'/'I' encodings (lower-1 / lower-2) into the
      // data-dependent upper bounds; anything else is a literal bound
      switch ((int)(cvParam.m_Lower - cvParam.m_Upper + 0.5)) {
      case 1:
        upper = m_NumAttributes;
        break;
      case 2:
        upper = m_TrainFoldSize;
        break;
      default:
        upper = cvParam.m_Upper;
        break;
      }
      // Step so that m_Steps values are visited from lower to upper
      double increment = (upper - cvParam.m_Lower) / (cvParam.m_Steps - 1);
      for(cvParam.m_ParamValue = cvParam.m_Lower; 
          cvParam.m_ParamValue <= upper; 
          cvParam.m_ParamValue += increment) {
        // Recurse to enumerate the cartesian product of all parameters
        findParamsByCrossValidation(depth + 1, trainData);
      }
    } else {
      // All parameters have a value assigned: evaluate this combination
      Evaluation evaluation = new Evaluation(trainData);

      // Set the classifier options
      String [] options = createOptions();
      if (m_Debug) {
        System.err.print("Setting options for " 
                         + m_Classifier.getClass().getName() + ":");
        for (int i = 0; i < options.length; i++) {
          System.err.print(" " + options[i]);
        }
System.err.println(""); } ((OptionHandler)m_Classifier).setOptions(options); for (int j = 0; j < m_NumFolds; j++) { Instances train = trainData.trainCV(m_NumFolds, j); Instances test = trainData.testCV(m_NumFolds, j); m_Classifier.buildClassifier(train); evaluation.setPriors(train); evaluation.evaluateModel(m_Classifier, test); } double error = evaluation.errorRate(); if (m_Debug) { System.err.println("Cross-validated error rate: " + Utils.doubleToString(error, 6, 4)); } if ((m_BestPerformance == -99) || (error < m_BestPerformance)) { m_BestPerformance = error; m_BestClassifierOptions = createOptions(); } } } /** * Returns an enumeration describing the available options. * * @return an enumeration of all the available options. */ public Enumeration listOptions() { Vector newVector = new Vector(5); newVector.addElement(new Option( "\tTurn on debugging output.", "D", 0, "-D")); newVector.addElement(new Option( "\tFull name of classifier to perform parameter selection on.\n" + "\teg: weka.classifiers.bayes.NaiveBayes", "W", 1, "-W <classifier class name>")); newVector.addElement(new Option( "\tNumber of folds used for cross validation (default 10).", "X", 1, "-X <number of folds>")); newVector.addElement(new Option( "\tClassifier parameter options.\n" + "\teg: \"N 1 5 10\" Sets an optimisation parameter for the\n" + "\tclassifier with name -N, with lower bound 1, upper bound\n" + "\t5, and 10 optimisation steps. The upper bound may be the\n" + "\tcharacter 'A' or 'I' to substitute the number of\n" + "\tattributes or instances in the training data,\n" + "\trespectively. 
This parameter may be supplied more than\n" + "\tonce to optimise over several classifier options\n" + "\tsimultaneously.", "P", 1, "-P <classifier parameter>")); newVector.addElement(new Option( "\tSets the random number seed (default 1).", "S", 1, "-S <random number seed>")); if ((m_Classifier != null) && (m_Classifier instanceof OptionHandler)) { newVector.addElement(new Option("", "", 0, "\nOptions specific to sub-classifier " + m_Classifier.getClass().getName() + ":\n(use -- to signal start of sub-classifier options)")); Enumeration enum = ((OptionHandler)m_Classifier).listOptions(); while (enum.hasMoreElements()) { newVector.addElement(enum.nextElement()); } } return newVector.elements(); } /** * Parses a given list of options. Valid options are:<p> * * -D <br> * Turn on debugging output.<p> * * -W classname <br> * Specify the full class name of classifier to perform cross-validation * selection on.<p> * * -X num <br> * Number of folds used for cross validation (default 10). <p> * * -S seed <br> * Random number seed (default 1).<p> * * -P "N 1 5 10" <br> * Sets an optimisation parameter for the classifier with name -N, * lower bound 1, upper bound 5, and 10 optimisation steps. * The upper bound may be the character 'A' or 'I' to substitute * the number of attributes or instances in the training data, * respectively. * This parameter may be supplied more than once to optimise over * several classifier options simultaneously. <p> * * Options after -- are passed to the designated sub-classifier. 
<p>
   *
   * @param options the list of options as an array of strings
   * @exception Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {

    setDebug(Utils.getFlag('D', options));

    String foldsString = Utils.getOption('X', options);
    if (foldsString.length() != 0) {
      setNumFolds(Integer.parseInt(foldsString));
    } else {
      setNumFolds(10);
    }

    String randomString = Utils.getOption('S', options);
    if (randomString.length() != 0) {
      setSeed(Integer.parseInt(randomString));
    } else {
      setSeed(1);
    }

    // Collect every -P specifier; the option may be supplied repeatedly
    // to optimise over several parameters simultaneously
    String cvParam;
    m_CVParams = new FastVector();
    do {
      cvParam = Utils.getOption('P', options);
      if (cvParam.length() != 0) {
        addCVParameter(cvParam);
      }
    } while (cvParam.length() != 0);
    if (m_CVParams.size() == 0) {
      throw new Exception("A parameter specifier must be given with"
                          + " the -P option.");
    }

    String classifierName = Utils.getOption('W', options);
    if (classifierName.length() == 0) {
      throw new Exception("A classifier must be specified with"
                          + " the -W option.");
    }
    // Everything after -- is handed to the sub-classifier
    setClassifier(Classifier.forName(classifierName,
                                     Utils.partitionOptions(options)));
    if (!(m_Classifier instanceof OptionHandler)) {
      throw new Exception("Base classifier must accept options");
    }
  }

  /**
   * Gets the current settings of the Classifier.
   *
   * @return an array of strings suitable for passing to setOptions
   */
  public String [] getOptions() {

    String [] classifierOptions = new String [0];
    if ((m_Classifier != null) && 
        (m_Classifier instanceof OptionHandler)) {
      classifierOptions = ((OptionHandler)m_Classifier).getOptions();
    }

    int current = 0;
    // Base size: sub-classifier options plus 8 slots for
    // -D, -X n, -S n, -W name and the "--" separator
    String [] options = new String [classifierOptions.length + 8];
    if (m_CVParams != null) {
      // Re-allocate to add a "-P" flag+specifier pair per parameter
      options = new String [m_CVParams.size() * 2 + options.length];
      for (int i = 0; i < m_CVParams.size(); i++) {
        options[current++] = "-P"; options[current++] = "" + getCVParameter(i);
      }
    }

    if (getDebug()) {
      options[current++] = "-D";
    }
    options[current++] = "-X"; options[current++] = "" + getNumFolds();
    options[current++] = "-S"; options[current++] = "" + getSeed();

    if (getClassifier() != null) {
      options[current++] = "-W";
      options[current++] = getClassifier().getClass().getName();
    }
    options[current++] = "--";
    System.arraycopy(classifierOptions, 0, options, current, 
                     classifierOptions.length);
    current += classifierOptions.length;
    // Pad unused slots (e.g. when -D or -W is absent) with empty strings
    while (current < options.length) {
      options[current++] = "";
    }
    return options;
  }

  /**
   * Generates the classifier.
   *
   * @param instances set of instances serving as training data
   * @exception Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances instances) throws Exception {

    if (instances.checkForStringAttributes()) {
      throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
    }
    // Work on a copy so the caller's data is not modified
    Instances trainData = new Instances(instances);
    trainData.deleteWithMissingClass();
    if (trainData.numInstances() == 0) {
      throw new Exception("No training instances without missing class.");
    }
    if (trainData.numInstances() < m_NumFolds) {
      throw new Exception("Number of training instances smaller than number of folds.");
    }

    // Check whether there are any parameters to optimize
    // (m_CVParams is only null if setOptions/addCVParameter was never called)
    if (m_CVParams == null) {
       m_Classifier.buildClassifier(trainData);
       return;
    }

    trainData.randomize(new Random(m_Seed));
    if (trainData.classAttribute().isNominal()) {
      trainData.stratify(m_NumFolds);
    }
    // -99 marks "no performance recorded yet"; tested in
    // findParamsByCrossValidation when deciding whether to keep a result
    m_BestPerformance = -99;
    m_BestClassifierOptions = null;
    m_NumAttributes = trainData.numAttributes();
    m_TrainFoldSize = trainData.trainCV(m_NumFolds, 0).numInstances();

    // Set up m_ClassifierOptions -- take getOptions() and remove
    // those being optimised.
    m_ClassifierOptions = ((OptionHandler)m_Classifier).getOptions();
    for (int i = 0; i < m_CVParams.size(); i++) {
      Utils.getOption(((CVParameter)m_CVParams.elementAt(i)).m_ParamChar,
                      m_ClassifierOptions);
    }
    findParamsByCrossValidation(0, trainData);

    // Rebuild the classifier on the full training data using the best
    // option settings found by the search
    String [] options = (String [])m_BestClassifierOptions.clone();
    ((OptionHandler)m_Classifier).setOptions(options);
    m_Classifier.buildClassifier(trainData);
  }

  /**
   * Predicts the class value for the given test instance.
   *
   * @param instance the instance to be classified
   * @return the predicted class value
   * @exception Exception if an error occurred during the prediction
   */
  public double classifyInstance(Instance instance) throws Exception {
    
    // Delegates to the base classifier built with the best options
    return m_Classifier.classifyInstance(instance);
  }

  /**
   * Sets the seed for random number generation.
* * @param seed the random number seed */ public void setSeed(int seed) { m_Seed = seed;; } /** * Gets the random number seed. * * @return the random number seed */ public int getSeed() { return m_Seed; } /** * Adds a scheme parameter to the list of parameters to be set * by cross-validation * * @param cvParam the string representation of a scheme parameter. The * format is: <br> * param_char lower_bound upper_bound increment <br> * eg to search a parameter -P from 1 to 10 by increments of 2: <br> * P 1 10 2 <br> * @exception Exception if the parameter specifier is of the wrong format */ public void addCVParameter(String cvParam) throws Exception { CVParameter newCV = new CVParameter(cvParam); m_CVParams.addElement(newCV); } /** * Gets the scheme paramter with the given index. */ public String getCVParameter(int index) { if (m_CVParams.size() <= index) { return ""; } return ((CVParameter)m_CVParams.elementAt(index)).toString(); } /** * Sets debugging mode * * @param debug true if debug output should be printed */ public void setDebug(boolean debug) { m_Debug = debug; } /** * Gets whether debugging is turned on * * @return true if debugging output is on */ public boolean getDebug() { return m_Debug; } /** * Get the number of folds used for cross-validation. * * @return the number of folds used for cross-validation. */ public int getNumFolds() { return m_NumFolds; } /** * Set the number of folds used for cross-validation. * * @param newNumFolds the number of folds used for cross-validation. */ public void setNumFolds(int newNumFolds) { m_NumFolds = newNumFolds; } /** * Set the classifier for boosting. * * @param newClassifier the Classifier to use. */ public void setClassifier(Classifier newClassifier) { m_Classifier = newClassifier; } /** * Get the classifier used as the classifier * * @return the classifier used as the classifier */ public Classifier getClassifier() { return m_Classifier; } /** * Returns description of the cross-validated classifier. 
* * @return description of the cross-validated classifier as a string */ public String toString() { if (m_BestClassifierOptions == null) return "CVParameterSelection: No model built yet."; String result = "Cross-validated Parameter selection.\n" + "Classifier: " + m_Classifier.getClass().getName() + "\n"; try { for (int i = 0; i < m_CVParams.size(); i++) { CVParameter cvParam = (CVParameter)m_CVParams.elementAt(i); result += "Cross-validation Parameter: '-" + cvParam.m_ParamChar + "'" + " ranged from " + cvParam.m_Lower + " to "; switch ((int)(cvParam.m_Lower - cvParam.m_Upper + 0.5)) { case 1: result += m_NumAttributes; break; case 2: result += m_TrainFoldSize; break; default: result += cvParam.m_Upper; break; } result += " with " + cvParam.m_Steps + " steps\n"; } } catch (Exception ex) { result += ex.getMessage(); } result += "Classifier Options: " + Utils.joinOptions(m_BestClassifierOptions) + "\n\n" + m_Classifier.toString(); return result; } public String toSummaryString() { String result = "Selected values: " + Utils.joinOptions(m_BestClassifierOptions); return result + '\n'; } /** * Main method for testing this class. * * @param argv the options */ public static void main(String [] argv) { try { System.out.println(Evaluation.evaluateModel(new CVParameterSelection(), argv)); } catch (Exception e) { System.err.println(e.getMessage()); } } }