/* * File: MultiCategoryAdaBoost.java * Authors: Justin Basilico * Company: Sandia National Laboratories * Project: Cognitive Foundry Learning Core * * Copyright March 24, 2011, Sandia Corporation. * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive * license for use of this work by or on behalf of the U.S. Government. Export * of this program may require a license from the United States Government. * */ package gov.sandia.cognition.learning.algorithm.ensemble; import gov.sandia.cognition.annotation.PublicationReference; import gov.sandia.cognition.annotation.PublicationType; import gov.sandia.cognition.collection.CollectionUtil; import gov.sandia.cognition.evaluator.Evaluator; import gov.sandia.cognition.learning.algorithm.AbstractAnytimeSupervisedBatchLearner; import gov.sandia.cognition.learning.algorithm.BatchLearner; import gov.sandia.cognition.learning.algorithm.BatchLearnerContainer; import gov.sandia.cognition.learning.data.DatasetUtil; import gov.sandia.cognition.learning.data.DefaultWeightedInputOutputPair; import gov.sandia.cognition.learning.data.InputOutputPair; import gov.sandia.cognition.util.ObjectUtil; import java.util.ArrayList; import java.util.Collection; import java.util.Set; /** * An implementation of a multi-class version of the Adaptive Boosting * (AdaBoost) algorithm, known as AdaBoost.M1. Note that the "weak learner" in * this version of AdaBoost requires a weighted error rate that is greater than * 0.5 to be accepted into the ensemble and for learning to continue. * * @param <InputType> * The type of input that the weak learner can learn over. * @param <CategoryType> * The type of categories to learn over. * @author Justin Basilico * @since 3.2.0 */ @PublicationReference( author={"Yoav Freund", "Robert E.Schapire"}, title="A decision-theoretic generalization of on-line learning and an application to boosting", publication="Journal of Computer and System Sciences", notes="Volume 55, Number 1", year=1997, pages={119,139}, type=PublicationType.Journal, url="http://www.cse.ucsd.edu/~yfreund/papers/adaboost.pdf") public class MultiCategoryAdaBoost<InputType, CategoryType> extends AbstractAnytimeSupervisedBatchLearner<InputType, CategoryType, WeightedVotingCategorizerEnsemble<InputType, CategoryType, Evaluator<? super InputType, ? extends CategoryType>>> implements BatchLearnerContainer<BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType, CategoryType>>, ? extends Evaluator<? super InputType, ? extends CategoryType>>> { /** The default maximum number of iterations is {@value}. */ public static final int DEFAULT_MAX_ITERATIONS = 100; /** * The "weak learner" that must learn from the weighted input-output pairs * on each iteration. */ protected BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType, CategoryType>>, ? extends Evaluator<? super InputType, ? extends CategoryType>> weakLearner; /** The ensemble learned by the algorithm. */ protected transient WeightedVotingCategorizerEnsemble<InputType, CategoryType, Evaluator<? super InputType, ? extends CategoryType>> ensemble; /** An array list containing the weighted version of the data. */ protected transient ArrayList<DefaultWeightedInputOutputPair<InputType, CategoryType>> weightedData; /** * Creates a new {@code MultiCategoryAdaBoost} with default parameters and a null * weak learner. */ public MultiCategoryAdaBoost() { this(null, DEFAULT_MAX_ITERATIONS); } /** * Creates a new {@code MultiCategoryAdaBoost} with the given parameters. * * @param weakLearner * The weak learner to use. * @param maxIterations * The maximum number of iterations. Must be positive. */ public MultiCategoryAdaBoost( final BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType, CategoryType>>, ? extends Evaluator<? super InputType, ? extends CategoryType>> weakLearner, final int maxIterations) { super(maxIterations); this.setWeakLearner(weakLearner); } @Override protected boolean initializeAlgorithm() { if (CollectionUtil.isEmpty(this.getData())) { return false; } // We initialize the weighted training examples and count them up // as we go so that we can initialize the weights in the next step. int numExamples = 0; this.weightedData = new ArrayList<DefaultWeightedInputOutputPair<InputType, CategoryType>>( this.getData().size()); for (InputOutputPair<? extends InputType, CategoryType> example : this.getData()) { if (example != null && example.getOutput() != null) { this.weightedData.add( new DefaultWeightedInputOutputPair<InputType, CategoryType>( example.getInput(), example.getOutput(), DatasetUtil.getWeight(example))); numExamples++; } } if (numExamples <= 0) { this.weightedData = null; return false; } // Figure out the set of categories. final Set<CategoryType> categories = DatasetUtil.findUniqueOutputs( this.weightedData); this.ensemble = new WeightedVotingCategorizerEnsemble<InputType, CategoryType, Evaluator<? super InputType, ? extends CategoryType>>( categories); return true; } @Override protected boolean step() { // Normalize the weights to sum to 1.0. double weightSum = 0.0; for (DefaultWeightedInputOutputPair<?, ?> example : this.weightedData) { weightSum += example.getWeight(); } for (DefaultWeightedInputOutputPair<?, ?> example : this.weightedData) { example.setWeight(example.getWeight() / weightSum); } // Call the weak learner. final Evaluator<? super InputType, ? extends CategoryType> member = this.getWeakLearner().learn(weightedData); // Next compute the weighted error rate (epsilon) for the newly trained // member. Note that the weights sum to one so the error rate can be // at most 1.0. double weightedErrorRate = 0.0; final int dataSize = this.weightedData.size(); final boolean[] correctness = new boolean[dataSize]; for (int i = 0; i < dataSize; i++) { final DefaultWeightedInputOutputPair<InputType, CategoryType> example = this.weightedData.get(i); // Figure out if the categorizer is correct on this example. final CategoryType actual = example.getOutput(); final CategoryType predicted = member.evaluate(example.getInput()); final boolean correct = ObjectUtil.equalsSafe(predicted, actual); // Keep track of the correctness in order to compute the weight // update loop correctness[i] = correct; if (!correct) { // This was an error. Increase the error rate. weightedErrorRate += example.getWeight(); } } if (weightedErrorRate > 0.5) { // The best weak learner had too many errors, so we stop. return false; } // Compute beta. final double beta = weightedErrorRate / (1.0 - weightedErrorRate); // Update all of the weights based on beta. for (int i = 0; i < dataSize; i++) { final DefaultWeightedInputOutputPair<InputType, CategoryType> example = this.weightedData.get(i); final boolean correct = correctness[i]; // Compute the new weight. final double oldWeight = example.getWeight(); double newWeight = oldWeight * Math.pow(beta, 1.0 - (correct ? 0.0 : 1.0)); example.setWeight(newWeight); } // Add the member to the ensemble using the weight log(1/beta). final double memberWeight = (beta == 0.0 ? Double.POSITIVE_INFINITY : Math.log(1.0 / beta)); this.ensemble.add(member, memberWeight); return true; } @Override protected void cleanupAlgorithm() { this.weightedData = null; } @Override public WeightedVotingCategorizerEnsemble<InputType, CategoryType, Evaluator<? super InputType, ? extends CategoryType>> getResult() { return this.ensemble; } @Override public BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType, CategoryType>>, ? extends Evaluator<? super InputType, ? extends CategoryType>> getLearner() { return this.weakLearner; } /** * Gets the weak learner that is passed the weighted training data on each * step of the algorithm. * * @return * The weak learner for the algorithm to use. */ public BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType, CategoryType>>, ? extends Evaluator<? super InputType, ? extends CategoryType>> getWeakLearner() { return this.weakLearner; } /** * Sets the weak learner that is passed the weighted training data on each * step of the algorithm. * * @param weakLearner * The weak learner for the algorithm to use. */ public void setWeakLearner( final BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType, CategoryType>>, ? extends Evaluator<? super InputType, ? extends CategoryType>> weakLearner) { this.weakLearner = weakLearner; } }