package func; import dist.*; import dist.Distribution; import dist.DiscreteDistribution; import func.dtree.BinaryDecisionTreeSplit; import func.dtree.DecisionTreeNode; import func.dtree.DecisionTreeSplit; import func.dtree.DecisionTreeSplitStatistics; import func.dtree.InformationGainSplitEvaluator; import func.dtree.PruningCriteria; import func.dtree.SplitEvaluator; import func.dtree.StandardDecisionTreeSplit; import shared.DataSet; import shared.DataSetDescription; import shared.Instance; /** * A class implementing a decision tree * @author Andrew Guillory gtg008g@mail.gatech.edu * @version 1.0 */ public class DecisionTreeClassifier extends AbstractConditionalDistribution implements FunctionApproximater { /** * The evaluator for deciding on splits */ private SplitEvaluator splitEvaluator; /** * Whether or not to prune */ private PruningCriteria pruningCriteria; /** * The root node in the tree */ private DecisionTreeNode root; /** * Whether or not to use binary splits */ private boolean useBinarySplits; /** * The ranges of the different attributes */ private int[] attributeRanges; /** * The range of the classifications */ private int classRange; /** * Create a new decision tree * @param splitEvaluator the splitting chooser * @param pruningCriteria the criteria for prunning * @param useBinarySplits whether or not to use binary splits * @param instances the instances to build the tree from */ public DecisionTreeClassifier(SplitEvaluator splitEvaluator, PruningCriteria pruningCriteria, boolean useBinarySplits) { this.splitEvaluator = splitEvaluator; this.pruningCriteria = pruningCriteria; this.useBinarySplits = useBinarySplits; } /** * Create a new decision tree with no prunning * @param splitEvaluator the splitting chooser * @param useBinarySplits whether or not to use binary splits * @param instances the instances to build the tree from */ public DecisionTreeClassifier(SplitEvaluator splitEvaluator, boolean useBinarySplits) { this(splitEvaluator, null, useBinarySplits); } /** * Create a new decision tree with no prunning * @param instances the instances to build the tree from */ public DecisionTreeClassifier() { this(new InformationGainSplitEvaluator(), null, false); } /** * Estimate from the given data set * @param set the set */ public void estimate(DataSet instances) { // make the description if it isn't there if (instances.getDescription() == null) { DataSetDescription desc = new DataSetDescription(); desc.induceFrom(instances); instances.setDescription(desc); } // initialize the ranges attributeRanges = new int[instances.getDescription().getAttributeTypes().length]; for (int i = 0; i < attributeRanges.length; i++) { attributeRanges[i] = instances.getDescription().getDiscreteRange(i); } // build the tree root = buildTree(instances); // if the root is pruned, use a stump why not? if (root == null) { DecisionStumpClassifier stump = new DecisionStumpClassifier(splitEvaluator); stump.estimate(instances); root = stump.getStump(); } } /** * Build a tree from the given instances * @param instances the instances to build the tree from * @return the tree */ private DecisionTreeNode buildTree(DataSet instances) { // nothing left in the tree if (instances.size() == 0) { return null; } // check if all of the same class boolean allOfSameClass = true; int sameClass = instances.get(0).getLabel().getDiscrete(); for (int i = 0; i < instances.size() && allOfSameClass; i++) { allOfSameClass = instances.get(i).getLabel().getDiscrete() == sameClass; } if (allOfSameClass) { return null; } // find the best splitter DecisionTreeSplit bestSplit = null; DecisionTreeSplitStatistics bestStats = null; double bestValue = Double.NEGATIVE_INFINITY; if (!useBinarySplits) { for (int i = 0; i < attributeRanges.length; i++) { DecisionTreeSplit split = new StandardDecisionTreeSplit(i, attributeRanges[i]); DecisionTreeSplitStatistics stats = new DecisionTreeSplitStatistics(split, instances); double value = splitEvaluator.splitValue(stats); if (value > bestValue) { bestValue = value; bestSplit = split; bestStats = stats; } } } else { for (int i = 0; i < attributeRanges.length; i++) { for (int j = 0; j < attributeRanges[i]; j++) { DecisionTreeSplit split = new BinaryDecisionTreeSplit(i, j); DecisionTreeSplitStatistics stats = new DecisionTreeSplitStatistics(split, instances); double value = splitEvaluator.splitValue(stats); if (value > bestValue) { bestValue = value; bestSplit = split; bestStats = stats; } } } } // divide up the instances Instance[][] divided = new Instance[bestSplit.getNumberOfBranches()][]; // check for at least two non zero branches int nonZero = 0; for (int i = 0; i < divided.length; i++) { divided[i] = new Instance[bestStats.getInstanceCount(i)]; if (divided[i].length != 0) { nonZero++; } } if (nonZero < 2) { return null; } // recursive step int[] counters = new int[divided.length]; for (int i = 0; i < instances.size(); i++) { int branch = bestSplit.getBranchOf(instances.get(i)); divided[branch][counters[branch]] = instances.get(i); counters[branch]++; } DecisionTreeNode[] nodes = new DecisionTreeNode[divided.length]; for (int i = 0; i < nodes.length; i++) { DataSet newSet = new DataSet(divided[i], instances.getDescription()); nodes[i] = buildTree(newSet); } DecisionTreeNode node = new DecisionTreeNode(bestSplit, bestStats, nodes); if (node.isLeaf() && pruningCriteria != null && pruningCriteria.shouldPrune(bestStats)) { return null; } return node; } /** * Get the class distribution for an instance * @param instance the instance to classify * @return the distribution */ public Distribution distributionFor(Instance instance) { DecisionTreeNode node = root; while (node.getNode(node.getSplit().getBranchOf(instance)) != null) { node = node.getNode(node.getSplit().getBranchOf(instance)); } int branch = node.getSplit().getBranchOf(instance); if (node.getSplitStatistics().getInstanceCount(branch) == 0) { return new DiscreteDistribution( node.getSplitStatistics().getClassProbabilities()); } else { return new DiscreteDistribution( node.getSplitStatistics().getConditionalClassProbabilities(branch)); } } /** * Get the classifiation for an instance * @param instance the instance to classify * @return the classification */ public Instance value(Instance instance) { return distributionFor(instance).mode(); } /** * Get the root node * @return the root */ public DecisionTreeNode getRoot() { return root; } /** * Get the split evaluator for the stump * @return the evaluator */ public SplitEvaluator getSplitEvaluator() { return splitEvaluator; } /** * Get the prunning criteria * @return the prunning criteria */ public PruningCriteria getPruningCriteria() { return pruningCriteria; } /** * Does the tree use binary splits * @return true if it should */ public boolean isUseBinarySplits() { return useBinarySplits; } /** * Set the pruning criteria * @param criteria the pruning criteria */ public void setPruningCriteria(PruningCriteria criteria) { pruningCriteria = criteria; } /** * Set the split evaluator * @param evaluator the split evaluator */ public void setSplitEvaluator(SplitEvaluator evaluator) { splitEvaluator = evaluator; } /** * Set whether to use binary splits * @param b true if we should */ public void setUseBinarySplits(boolean b) { useBinarySplits = b; } /** * Get the height of the tree * @return the height */ public int getHeight() { return height(root); } /** * Get the height of the tree * @param root the root node * @return the height */ private int height(DecisionTreeNode root) { if (root == null) { return 0; } int height = 1; for (int i = 0; i < root.getNodes().length; i++) { height = Math.max(height, 1 + height(root.getNode(i))); } return height; } /** * @see java.lang.Object#toString() */ public String toString() { return root.toString(); } }