/*
* File: VectorThresholdInformationGainLearner.java
* Authors: Justin Basilico, Art Munson
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright November 8, 2007, Sandia Corporation. Under the terms of Contract
* DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
* or on behalf of the U.S. Government. Export of this program may require a
* license from the United States Government. See CopyrightHistory.txt for
* complete details.
*
*/
package gov.sandia.cognition.learning.algorithm.tree;
import gov.sandia.cognition.collection.ArrayUtil;
import gov.sandia.cognition.math.MathUtil;
import gov.sandia.cognition.math.matrix.mtj.Vector2;
import gov.sandia.cognition.statistics.distribution.DefaultDataDistribution;
import java.util.ArrayList;
import java.util.Map;
/**
* The {@code VectorThresholdInformationGainLearner} computes the best
* threshold over a dataset of vectors using information gain to determine the
* optimal index and threshold. This is an implementation of what is used in
* the C4.5 decision tree algorithm.
* <BR><BR>
 * Information gain for a given split (sets X and Y) for two categories (a and b):
 * <BR> ig(X, Y) = entropy(X + Y)
 * <BR>     - (|X| / (|X| + |Y|)) entropy(X)
 * <BR>     - (|Y| / (|X| + |Y|)) entropy(Y)
 * <BR> with
 * <BR><BR>
 * <BR>     entropy(Z) = - (Za / |Z|) log2(Za / |Z|) - (Zb / |Z|) log2(Zb / |Z|)
 * <BR><BR>
 * <BR> where
 * <BR>     Za = number of a's in Z, and
 * <BR>     Zb = number of b's in Z.
 * <BR> In the multi-class case, the entropy is defined as the sum over all of the
 * categories (c) of -(Zc / |Z|) log2(Zc / |Z|).
*
* @param <OutputType> The output type of the data.
* @author Justin Basilico
* @since 2.0
*/
public class VectorThresholdInformationGainLearner<OutputType>
    extends AbstractVectorThresholdMaximumGainLearner<OutputType>
    implements PriorWeightedNodeLearner<OutputType>
{

    /** The categories for the prior. Fixes the index order shared by the
     *  parallel arrays below. Null until {@link #configure} is called. */
    protected ArrayList<OutputType> categories = null;

    /** The prior probability for each category, indexed parallel to
     *  {@code categories}. Null until configured. */
    protected double[] categoryPriors = null;

    /** The training-set count for each category, indexed parallel to
     *  {@code categories}. Null until configured. */
    protected int[] categoryCounts = null;

    /** Scratch space used when computing weighted entropy. It is declared
     *  here so it can be allocated once, instead of during every entropy
     *  evaluation. */
    protected double[] categoryProbabilities = null;

    /**
     * Creates a new instance of VectorThresholdInformationGainLearner.
     */
    public VectorThresholdInformationGainLearner()
    {
        super();
    }

    /**
     * Creates a new {@code VectorThresholdInformationGainLearner}.
     *
     * @param minSplitSize
     *      The minimum split size. Must be positive.
     */
    public VectorThresholdInformationGainLearner(
        final int minSplitSize)
    {
        super(minSplitSize, null);
    }

    @Override
    public VectorThresholdInformationGainLearner<OutputType> clone()
    {
        @SuppressWarnings("unchecked")
        final VectorThresholdInformationGainLearner<OutputType> result =
            (VectorThresholdInformationGainLearner<OutputType>) super.clone();

        // Deep-copy the prior configuration so the clone does not share
        // mutable state (in particular the scratch array) with this object.
        result.categories = this.categories == null
            ? null : new ArrayList<>(this.categories);
        result.categoryPriors = ArrayUtil.copy(this.categoryPriors);
        result.categoryCounts = ArrayUtil.copy(this.categoryCounts);
        result.categoryProbabilities =
            ArrayUtil.copy(this.categoryProbabilities);
        return result;
    }

    /**
     * Computes the information gain of splitting the base counts into the
     * given positive and negative groups. Uses the prior-weighted entropy
     * of Breiman et al. (1984) when class priors have been configured via
     * {@link #configure}; otherwise falls back to the unweighted legacy
     * computation.
     *
     * @param baseCounts Category counts of the data before the split.
     * @param positiveCounts Category counts on the positive side of the split.
     * @param negativeCounts Category counts on the negative side of the split.
     * @return The information gain of the proposed split.
     */
    @Override
    public double computeSplitGain(
        final DefaultDataDistribution<OutputType> baseCounts,
        final DefaultDataDistribution<OutputType> positiveCounts,
        final DefaultDataDistribution<OutputType> negativeCounts)
    {
        if (this.categoryPriors == null) {
            // Support legacy code that does not configure class priors.
            return this.legacyComputeSplitGain(
                baseCounts, positiveCounts, negativeCounts);
        }

        final Vector2 baseEntropy = this.weightedEntropy(baseCounts);
        final Vector2 posEntropy = this.weightedEntropy(positiveCounts);
        final Vector2 negEntropy = this.weightedEntropy(negativeCounts);

        // Weight each child's entropy by the conditional probability of a
        // point reaching that child, given that it reached this node.
        final double posWt = posEntropy.getSecond() / baseEntropy.getSecond();
        final double negWt = negEntropy.getSecond() / baseEntropy.getSecond();
        return baseEntropy.getFirst()
            - posWt * posEntropy.getFirst()
            - negWt * negEntropy.getFirst();
    }

    /**
     * Computes entropy of the counts, weighted by prior probabilities. This
     * entropy calculation comes from Breiman et al. (1984),
     * "Classification and Regression Trees".
     *
     * @param counts The per-category counts at a tree node.
     * @return The pair of values (entropy, marginal node probability).
     */
    private Vector2 weightedEntropy(
        final DefaultDataDistribution<OutputType> counts)
    {
        // p_t is the marginal probability of a training point reaching this
        // tree node (node t): p(t) = sum_j p(j, t), where j indexes classes.
        double p_t = 0;
        for (int j = 0; j < this.categoryProbabilities.length; ++j) {
            // Joint probability of seeing class j and landing in this node:
            //     p(j, t) = prior(j) * p(t | j)
            // where
            //     p(t | j) = (# class j at node t) / (# class j in training).
            // Guard the division: a category with zero training examples
            // would otherwise produce 0/0 = NaN and poison the gain.
            if (this.categoryCounts[j] == 0) {
                this.categoryProbabilities[j] = 0.0;
            }
            else {
                this.categoryProbabilities[j] = this.categoryPriors[j]
                    * counts.get(this.categories.get(j))
                    / (double) this.categoryCounts[j];
            }
            p_t += this.categoryProbabilities[j];
        }

        // The entropy of the data at node t equals
        //     -sum_j p(j | t) log2 p(j | t),
        // with p(j | t) = p(j, t) / p(t). Zero-probability classes
        // contribute zero by the usual convention (x log x -> 0 as x -> 0).
        double entropy = 0;
        for (int j = 0; j < this.categoryProbabilities.length; ++j) {
            final double condProb = this.categoryProbabilities[j] / p_t;
            if (condProb > 0) {
                entropy -= condProb * MathUtil.log2(condProb);
            }
        }
        return new Vector2(entropy, p_t);
    }

    /**
     * Legacy implementation of gain computation. This code does not
     * incorporate class priors: each child's entropy is weighted simply by
     * the fraction of the points that fall into it.
     *
     * @param baseCounts Category counts of the data before the split.
     * @param positiveCounts Category counts on the positive side of the split.
     * @param negativeCounts Category counts on the negative side of the split.
     * @return The (unweighted) information gain of the proposed split.
     */
    private double legacyComputeSplitGain(
        final DefaultDataDistribution<OutputType> baseCounts,
        final DefaultDataDistribution<OutputType> positiveCounts,
        final DefaultDataDistribution<OutputType> negativeCounts)
    {
        final double totalCount = baseCounts.getTotal();
        final double entropyBase = baseCounts.getEntropy();
        final double entropyPositive = positiveCounts.getEntropy();
        final double entropyNegative = negativeCounts.getEntropy();
        final double proportionPositive =
            positiveCounts.getTotal() / totalCount;
        final double proportionNegative =
            negativeCounts.getTotal() / totalCount;
        return entropyBase
            - proportionPositive * entropyPositive
            - proportionNegative * entropyNegative;
    }

    ///// Implementation of PriorWeightedNodeLearner /////

    /**
     * Configures the class priors and training counts used by the
     * prior-weighted entropy computation.
     *
     * @param priors Prior probability per category. If null, the priors
     *      default to the relative class frequencies in {@code trainCounts}
     *      (or to a uniform prior when the total count is zero).
     * @param trainCounts Number of training examples per category; its key
     *      set defines the set of categories.
     */
    @Override
    public void configure(
        final Map<OutputType, Double> priors,
        final Map<OutputType, Integer> trainCounts)
    {
        this.categories = new ArrayList<OutputType>(trainCounts.keySet());
        this.categoryCounts = new int[this.categories.size()];
        int total = 0;
        for (int j = 0; j < this.categories.size(); ++j) {
            this.categoryCounts[j] = trainCounts.get(this.categories.get(j));
            total += this.categoryCounts[j];
        }

        this.categoryPriors = new double[this.categories.size()];
        if (priors == null) {
            if (total > 0) {
                // Default to relative class frequencies.
                for (int j = 0; j < this.categories.size(); ++j) {
                    this.categoryPriors[j] =
                        this.categoryCounts[j] / ((double) total);
                }
            }
            else {
                // Degenerate case of no training data at all: fall back to
                // a uniform prior.
                for (int j = 0; j < this.categories.size(); ++j) {
                    this.categoryPriors[j] = 1.0 / this.categories.size();
                }
            }
        }
        else {
            for (int j = 0; j < this.categories.size(); ++j) {
                this.categoryPriors[j] = priors.get(this.categories.get(j));
            }
        }
        this.categoryProbabilities = new double[this.categories.size()];
    }

}