/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ /* * RotationForest.java * Copyright (C) 2008 Juan Jose Rodriguez * Copyright (C) 2008 University of Waikato, Hamilton, New Zealand * */ package weka.classifiers.meta; import weka.classifiers.RandomizableIteratedSingleClassifierEnhancer; import weka.core.Attribute; import weka.core.FastVector; import weka.core.Instance; import weka.core.Instances; import weka.core.Option; import weka.core.OptionHandler; import weka.core.Randomizable; import weka.core.RevisionUtils; import weka.core.TechnicalInformation; import weka.core.TechnicalInformation.Field; import weka.core.TechnicalInformation.Type; import weka.core.TechnicalInformationHandler; import weka.core.WeightedInstancesHandler; import weka.core.Utils; import weka.filters.Filter; import weka.filters.unsupervised.attribute.Normalize; import weka.filters.unsupervised.attribute.PrincipalComponents; import weka.filters.unsupervised.attribute.RemoveUseless; import weka.filters.unsupervised.instance.RemovePercentage; import java.util.Enumeration; import java.util.LinkedList; import java.util.Random; import java.util.Vector; import weka.core.DenseInstance; /** <!-- globalinfo-start --> * Class for construction a Rotation Forest. Can do classification and regression depending on the base learner. <br/> * <br/> * For more information, see<br/> * <br/> * Juan J. Rodriguez, Ludmila I. Kuncheva, Carlos J. Alonso (2006). Rotation Forest: A new classifier ensemble method. IEEE Transactions on Pattern Analysis and Machine Intelligence. 28(10):1619-1630. URL http://doi.ieeecomputersociety.org/10.1109/TPAMI.2006.211. * <p/> <!-- globalinfo-end --> * <!-- technical-bibtex-start --> * BibTeX: * <pre> * @article{Rodriguez2006, * author = {Juan J. Rodriguez and Ludmila I. Kuncheva and Carlos J. Alonso}, * journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence}, * number = {10}, * pages = {1619-1630}, * title = {Rotation Forest: A new classifier ensemble method}, * volume = {28}, * year = {2006}, * ISSN = {0162-8828}, * URL = {http://doi.ieeecomputersociety.org/10.1109/TPAMI.2006.211} * } * </pre> * <p/> <!-- technical-bibtex-end --> * <!-- options-start --> * Valid options are: <p/> * * <pre> -N * Whether minGroup (-G) and maxGroup (-H) refer to * the number of groups or their size. * (default: false)</pre> * * <pre> -G <num> * Minimum size of a group of attributes: * if numberOfGroups is true, the minimum number * of groups. * (default: 3)</pre> * * <pre> -H <num> * Maximum size of a group of attributes: * if numberOfGroups is true, the maximum number * of groups. * (default: 3)</pre> * * <pre> -P <num> * Percentage of instances to be removed. * (default: 50)</pre> * * <pre> -F <filter specification> * Full class name of filter to use, followed * by filter options. * eg: "weka.filters.unsupervised.attribute.PrincipalComponents-R 1.0"</pre> * * <pre> -S <num> * Random number seed. * (default 1)</pre> * * <pre> -I <num> * Number of iterations. * (default 10)</pre> * * <pre> -D * If set, classifier is run in debug mode and * may output additional info to the console</pre> * * <pre> -W * Full name of base classifier. * (default: weka.classifiers.trees.J48)</pre> * * <pre> * Options specific to classifier weka.classifiers.trees.J48: * </pre> * * <pre> -U * Use unpruned tree.</pre> * * <pre> -C <pruning confidence> * Set confidence threshold for pruning. * (default 0.25)</pre> * * <pre> -M <minimum number of instances> * Set minimum number of instances per leaf. * (default 2)</pre> * * <pre> -R * Use reduced error pruning.</pre> * * <pre> -N <number of folds> * Set number of folds for reduced error * pruning. One fold is used as pruning set. * (default 3)</pre> * * <pre> -B * Use binary splits only.</pre> * * <pre> -S * Don't perform subtree raising.</pre> * * <pre> -L * Do not clean up after the tree has been built.</pre> * * <pre> -A * Laplace smoothing for predicted probabilities.</pre> * * <pre> -Q <seed> * Seed for random data shuffling (default 1).</pre> * <!-- options-end --> * * @author Juan Jose Rodriguez (jjrodriguez@ubu.es) * @version $Revision: 7012 $ */ public class RotationForest extends RandomizableIteratedSingleClassifierEnhancer implements WeightedInstancesHandler, TechnicalInformationHandler { // It implements WeightedInstancesHandler because the base classifier // can implement this interface, but in this method the weights are // not used /** for serialization */ static final long serialVersionUID = -3255631880798499936L; /** The minimum size of a group */ protected int m_MinGroup = 3; /** The maximum size of a group */ protected int m_MaxGroup = 3; /** * Whether minGroup and maxGroup refer to the number of groups or their * size */ protected boolean m_NumberOfGroups = false; /** The percentage of instances to be removed */ protected int m_RemovedPercentage = 50; /** The attributes of each group */ protected int [][][] m_Groups = null; /** The type of projection filter */ protected Filter m_ProjectionFilter = null; /** The projection filters */ protected Filter [][] m_ProjectionFilters = null; /** Headers of the transformed dataset */ protected Instances [] m_Headers = null; /** Headers of the reduced datasets */ protected Instances [][] m_ReducedHeaders = null; /** Filter that remove useless attributes */ protected RemoveUseless m_RemoveUseless = null; /** Filter that normalized the attributes */ protected Normalize m_Normalize = null; /** * Constructor. */ public RotationForest() { m_Classifier = new weka.classifiers.trees.J48(); m_ProjectionFilter = defaultFilter(); } /** * Default projection method. */ protected Filter defaultFilter() { PrincipalComponents filter = new PrincipalComponents(); //filter.setNormalize(false); filter.setVarianceCovered(1.0); return filter; } /** * Returns a string describing classifier * @return a description suitable for * displaying in the explorer/experimenter gui */ public String globalInfo() { return "Class for construction a Rotation Forest. Can do classification " + "and regression depending on the base learner. \n\n" + "For more information, see\n\n" + getTechnicalInformation().toString(); } /** * Returns an instance of a TechnicalInformation object, containing * detailed information about the technical background of this class, * e.g., paper reference or book this class is based on. * * @return the technical information about this class */ public TechnicalInformation getTechnicalInformation() { TechnicalInformation result; result = new TechnicalInformation(Type.ARTICLE); result.setValue(Field.AUTHOR, "Juan J. Rodriguez and Ludmila I. Kuncheva and Carlos J. Alonso"); result.setValue(Field.YEAR, "2006"); result.setValue(Field.TITLE, "Rotation Forest: A new classifier ensemble method"); result.setValue(Field.JOURNAL, "IEEE Transactions on Pattern Analysis and Machine Intelligence"); result.setValue(Field.VOLUME, "28"); result.setValue(Field.NUMBER, "10"); result.setValue(Field.PAGES, "1619-1630"); result.setValue(Field.ISSN, "0162-8828"); result.setValue(Field.URL, "http://doi.ieeecomputersociety.org/10.1109/TPAMI.2006.211"); return result; } /** * String describing default classifier. * * @return the default classifier classname */ protected String defaultClassifierString() { return "weka.classifiers.trees.J48"; } /** * Returns an enumeration describing the available options. * * @return an enumeration of all the available options. */ public Enumeration listOptions() { Vector newVector = new Vector(5); newVector.addElement(new Option( "\tWhether minGroup (-G) and maxGroup (-H) refer to" + "\n\tthe number of groups or their size." + "\n\t(default: false)", "N", 0, "-N")); newVector.addElement(new Option( "\tMinimum size of a group of attributes:" + "\n\t\tif numberOfGroups is true, the minimum number" + "\n\t\tof groups." + "\n\t\t(default: 3)", "G", 1, "-G <num>")); newVector.addElement(new Option( "\tMaximum size of a group of attributes:" + "\n\t\tif numberOfGroups is true, the maximum number" + "\n\t\tof groups." + "\n\t\t(default: 3)", "H", 1, "-H <num>")); newVector.addElement(new Option( "\tPercentage of instances to be removed." + "\n\t\t(default: 50)", "P", 1, "-P <num>")); newVector.addElement(new Option( "\tFull class name of filter to use, followed\n" + "\tby filter options.\n" + "\teg: \"weka.filters.unsupervised.attribute.PrincipalComponents-R 1.0\"", "F", 1, "-F <filter specification>")); Enumeration enu = super.listOptions(); while (enu.hasMoreElements()) { newVector.addElement(enu.nextElement()); } return newVector.elements(); } /** * Parses a given list of options. <p/> * <!-- options-start --> * Valid options are: <p/> * * <pre> -N * Whether minGroup (-G) and maxGroup (-H) refer to * the number of groups or their size. * (default: false)</pre> * * <pre> -G <num> * Minimum size of a group of attributes: * if numberOfGroups is true, the minimum number * of groups. * (default: 3)</pre> * * <pre> -H <num> * Maximum size of a group of attributes: * if numberOfGroups is true, the maximum number * of groups. * (default: 3)</pre> * * <pre> -P <num> * Percentage of instances to be removed. * (default: 50)</pre> * * <pre> -F <filter specification> * Full class name of filter to use, followed * by filter options. * eg: "weka.filters.unsupervised.attribute.PrincipalComponents-R 1.0"</pre> * * <pre> -S <num> * Random number seed. * (default 1)</pre> * * <pre> -I <num> * Number of iterations. * (default 10)</pre> * * <pre> -D * If set, classifier is run in debug mode and * may output additional info to the console</pre> * * <pre> -W * Full name of base classifier. * (default: weka.classifiers.trees.J48)</pre> * * <pre> * Options specific to classifier weka.classifiers.trees.J48: * </pre> * * <pre> -U * Use unpruned tree.</pre> * * <pre> -C <pruning confidence> * Set confidence threshold for pruning. * (default 0.25)</pre> * * <pre> -M <minimum number of instances> * Set minimum number of instances per leaf. * (default 2)</pre> * * <pre> -R * Use reduced error pruning.</pre> * * <pre> -N <number of folds> * Set number of folds for reduced error * pruning. One fold is used as pruning set. * (default 3)</pre> * * <pre> -B * Use binary splits only.</pre> * * <pre> -S * Don't perform subtree raising.</pre> * * <pre> -L * Do not clean up after the tree has been built.</pre> * * <pre> -A * Laplace smoothing for predicted probabilities.</pre> * * <pre> -Q <seed> * Seed for random data shuffling (default 1).</pre> * <!-- options-end --> * * @param options the list of options as an array of strings * @throws Exception if an option is not supported */ public void setOptions(String[] options) throws Exception { /* Taken from FilteredClassifier */ String filterString = Utils.getOption('F', options); if (filterString.length() > 0) { String [] filterSpec = Utils.splitOptions(filterString); if (filterSpec.length == 0) { throw new IllegalArgumentException("Invalid filter specification string"); } String filterName = filterSpec[0]; filterSpec[0] = ""; setProjectionFilter((Filter) Utils.forName(Filter.class, filterName, filterSpec)); } else { setProjectionFilter(defaultFilter()); } String tmpStr; tmpStr = Utils.getOption('G', options); if (tmpStr.length() != 0) setMinGroup(Integer.parseInt(tmpStr)); else setMinGroup(3); tmpStr = Utils.getOption('H', options); if (tmpStr.length() != 0) setMaxGroup(Integer.parseInt(tmpStr)); else setMaxGroup(3); tmpStr = Utils.getOption('P', options); if (tmpStr.length() != 0) setRemovedPercentage(Integer.parseInt(tmpStr)); else setRemovedPercentage(50); setNumberOfGroups(Utils.getFlag('N', options)); super.setOptions(options); } /** * Gets the current settings of the Classifier. * * @return an array of strings suitable for passing to setOptions */ public String [] getOptions() { String [] superOptions = super.getOptions(); String [] options = new String [superOptions.length + 9]; int current = 0; if (getNumberOfGroups()) { options[current++] = "-N"; } options[current++] = "-G"; options[current++] = "" + getMinGroup(); options[current++] = "-H"; options[current++] = "" + getMaxGroup(); options[current++] = "-P"; options[current++] = "" + getRemovedPercentage(); options[current++] = "-F"; options[current++] = getProjectionFilterSpec(); System.arraycopy(superOptions, 0, options, current, superOptions.length); current += superOptions.length; while (current < options.length) { options[current++] = ""; } return options; } /** * Returns the tip text for this property * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String numberOfGroupsTipText() { return "Whether minGroup and maxGroup refer to the number of groups or their size."; } /** * Set whether minGroup and maxGroup refer to the number of groups or their * size * * @param numberOfGroups whether minGroup and maxGroup refer to the number * of groups or their size */ public void setNumberOfGroups(boolean numberOfGroups) { m_NumberOfGroups = numberOfGroups; } /** * Get whether minGroup and maxGroup refer to the number of groups or their * size * * @return whether minGroup and maxGroup refer to the number of groups or * their size */ public boolean getNumberOfGroups() { return m_NumberOfGroups; } /** * Returns the tip text for this property * @return tip text for this property suitable for displaying in the * explorer/experimenter gui */ public String minGroupTipText() { return "Minimum size of a group (if numberOfGrups is true, the minimum number of groups."; } /** * Sets the minimum size of a group. * * @param minGroup the minimum value. * of attributes. */ public void setMinGroup( int minGroup ) throws IllegalArgumentException { if( minGroup <= 0 ) throw new IllegalArgumentException( "MinGroup has to be positive." ); m_MinGroup = minGroup; } /** * Gets the minimum size of a group. * * @return the minimum value. */ public int getMinGroup() { return m_MinGroup; } /** * Returns the tip text for this property * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String maxGroupTipText() { return "Maximum size of a group (if numberOfGrups is true, the maximum number of groups."; } /** * Sets the maximum size of a group. * * @param maxGroup the maximum value. * of attributes. */ public void setMaxGroup( int maxGroup ) throws IllegalArgumentException { if( maxGroup <= 0 ) throw new IllegalArgumentException( "MaxGroup has to be positive." ); m_MaxGroup = maxGroup; } /** * Gets the maximum size of a group. * * @return the maximum value. */ public int getMaxGroup() { return m_MaxGroup; } /** * Returns the tip text for this property * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String removedPercentageTipText() { return "The percentage of instances to be removed."; } /** * Sets the percentage of instance to be removed * * @param removedPercentage the percentage. */ public void setRemovedPercentage( int removedPercentage ) throws IllegalArgumentException { if( removedPercentage < 0 ) throw new IllegalArgumentException( "RemovedPercentage has to be >=0." ); if( removedPercentage >= 100 ) throw new IllegalArgumentException( "RemovedPercentage has to be <100." ); m_RemovedPercentage = removedPercentage; } /** * Gets the percentage of instances to be removed * * @return the percentage. */ public int getRemovedPercentage() { return m_RemovedPercentage; } /** * Returns the tip text for this property * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String projectionFilterTipText() { return "The filter used to project the data (e.g., PrincipalComponents)."; } /** * Sets the filter used to project the data. * * @param projectionFilter the filter. */ public void setProjectionFilter( Filter projectionFilter ) { m_ProjectionFilter = projectionFilter; } /** * Gets the filter used to project the data. * * @return the filter. */ public Filter getProjectionFilter() { return m_ProjectionFilter; } /** * Gets the filter specification string, which contains the class name of * the filter and any options to the filter * * @return the filter string. */ /* Taken from FilteredClassifier */ protected String getProjectionFilterSpec() { Filter c = getProjectionFilter(); if (c instanceof OptionHandler) { return c.getClass().getName() + " " + Utils.joinOptions(((OptionHandler)c).getOptions()); } return c.getClass().getName(); } /** * Returns description of the Rotation Forest classifier. * * @return description of the Rotation Forest classifier as a string */ public String toString() { if (m_Classifiers == null) { return "RotationForest: No model built yet."; } StringBuffer text = new StringBuffer(); text.append("All the base classifiers: \n\n"); for (int i = 0; i < m_Classifiers.length; i++) text.append(m_Classifiers[i].toString() + "\n\n"); return text.toString(); } /** * Returns the revision string. * * @return the revision */ public String getRevision() { return RevisionUtils.extract("$Revision: 7012 $"); } /** * builds the classifier. * * @param data the training data to be used for generating the * classifier. * @throws Exception if the classifier could not be built successfully */ public void buildClassifier(Instances data) throws Exception { // can classifier handle the data? getCapabilities().testWithFail(data); data = new Instances( data ); super.buildClassifier(data); checkMinMax(data); Random random; if( data.numInstances() > 0 ) { // This function fails if there are 0 instances random = data.getRandomNumberGenerator(m_Seed); } else { random = new Random(m_Seed); } m_RemoveUseless = new RemoveUseless(); m_RemoveUseless.setInputFormat(data); data = Filter.useFilter(data, m_RemoveUseless); m_Normalize = new Normalize(); m_Normalize.setInputFormat(data); data = Filter.useFilter(data, m_Normalize); if(m_NumberOfGroups) { generateGroupsFromNumbers(data, random); } else { generateGroupsFromSizes(data, random); } m_ProjectionFilters = new Filter[m_Groups.length][]; for(int i = 0; i < m_ProjectionFilters.length; i++ ) { m_ProjectionFilters[i] = Filter.makeCopies( m_ProjectionFilter, m_Groups[i].length ); } int numClasses = data.numClasses(); // Split the instances according to their class Instances [] instancesOfClass = new Instances[numClasses + 1]; if( data.classAttribute().isNumeric() ) { instancesOfClass = new Instances[numClasses]; instancesOfClass[0] = data; } else { instancesOfClass = new Instances[numClasses+1]; for( int i = 0; i < instancesOfClass.length; i++ ) { instancesOfClass[ i ] = new Instances( data, 0 ); } Enumeration enu = data.enumerateInstances(); while( enu.hasMoreElements() ) { Instance instance = (Instance)enu.nextElement(); if( instance.classIsMissing() ) { instancesOfClass[numClasses].add( instance ); } else { int c = (int)instance.classValue(); instancesOfClass[c].add( instance ); } } // If there are not instances with a missing class, we do not need to // consider them if( instancesOfClass[numClasses].numInstances() == 0 ) { Instances [] tmp = instancesOfClass; instancesOfClass = new Instances[ numClasses ]; System.arraycopy( tmp, 0, instancesOfClass, 0, numClasses ); } } // These arrays keep the information of the transformed data set m_Headers = new Instances[ m_Classifiers.length ]; m_ReducedHeaders = new Instances[ m_Classifiers.length ][]; // Construction of the base classifiers for(int i = 0; i < m_Classifiers.length; i++) { m_ReducedHeaders[i] = new Instances[ m_Groups[i].length ]; FastVector transformedAttributes = new FastVector( data.numAttributes() ); // Construction of the dataset for each group of attributes for( int j = 0; j < m_Groups[ i ].length; j++ ) { FastVector fv = new FastVector( m_Groups[i][j].length + 1 ); for( int k = 0; k < m_Groups[i][j].length; k++ ) { String newName = data.attribute( m_Groups[i][j][k] ).name() + "_" + k; fv.addElement( data.attribute( m_Groups[i][j][k] ).copy(newName) ); } fv.addElement( data.classAttribute( ).copy() ); Instances dataSubSet = new Instances( "rotated-" + i + "-" + j + "-", fv, 0); dataSubSet.setClassIndex( dataSubSet.numAttributes() - 1 ); // Select instances for the dataset m_ReducedHeaders[i][j] = new Instances( dataSubSet, 0 ); boolean [] selectedClasses = selectClasses( instancesOfClass.length, random ); for( int c = 0; c < selectedClasses.length; c++ ) { if( !selectedClasses[c] ) continue; Enumeration enu = instancesOfClass[c].enumerateInstances(); while( enu.hasMoreElements() ) { Instance instance = (Instance)enu.nextElement(); Instance newInstance = new DenseInstance(dataSubSet.numAttributes()); newInstance.setDataset( dataSubSet ); for( int k = 0; k < m_Groups[i][j].length; k++ ) { newInstance.setValue( k, instance.value( m_Groups[i][j][k] ) ); } newInstance.setClassValue( instance.classValue( ) ); dataSubSet.add( newInstance ); } } dataSubSet.randomize(random); // Remove a percentage of the instances Instances originalDataSubSet = dataSubSet; dataSubSet.randomize(random); RemovePercentage rp = new RemovePercentage(); rp.setPercentage( m_RemovedPercentage ); rp.setInputFormat( dataSubSet ); dataSubSet = Filter.useFilter( dataSubSet, rp ); if( dataSubSet.numInstances() < 2 ) { dataSubSet = originalDataSubSet; } // Project de data m_ProjectionFilters[i][j].setInputFormat( dataSubSet ); Instances projectedData = null; do { try { projectedData = Filter.useFilter( dataSubSet, m_ProjectionFilters[i][j] ); } catch ( Exception e ) { // The data could not be projected, we add some random instances addRandomInstances( dataSubSet, 10, random ); } } while( projectedData == null ); // Include the projected attributes in the attributes of the // transformed dataset for( int a = 0; a < projectedData.numAttributes() - 1; a++ ) { String newName = projectedData.attribute(a).name() + "_" + j; transformedAttributes.addElement( projectedData.attribute(a).copy(newName)); } } transformedAttributes.addElement( data.classAttribute().copy() ); Instances buildClas = new Instances( "rotated-" + i + "-", transformedAttributes, 0 ); buildClas.setClassIndex( buildClas.numAttributes() - 1 ); m_Headers[ i ] = new Instances( buildClas, 0 ); // Project all the training data Enumeration enu = data.enumerateInstances(); while( enu.hasMoreElements() ) { Instance instance = (Instance)enu.nextElement(); Instance newInstance = convertInstance( instance, i ); buildClas.add( newInstance ); } // Build the base classifier if (m_Classifier instanceof Randomizable) { ((Randomizable) m_Classifiers[i]).setSeed(random.nextInt()); } m_Classifiers[i].buildClassifier( buildClas ); } if(m_Debug){ printGroups(); } } /** * Adds random instances to the dataset. * * @param dataset the dataset * @param numInstances the number of instances * @param random a random number generator */ protected void addRandomInstances( Instances dataset, int numInstances, Random random ) { int n = dataset.numAttributes(); double [] v = new double[ n ]; for( int i = 0; i < numInstances; i++ ) { for( int j = 0; j < n; j++ ) { Attribute att = dataset.attribute( j ); if( att.isNumeric() ) { v[ j ] = random.nextDouble(); } else if ( att.isNominal() ) { v[ j ] = random.nextInt( att.numValues() ); } } dataset.add( new DenseInstance( 1, v ) ); } } /** * Checks m_MinGroup and m_MaxGroup * * @param data the dataset */ protected void checkMinMax(Instances data) { if( m_MinGroup > m_MaxGroup ) { int tmp = m_MaxGroup; m_MaxGroup = m_MinGroup; m_MinGroup = tmp; } int n = data.numAttributes(); if( m_MaxGroup >= n ) m_MaxGroup = n - 1; if( m_MinGroup >= n ) m_MinGroup = n - 1; } /** * Selects a non-empty subset of the classes * * @param numClasses the number of classes * @param random the random number generator. * @return a random subset of classes */ protected boolean [] selectClasses( int numClasses, Random random ) { int numSelected = 0; boolean selected[] = new boolean[ numClasses ]; for( int i = 0; i < selected.length; i++ ) { if(random.nextBoolean()) { selected[i] = true; numSelected++; } } if( numSelected == 0 ) { selected[random.nextInt( selected.length )] = true; } return selected; } /** * generates the groups of attributes, given their minimum and maximum * sizes. * * @param data the training data to be used for generating the * groups. * @param random the random number generator. */ protected void generateGroupsFromSizes(Instances data, Random random) { m_Groups = new int[m_Classifiers.length][][]; for( int i = 0; i < m_Classifiers.length; i++ ) { int [] permutation = attributesPermutation(data.numAttributes(), data.classIndex(), random); // The number of groups that have a given size int [] numGroupsOfSize = new int[m_MaxGroup - m_MinGroup + 1]; int numAttributes = 0; int numGroups; // Select the size of each group for( numGroups = 0; numAttributes < permutation.length; numGroups++ ) { int n = random.nextInt( numGroupsOfSize.length ); numGroupsOfSize[n]++; numAttributes += m_MinGroup + n; } m_Groups[i] = new int[numGroups][]; int currentAttribute = 0; int currentSize = 0; for( int j = 0; j < numGroups; j++ ) { while( numGroupsOfSize[ currentSize ] == 0 ) currentSize++; numGroupsOfSize[ currentSize ]--; int n = m_MinGroup + currentSize; m_Groups[i][j] = new int[n]; for( int k = 0; k < n; k++ ) { if( currentAttribute < permutation.length ) m_Groups[i][j][k] = permutation[ currentAttribute ]; else // For the last group, it can be necessary to reuse some attributes m_Groups[i][j][k] = permutation[ random.nextInt( permutation.length ) ]; currentAttribute++; } } } } /** * generates the groups of attributes, given their minimum and maximum * numbers. * * @param data the training data to be used for generating the * groups. * @param random the random number generator. */ protected void generateGroupsFromNumbers(Instances data, Random random) { m_Groups = new int[m_Classifiers.length][][]; for( int i = 0; i < m_Classifiers.length; i++ ) { int [] permutation = attributesPermutation(data.numAttributes(), data.classIndex(), random); int numGroups = m_MinGroup + random.nextInt(m_MaxGroup - m_MinGroup + 1); m_Groups[i] = new int[numGroups][]; int groupSize = permutation.length / numGroups; // Some groups will have an additional attribute int numBiggerGroups = permutation.length % numGroups; // Distribute the attributes in the groups int currentAttribute = 0; for( int j = 0; j < numGroups; j++ ) { if( j < numBiggerGroups ) { m_Groups[i][j] = new int[groupSize + 1]; } else { m_Groups[i][j] = new int[groupSize]; } for( int k = 0; k < m_Groups[i][j].length; k++ ) { m_Groups[i][j][k] = permutation[currentAttribute++]; } } } } /** * generates a permutation of the attributes. * * @param numAttributes the number of attributes. * @param classAttributes the index of the class attribute. * @param random the random number generator. * @return a permutation of the attributes */ protected int [] attributesPermutation(int numAttributes, int classAttribute, Random random) { int [] permutation = new int[numAttributes-1]; int i = 0; for(; i < classAttribute; i++){ permutation[i] = i; } for(; i < permutation.length; i++){ permutation[i] = i + 1; } permute( permutation, random ); return permutation; } /** * permutes the elements of a given array. * * @param v the array to permute * @param random the random number generator. */ protected void permute( int v[], Random random ) { for(int i = v.length - 1; i > 0; i-- ) { int j = random.nextInt( i + 1 ); if( i != j ) { int tmp = v[i]; v[i] = v[j]; v[j] = tmp; } } } /** * prints the groups. */ protected void printGroups( ) { for( int i = 0; i < m_Groups.length; i++ ) { for( int j = 0; j < m_Groups[i].length; j++ ) { System.err.print( "( " ); for( int k = 0; k < m_Groups[i][j].length; k++ ) { System.err.print( m_Groups[i][j][k] ); System.err.print( " " ); } System.err.print( ") " ); } System.err.println( ); } } /** * Transforms an instance for the i-th classifier. * * @param instance the instance to be transformed * @param i the base classifier number * @return the transformed instance * @throws Exception if the instance can't be converted successfully */ protected Instance convertInstance( Instance instance, int i ) throws Exception { Instance newInstance = new DenseInstance( m_Headers[ i ].numAttributes( ) ); newInstance.setWeight(instance.weight()); newInstance.setDataset( m_Headers[ i ] ); int currentAttribute = 0; // Project the data for each group for( int j = 0; j < m_Groups[i].length; j++ ) { Instance auxInstance = new DenseInstance( m_Groups[i][j].length + 1 ); int k; for( k = 0; k < m_Groups[i][j].length; k++ ) { auxInstance.setValue( k, instance.value( m_Groups[i][j][k] ) ); } auxInstance.setValue( k, instance.classValue( ) ); auxInstance.setDataset( m_ReducedHeaders[ i ][ j ] ); m_ProjectionFilters[i][j].input( auxInstance ); auxInstance = m_ProjectionFilters[i][j].output( ); m_ProjectionFilters[i][j].batchFinished(); for( int a = 0; a < auxInstance.numAttributes() - 1; a++ ) { newInstance.setValue( currentAttribute++, auxInstance.value( a ) ); } } newInstance.setClassValue( instance.classValue() ); return newInstance; } /** * Calculates the class membership probabilities for the given test * instance. * * @param instance the instance to be classified * @return preedicted class probability distribution * @throws Exception if distribution can't be computed successfully */ public double[] distributionForInstance(Instance instance) throws Exception { m_RemoveUseless.input(instance); instance =m_RemoveUseless.output(); m_RemoveUseless.batchFinished(); m_Normalize.input(instance); instance =m_Normalize.output(); m_Normalize.batchFinished(); double [] sums = new double [instance.numClasses()], newProbs; for (int i = 0; i < m_Classifiers.length; i++) { Instance convertedInstance = convertInstance(instance, i); if (instance.classAttribute().isNumeric() == true) { sums[0] += m_Classifiers[i].classifyInstance(convertedInstance); } else { newProbs = m_Classifiers[i].distributionForInstance(convertedInstance); for (int j = 0; j < newProbs.length; j++) sums[j] += newProbs[j]; } } if (instance.classAttribute().isNumeric() == true) { sums[0] /= (double)m_NumIterations; return sums; } else if (Utils.eq(Utils.sum(sums), 0)) { return sums; } else { Utils.normalize(sums); return sums; } } /** * Main method for testing this class. * * @param argv the options */ public static void main(String [] argv) { runClassifier(new RotationForest(), argv); } }