package func.dtree; import shared.DataSet; /** * A class representing the distribution of nodes along * a decision tree split * @author Andrew Guillory gtg008g@mail.gatech.edu * @version 1.0 */ public class DecisionTreeSplitStatistics { /** * The instance counts for each of the branches */ private int[] instanceCounts; /** * The class probabilities * that is classProbabilities[i] is * the P(class = i) */ private double[] classProbabilities; /** * The conditional probabilities * that is conditionalClassProbabilities[i][j] is * the P(class = j | instance is in branch i) */ private double[][] conditionalClassProbabilities; /** * The branch probabilities * that is branchProbabilities[i] is * the P(instance is in branch i) */ private double[] branchProbabilities; /** * Calculate statistics from the given split and instances * @param split the split * @param instances the instances split on * @param classRange the range of class values */ public DecisionTreeSplitStatistics(DecisionTreeSplit split, DataSet instances) { int classRange = instances.getDescription().getLabelDescription().getDiscreteRange(); instanceCounts = new int[split.getNumberOfBranches()]; classProbabilities = new double[classRange]; conditionalClassProbabilities = new double[split.getNumberOfBranches()][classRange]; branchProbabilities = new double[split.getNumberOfBranches()]; // the sum of all of the weights double weightSum = 0; for (int i = 0; i < instances.size(); i++) { double weight = instances.get(i).getWeight(); int branch = split.getBranchOf(instances.get(i)); instanceCounts[branch]++; classProbabilities[instances.get(i).getLabel().getDiscrete()] += weight; branchProbabilities[branch] += weight; // we first just calculate the joint probabilities conditionalClassProbabilities[branch] [instances.get(i).getLabel().getDiscrete()] += weight; weightSum += weight; } // turn the unnormalized joint prob's into normalized conditional probs for (int i = 0; i < conditionalClassProbabilities.length; i++) { if (branchProbabilities[i] == 0) { continue; } for (int j = 0; j < conditionalClassProbabilities[i].length; j++) { conditionalClassProbabilities[i][j] /= branchProbabilities[i]; } } // normalize the attribute and class arrays for (int i = 0; i < classProbabilities.length; i++) { classProbabilities[i] /= weightSum; } for (int i = 0; i < branchProbabilities.length; i++) { branchProbabilities[i] /= weightSum; } } /** * Get the branch probabilties * @return the branch probabilites */ public double[] getBranchProbabilities() { return branchProbabilities; } /** * Get the number of branches * @return the number of branches */ public int getBranchCount() { return branchProbabilities.length; } /** * Get a branch probabilty * @param branch the branch * @return the probability */ public double getBranchProbability(int branch) { return branchProbabilities[branch]; } /** * Get the class probabilties * @return the probabilites */ public double[] getClassProbabilities() { return classProbabilities; } /** * Get the number of classes * @return the number of classes */ public int getClassCount() { return classProbabilities.length; } /** * Get a class probabilty * @param c the class * @return the probability */ public double getClassProbability(int c) { return classProbabilities[c]; } /** * Get the conditional class probabilites * @return the probabilties */ public double[][] getConditionalClassProbabilities() { return conditionalClassProbabilities; } /** * Get the conditional class probabilites * for a given branch * @param branc the branch * @return the probabilties */ public double[] getConditionalClassProbabilities(int branch) { return conditionalClassProbabilities[branch]; } /** * Get the instance counts * @return the counts */ public int[] getInstanceCounts() { return instanceCounts; } /** * Get an instance counts for a given branch * @param branch the branch * @return the count */ public int getInstanceCount(int branch) { return instanceCounts[branch]; } /** * Get the most likely class * @return the most likely class */ public int getMostLikelyClass() { int mostLikely = 0; for (int i = 1; i < classProbabilities.length; i++) { if (classProbabilities[i] > classProbabilities[mostLikely]) { mostLikely = i; } } return mostLikely; } }