/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ /* * RandomForest.java * Copyright (C) 2001 Richard Kirkby * */ package weka.classifiers.trees; import weka.classifiers.*; import weka.classifiers.meta.Bagging; import weka.core.*; import java.util.*; /** * Class for constructing random forests. * * For more information see: <p> * Leo Breiman. Random Forests. Machine Learning 45 (1):5-32, October 2001. <p> * * Valid options are: <p> * * -I num <br> * Set the number of trees in the forest * (default 10) <p> * * -K num <br> * Set the number of features to consider. * If < 1 (the default) will use logM+1, where M is the number of inputs. <p> * * -S seed <br> * Random number seed (default 1). <p> * * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz) * @version $Revision: 1.1.1.1 $ */ public class RandomForest extends DistributionClassifier implements OptionHandler, Randomizable, WeightedInstancesHandler { /** Number of trees in forest. */ protected int m_numTrees = 10; /** Number of features to consider in random feature selection. If less than 1 will use int(logM+1) ) */ protected int m_numFeatures = 0; /** The random seed. */ protected int m_randomSeed = 1; /** Final number of features that were considered in last build. */ protected int m_KValue = 0; /** The bagger. */ protected Bagging m_bagger = null; /** * Get the value of numTrees. * * @return Value of numTrees. */ public int getNumTrees() { return m_numTrees; } /** * Set the value of numTrees. * * @param newNumTrees Value to assign to numTrees. */ public void setNumTrees(int newNumTrees) { m_numTrees = newNumTrees; } /** * Get the number of features used in random selection. * * @return Value of numFeatures. */ public int getNumFeatures() { return m_numFeatures; } /** * Set the number of features to use in random selection. * * @param newNumFeatures Value to assign to numFeatures. */ public void setNumFeatures(int newNumFeatures) { m_numFeatures = newNumFeatures; } /** * Set the seed for random number generation. * * @param seed the seed */ public void setSeed(int seed) { m_randomSeed = seed; } /** * Gets the seed for the random number generations * * @return the seed for the random number generation */ public int getSeed() { return m_randomSeed; } /** * Returns an enumeration describing the available options. * * @return an enumeration of all the available options */ public Enumeration listOptions() { Vector newVector = new Vector(3); newVector. addElement(new Option("\tNumber of trees to build.", "I", 1, "-I <number of trees>")); newVector. addElement(new Option("\tNumber of features to consider (<1=int(logM+1)).", "K", 1, "-K <number of features>")); newVector .addElement(new Option("\tSeed for random number generator.\n" + "\t(default 1)", "S", 1, "-S")); return newVector.elements(); } /** * Gets the current settings of the forest. * * @return an array of strings suitable for passing to setOptions() */ public String[] getOptions() { String [] options = new String [10]; int current = 0; options[current++] = "-I"; options[current++] = "" + getNumTrees(); options[current++] = "-K"; options[current++] = "" + getNumFeatures(); options[current++] = "-S"; options[current++] = "" + getSeed(); while (current < options.length) { options[current++] = ""; } return options; } /** * Parses a given list of options. * @param options the list of options as an array of strings * @exception Exception if an option is not supported */ public void setOptions(String[] options) throws Exception{ String numTreesString = Utils.getOption('I', options); if (numTreesString.length() != 0) { m_numTrees = Integer.parseInt(numTreesString); } else { m_numTrees = 10; } String numFeaturesString = Utils.getOption('K', options); if (numFeaturesString.length() != 0) { m_numFeatures = Integer.parseInt(numFeaturesString); } else { m_numFeatures = 0; } String seed = Utils.getOption('S', options); if (seed.length() != 0) { setSeed(Integer.parseInt(seed)); } else { setSeed(1); } Utils.checkForRemainingOptions(options); } /** * Builds a classifier for a set of instances. * * @param instances the instances to train the classifier with * @exception Exception if something goes wrong */ public void buildClassifier(Instances data) throws Exception { m_bagger = new Bagging(); RandomTree rTree = new RandomTree(); // set up the random tree options m_KValue = m_numFeatures; if (m_KValue < 1) m_KValue = (int) Utils.log2(data.numAttributes())+1; rTree.setKValue(m_KValue); // set up the bagger and build the forest m_bagger.setClassifier(rTree); m_bagger.setSeed(m_randomSeed); m_bagger.setNumIterations(m_numTrees); m_bagger.buildClassifier(data); } /** * Returns the class probability distribution for an instance. * * @param instance the instance to be classified * @return the distribution the forest generates for the instance */ public double[] distributionForInstance(Instance instance) throws Exception { return m_bagger.distributionForInstance(instance); } /** * Outputs a description of this classifier. * * @return a string containing a description of the classifier */ public String toString() { if (m_bagger == null) return "Random forest not built yet"; else return "Random forest of " + m_numTrees + " trees, each constructed while considering " + m_KValue + " random feature" + (m_KValue==1 ? "" : "s") + ".\n\n"; } /** * Main method for this class. * * @param argv the options */ public static void main(String[] argv) { try { System.out.println(Evaluation.evaluateModel(new RandomForest(), argv)); } catch (Exception e) { e.printStackTrace(); System.err.println(e.getMessage()); } } }