/*
* File: CategoryBalancedIVotingLearner.java
* Authors: Justin Basilico
* Company: Sandia National Laboratories
* Project: Cognitive Foundry Learning Core
*
* Copyright February 08, 2011, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government. Export
* of this program may require a license from the United States Government.
*
*/
package gov.sandia.cognition.learning.algorithm.ensemble;
import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.factory.Factory;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.statistics.DataDistribution;
import gov.sandia.cognition.statistics.distribution.DefaultDataDistribution;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.Random;
/**
* An extension of IVoting for dealing with skew problems that makes sure that
* there are an equal number of examples from each category in each sample that
* an ensemble member is trained on.
*
* @param <InputType>
* The type of the input for the categorizer to learn. This is the type
* passed to the internal batch learner to learn each ensemble member.
* @param <CategoryType>
* The type of the category that is the output for the categorizer to
* learn. It is also passed to the internal batch learner to learn each
* ensemble member. It must have a valid equals and hashCode method.
* @author Justin Basilico
* @since 3.3.0
*/
public class CategoryBalancedIVotingLearner<InputType, CategoryType>
    extends IVotingCategorizerLearner<InputType, CategoryType>
{

    /**
     * Creates a new {@code CategoryBalancedIVotingLearner} with no base
     * learner, the default maximum iterations, the default percent to sample,
     * and a new {@code Random}.
     */
    public CategoryBalancedIVotingLearner()
    {
        this(null, DEFAULT_MAX_ITERATIONS, DEFAULT_PERCENT_TO_SAMPLE,
            new Random());
    }

    /**
     * Creates a new {@code CategoryBalancedIVotingLearner}.
     *
     * @param   learner
     *      The learner to use to create the categorizer on each iteration.
     * @param   maxIterations
     *      The maximum number of iterations to run for, which is also the
     *      number of learners to create.
     * @param   percentToSample
     *      The percentage of the total size of the data to sample on each
     *      iteration. Must be positive.
     * @param   random
     *      The random number generator to use.
     */
    public CategoryBalancedIVotingLearner(
        final BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType, CategoryType>>, ? extends Evaluator<? super InputType, ? extends CategoryType>> learner,
        final int maxIterations,
        final double percentToSample,
        final Random random)
    {
        this(learner, maxIterations, percentToSample,
            DEFAULT_PROPORTION_INCORRECT_IN_SAMPLE,
            DEFAULT_VOTE_OUT_OF_BAG_ONLY,
            new DefaultDataDistribution.DefaultFactory<CategoryType>(2),
            random);
    }

    /**
     * Creates a new {@code CategoryBalancedIVotingLearner}.
     *
     * @param   learner
     *      The learner to use to create the categorizer on each iteration.
     * @param   maxIterations
     *      The maximum number of iterations to run for, which is also the
     *      number of learners to create.
     * @param   percentToSample
     *      The percentage of the total size of the data to sample on each
     *      iteration. Must be positive.
     * @param   proportionIncorrectInSample
     *      The percentage of incorrect examples to put in each sample. Must
     *      be between 0.0 and 1.0 (inclusive).
     * @param   voteOutOfBagOnly
     *      Controls whether or not in-bag or out-of-bag votes are used to
     *      determine accuracy.
     * @param   counterFactory
     *      The factory for counting votes.
     * @param   random
     *      The random number generator to use.
     */
    public CategoryBalancedIVotingLearner(
        final BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType, CategoryType>>, ? extends Evaluator<? super InputType, ? extends CategoryType>> learner,
        final int maxIterations,
        final double percentToSample,
        final double proportionIncorrectInSample,
        final boolean voteOutOfBagOnly,
        final Factory<? extends DataDistribution<CategoryType>> counterFactory,
        final Random random)
    {
        super(learner, maxIterations, percentToSample, proportionIncorrectInSample,
            voteOutOfBagOnly, counterFactory, random);
    }

    /**
     * Creates the bag for the current iteration by sampling with replacement
     * from each category separately, so that every category contributes the
     * same number of draws. The given indices are first partitioned by the
     * output category of the example they refer to. Then, for each category
     * of the ensemble, a fixed number of correct and incorrect indices are
     * sampled into the current bag. If a category has no incorrect (or no
     * correct) indices, the other list for that same category is sampled in
     * its place.
     *
     * @param   correctIndices
     *      The indices of the examples that are currently considered correct.
     * @param   incorrectIndices
     *      The indices of the examples that are currently considered
     *      incorrect.
     */
    @Override
    protected void createBag(
        final ArrayList<Integer> correctIndices,
        final ArrayList<Integer> incorrectIndices)
    {
        // First we need to figure out which items are currently correct and
        // incorrect in each category.

        // Initialize the data structures. Every ensemble category gets an
        // (initially empty) list so the lookups below find an entry for each
        // of those categories.
        final LinkedHashMap<CategoryType, ArrayList<Integer>> correctPerCategory =
            new LinkedHashMap<CategoryType, ArrayList<Integer>>();
        final LinkedHashMap<CategoryType, ArrayList<Integer>> incorrectPerCategory =
            new LinkedHashMap<CategoryType, ArrayList<Integer>>();
        for (CategoryType category : this.ensemble.getCategories())
        {
            correctPerCategory.put(category, new ArrayList<Integer>());
            incorrectPerCategory.put(category, new ArrayList<Integer>());
        }

        // Add each index to the list for the category of its example.
        for (Integer index : correctIndices)
        {
            final CategoryType category = this.dataList.get(index).getOutput();
            correctPerCategory.get(category).add(index);
        }

        for (Integer index : incorrectIndices)
        {
            final CategoryType category = this.dataList.get(index).getOutput();
            incorrectPerCategory.get(category).add(index);
        }

        // Figure out how many to sample per category. Integer division
        // splits the totals evenly; at least one of each kind is always
        // drawn per category.
        final int categoryCount = this.ensemble.getCategories().size();
        final int correctPerCategorySize =
            Math.max(1, this.numCorrectToSample / categoryCount);
        final int incorrectPerCategorySize =
            Math.max(1, this.numIncorrectToSample / categoryCount);

        // Now sample from each category.
        for (CategoryType category : this.ensemble.getCategories())
        {
            // Get the correct and incorrect indices for this category.
            ArrayList<Integer> categoryCorrect =
                correctPerCategory.get(category);
            ArrayList<Integer> categoryIncorrect =
                incorrectPerCategory.get(category);

            if (categoryIncorrect.isEmpty())
            {
                // Nothing incorrect, so just sample more from correct.
                categoryIncorrect = categoryCorrect;
            }
            else if (categoryCorrect.isEmpty())
            {
                // Nothing correct in THIS category, so just sample more from
                // incorrect. The per-category list must be checked here, not
                // the global correctIndices: a category can have no correct
                // examples while other categories do, and then an empty list
                // would otherwise be passed to the sampler.
                categoryCorrect = categoryIncorrect;
            }

            // Sample with replacement.
            sampleIndicesWithReplacementInto(categoryCorrect, this.dataList,
                correctPerCategorySize, this.random,
                this.currentBag, this.dataInBag);
            sampleIndicesWithReplacementInto(categoryIncorrect, this.dataList,
                incorrectPerCategorySize, this.random,
                this.currentBag, this.dataInBag);
        }
    }

}