package func; import dist.*; import dist.Distribution; import dist.DiscreteDistribution; import shared.DataSet; import shared.DataSetDescription; import shared.Instance; /** * A class for construcing a ensemble of decision stumps * @author Andrew Guillory gtg008g@mail.gatech.edu * @version 1.0 */ public class AdaBoostClassifier extends AbstractConditionalDistribution implements FunctionApproximater { /** * The classifier class to use */ private Class classifier; /** * The stumps themselves */ private FunctionApproximater[] classifiers; /** * The weights for each of the stumps */ private double[] weights; /** * The range of the class */ private int classRange; /** * The size of the ensemble */ private int size; /** * Create a new decision stump ensemble * @param size the number of stumps * @param splitEvaluator the splitting strategy to use * @param classifier the classifier class to use */ public AdaBoostClassifier(int size, Class classifier) { this.size = size; this.classifier = classifier; } /** * Create a new decision stump ensemble * @param size the number of stumps * @param splitEvaluator the splitting strategy to use */ public AdaBoostClassifier(int size) { this(size, DecisionStumpClassifier.class); } /** * Make a new default ensemble */ public AdaBoostClassifier() { this(100); } /** * Build the ensemble * @param instances the instances to train with */ public void estimate(DataSet instances) { classifiers = new FunctionApproximater[size]; weights = new double[size]; // initialize the weights of the instances for (int i = 0; i < instances.size(); i++) { instances.get(i).setWeight(1.0 / instances.size()); } // getting some info if (instances.getDescription() == null) { DataSetDescription desc = new DataSetDescription(); desc.induceFrom(instances); instances.setDescription(desc); } classRange = instances.getDescription().getLabelDescription().getDiscreteRange(); for (int i = 0; i < classifiers.length; i++) { try { // make a new classifier classifiers[i] = (FunctionApproximater) classifier.getConstructor(new Class[0]).newInstance(new Object[0]); classifiers[i].estimate(instances); } catch (Exception e) { throw new UnsupportedOperationException("Could not create " + classifier); } // find the error for the newest classifier double error = 0; for (int j = 0; j < instances.size(); j++) { if (classifiers[i].value(instances.get(j)).getDiscrete() != instances.get(j).getLabel().getDiscrete()) { error += instances.get(j).getWeight(); } } double beta = error / (1 - error); // set the weight of the classifier weights[i] = Math.log(1 / beta); // the classifier didn't do any good if (error == .5) { classifiers[i] = null; break; } else if (error == 0) { break; } // readjust the weights of the instances // and calculate the sum of the weights double weightSum = 0; for (int j = 0; j < instances.size(); j++) { if (classifiers[i].value(instances.get(j)).getDiscrete() == instances.get(j).getLabel().getDiscrete()) { instances.get(j).setWeight(instances.get(j).getWeight() * beta); weightSum += instances.get(j).getWeight(); } else { weightSum += instances.get(j).getWeight(); } } // normalize the weights for (int j = 0; j < instances.size(); j++) { instances.get(j).setWeight(instances.get(j).getWeight() / weightSum); } } } /** * Get the classification for an instances * @param data the data to classify * @return the class distribution */ public Instance value(Instance data) { double[] votes = new double[classRange]; for (int i = 0; i < classifiers.length && classifiers[i] != null; i++) { votes[classifiers[i].value(data).getDiscrete()] += weights[i]; } int classification = 0; for (int i = 1; i < votes.length; i++) { if (votes[i] > votes[classification]) { classification = i; } } return new Instance(classification); } /** * @see func.Classifier#classDistribution(shared.Instance) */ public Distribution distributionFor(Instance data) { Instance v = value(data); double[] p = new double[classRange]; p[v.getDiscrete()] = 1; return new DiscreteDistribution(p); } /** * Get the stump count * @return the stump count */ public int getSize() { return size; } /** * Set the stump count * @param i the stump count */ public void setSize(int i) { size = i; } /** * Get the decision stumps * @return the stumps */ public FunctionApproximater[] getClassifiers() { return classifiers; } /** * Get the weights of the stumps * @return the stumps */ public double[] getWeights() { return weights; } /** * @see java.lang.Object#toString() */ public String toString() { String ret = ""; for (int i = 0; i < classifiers.length && classifiers[i] != null; i++) { ret += "weight " + weights[i] + "\n"; ret += classifiers[i] + "\n\n"; } return ret; } }