/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
/*
* RandomTree.java
* Copyright (C) 2001-2012 University of Waikato, Hamilton, New Zealand
*
*/
package weka.classifiers.trees;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.ContingencyTables;
import weka.core.Drawable;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.Randomizable;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
/**
* <!-- globalinfo-start -->
* Class for constructing a tree that considers K randomly chosen attributes at each node. Performs no pruning. Also has an option to allow estimation of class probabilities based on a hold-out set (backfitting).
* <p/>
* <!-- globalinfo-end -->
*
* <!-- options-start -->
* Valid options are: <p/>
*
* <pre> -K <number of attributes>
* Number of attributes to randomly investigate
* (<0 = int(log_2(#attributes)+1)).</pre>
*
* <pre> -M <minimum number of instances>
* Set minimum number of instances per leaf.</pre>
*
* <pre> -S <num>
* Seed for random number generator.
* (default 1)</pre>
*
* <pre> -depth <num>
* The maximum depth of the tree, 0 for unlimited.
* (default 0)</pre>
*
* <pre> -N <num>
* Number of folds for backfitting (default 0, no backfitting).</pre>
*
* <pre> -U
* Allow unclassified instances.</pre>
*
* <pre> -D
* If set, classifier is run in debug mode and
* may output additional info to the console</pre>
*
* <!-- options-end -->
*
* @author Eibe Frank (eibe@cs.waikato.ac.nz)
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 8034 $
*/
public class RandomTree extends AbstractClassifier implements OptionHandler,
WeightedInstancesHandler, Randomizable, Drawable {
/** for serialization */
static final long serialVersionUID = 8934314652175299374L;
/** The subtrees appended to this tree. */
protected RandomTree[] m_Successors;
/** The attribute to split on. */
protected int m_Attribute = -1;
/** The split point. */
protected double m_SplitPoint = Double.NaN;
/** The header information. */
protected Instances m_Info = null;
/** The proportions of training instances going down each branch. */
protected double[] m_Prop = null;
/** Class probabilities from the training data. */
protected double[] m_ClassDistribution = null;
/** Minimum number of instances for leaf. */
protected double m_MinNum = 1.0;
/** The number of attributes considered for a split. */
protected int m_KValue = 0;
/** The random seed to use. */
protected int m_randomSeed = 1;
/** The maximum depth of the tree (0 = unlimited) */
protected int m_MaxDepth = 0;
/** Determines how much data is used for backfitting */
protected int m_NumFolds = 0;
/** Whether unclassified instances are allowed */
protected boolean m_AllowUnclassifiedInstances = false;
/** a ZeroR model in case no model can be built from the data */
protected Classifier m_ZeroR;
/**
* Returns a string describing classifier
*
* @return a description suitable for displaying in the
* explorer/experimenter gui
*/
public String globalInfo() {
return "Class for constructing a tree that considers K randomly "
+ " chosen attributes at each node. Performs no pruning. Also has"
+ " an option to allow estimation of class probabilities based on"
+ " a hold-out set (backfitting).";
}
/**
* Returns the tip text for this property
*
* @return tip text for this property suitable for displaying in the
* explorer/experimenter gui
*/
public String minNumTipText() {
return "The minimum total weight of the instances in a leaf.";
}
/**
* Get the value of MinNum.
*
* @return Value of MinNum.
*/
public double getMinNum() {
return m_MinNum;
}
/**
* Set the value of MinNum.
*
* @param newMinNum
* Value to assign to MinNum.
*/
public void setMinNum(double newMinNum) {
m_MinNum = newMinNum;
}
/**
* Returns the tip text for this property
*
* @return tip text for this property suitable for displaying in the
* explorer/experimenter gui
*/
public String KValueTipText() {
return "Sets the number of randomly chosen attributes. If 0, log_2(number_of_attributes) + 1 is used.";
}
/**
* Get the value of K.
*
* @return Value of K.
*/
public int getKValue() {
return m_KValue;
}
/**
* Set the value of K.
*
* @param k
* Value to assign to K.
*/
public void setKValue(int k) {
m_KValue = k;
}
/**
* Returns the tip text for this property
*
* @return tip text for this property suitable for displaying in the
* explorer/experimenter gui
*/
public String seedTipText() {
return "The random number seed used for selecting attributes.";
}
/**
* Set the seed for random number generation.
*
* @param seed
* the seed
*/
public void setSeed(int seed) {
m_randomSeed = seed;
}
/**
* Gets the seed for the random number generations
*
* @return the seed for the random number generation
*/
public int getSeed() {
return m_randomSeed;
}
/**
* Returns the tip text for this property
*
* @return tip text for this property suitable for displaying in the
* explorer/experimenter gui
*/
public String maxDepthTipText() {
return "The maximum depth of the tree, 0 for unlimited.";
}
/**
* Get the maximum depth of trh tree, 0 for unlimited.
*
* @return the maximum depth.
*/
public int getMaxDepth() {
return m_MaxDepth;
}
/**
* Returns the tip text for this property
* @return tip text for this property suitable for
* displaying in the explorer/experimenter gui
*/
public String numFoldsTipText() {
return "Determines the amount of data used for backfitting. One fold is used for "
+ "backfitting, the rest for growing the tree. (Default: 0, no backfitting)";
}
/**
* Get the value of NumFolds.
*
* @return Value of NumFolds.
*/
public int getNumFolds() {
return m_NumFolds;
}
/**
* Set the value of NumFolds.
*
* @param newNumFolds Value to assign to NumFolds.
*/
public void setNumFolds(int newNumFolds) {
m_NumFolds = newNumFolds;
}
/**
* Returns the tip text for this property
* @return tip text for this property suitable for
* displaying in the explorer/experimenter gui
*/
public String allowUnclassifiedInstancesTipText() {
return "Whether to allow unclassified instances.";
}
/**
* Get the value of NumFolds.
*
* @return Value of NumFolds.
*/
public boolean getAllowUnclassifiedInstances() {
return m_AllowUnclassifiedInstances;
}
/**
* Set the value of AllowUnclassifiedInstances.
*
* @param newAllowUnclassifiedInstances Value to assign to AllowUnclassifiedInstances.
*/
public void setAllowUnclassifiedInstances(boolean newAllowUnclassifiedInstances) {
m_AllowUnclassifiedInstances = newAllowUnclassifiedInstances;
}
/**
* Set the maximum depth of the tree, 0 for unlimited.
*
* @param value
* the maximum depth.
*/
public void setMaxDepth(int value) {
m_MaxDepth = value;
}
/**
* Lists the command-line options for this classifier.
*
* @return an enumeration over all possible options
*/
public Enumeration listOptions() {
Vector newVector = new Vector();
newVector.addElement(new Option(
"\tNumber of attributes to randomly investigate\n"
+ "\t(<0 = int(log_2(#attributes)+1)).", "K", 1,
"-K <number of attributes>"));
newVector.addElement(new Option(
"\tSet minimum number of instances per leaf.", "M", 1,
"-M <minimum number of instances>"));
newVector.addElement(new Option("\tSeed for random number generator.\n"
+ "\t(default 1)", "S", 1, "-S <num>"));
newVector.addElement(new Option(
"\tThe maximum depth of the tree, 0 for unlimited.\n"
+ "\t(default 0)", "depth", 1, "-depth <num>"));
newVector.
addElement(new Option("\tNumber of folds for backfitting " +
"(default 0, no backfitting).",
"N", 1, "-N <num>"));
newVector.
addElement(new Option("\tAllow unclassified instances.",
"U", 0, "-U"));
Enumeration enu = super.listOptions();
while (enu.hasMoreElements()) {
newVector.addElement(enu.nextElement());
}
return newVector.elements();
}
/**
* Gets options from this classifier.
*
* @return the options for the current setup
*/
public String[] getOptions() {
Vector result;
String[] options;
int i;
result = new Vector();
result.add("-K");
result.add("" + getKValue());
result.add("-M");
result.add("" + getMinNum());
result.add("-S");
result.add("" + getSeed());
if (getMaxDepth() > 0) {
result.add("-depth");
result.add("" + getMaxDepth());
}
if (getNumFolds() > 0) {
result.add("-N");
result.add("" + getNumFolds());
}
if (getAllowUnclassifiedInstances()) {
result.add("-U");
}
options = super.getOptions();
for (i = 0; i < options.length; i++)
result.add(options[i]);
return (String[]) result.toArray(new String[result.size()]);
}
/**
* Parses a given list of options.
* <p/>
*
* <!-- options-start -->
* Valid options are: <p/>
*
* <pre> -K <number of attributes>
* Number of attributes to randomly investigate
* (<0 = int(log_2(#attributes)+1)).</pre>
*
* <pre> -M <minimum number of instances>
* Set minimum number of instances per leaf.</pre>
*
* <pre> -S <num>
* Seed for random number generator.
* (default 1)</pre>
*
* <pre> -depth <num>
* The maximum depth of the tree, 0 for unlimited.
* (default 0)</pre>
*
* <pre> -N <num>
* Number of folds for backfitting (default 0, no backfitting).</pre>
*
* <pre> -U
* Allow unclassified instances.</pre>
*
* <pre> -D
* If set, classifier is run in debug mode and
* may output additional info to the console</pre>
*
* <!-- options-end -->
*
* @param options
* the list of options as an array of strings
* @throws Exception
* if an option is not supported
*/
public void setOptions(String[] options) throws Exception {
String tmpStr;
tmpStr = Utils.getOption('K', options);
if (tmpStr.length() != 0) {
m_KValue = Integer.parseInt(tmpStr);
} else {
m_KValue = 0;
}
tmpStr = Utils.getOption('M', options);
if (tmpStr.length() != 0) {
m_MinNum = Double.parseDouble(tmpStr);
} else {
m_MinNum = 1;
}
tmpStr = Utils.getOption('S', options);
if (tmpStr.length() != 0) {
setSeed(Integer.parseInt(tmpStr));
} else {
setSeed(1);
}
tmpStr = Utils.getOption("depth", options);
if (tmpStr.length() != 0) {
setMaxDepth(Integer.parseInt(tmpStr));
} else {
setMaxDepth(0);
}
String numFoldsString = Utils.getOption('N', options);
if (numFoldsString.length() != 0) {
m_NumFolds = Integer.parseInt(numFoldsString);
} else {
m_NumFolds = 0;
}
setAllowUnclassifiedInstances(Utils.getFlag('U', options));
super.setOptions(options);
Utils.checkForRemainingOptions(options);
}
/**
* Returns default capabilities of the classifier.
*
* @return the capabilities of this classifier
*/
public Capabilities getCapabilities() {
Capabilities result = super.getCapabilities();
result.disableAll();
// attributes
result.enable(Capability.NOMINAL_ATTRIBUTES);
result.enable(Capability.NUMERIC_ATTRIBUTES);
result.enable(Capability.DATE_ATTRIBUTES);
result.enable(Capability.MISSING_VALUES);
// class
result.enable(Capability.NOMINAL_CLASS);
result.enable(Capability.MISSING_CLASS_VALUES);
return result;
}
/**
* Builds classifier.
*
* @param data
* the data to train with
* @throws Exception
* if something goes wrong or the data doesn't fit
*/
public void buildClassifier(Instances data) throws Exception {
// Make sure K value is in range
if (m_KValue > data.numAttributes() - 1)
m_KValue = data.numAttributes() - 1;
if (m_KValue < 1)
m_KValue = (int) Utils.log2(data.numAttributes()) + 1;
// can classifier handle the data?
getCapabilities().testWithFail(data);
// remove instances with missing class
data = new Instances(data);
data.deleteWithMissingClass();
// only class? -> build ZeroR model
if (data.numAttributes() == 1) {
System.err
.println("Cannot build model (only class attribute present in data!), "
+ "using ZeroR model instead!");
m_ZeroR = new weka.classifiers.rules.ZeroR();
m_ZeroR.buildClassifier(data);
return;
} else {
m_ZeroR = null;
}
// Figure out appropriate datasets
Instances train = null;
Instances backfit = null;
Random rand = data.getRandomNumberGenerator(m_randomSeed);
if (m_NumFolds <= 0) {
train = data;
} else {
data.randomize(rand);
data.stratify(m_NumFolds);
train = data.trainCV(m_NumFolds, 1, rand);
backfit = data.testCV(m_NumFolds, 1);
}
// Create the attribute indices window
int[] attIndicesWindow = new int[data.numAttributes() - 1];
int j = 0;
for (int i = 0; i < attIndicesWindow.length; i++) {
if (j == data.classIndex())
j++; // do not include the class
attIndicesWindow[i] = j++;
}
// Compute initial class counts
double[] classProbs = new double[train.numClasses()];
for (int i = 0; i < train.numInstances(); i++) {
Instance inst = train.instance(i);
classProbs[(int) inst.classValue()] += inst.weight();
}
// Build tree
buildTree(train, classProbs, new Instances(data, 0), m_MinNum, m_Debug, attIndicesWindow,
rand, 0, getAllowUnclassifiedInstances());
// Backfit if required
if (backfit != null) {
backfitData(backfit);
}
}
/**
* Backfits the given data into the tree.
*/
public void backfitData(Instances data) throws Exception {
// Compute initial class counts
double[] classProbs = new double[data.numClasses()];
for (int i = 0; i < data.numInstances(); i++) {
Instance inst = data.instance(i);
classProbs[(int) inst.classValue()] += inst.weight();
}
// Fit data into tree
backfitData(data, classProbs);
}
/**
* Computes class distribution of an instance using the decision tree.
*
* @param instance
* the instance to compute the distribution for
* @return the computed class distribution
* @throws Exception
* if computation fails
*/
public double[] distributionForInstance(Instance instance) throws Exception {
// default model?
if (m_ZeroR != null) {
return m_ZeroR.distributionForInstance(instance);
}
double[] returnedDist = null;
if (m_Attribute > -1) {
// Node is not a leaf
if (instance.isMissing(m_Attribute)) {
// Value is missing
returnedDist = new double[m_Info.numClasses()];
// Split instance up
for (int i = 0; i < m_Successors.length; i++) {
double[] help = m_Successors[i]
.distributionForInstance(instance);
if (help != null) {
for (int j = 0; j < help.length; j++) {
returnedDist[j] += m_Prop[i] * help[j];
}
}
}
} else if (m_Info.attribute(m_Attribute).isNominal()) {
// For nominal attributes
returnedDist = m_Successors[(int) instance.value(m_Attribute)]
.distributionForInstance(instance);
} else {
// For numeric attributes
if (instance.value(m_Attribute) < m_SplitPoint) {
returnedDist = m_Successors[0]
.distributionForInstance(instance);
} else {
returnedDist = m_Successors[1]
.distributionForInstance(instance);
}
}
}
// Node is a leaf or successor is empty?
if ((m_Attribute == -1) || (returnedDist == null)) {
// Is node empty?
if (m_ClassDistribution == null) {
if (getAllowUnclassifiedInstances()) {
return new double[m_Info.numClasses()];
} else {
return null;
}
}
// Else return normalized distribution
double[] normalizedDistribution = (double[]) m_ClassDistribution.clone();
Utils.normalize(normalizedDistribution);
return normalizedDistribution;
} else {
return returnedDist;
}
}
/**
* Outputs the decision tree as a graph
*
* @return the tree as a graph
*/
public String toGraph() {
try {
StringBuffer resultBuff = new StringBuffer();
toGraph(resultBuff, 0);
String result = "digraph Tree {\n" + "edge [style=bold]\n"
+ resultBuff.toString() + "\n}\n";
return result;
} catch (Exception e) {
return null;
}
}
/**
* Outputs one node for graph.
*
* @param text
* the buffer to append the output to
* @param num
* unique node id
* @return the next node id
* @throws Exception
* if generation fails
*/
public int toGraph(StringBuffer text, int num) throws Exception {
int maxIndex = Utils.maxIndex(m_ClassDistribution);
String classValue = m_Info.classAttribute().value(maxIndex);
num++;
if (m_Attribute == -1) {
text.append("N" + Integer.toHexString(hashCode()) + " [label=\""
+ num + ": " + classValue + "\"" + "shape=box]\n");
} else {
text.append("N" + Integer.toHexString(hashCode()) + " [label=\""
+ num + ": " + classValue + "\"]\n");
for (int i = 0; i < m_Successors.length; i++) {
text.append("N" + Integer.toHexString(hashCode()) + "->" + "N"
+ Integer.toHexString(m_Successors[i].hashCode())
+ " [label=\"" + m_Info.attribute(m_Attribute).name());
if (m_Info.attribute(m_Attribute).isNumeric()) {
if (i == 0) {
text.append(" < "
+ Utils.doubleToString(m_SplitPoint, 2));
} else {
text.append(" >= "
+ Utils.doubleToString(m_SplitPoint, 2));
}
} else {
text.append(" = " + m_Info.attribute(m_Attribute).value(i));
}
text.append("\"]\n");
num = m_Successors[i].toGraph(text, num);
}
}
return num;
}
/**
* Outputs the decision tree.
*
* @return a string representation of the classifier
*/
public String toString() {
// only ZeroR model?
if (m_ZeroR != null) {
StringBuffer buf = new StringBuffer();
buf
.append(this.getClass().getName().replaceAll(".*\\.", "")
+ "\n");
buf.append(this.getClass().getName().replaceAll(".*\\.", "")
.replaceAll(".", "=")
+ "\n\n");
buf
.append("Warning: No model could be built, hence ZeroR model is used:\n\n");
buf.append(m_ZeroR.toString());
return buf.toString();
}
if (m_Successors == null) {
return "RandomTree: no model has been built yet.";
} else {
return "\nRandomTree\n==========\n"
+ toString(0)
+ "\n"
+ "\nSize of the tree : "
+ numNodes()
+ (getMaxDepth() > 0 ? ("\nMax depth of tree: " + getMaxDepth())
: (""));
}
}
/**
* Outputs a leaf.
*
* @return the leaf as string
* @throws Exception
* if generation fails
*/
protected String leafString() throws Exception {
double sum = 0, maxCount = 0;
int maxIndex = 0;
if (m_ClassDistribution != null) {
sum = Utils.sum(m_ClassDistribution);
maxIndex = Utils.maxIndex(m_ClassDistribution);
maxCount = m_ClassDistribution[maxIndex];
}
return " : "
+ m_Info.classAttribute().value(maxIndex)
+ " ("
+ Utils.doubleToString(sum, 2)
+ "/"
+ Utils.doubleToString(sum - maxCount, 2) + ")";
}
/**
* Recursively outputs the tree.
*
* @param level
* the current level of the tree
* @return the generated subtree
*/
protected String toString(int level) {
try {
StringBuffer text = new StringBuffer();
if (m_Attribute == -1) {
// Output leaf info
return leafString();
} else if (m_Info.attribute(m_Attribute).isNominal()) {
// For nominal attributes
for (int i = 0; i < m_Successors.length; i++) {
text.append("\n");
for (int j = 0; j < level; j++) {
text.append("| ");
}
text.append(m_Info.attribute(m_Attribute).name() + " = "
+ m_Info.attribute(m_Attribute).value(i));
text.append(m_Successors[i].toString(level + 1));
}
} else {
// For numeric attributes
text.append("\n");
for (int j = 0; j < level; j++) {
text.append("| ");
}
text.append(m_Info.attribute(m_Attribute).name() + " < "
+ Utils.doubleToString(m_SplitPoint, 2));
text.append(m_Successors[0].toString(level + 1));
text.append("\n");
for (int j = 0; j < level; j++) {
text.append("| ");
}
text.append(m_Info.attribute(m_Attribute).name() + " >= "
+ Utils.doubleToString(m_SplitPoint, 2));
text.append(m_Successors[1].toString(level + 1));
}
return text.toString();
} catch (Exception e) {
e.printStackTrace();
return "RandomTree: tree can't be printed";
}
}
/**
* Recursively backfits data into the tree.
*
* @param data
* the data to work with
* @param classProbs
* the class distribution
* @throws Exception
* if generation fails
*/
protected void backfitData(Instances data, double[] classProbs) throws Exception {
// Make leaf if there are no training instances
if (data.numInstances() == 0) {
m_Attribute = -1;
m_ClassDistribution = null;
m_Prop = null;
return;
}
// Check if node doesn't contain enough instances or is pure
// or maximum depth reached
m_ClassDistribution = (double[]) classProbs.clone();
/* if (Utils.sum(m_ClassDistribution) < 2 * m_MinNum
|| Utils.eq(m_ClassDistribution[Utils.maxIndex(m_ClassDistribution)], Utils
.sum(m_ClassDistribution))) {
// Make leaf
m_Attribute = -1;
m_Prop = null;
return;
}*/
// Are we at an inner node
if (m_Attribute > -1) {
// Compute new weights for subsets based on backfit data
m_Prop = new double[m_Successors.length];
for (int i = 0; i < data.numInstances(); i++) {
Instance inst = data.instance(i);
if (!inst.isMissing(m_Attribute)) {
if (data.attribute(m_Attribute).isNominal()) {
m_Prop[(int)inst.value(m_Attribute)] += inst.weight();
} else {
m_Prop[(inst.value(m_Attribute) < m_SplitPoint) ? 0 : 1] += inst.weight();
}
}
}
// If we only have missing values we can make this node into a leaf
if (Utils.sum(m_Prop) <= 0) {
m_Attribute = -1;
m_Prop = null;
return;
}
// Otherwise normalize the proportions
Utils.normalize(m_Prop);
// Split data
Instances[] subsets = splitData(data);
// Go through subsets
for (int i = 0; i < subsets.length; i++) {
// Compute distribution for current subset
double[] dist = new double[data.numClasses()];
for (int j = 0; j < subsets[i].numInstances(); j++) {
dist[(int)subsets[i].instance(j).classValue()] += subsets[i].instance(j).weight();
}
// Backfit subset
m_Successors[i].backfitData(subsets[i], dist);
}
// If unclassified instances are allowed, we don't need to store the class distribution
if (getAllowUnclassifiedInstances()) {
m_ClassDistribution = null;
return;
}
// Otherwise, if all successors are non-empty, we don't need to store the class distribution
boolean emptySuccessor = false;
for (int i = 0; i < subsets.length; i++) {
if (m_Successors[i].m_ClassDistribution == null) {
emptySuccessor = true;
return;
}
}
m_ClassDistribution = null;
// If we have a least two non-empty successors, we should keep this tree
/* int nonEmptySuccessors = 0;
for (int i = 0; i < subsets.length; i++) {
if (m_Successors[i].m_ClassDistribution != null) {
nonEmptySuccessors++;
if (nonEmptySuccessors > 1) {
return;
}
}
}
// Otherwise, this node is a leaf or should become a leaf
m_Successors = null;
m_Attribute = -1;
m_Prop = null;
return;*/
}
}
/**
* Recursively generates a tree.
*
* @param data
* the data to work with
* @param classProbs
* the class distribution
* @param header
* the header of the data
* @param minNum
* the minimum number of instances per leaf
* @param debug
* whether debugging is on
* @param attIndicesWindow
* the attribute window to choose attributes from
* @param random
* random number generator for choosing random attributes
* @param depth
* the current depth
* @param determineStructure
* whether to determine structure
* @throws Exception
* if generation fails
*/
protected void buildTree(Instances data, double[] classProbs, Instances header,
double minNum, boolean debug, int[] attIndicesWindow,
Random random, int depth, boolean allow) throws Exception {
// Store structure of dataset, set minimum number of instances
m_Info = header;
m_Debug = debug;
m_MinNum = minNum;
m_AllowUnclassifiedInstances = allow;
// Make leaf if there are no training instances
if (data.numInstances() == 0) {
m_Attribute = -1;
m_ClassDistribution = null;
m_Prop = null;
return;
}
// Check if node doesn't contain enough instances or is pure
// or maximum depth reached
m_ClassDistribution = (double[]) classProbs.clone();
if (Utils.sum(m_ClassDistribution) < 2 * m_MinNum
|| Utils.eq(m_ClassDistribution[Utils.maxIndex(m_ClassDistribution)], Utils
.sum(m_ClassDistribution))
|| ((getMaxDepth() > 0) && (depth >= getMaxDepth()))) {
// Make leaf
m_Attribute = -1;
m_Prop = null;
return;
}
// Compute class distributions and value of splitting
// criterion for each attribute
double[] vals = new double[data.numAttributes()];
double[][][] dists = new double[data.numAttributes()][0][0];
double[][] props = new double[data.numAttributes()][0];
double[] splits = new double[data.numAttributes()];
// Investigate K random attributes
int attIndex = 0;
int windowSize = attIndicesWindow.length;
int k = m_KValue;
boolean gainFound = false;
while ((windowSize > 0) && (k-- > 0 || !gainFound)) {
int chosenIndex = random.nextInt(windowSize);
attIndex = attIndicesWindow[chosenIndex];
// shift chosen attIndex out of window
attIndicesWindow[chosenIndex] = attIndicesWindow[windowSize - 1];
attIndicesWindow[windowSize - 1] = attIndex;
windowSize--;
splits[attIndex] = distribution(props, dists, attIndex, data);
vals[attIndex] = gain(dists[attIndex], priorVal(dists[attIndex]));
if (Utils.gr(vals[attIndex], 0))
gainFound = true;
}
// Find best attribute
m_Attribute = Utils.maxIndex(vals);
double[][] distribution = dists[m_Attribute];
// Any useful split found?
if (Utils.gr(vals[m_Attribute], 0)) {
// Build subtrees
m_SplitPoint = splits[m_Attribute];
m_Prop = props[m_Attribute];
Instances[] subsets = splitData(data);
m_Successors = new RandomTree[distribution.length];
for (int i = 0; i < distribution.length; i++) {
m_Successors[i] = new RandomTree();
m_Successors[i].setKValue(m_KValue);
m_Successors[i].setMaxDepth(getMaxDepth());
m_Successors[i].buildTree(subsets[i], distribution[i], header, m_MinNum, m_Debug,
attIndicesWindow, random, depth + 1, allow);
}
// If all successors are non-empty, we don't need to store the class distribution
boolean emptySuccessor = false;
for (int i = 0; i < subsets.length; i++) {
if (m_Successors[i].m_ClassDistribution == null) {
emptySuccessor = true;
break;
}
}
if (!emptySuccessor) {
m_ClassDistribution = null;
}
} else {
// Make leaf
m_Attribute = -1;
}
}
/**
* Computes size of the tree.
*
* @return the number of nodes
*/
public int numNodes() {
if (m_Attribute == -1) {
return 1;
} else {
int size = 1;
for (int i = 0; i < m_Successors.length; i++) {
size += m_Successors[i].numNodes();
}
return size;
}
}
/**
* Splits instances into subsets based on the given split.
*
* @param data
* the data to work with
* @return the subsets of instances
* @throws Exception
* if something goes wrong
*/
protected Instances[] splitData(Instances data) throws Exception {
// Allocate array of Instances objects
Instances[] subsets = new Instances[m_Prop.length];
for (int i = 0; i < m_Prop.length; i++) {
subsets[i] = new Instances(data, data.numInstances());
}
// Go through the data
for (int i = 0; i < data.numInstances(); i++) {
// Get instance
Instance inst = data.instance(i);
// Does the instance have a missing value?
if (inst.isMissing(m_Attribute)) {
// Split instance up
for (int k = 0; k < m_Prop.length; k++) {
if (m_Prop[k] > 0) {
Instance copy = (Instance)inst.copy();
copy.setWeight(m_Prop[k] * inst.weight());
subsets[k].add(copy);
}
}
// Proceed to next instance
continue;
}
// Do we have a nominal attribute?
if (data.attribute(m_Attribute).isNominal()) {
subsets[(int)inst.value(m_Attribute)].add(inst);
// Proceed to next instance
continue;
}
// Do we have a numeric attribute?
if (data.attribute(m_Attribute).isNumeric()) {
subsets[(inst.value(m_Attribute) < m_SplitPoint) ? 0 : 1].add(inst);
// Proceed to next instance
continue;
}
// Else throw an exception
throw new IllegalArgumentException("Unknown attribute type");
}
// Save memory
for (int i = 0; i < m_Prop.length; i++) {
subsets[i].compactify();
}
// Return the subsets
return subsets;
}
/**
* Computes class distribution for an attribute.
*
* @param props
* @param dists
* @param att
* the attribute index
* @param data
* the data to work with
* @throws Exception
* if something goes wrong
*/
protected double distribution(double[][] props, double[][][] dists, int att, Instances data)
throws Exception {
double splitPoint = Double.NaN;
Attribute attribute = data.attribute(att);
double[][] dist = null;
int indexOfFirstMissingValue = -1;
if (attribute.isNominal()) {
// For nominal attributes
dist = new double[attribute.numValues()][data.numClasses()];
for (int i = 0; i < data.numInstances(); i++) {
Instance inst = data.instance(i);
if (inst.isMissing(att)) {
// Skip missing values at this stage
if (indexOfFirstMissingValue < 0) {
indexOfFirstMissingValue = i;
}
continue;
}
dist[(int) inst.value(att)][(int) inst.classValue()] += inst.weight();
}
} else {
// For numeric attributes
double[][] currDist = new double[2][data.numClasses()];
dist = new double[2][data.numClasses()];
// Sort data
data.sort(att);
// Move all instances into second subset
for (int j = 0; j < data.numInstances(); j++) {
Instance inst = data.instance(j);
if (inst.isMissing(att)) {
// Can stop as soon as we hit a missing value
indexOfFirstMissingValue = j;
break;
}
currDist[1][(int) inst.classValue()] += inst.weight();
}
// Value before splitting
double priorVal = priorVal(currDist);
// Save initial distribution
for (int j = 0; j < currDist.length; j++) {
System.arraycopy(currDist[j], 0, dist[j], 0, dist[j].length);
}
// Try all possible split points
double currSplit = data.instance(0).value(att);
double currVal, bestVal = -Double.MAX_VALUE;
for (int i = 0; i < data.numInstances(); i++) {
Instance inst = data.instance(i);
if (inst.isMissing(att)) {
// Can stop as soon as we hit a missing value
break;
}
// Can we place a sensible split point here?
if (inst.value(att) > currSplit) {
// Compute gain for split point
currVal = gain(currDist, priorVal);
// Is the current split point the best point so far?
if (currVal > bestVal) {
// Store value of current point
bestVal = currVal;
// Save split point
splitPoint = (inst.value(att) + currSplit) / 2.0;
// Save distribution
for (int j = 0; j < currDist.length; j++) {
System.arraycopy(currDist[j], 0, dist[j], 0, dist[j].length);
}
}
}
currSplit = inst.value(att);
// Shift over the weight
currDist[0][(int) inst.classValue()] += inst.weight();
currDist[1][(int) inst.classValue()] -= inst.weight();
}
}
// Compute weights for subsets
props[att] = new double[dist.length];
for (int k = 0; k < props[att].length; k++) {
props[att][k] = Utils.sum(dist[k]);
}
if (Utils.eq(Utils.sum(props[att]), 0)) {
for (int k = 0; k < props[att].length; k++) {
props[att][k] = 1.0 / (double) props[att].length;
}
} else {
Utils.normalize(props[att]);
}
// Any instances with missing values ?
if (indexOfFirstMissingValue > -1) {
// Distribute weights for instances with missing values
for (int i = indexOfFirstMissingValue; i < data.numInstances(); i++) {
Instance inst = data.instance(i);
if (attribute.isNominal()) {
// Need to check if attribute value is missing
if (inst.isMissing(att)) {
for (int j = 0; j < dist.length; j++) {
dist[j][(int) inst.classValue()] += props[att][j] * inst.weight();
}
}
} else {
// Can be sure that value is missing, so no test required
for (int j = 0; j < dist.length; j++) {
dist[j][(int) inst.classValue()] += props[att][j] * inst.weight();
}
}
}
}
// Return distribution and split point
dists[att] = dist;
return splitPoint;
}
/**
* Computes value of splitting criterion before split.
*
* @param dist
* the distributions
* @return the splitting criterion
*/
protected double priorVal(double[][] dist) {
return ContingencyTables.entropyOverColumns(dist);
}
/**
* Computes value of splitting criterion after split.
*
* @param dist
* the distributions
* @param priorVal
* the splitting criterion
* @return the gain after the split
*/
protected double gain(double[][] dist, double priorVal) {
return priorVal - ContingencyTables.entropyConditionedOnRows(dist);
}
/**
* Returns the revision string.
*
* @return the revision
*/
public String getRevision() {
return RevisionUtils.extract("$Revision: 8034 $");
}
/**
* Main method for this class.
*
* @param argv
* the commandline parameters
*/
public static void main(String[] argv) {
runClassifier(new RandomTree(), argv);
}
/**
* Returns graph describing the tree.
*
* @return the graph describing the tree
* @throws Exception
* if graph can't be computed
*/
public String graph() throws Exception {
if (m_Successors == null) {
throw new Exception("RandomTree: No model built yet.");
}
StringBuffer resultBuff = new StringBuffer();
toGraph(resultBuff, 0, null);
String result = "digraph RandomTree {\n" + "edge [style=bold]\n"
+ resultBuff.toString() + "\n}\n";
return result;
}
/**
* Returns the type of graph this classifier represents.
*
* @return Drawable.TREE
*/
public int graphType() {
return Drawable.TREE;
}
/**
* Outputs one node for graph.
*
* @param text
* the buffer to append the output to
* @param num
* the current node id
* @param parent
* the parent of the nodes
* @return the next node id
* @throws Exception
* if something goes wrong
*/
protected int toGraph(StringBuffer text, int num, RandomTree parent)
throws Exception {
num++;
if (m_Attribute == -1) {
text.append("N" + Integer.toHexString(RandomTree.this.hashCode())
+ " [label=\"" + num + leafString() + "\""
+ " shape=box]\n");
} else {
text.append("N" + Integer.toHexString(RandomTree.this.hashCode())
+ " [label=\"" + num + ": "
+ m_Info.attribute(m_Attribute).name() + "\"]\n");
for (int i = 0; i < m_Successors.length; i++) {
text.append("N"
+ Integer.toHexString(RandomTree.this.hashCode())
+ "->" + "N"
+ Integer.toHexString(m_Successors[i].hashCode())
+ " [label=\"");
if (m_Info.attribute(m_Attribute).isNumeric()) {
if (i == 0) {
text.append(" < "
+ Utils.doubleToString(m_SplitPoint, 2));
} else {
text.append(" >= "
+ Utils.doubleToString(m_SplitPoint, 2));
}
} else {
text.append(" = " + m_Info.attribute(m_Attribute).value(i));
}
text.append("\"]\n");
num = m_Successors[i].toGraph(text, num, this);
}
}
return num;
}
}