/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* RandomTree.java
* Copyright (C) 2001 Richard Kirkby, Eibe Frank
*
*/
package weka.classifiers.trees;
import weka.classifiers.*;
import weka.core.*;
import java.util.*;
/**
* Class for constructing a tree that considers K random features at each node.
* Performs no pruning.
*
* @author Eibe Frank (eibe@cs.waikato.ac.nz)
* @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
* @version $Revision: 1.1.1.1 $
*/
public class RandomTree extends DistributionClassifier
implements OptionHandler, WeightedInstancesHandler, Randomizable {
/** The subtrees appended to this tree. */
protected RandomTree[] m_Successors;
/** The attribute to split on. */
protected int m_Attribute = -1;
/** The split point. */
protected double m_SplitPoint = Double.NaN;
/** The class distribution from the training data. */
protected double[][] m_Distribution = null;
/** The header information. */
protected Instances m_Info = null;
/** The proportions of training instances going down each branch. */
protected double[] m_Prop = null;
/** Class probabilities from the training data. */
protected double[] m_ClassProbs = null;
/** Minimum number of instances for leaf. */
protected double m_MinNum = 1.0;
/** Debug info */
protected boolean m_Debug = false;
/** The number of attributes considered for a split. */
protected int m_KValue = 1;
/** Random number generator. */
protected Random m_random;
/** The random seed to use. */
protected int m_randomSeed = 1;
/**
* Get the value of MinNum.
*
* @return Value of MinNum.
*/
public double getMinNum() {
return m_MinNum;
}
/**
* Set the value of MinNum.
*
* @param newMinNum Value to assign to MinNum.
*/
public void setMinNum(double newMinNum) {
m_MinNum = newMinNum;
}
/**
* Get the value of K.
*
* @return Value of K.
*/
public int getKValue() {
return m_KValue;
}
/**
* Set the value of K.
*
* @param k Value to assign to K.
*/
public void setKValue(int k) {
m_KValue = k;
}
/**
* Get the value of Debug.
*
* @return Value of Debug.
*/
public boolean getDebug() {
return m_Debug;
}
/**
* Set the value of Debug.
*
* @param newDebug Value to assign to Debug.
*/
public void setDebug(boolean newDebug) {
m_Debug = newDebug;
}
/**
* Set the seed for random number generation.
*
* @param seed the seed
*/
public void setSeed(int seed) {
m_randomSeed = seed;
}
/**
* Gets the seed for the random number generations
*
* @return the seed for the random number generation
*/
public int getSeed() {
return m_randomSeed;
}
/**
* Lists the command-line options for this classifier.
*/
public Enumeration listOptions() {
Vector newVector = new Vector(6);
newVector.
addElement(new Option("\tNumber of attributes to randomly investigate.",
"K", 1, "-K <number of attributes>"));
newVector.
addElement(new Option("\tSet minimum number of instances per leaf.",
"M", 1, "-M <minimum number of instances>"));
newVector.
addElement(new Option("\tTurns debugging info on.",
"D", 0, "-D"));
newVector
.addElement(new Option("\tSeed for random number generator.\n"
+ "\t(default 1)",
"S", 1, "-S"));
return newVector.elements();
}
/**
* Gets options from this classifier.
*/
public String[] getOptions() {
String [] options = new String [10];
int current = 0;
options[current++] = "-K";
options[current++] = "" + getKValue();
options[current++] = "-M";
options[current++] = "" + getMinNum();
options[current++] = "-S";
options[current++] = "" + getSeed();
if (getDebug()) {
options[current++] = "-D";
}
while (current < options.length) {
options[current++] = "";
}
return options;
}
/**
* Parses a given list of options.
* @param options the list of options as an array of strings
* @exception Exception if an option is not supported
*/
public void setOptions(String[] options) throws Exception{
String kValueString = Utils.getOption('K', options);
if (kValueString.length() != 0) {
m_KValue = Integer.parseInt(kValueString);
} else {
m_KValue = 1;
}
String minNumString = Utils.getOption('M', options);
if (minNumString.length() != 0) {
m_MinNum = (double)Integer.parseInt(minNumString);
} else {
m_MinNum = 1;
}
String seed = Utils.getOption('S', options);
if (seed.length() != 0) {
setSeed(Integer.parseInt(seed));
} else {
setSeed(1);
}
m_Debug = Utils.getFlag('D', options);
Utils.checkForRemainingOptions(options);
}
/**
* Builds classifier.
*/
public void buildClassifier(Instances data) throws Exception {
// Make sure K value is in range
if (m_KValue > data.numAttributes()-1) m_KValue = data.numAttributes()-1;
// Delete instances with missing class
data = new Instances(data);
data.deleteWithMissingClass();
Instances train = data;
// Create array of sorted indices and weights
int[][] sortedIndices = new int[train.numAttributes()][0];
double[][] weights = new double[train.numAttributes()][0];
double[] vals = new double[train.numInstances()];
for (int j = 0; j < train.numAttributes(); j++) {
if (j != train.classIndex()) {
weights[j] = new double[train.numInstances()];
if (train.attribute(j).isNominal()) {
// Handling nominal attributes. Putting indices of
// instances with missing values at the end.
sortedIndices[j] = new int[train.numInstances()];
int count = 0;
for (int i = 0; i < train.numInstances(); i++) {
Instance inst = train.instance(i);
if (!inst.isMissing(j)) {
sortedIndices[j][count] = i;
weights[j][count] = inst.weight();
count++;
}
}
for (int i = 0; i < train.numInstances(); i++) {
Instance inst = train.instance(i);
if (inst.isMissing(j)) {
sortedIndices[j][count] = i;
weights[j][count] = inst.weight();
count++;
}
}
} else {
// Sorted indices are computed for numeric attributes
for (int i = 0; i < train.numInstances(); i++) {
Instance inst = train.instance(i);
vals[i] = inst.value(j);
}
sortedIndices[j] = Utils.sort(vals);
for (int i = 0; i < train.numInstances(); i++) {
weights[j][i] = train.instance(sortedIndices[j][i]).weight();
}
}
}
}
// Compute initial class counts
double[] classProbs = new double[train.numClasses()];
for (int i = 0; i < train.numInstances(); i++) {
Instance inst = train.instance(i);
classProbs[(int)inst.classValue()] += inst.weight();
}
// Create the attribute indices window
int[] attIndicesWindow = new int[data.numAttributes()-1];
int j=0;
for (int i=0; i<attIndicesWindow.length; i++) {
if (j == data.classIndex()) j++; // do not include the class
attIndicesWindow[i] = j++;
}
// Build tree
buildTree(sortedIndices, weights, train, classProbs,
new Instances(train, 0), m_MinNum, m_Debug,
attIndicesWindow);
}
/**
* Computes class distribution of an instance using the decision tree.
*/
public double[] distributionForInstance(Instance instance) throws Exception {
double[] returnedDist = null;
if (m_Attribute > -1) {
// Node is not a leaf
if (instance.isMissing(m_Attribute)) {
// Value is missing
returnedDist = new double[m_Info.numClasses()];
// Split instance up
for (int i = 0; i < m_Successors.length; i++) {
double[] help = m_Successors[i].distributionForInstance(instance);
if (help != null) {
for (int j = 0; j < help.length; j++) {
returnedDist[j] += m_Prop[i] * help[j];
}
}
}
} else if (m_Info.attribute(m_Attribute).isNominal()) {
// For nominal attributes
returnedDist = m_Successors[(int)instance.value(m_Attribute)].
distributionForInstance(instance);
} else {
// For numeric attributes
if (Utils.sm(instance.value(m_Attribute), m_SplitPoint)) {
returnedDist = m_Successors[0].distributionForInstance(instance);
} else {
returnedDist = m_Successors[1].distributionForInstance(instance);
}
}
}
if ((m_Attribute == -1) || (returnedDist == null)) {
// Node is a leaf or successor is empty
return m_ClassProbs;
} else {
return returnedDist;
}
}
/**
* Outputs the decision tree as a graph
*/
public String toGraph() {
try {
StringBuffer resultBuff = new StringBuffer();
toGraph(resultBuff, 0);
String result = "digraph Tree {\n" + "edge [style=bold]\n" + resultBuff.toString()
+ "\n}\n";
return result;
} catch (Exception e) {
return null;
}
}
/**
* Outputs one node for graph.
*/
public int toGraph(StringBuffer text, int num) throws Exception {
int maxIndex = Utils.maxIndex(m_ClassProbs);
String classValue = m_Info.classAttribute().value(maxIndex);
num++;
if (m_Attribute == -1) {
text.append("N" + Integer.toHexString(hashCode()) +
" [label=\"" + num + ": " + classValue + "\"" +
"shape=box]\n");
}else {
text.append("N" + Integer.toHexString(hashCode()) +
" [label=\"" + num + ": " + classValue + "\"]\n");
for (int i = 0; i < m_Successors.length; i++) {
text.append("N" + Integer.toHexString(hashCode())
+ "->" +
"N" + Integer.toHexString(m_Successors[i].hashCode()) +
" [label=\"" + m_Info.attribute(m_Attribute).name());
if (m_Info.attribute(m_Attribute).isNumeric()) {
if (i == 0) {
text.append(" < " +
Utils.doubleToString(m_SplitPoint, 2));
} else {
text.append(" >= " +
Utils.doubleToString(m_SplitPoint, 2));
}
} else {
text.append(" = " + m_Info.attribute(m_Attribute).value(i));
}
text.append("\"]\n");
num = m_Successors[i].toGraph(text, num);
}
}
return num;
}
/**
* Outputs the decision tree.
*/
public String toString() {
return
"\nRandomTree\n==========\n" + toString(0) + "\n" +
"\nSize of the tree : " + numNodes();
}
/**
* Outputs a leaf.
*/
protected String leafString() throws Exception {
int maxIndex = Utils.maxIndex(m_Distribution[0]);
return " : " + m_Info.classAttribute().value(maxIndex) +
" (" + Utils.doubleToString(Utils.sum(m_Distribution[0]), 2) + "/" +
Utils.doubleToString((Utils.sum(m_Distribution[0]) -
m_Distribution[0][maxIndex]), 2) + ")";
}
/**
* Recursively outputs the tree.
*/
protected String toString(int level) {
try {
StringBuffer text = new StringBuffer();
if (m_Attribute == -1) {
// Output leaf info
return leafString();
} else if (m_Info.attribute(m_Attribute).isNominal()) {
// For nominal attributes
for (int i = 0; i < m_Successors.length; i++) {
text.append("\n");
for (int j = 0; j < level; j++) {
text.append("| ");
}
text.append(m_Info.attribute(m_Attribute).name() + " = " +
m_Info.attribute(m_Attribute).value(i));
text.append(m_Successors[i].toString(level + 1));
}
} else {
// For numeric attributes
text.append("\n");
for (int j = 0; j < level; j++) {
text.append("| ");
}
text.append(m_Info.attribute(m_Attribute).name() + " < " +
Utils.doubleToString(m_SplitPoint, 2));
text.append(m_Successors[0].toString(level + 1));
text.append("\n");
for (int j = 0; j < level; j++) {
text.append("| ");
}
text.append(m_Info.attribute(m_Attribute).name() + " >= " +
Utils.doubleToString(m_SplitPoint, 2));
text.append(m_Successors[1].toString(level + 1));
}
return text.toString();
} catch (Exception e) {
e.printStackTrace();
return "RandomTree: tree can't be printed";
}
}
/**
* Recursively generates a tree.
*/
protected void buildTree(int[][] sortedIndices, double[][] weights,
Instances data, double[] classProbs,
Instances header, double minNum, boolean debug,
int[] attIndicesWindow)
throws Exception {
// create the random generator
m_random = new Random(m_randomSeed);
// Store structure of dataset, set minimum number of instances
m_Info = header;
m_Debug = debug;
m_MinNum = minNum;
// Make leaf if there are no training instances
if (sortedIndices[0].length == 0) {
m_Distribution = new double[1][data.numClasses()];
m_ClassProbs = null;
return;
}
// Check if node doesn't contain enough instances or is pure
m_ClassProbs = new double[classProbs.length];
System.arraycopy(classProbs, 0, m_ClassProbs, 0, classProbs.length);
if (Utils.sm(Utils.sum(m_ClassProbs), 2 * m_MinNum) ||
Utils.eq(m_ClassProbs[Utils.maxIndex(m_ClassProbs)],
Utils.sum(m_ClassProbs))) {
// Make leaf
m_Attribute = -1;
m_Distribution = new double[1][m_ClassProbs.length];
for (int i = 0; i < m_ClassProbs.length; i++) {
m_Distribution[0][i] = m_ClassProbs[i];
}
Utils.normalize(m_ClassProbs);
return;
}
// Compute class distributions and value of splitting
// criterion for each attribute
double[] vals = new double[data.numAttributes()];
double[][][] dists = new double[data.numAttributes()][0][0];
double[][] props = new double[data.numAttributes()][0];
double[] splits = new double[data.numAttributes()];
// Investigate K random attributes
int attIndex = 0;
int windowSize = attIndicesWindow.length;
int k = m_KValue;
boolean gainFound = false;
while ((windowSize > 0) && (k-- > 0 || !gainFound)) {
int chosenIndex = m_random.nextInt(windowSize);
attIndex = attIndicesWindow[chosenIndex];
// shift chosen attIndex out of window
attIndicesWindow[chosenIndex] = attIndicesWindow[windowSize-1];
attIndicesWindow[windowSize-1] = attIndex;
windowSize--;
splits[attIndex] = distribution(props, dists, attIndex,
sortedIndices[attIndex],
weights[attIndex], data);
vals[attIndex] = gain(dists[attIndex], priorVal(dists[attIndex]));
if (Utils.gr(vals[attIndex], 0)) gainFound = true;
}
// Find best attribute
m_Attribute = Utils.maxIndex(vals);
m_Distribution = dists[m_Attribute];
// Any useful split found?
if (Utils.gr(vals[m_Attribute], 0)) {
// Build subtrees
m_SplitPoint = splits[m_Attribute];
m_Prop = props[m_Attribute];
int[][][] subsetIndices =
new int[m_Distribution.length][data.numAttributes()][0];
double[][][] subsetWeights =
new double[m_Distribution.length][data.numAttributes()][0];
splitData(subsetIndices, subsetWeights, m_Attribute, m_SplitPoint,
sortedIndices, weights, m_Distribution, data);
m_Successors = new RandomTree[m_Distribution.length];
for (int i = 0; i < m_Distribution.length; i++) {
m_Successors[i] = new RandomTree();
m_Successors[i].setSeed(m_random.nextInt());
m_Successors[i].setKValue(m_KValue);
m_Successors[i].buildTree(subsetIndices[i], subsetWeights[i], data,
m_Distribution[i], header, m_MinNum, m_Debug,
attIndicesWindow);
}
} else {
// Make leaf
m_Attribute = -1;
m_Distribution = new double[1][m_ClassProbs.length];
for (int i = 0; i < m_ClassProbs.length; i++) {
m_Distribution[0][i] = m_ClassProbs[i];
}
}
// Normalize class counts
Utils.normalize(m_ClassProbs);
}
/**
* Computes size of the tree.
*/
public int numNodes() {
if (m_Attribute == -1) {
return 1;
} else {
int size = 1;
for (int i = 0; i < m_Successors.length; i++) {
size += m_Successors[i].numNodes();
}
return size;
}
}
/**
* Splits instances into subsets.
*/
protected void splitData(int[][][] subsetIndices, double[][][] subsetWeights,
int att, double splitPoint,
int[][] sortedIndices, double[][] weights,
double[][] dist, Instances data) throws Exception {
int j;
int[] num;
// For each attribute
for (int i = 0; i < data.numAttributes(); i++) {
if (i != data.classIndex()) {
if (data.attribute(att).isNominal()) {
// For nominal attributes
num = new int[data.attribute(att).numValues()];
for (int k = 0; k < num.length; k++) {
subsetIndices[k][i] = new int[sortedIndices[i].length];
subsetWeights[k][i] = new double[sortedIndices[i].length];
}
for (j = 0; j < sortedIndices[i].length; j++) {
Instance inst = data.instance(sortedIndices[i][j]);
if (inst.isMissing(att)) {
// Split instance up
for (int k = 0; k < num.length; k++) {
if (Utils.gr(m_Prop[k], 0)) {
subsetIndices[k][i][num[k]] = sortedIndices[i][j];
subsetWeights[k][i][num[k]] = m_Prop[k] * weights[i][j];
num[k]++;
}
}
} else {
int subset = (int)inst.value(att);
subsetIndices[subset][i][num[subset]] = sortedIndices[i][j];
subsetWeights[subset][i][num[subset]] = weights[i][j];
num[subset]++;
}
}
} else {
// For numeric attributes
num = new int[2];
for (int k = 0; k < 2; k++) {
subsetIndices[k][i] = new int[sortedIndices[i].length];
subsetWeights[k][i] = new double[weights[i].length];
}
for (j = 0; j < sortedIndices[i].length; j++) {
Instance inst = data.instance(sortedIndices[i][j]);
if (inst.isMissing(att)) {
// Split instance up
for (int k = 0; k < num.length; k++) {
if (Utils.gr(m_Prop[k], 0)) {
subsetIndices[k][i][num[k]] = sortedIndices[i][j];
subsetWeights[k][i][num[k]] = m_Prop[k] * weights[i][j];
num[k]++;
}
}
} else {
int subset = Utils.sm(inst.value(att), splitPoint) ? 0 : 1;
subsetIndices[subset][i][num[subset]] = sortedIndices[i][j];
subsetWeights[subset][i][num[subset]] = weights[i][j];
num[subset]++;
}
}
}
// Trim arrays
for (int k = 0; k < num.length; k++) {
int[] copy = new int[num[k]];
System.arraycopy(subsetIndices[k][i], 0, copy, 0, num[k]);
subsetIndices[k][i] = copy;
double[] copyWeights = new double[num[k]];
System.arraycopy(subsetWeights[k][i], 0, copyWeights, 0, num[k]);
subsetWeights[k][i] = copyWeights;
}
}
}
}
/**
* Computes class distribution for an attribute.
*/
protected double distribution(double[][] props, double[][][] dists, int att,
int[] sortedIndices,
double[] weights, Instances data)
throws Exception {
double splitPoint = Double.NaN;
Attribute attribute = data.attribute(att);
double[][] dist = null;
int i;
if (attribute.isNominal()) {
// For nominal attributes
dist = new double[attribute.numValues()][data.numClasses()];
for (i = 0; i < sortedIndices.length; i++) {
Instance inst = data.instance(sortedIndices[i]);
if (inst.isMissing(att)) {
break;
}
dist[(int)inst.value(att)][(int)inst.classValue()] += weights[i];
}
} else {
// For numeric attributes
double[][] currDist = new double[2][data.numClasses()];
dist = new double[2][data.numClasses()];
// Move all instances into second subset
for (int j = 0; j < sortedIndices.length; j++) {
Instance inst = data.instance(sortedIndices[j]);
if (inst.isMissing(att)) {
break;
}
currDist[1][(int)inst.classValue()] += weights[j];
}
double priorVal = priorVal(currDist);
for (int j = 0; j < currDist.length; j++) {
System.arraycopy(currDist[j], 0, dist[j], 0, dist[j].length);
}
// Try all possible split points
double currSplit = data.instance(sortedIndices[0]).value(att);
double currVal, bestVal = -Double.MAX_VALUE;
for (i = 0; i < sortedIndices.length; i++) {
Instance inst = data.instance(sortedIndices[i]);
if (inst.isMissing(att)) {
break;
}
if (Utils.gr(inst.value(att), currSplit)) {
currVal = gain(currDist, priorVal);
if (Utils.gr(currVal, bestVal)) {
bestVal = currVal;
splitPoint = (inst.value(att) + currSplit) / 2.0;
for (int j = 0; j < currDist.length; j++) {
System.arraycopy(currDist[j], 0, dist[j], 0, dist[j].length);
}
}
}
currSplit = inst.value(att);
currDist[0][(int)inst.classValue()] += weights[i];
currDist[1][(int)inst.classValue()] -= weights[i];
}
}
// Compute weights
props[att] = new double[dist.length];
for (int k = 0; k < props[att].length; k++) {
props[att][k] = Utils.sum(dist[k]);
}
if (Utils.eq(Utils.sum(props[att]), 0)) {
for (int k = 0; k < props[att].length; k++) {
props[att][k] = 1.0 / (double)props[att].length;
}
} else {
Utils.normalize(props[att]);
}
// Any instances with missing values ?
if (i < sortedIndices.length) {
// Distribute counts
while (i < sortedIndices.length) {
Instance inst = data.instance(sortedIndices[i]);
for (int j = 0; j < dist.length; j++) {
dist[j][(int)inst.classValue()] += props[att][j] * weights[i];
}
i++;
}
}
// Return distribution and split point
dists[att] = dist;
return splitPoint;
}
/**
* Computes value of splitting criterion before split.
*/
protected double priorVal(double[][] dist) {
return ContingencyTables.entropyOverColumns(dist);
}
/**
* Computes value of splitting criterion after split.
*/
protected double gain(double[][] dist, double priorVal) {
return priorVal - ContingencyTables.entropyConditionedOnRows(dist);
}
/**
* Main method for this class.
*/
public static void main(String[] argv) {
try {
System.out.println(Evaluation.evaluateModel(new RandomTree(), argv));
} catch (Exception e) {
e.printStackTrace();
System.err.println(e.getMessage());
}
}
}