/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ /* * SVMAttributeEval.java * Copyright (C) 2002 Eibe Frank * Mod by Kieran Holland * */ package weka.attributeSelection; import java.io.*; import java.util.*; import weka.core.*; import weka.classifiers.functions.SMO; import weka.filters.Filter; import weka.filters.unsupervised.attribute.MakeIndicator; import weka.attributeSelection.*; /** * Class for Evaluating attributes individually by using the SVM * classifier. Attributes are ranked by the square of the weight * assigned by the SVM. Attribute selection for multiclass problems * is handled by ranking attributes for each class seperately * using a one-vs-all method and then "dealing" from the top of * each pile to give a final ranking.<p> * * Valid options are: <p> * * -X <constant rate of elimination> <br> * Specify constant rate at which attributes are eliminated per invocation * of the support vector machine. Default = 1.<p> * * -Y <percent rate of elimination> <br> * Specify the percentage rate at which attributes are eliminated per invocation * of the support vector machine. This setting trumps the constant rate setting. * Default = 0 (percentage rate ignored).<p> * * -Z <threshold for percent elimination> <br> * Specify the threshold below which the percentage elimination method * reverts to the constant elimination method.<p> * * -C <complexity parameter> <br> * Specify the value of C - the complexity parameter to be passed on * to the support vector machine. <p> * * -P <episilon> <br> * Sets the epsilon for round-off error. (default 1.0e-25)<p> * * -T <tolerance> <br> * Sets the tolerance parameter. (default 1.0e-10)<p> * * @author Eibe Frank (eibe@cs.waikato.ac.nz) * @author Mark Hall (mhall@cs.waikato.ac.nz) * @version $Revision: 1.1.1.1 $ */ public class SVMAttributeEval extends AttributeEvaluator implements OptionHandler { /** The attribute scores */ private double[] m_attScores; /** Constant rate of attribute elimination per iteration */ private int m_numToEliminate = 1; /** Percentage rate of attribute elimination, trumps constant rate (above threshold), ignored if = 0 */ private int m_percentToEliminate = 0; /** Threshold below which percent elimination switches to constant elimination */ private int m_percentThreshold = 0; /** Complexity parameter to pass on to SMO */ private double m_smoCParameter = 1.0; /** Tolerance parameter to pass on to SMO */ private double m_smoTParameter = 1.0e-10; /** Epsilon parameter to pass on to SMO */ private double m_smoPParameter = 1.0e-25; /** Filter parameter to pass on to SMO */ private int m_smoFilterType = 0; /** * Returns a string describing this attribute evaluator * @return a description of the evaluator suitable for * displaying in the explorer/experimenter gui */ public String globalInfo() { return "SVMAttributeEval :\n\nEvaluates the worth of an attribute by " + "using an SVM classifier.\n"; } /** * Constructor */ public SVMAttributeEval() { resetOptions(); } /** * Returns an enumeration describing all the available options * * @return an enumeration of options */ public Enumeration listOptions() { Vector newVector = new Vector(4); newVector.addElement( new Option( "\tSpecify the constant rate of attribute\n" + "\telimination per invocation of\n" + "\tthe support vector machine.\n" + "\tDefault = 1.", "X", 1, "-X <constant rate of elimination>")); newVector.addElement( new Option( "\tSpecify the percentage rate of attributes to\n" + "\telimination per invocation of\n" + "\tthe support vector machine.\n" + "\tTrumps constant rate (above threshold).\n" + "\tDefault = 0.", "Y", 1, "-Y <percent rate of elimination>")); newVector.addElement( new Option( "\tSpecify the threshold below which \n" + "\tpercentage attribute elimination\n" + "\treverts to the constant method.\n", "Z", 1, "-Z <threshold for percent elimination>")); newVector.addElement( new Option( "\tSpecify the value of P (epsilon\n" + "\tparameter) to pass on to the\n" + "\tsupport vector machine.\n" + "\tDefault = 1.0e-25", "P", 1, "-P <epsilon>")); newVector.addElement( new Option( "\tSpecify the value of T (tolerance\n" + "\tparameter) to pass on to the\n" + "\tsupport vector machine.\n" + "\tDefault = 1.0e-10", "T", 1, "-T <tolerance>")); newVector.addElement( new Option( "\tSpecify the value of C (complexity\n" + "\tparameter) to pass on to the\n" + "\tsupport vector machine.\n" + "\tDefault = 1.0", "C", 1, "-C <complexity>")); newVector.addElement(new Option("\tWhether the SVM should " + "0=normalize/1=standardize/2=neither. " + "(default 0=normalize)", "N", 1, "-N")); return newVector.elements(); } /** * Parses a given list of options * * Valid options are: <p> * * -X <constant rate of elimination> <br> * Specify constant rate at which attributes are eliminated per invocation * of the support vector machine. Default = 1.<p> * * -Y <percent rate of elimination> <br> * Specify the percentage rate at which attributes are eliminated per invocation * of the support vector machine. This setting trumps the constant rate setting. * Default = 0 (percentage rate ignored).<p> * * -Z <threshold for percent elimination> <br> * Specify the threshold below which the percentage elimination method * reverts to the constant elimination method.<p> * * -C <complexity parameter> <br> * Specify the value of C - the complexity parameter to be passed on * to the support vector machine. <p> * * -P <episilon> <br> * Sets the epsilon for round-off error. (default 1.0e-25)<p> * * -T <tolerance> <br> * Sets the tolerance parameter. (default 1.0e-10)<p> * * -N <0|1|2> <br> * Whether the SVM should 0=normalize/1=standardize/2=neither. (default 0=normalize)<p> * * @param options the list of options as an array of strings * @exception Exception if an error occurs */ public void setOptions(String[] options) throws Exception { String optionString; optionString = Utils.getOption('X', options); if (optionString.length() != 0) { setAttsToEliminatePerIteration(Integer.parseInt(optionString)); } optionString = Utils.getOption('Y', options); if (optionString.length() != 0) { setPercentToEliminatePerIteration(Integer.parseInt(optionString)); } optionString = Utils.getOption('Z', options); if (optionString.length() != 0) { setPercentThreshold(Integer.parseInt(optionString)); } optionString = Utils.getOption('P', options); if (optionString.length() != 0) { setEpsilonParameter((new Double(optionString)).doubleValue()); } optionString = Utils.getOption('T', options); if (optionString.length() != 0) { setToleranceParameter((new Double(optionString)).doubleValue()); } optionString = Utils.getOption('C', options); if (optionString.length() != 0) { setComplexityParameter((new Double(optionString)).doubleValue()); } optionString = Utils.getOption('N', options); if (optionString.length() != 0) { setFilterType(new SelectedTag(Integer.parseInt(optionString), SMO.TAGS_FILTER)); } else { setFilterType(new SelectedTag(SMO.FILTER_NORMALIZE, SMO.TAGS_FILTER)); } Utils.checkForRemainingOptions(options); } /** * Gets the current settings of SVMAttributeEval * * @return an array of strings suitable for passing to setOptions() */ public String[] getOptions() { String[] options = new String[14]; int current = 0; options[current++] = "-X"; options[current++] = "" + getAttsToEliminatePerIteration(); options[current++] = "-Y"; options[current++] = "" + getPercentToEliminatePerIteration(); options[current++] = "-Z"; options[current++] = "" + getPercentThreshold(); options[current++] = "-P"; options[current++] = "" + getEpsilonParameter(); options[current++] = "-T"; options[current++] = "" + getToleranceParameter(); options[current++] = "-C"; options[current++] = "" + getComplexityParameter(); options[current++] = "-N"; options[current++] = "" + m_smoFilterType; while (current < options.length) { options[current++] = ""; } return options; } //________________________________________________________________________ /** * Returns a tip text for this property suitable for display in the * GUI * * @return tip text string describing this property */ public String attsToEliminatePerIterationTipText() { return "Constant rate of attribute elimination."; } /** * Returns a tip text for this property suitable for display in the * GUI * * @return tip text string describing this property */ public String percentToEliminatePerIterationTipText() { return "Percent rate of attribute elimination."; } /** * Returns a tip text for this property suitable for display in the * GUI * * @return tip text string describing this property */ public String percentThresholdTipText() { return "Threshold below which percent elimination reverts to constant elimination."; } /** * Returns a tip text for this property suitable for display in the * GUI * * @return tip text string describing this property */ public String epsilonParameterTipText() { return "P epsilon parameter to pass to the SVM"; } /** * Returns a tip text for this property suitable for display in the * GUI * * @return tip text string describing this property */ public String toleranceParameterTipText() { return "T tolerance parameter to pass to the SVM"; } /** * Returns a tip text for this property suitable for display in the * GUI * * @return tip text string describing this property */ public String complexityParameterTipText() { return "C complexity parameter to pass to the SVM"; } /** * Returns a tip text for this property suitable for display in the * GUI * * @return tip text string describing this property */ public String filterTypeTipText() { return "filtering used by the SVM"; } //________________________________________________________________________ /** * Set the constant rate of attribute elimination per iteration * * @param X the constant rate of attribute elimination per iteration */ public void setAttsToEliminatePerIteration(int cRate) { m_numToEliminate = cRate; } /** * Get the constant rate of attribute elimination per iteration * * @return the constant rate of attribute elimination per iteration */ public int getAttsToEliminatePerIteration() { return m_numToEliminate; } /** * Set the percentage of attributes to eliminate per iteration * * @param Y percent of attributes to eliminate per iteration */ public void setPercentToEliminatePerIteration(int pRate) { m_percentToEliminate = pRate; } /** * Get the percentage rate of attribute elimination per iteration * * @return the percentage rate of attribute elimination per iteration */ public int getPercentToEliminatePerIteration() { return m_percentToEliminate; } /** * Set the threshold below which percentage elimination reverts to * constant elimination. * * @param thresh percent of attributes to eliminate per iteration */ public void setPercentThreshold(int pThresh) { m_percentThreshold = pThresh; } /** * Get the threshold below which percentage elimination reverts to * constant elimination. * * @return the threshold below which percentage elimination stops */ public int getPercentThreshold() { return m_percentThreshold; } /** * Set the value of P for SMO * * @param svmP the value of P */ public void setEpsilonParameter(double svmP) { m_smoPParameter = svmP; } /** * Get the value of P used with SMO * * @return the value of P */ public double getEpsilonParameter() { return m_smoPParameter; } /** * Set the value of T for SMO * * @param svmC the value of T */ public void setToleranceParameter(double svmT) { m_smoTParameter = svmT; } /** * Get the value of T used with SMO * * @return the value of T */ public double getToleranceParameter() { return m_smoTParameter; } /** * Set the value of C for SMO * * @param svmC the value of C */ public void setComplexityParameter(double svmC) { m_smoCParameter = svmC; } /** * Get the value of C used with SMO * * @return the value of C */ public double getComplexityParameter() { return m_smoCParameter; } /** * The filtering mode to pass to SMO * * @param newType the new filtering mode */ public void setFilterType(SelectedTag newType) { if (newType.getTags() == SMO.TAGS_FILTER) { m_smoFilterType = newType.getSelectedTag().getID(); } } /** * Get the filtering mode passed to SMO * * @return the filtering mode */ public SelectedTag getFilterType() { return new SelectedTag(m_smoFilterType, SMO.TAGS_FILTER); } //________________________________________________________________________ /** * Initializes the evaluator. * * @param data set of instances serving as training data * @exception Exception if the evaluator has not been * generated successfully */ public void buildEvaluator(Instances data) throws Exception { if (data.checkForStringAttributes()) { throw new Exception("Can't handle string attributes!"); } if (!data.classAttribute().isNominal()) { throw new Exception("Class must be nominal!"); } for (int i = 0; i < data.numAttributes(); i++) { if (data.attribute(i).isNominal() && (data.attribute(i).numValues() != 2) && !(i==data.classIndex()) ) { throw new Exception("All nominal attributes must be binary!"); } } System.out.println("Class attribute: " + data.attribute(data.classIndex()).name()); // Check settings m_numToEliminate = (m_numToEliminate > 1) ? m_numToEliminate : 1; m_percentToEliminate = (m_percentToEliminate < 100) ? m_percentToEliminate : 100; m_percentToEliminate = (m_percentToEliminate > 0) ? m_percentToEliminate : 0; m_percentThreshold = (m_percentThreshold < data.numAttributes()) ? m_percentThreshold : data.numAttributes() - 1; m_percentThreshold = (m_percentThreshold > 0) ? m_percentThreshold : 0; // Get ranked attributes for each class seperately, one-vs-all int[][] attScoresByClass; int numAttr = data.numAttributes() - 1; if(data.numClasses()>2) { attScoresByClass = new int[data.numClasses()][numAttr]; for (int i = 0; i < data.numClasses(); i++) { attScoresByClass[i] = rankBySVM(i, data); } } else { attScoresByClass = new int[1][numAttr]; attScoresByClass[0] = rankBySVM(0, data); } // Cycle through class-specific ranked lists, poping top one off for each class // and adding it to the overall ranked attribute list if it's not there already ArrayList ordered = new ArrayList(numAttr); for (int i = 0; i < numAttr; i++) { for (int j = 0; j < (data.numClasses()>2 ? data.numClasses() : 1); j++) { Integer rank = new Integer(attScoresByClass[j][i]); if (!ordered.contains(rank)) ordered.add(rank); } } m_attScores = new double[data.numAttributes()]; Iterator listIt = ordered.iterator(); for (double i = (double) numAttr; listIt.hasNext(); i = i - 1.0) { m_attScores[((Integer) listIt.next()).intValue()] = i; } } /** * Get SVM-ranked attribute indexes (best to worst) selected for * the class attribute indexed by classInd (one-vs-all). */ private int[] rankBySVM(int classInd, Instances data) { // Holds a mapping into the original array of attribute indices int[] origIndices = new int[data.numAttributes()]; for (int i = 0; i < origIndices.length; i++) origIndices[i] = i; // Count down of number of attributes remaining int numAttrLeft = data.numAttributes()-1; // Ranked attribute indices for this class, one vs.all (highest->lowest) int[] attRanks = new int[numAttrLeft]; try { MakeIndicator filter = new MakeIndicator(); filter.setAttributeIndex(data.classIndex()); filter.setNumeric(false); filter.setValueIndex(classInd); filter.setInputFormat(data); Instances trainCopy = Filter.useFilter(data, filter); double pctToElim = ((double) m_percentToEliminate) / 100.0; while (numAttrLeft > 0) { int numToElim; if (pctToElim > 0) { numToElim = (int) (trainCopy.numAttributes() * pctToElim); numToElim = (numToElim > 1) ? numToElim : 1; if (numAttrLeft - numToElim <= m_percentThreshold) { pctToElim = 0; numToElim = numAttrLeft - m_percentThreshold; } } else { numToElim = (numAttrLeft >= m_numToEliminate) ? m_numToEliminate : numAttrLeft; } // Build the linear SVM with default parameters SMO smo = new SMO(); // SMO seems to get stuck if data not normalised when few attributes remain // smo.setNormalizeData(numAttrLeft < 40); smo.setFilterType(new SelectedTag(m_smoFilterType, SMO.TAGS_FILTER)); smo.setEpsilon(m_smoPParameter); smo.setToleranceParameter(m_smoTParameter); smo.setC(m_smoCParameter); smo.buildClassifier(trainCopy); // Find the attribute with maximum weight^2 FastVector weightsAndIndices = smo.weights(); double[] weightsSparse = (double[]) weightsAndIndices.elementAt(0); int[] indicesSparse = (int[]) weightsAndIndices.elementAt(1); double[] weights = new double[trainCopy.numAttributes()]; for (int j = 0; j < weightsSparse.length; j++) { weights[indicesSparse[j]] = weightsSparse[j] * weightsSparse[j]; } weights[trainCopy.classIndex()] = Double.MAX_VALUE; int minWeightIndex; int[] featArray = new int[numToElim]; boolean[] eliminated = new boolean[origIndices.length]; for (int j = 0; j < numToElim; j++) { minWeightIndex = Utils.minIndex(weights); attRanks[--numAttrLeft] = origIndices[minWeightIndex]; featArray[j] = minWeightIndex; eliminated[minWeightIndex] = true; weights[minWeightIndex] = Double.MAX_VALUE; } // Delete the worst attributes. weka.filters.unsupervised.attribute.Remove delTransform = new weka.filters.unsupervised.attribute.Remove(); delTransform.setInvertSelection(false); delTransform.setAttributeIndicesArray(featArray); delTransform.setInputFormat(trainCopy); trainCopy = Filter.useFilter(trainCopy, delTransform); // Update the array of remaining attribute indices int[] temp = new int[origIndices.length - numToElim]; int k = 0; for (int j = 0; j < origIndices.length; j++) { if (!eliminated[j]) { temp[k++] = origIndices[j]; } } origIndices = temp; } // Carefully handle all exceptions } catch (Exception e) { e.printStackTrace(); } return attRanks; } /** * Resets options to defaults. */ protected void resetOptions() { m_attScores = null; } /** * Evaluates an attribute by returning the rank of the square of its coefficient in a * linear support vector machine. * *@param attribute the index of the attribute to be evaluated * @exception Exception if the attribute could not be evaluated */ public double evaluateAttribute(int attribute) throws Exception { return m_attScores[attribute]; } /** * Return a description of the evaluator * @return description as a string */ public String toString() { StringBuffer text = new StringBuffer(); if (m_attScores == null) { text.append("\tSVM feature evaluator has not been built yet"); } else { text.append("\tSVM feature evaluator"); } text.append("\n"); return text.toString(); } /** * Main method for testing this class. * * @param args the options */ public static void main(String[] args) { try { System.out.println(AttributeSelection.SelectAttributes(new SVMAttributeEval(), args)); } catch (Exception e) { e.printStackTrace(); System.out.println(e.getMessage()); } } }