/* * File: CategoryBalancedBaggingLearner.java * Authors: Justin Basilico * Company: Sandia National Laboratories * Project: Cognitive Foundry Learning Core * * Copyright April 18, 2011, Sandia Corporation. * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive * license for use of this work by or on behalf of the U.S. Government. Export * of this program may require a license from the United States Government. * */ package gov.sandia.cognition.learning.algorithm.ensemble; import gov.sandia.cognition.evaluator.Evaluator; import gov.sandia.cognition.learning.algorithm.BatchLearner; import gov.sandia.cognition.learning.data.DatasetUtil; import gov.sandia.cognition.learning.data.InputOutputPair; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.Random; import java.util.Set; /** * An extension of the basic bagging learner that attempts to sample bags that * have equal numbers of examples from every category. * * @param <InputType> * The input type for supervised learning. Passed on to the internal * learning algorithm. Also the input type for the learned ensemble. * @param <CategoryType> * The output type for supervised learning. Passed on to the internal * learning algorithm. Also the output type of the learned ensemble. * @author Justin Basilico * @since 3.3.0 */ public class CategoryBalancedBaggingLearner<InputType, CategoryType> extends BaggingCategorizerLearner<InputType, CategoryType> { /** The list of categories. */ protected ArrayList<CategoryType> categoryList; /** The mapping of categories to indices of examples belonging to the category. */ protected HashMap<CategoryType, ArrayList<Integer>> dataPerCategory; /** * Creates a new instance of CategoryBalancedBaggingLearner. */ public CategoryBalancedBaggingLearner() { this(null); } /** * Creates a new instance of CategoryBalancedBaggingLearner. * * @param learner * The learner to use to create the categorizer on each iteration. */ public CategoryBalancedBaggingLearner( final BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType, CategoryType>>, ? extends Evaluator<? super InputType, ? extends CategoryType>> learner) { this(learner, DEFAULT_MAX_ITERATIONS, DEFAULT_PERCENT_TO_SAMPLE, new Random()); } /** * Creates a new instance of CategoryBalancedBaggingLearner. * * @param learner * The learner to use to create the categorizer on each iteration. * @param maxIterations * The maximum number of iterations to run for, which is also the * number of learners to create. * @param percentToSample * The percentage of the total size of the data to sample on each * iteration. Must be positive. * @param random * The random number generator to use. */ public CategoryBalancedBaggingLearner( final BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType, CategoryType>>, ? extends Evaluator<? super InputType, ? extends CategoryType>> learner, final int maxIterations, final double percentToSample, final Random random) { super(learner, maxIterations, percentToSample, random); } @Override protected boolean initializeAlgorithm() { boolean result = super.initializeAlgorithm(); if (result) { // Map each category to a list of indices for it. final int dataSize = this.dataList.size(); final Set<CategoryType> categories = DatasetUtil.findUniqueOutputs( this.dataList); this.categoryList = new ArrayList<CategoryType>(categories); this.dataPerCategory = new LinkedHashMap<CategoryType, ArrayList<Integer>>( categories.size()); for (CategoryType category : categories) { this.dataPerCategory.put(category, new ArrayList<Integer>()); } for (int i = 0; i < dataSize; i++) { final CategoryType category = this.dataList.get(i).getOutput(); this.dataPerCategory.get(category).add(i); } } return result; } @Override protected void fillBag( final int sampleCount) { // Get the number of categories. final int categoryCount = this.categoryList.size(); if ((sampleCount % categoryCount) != 0) { // Shuffle the category list to deal with uneven numbers of // examples per category. Collections.shuffle(this.categoryList, this.random); } int remainingSampleSize = sampleCount; for (int i = 0; i < categoryCount && remainingSampleSize > 0; i++) { final CategoryType category = this.categoryList.get(i); final ArrayList<Integer> indices = this.dataPerCategory.get(category); final int categorySize = indices.size(); final int categorySampleSize = Math.max(1, remainingSampleSize / (categoryCount - i)); for (int j = 0; j < categorySampleSize; j++) { final int index = indices.get( this.random.nextInt(categorySize)); final InputOutputPair<? extends InputType, CategoryType> example = this.dataList.get(index); this.bag.add(example); this.dataInBag[index] += 1; } remainingSampleSize -= categorySampleSize; } } @Override protected void cleanupAlgorithm() { this.dataPerCategory = null; this.categoryList = null; super.cleanupAlgorithm(); } }