package weka.classifiers.meta;

import java.util.Enumeration;
import java.util.Vector;

import weka.classifiers.RandomizableMultipleClassifiersCombiner;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;

/**
 * Meta classifier that combines the predictions of its base classifiers as a
 * weighted sum of their class distributions.
 * <p>
 * Each base classifier is given one weight. When no weights have been set,
 * every classifier receives the same weight (even weighting).
 *
 * @author rlichten
 */
public class WeightedVote extends RandomizableMultipleClassifiersCombiner
    implements TechnicalInformationHandler {

  /**
   * Normalised per-classifier weights, or <code>null</code> when the default
   * even weighting is in effect.
   */
  protected double[] m_Weights;

  /**
   * Sets the weights applied to the base classifiers.
   * <p>
   * The supplied values are copied, so the caller's array is never modified.
   * If any weight is negative, all weights are shifted so that the smallest
   * becomes zero. If the shifted weights do not sum to a positive value, even
   * weights are used instead. The stored weights always sum to one.
   *
   * @param weights one weight per base classifier; <code>null</code> or an
   *          empty array restores the default even weighting
   * @throws IllegalArgumentException if a weight is NaN or infinite
   */
  public void setWeights(double[] weights) {
    if (weights == null || weights.length == 0) {
      m_Weights = null;
      return;
    }

    double[] normalized = new double[weights.length];
    for (int i = 0; i < weights.length; i++) {
      if (Double.isNaN(weights[i]) || Double.isInfinite(weights[i])) {
        throw new IllegalArgumentException(
            "Weight " + i + " is not a finite number: " + weights[i]);
      }
      normalized[i] = weights[i];
    }

    // Shift so that the smallest weight is zero when negatives are present.
    double min = normalized[Utils.minIndex(normalized)];
    if (min < 0) {
      for (int i = 0; i < normalized.length; i++) {
        normalized[i] -= min;
      }
    }

    // A non-positive total cannot be normalised; fall back to even weights.
    if (Utils.sum(normalized) <= 0) {
      for (int i = 0; i < normalized.length; i++) {
        normalized[i] = 1.0 / (double) normalized.length;
      }
    }

    Utils.normalize(normalized);
    m_Weights = normalized;
  }

  /**
   * Returns the normalised weights currently in use.
   *
   * @return a copy of the weights, or <code>null</code> if none have been set
   *         (even weighting)
   */
  public double[] getWeights() {
    if (m_Weights == null) {
      return null;
    }
    double[] copy = new double[m_Weights.length];
    System.arraycopy(m_Weights, 0, copy, 0, m_Weights.length);
    return copy;
  }

  /**
   * Returns an enumeration describing the available options: those of the
   * superclass followed by <code>-w</code>.
   *
   * @return an enumeration of all the available options.
   */
  public Enumeration listOptions() {
    Vector result = new Vector();

    Enumeration optionEnum = super.listOptions();
    while (optionEnum.hasMoreElements()) {
      result.addElement(optionEnum.nextElement());
    }

    result.addElement(new Option(
        "\tThe weights to use (default: even weighting)\n",
        "w", 1, "-w <comma-separated-weights>"));

    return result.elements();
  }

  /**
   * Parses the given options.
   * <p>
   * <code>-w &lt;comma-separated-weights&gt;</code> sets the classifier
   * weights; when it is absent the default even weighting is restored. The
   * remaining options are handed to the superclass.
   *
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported or a weight cannot be
   *           parsed as a number
   */
  public void setOptions(String[] options) throws Exception {
    String weightStr = Utils.getOption('w', options);
    if (weightStr.length() != 0) {
      String[] tokens = weightStr.split(",");
      double[] parsed = new double[tokens.length];
      for (int i = 0; i < tokens.length; i++) {
        try {
          parsed[i] = Double.parseDouble(tokens[i].trim());
        } catch (NumberFormatException e) {
          throw new IllegalArgumentException(
              "Invalid weight '" + tokens[i] + "' in -w " + weightStr, e);
        }
      }
      setWeights(parsed);
    } else {
      setWeights(null);
    }

    // Own option consumed; let the superclass handle everything else.
    super.setOptions(options);
  }

  /**
   * Gets the current settings: <code>-w</code> (only when weights have been
   * set) followed by the options reported by the superclass.
   *
   * @return an array of strings suitable for passing to setOptions
   */
  public String[] getOptions() {
    String[] superOptions = super.getOptions();
    if (m_Weights == null) {
      return superOptions;
    }

    StringBuilder joined = new StringBuilder();
    for (int i = 0; i < m_Weights.length; i++) {
      if (i > 0) {
        joined.append(',');
      }
      joined.append(m_Weights[i]);
    }

    String[] options = new String[superOptions.length + 2];
    options[0] = "-w";
    options[1] = joined.toString();
    System.arraycopy(superOptions, 0, options, 2, superOptions.length);
    return options;
  }

  /**
   * Classifies the given test instance.
   *
   * @param instance the instance to be classified
   * @return the index of the class with the largest combined score for a
   *         nominal class, the combined value for a numeric class, or
   *         Instance.missingValue() if no prediction is made
   * @throws Exception if an error occurred during the prediction
   */
  public double classifyInstance(Instance instance) throws Exception {
    double[] dist = distributionForInstance(instance);

    if (instance.classAttribute().isNominal()) {
      int index = Utils.maxIndex(dist);
      // A zero maximum means no class received any score.
      if (dist[index] == 0) {
        return Instance.missingValue();
      }
      return index;
    }
    if (instance.classAttribute().isNumeric()) {
      return dist[0];
    }
    return Instance.missingValue();
  }

  /**
   * Computes the weighted sum of the base classifiers' distributions for the
   * given instance. For a non-numeric class the result is normalised when its
   * sum is positive.
   *
   * @param instance the instance to be classified
   * @return the combined distribution
   * @throws Exception if instance could not be classified successfully, or if
   *           the number of weights does not match the number of classifiers
   */
  public double[] distributionForInstance(Instance instance) throws Exception {
    checkWeightCount();

    double[] result = new double[instance.numClasses()];
    for (int i = 0; i < m_Classifiers.length; i++) {
      // No explicit weights: every classifier contributes equally.
      double weight =
          (m_Weights == null) ? 1.0 / (double) m_Classifiers.length : m_Weights[i];
      double[] dist = getClassifier(i).distributionForInstance(instance);
      for (int j = 0; j < result.length; j++) {
        result[j] += weight * dist[j];
      }
    }

    if (!instance.classAttribute().isNumeric() && (Utils.sum(result) > 0)) {
      Utils.normalize(result);
    }
    return result;
  }

  /**
   * Trains every base classifier on a copy of the training data from which
   * instances with a missing class value have been removed.
   *
   * @param data the training data to be used for generating the classifier.
   * @throws Exception if the classifier could not be built successfully, or if
   *           the number of weights does not match the number of classifiers
   */
  public void buildClassifier(Instances data) throws Exception {
    getCapabilities().testWithFail(data);
    // Fail before spending time on training if the configuration is unusable.
    checkWeightCount();

    Instances newData = new Instances(data);
    newData.deleteWithMissingClass();

    for (int i = 0; i < m_Classifiers.length; i++) {
      getClassifier(i).buildClassifier(newData);
    }
  }

  /**
   * Verifies that explicitly set weights line up one-to-one with the base
   * classifiers.
   *
   * @throws IllegalStateException if weights are set and their count differs
   *           from the number of base classifiers
   */
  private void checkWeightCount() {
    if (m_Weights != null && m_Weights.length != m_Classifiers.length) {
      throw new IllegalStateException(
          "Number of weights (" + m_Weights.length
              + ") does not match number of classifiers ("
              + m_Classifiers.length + ")");
    }
  }

  /**
   * Returns the revision string.
   *
   * @return the revision
   */
  public String getRevision() {
    return RevisionUtils.extract("$Revision: 1.0$");
  }

  /**
   * Not implemented: no technical reference is provided for this classifier.
   *
   * @return never returns normally
   * @throws UnsupportedOperationException always
   */
  public TechnicalInformation getTechnicalInformation() {
    throw new UnsupportedOperationException("Not supported yet.");
  }

  /**
   * Main method for testing this class.
   *
   * @param argv should contain the following arguments:
   *          -t training file [-T test file] [-c class index]
   */
  public static void main(String[] argv) {
    runClassifier(new WeightedVote(), argv);
  }
}