package func;
import dist.AbstractConditionalDistribution;
import dist.DiscreteDistribution;
import dist.Distribution;
import func.dtree.BinaryDecisionTreeSplit;
import func.dtree.DecisionTreeNode;
import func.dtree.DecisionTreeSplit;
import func.dtree.DecisionTreeSplitStatistics;
import func.dtree.InformationGainSplitEvaluator;
import func.dtree.SplitEvaluator;
import shared.DataSet;
import shared.DataSetDescription;
import shared.Instance;
/**
 * A decision stump: a single binary split on one attribute value, with
 * class probabilities recorded for each branch of that split.
 * @author Andrew Guillory gtg008g@mail.gatech.edu
 * @version 1.0
 */
public class DecisionStumpClassifier extends AbstractConditionalDistribution implements FunctionApproximater {
    /**
     * The evaluator for deciding on splits
     */
    private final SplitEvaluator splitEvaluator;
    /**
     * The stump itself; null until {@link #estimate(DataSet)} succeeds
     */
    private DecisionTreeNode stump;
    /**
     * The discrete ranges of the different attributes, indexed by attribute
     */
    private int[] attributeRanges;

    /**
     * Create a new decision stump
     * @param splitEvaluator the splitting chooser
     */
    public DecisionStumpClassifier(SplitEvaluator splitEvaluator) {
        this.splitEvaluator = splitEvaluator;
    }

    /**
     * Create a new decision stump that scores splits
     * with an {@link InformationGainSplitEvaluator}
     */
    public DecisionStumpClassifier() {
        this(new InformationGainSplitEvaluator());
    }

    /**
     * Estimate from data. If the data set has no description, one is
     * induced from the instances and stored on the data set.
     * @param instances the data set
     * @throws RuntimeException if no usable split could be found
     */
    public void estimate(DataSet instances) {
        // make the description if it isn't there
        if (instances.getDescription() == null) {
            DataSetDescription desc = new DataSetDescription();
            desc.induceFrom(instances);
            instances.setDescription(desc);
        }
        DataSetDescription description = instances.getDescription();
        // initialize the ranges
        attributeRanges = new int[description.getAttributeTypes().length];
        for (int i = 0; i < attributeRanges.length; i++) {
            attributeRanges[i] = description.getDiscreteRange(i);
        }
        // build the stump
        stump = buildStump(instances);
        if (stump == null) {
            throw new RuntimeException("Invalid Stump Exception");
        }
    }

    /**
     * Build a stump from the instances by trying a binary split for every
     * (attribute, value) pair and keeping the one with the highest value
     * according to the split evaluator.
     * @param instances the instances to build the stump from
     * @return the stump, or null if no candidate split was selected
     */
    private DecisionTreeNode buildStump(DataSet instances) {
        // find the best binary splitter
        DecisionTreeSplit bestSplit = null;
        DecisionTreeSplitStatistics bestStats = null;
        double bestValue = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < attributeRanges.length; i++) {
            for (int j = 0; j < attributeRanges[i]; j++) {
                DecisionTreeSplit split = new BinaryDecisionTreeSplit(i, j);
                DecisionTreeSplitStatistics stats = new DecisionTreeSplitStatistics(split, instances);
                double value = splitEvaluator.splitValue(stats);
                if (value > bestValue) {
                    bestValue = value;
                    bestSplit = split;
                    bestStats = stats;
                }
            }
        }
        // Nothing was selected: there were no candidates at all, or no
        // candidate scored above negative infinity (NaN never compares
        // greater). Report that to the caller instead of dereferencing
        // the null statistics below.
        if (bestSplit == null) {
            return null;
        }
        return new DecisionTreeNode(bestSplit, bestStats,
                new DecisionTreeNode[bestStats.getBranchCount()]);
    }

    /**
     * Get the class distribution for an instance. The distribution is
     * conditioned on the branch the instance falls into, unless that
     * branch has an instance count of zero, in which case the
     * unconditional class probabilities of the split are used.
     * @see dist.ConditionalDistribution#distributionFor(shared.Instance)
     */
    public Distribution distributionFor(Instance instance) {
        DecisionTreeSplitStatistics stats = stump.getSplitStatistics();
        int branch = stump.getSplit().getBranchOf(instance);
        if (stats.getInstanceCount(branch) == 0) {
            return new DiscreteDistribution(stats.getClassProbabilities());
        } else {
            return new DiscreteDistribution(stats.getConditionalClassProbabilities(branch));
        }
    }

    /**
     * Get the mode of the distribution for the given instance
     * @see func.FunctionApproximater#value(shared.Instance)
     */
    public Instance value(Instance i) {
        return distributionFor(i).mode();
    }

    /**
     * Get the stump for the decision tree node
     * @return the stump
     */
    public DecisionTreeNode getStump() {
        return stump;
    }

    /**
     * Get the split evaluator for the stump
     * @return the evaluator
     */
    public SplitEvaluator getSplitEvaluator() {
        return splitEvaluator;
    }

    /**
     * @see java.lang.Object#toString()
     */
    public String toString() {
        return stump.toString();
    }
}