/*
* File: MaximumAPosterioriCategorizer.java
* Authors: Kevin R. Dixon
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright Mar 26, 2010, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government.
* Export of this program may require a license from the United States
* Government. See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.learning.function.categorization;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.learning.algorithm.SupervisedBatchLearner;
import gov.sandia.cognition.learning.data.DefaultWeightedValueDiscriminant;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.math.Ring;
import gov.sandia.cognition.statistics.AbstractDistribution;
import gov.sandia.cognition.statistics.ComputableDistribution;
import gov.sandia.cognition.statistics.DataDistribution;
import gov.sandia.cognition.statistics.DistributionWithMean;
import gov.sandia.cognition.statistics.ProbabilityFunction;
import gov.sandia.cognition.statistics.distribution.DefaultDataDistribution;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.DefaultWeightedValue;
import gov.sandia.cognition.util.ObjectUtil;
import gov.sandia.cognition.util.WeightedValue;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.Map;
import java.util.Random;
import java.util.Set;
/**
 * Categorizer that returns the category with the highest posterior likelihood
 * for a given observation. This is known as a MAP categorizer, where
 * the posterior is proportional to the category's conditional likelihood
 * for a given observation times the prior probability of the category.
 * @param <ObservationType> Type of observations
 * @param <CategoryType> Type of categories
 * @author Kevin R. Dixon
 * @since 3.0
 */
@PublicationReference(
    author="Wikipedia",
    title="Maximum a posteriori estimation",
    type=PublicationType.WebPage,
    year=2010,
    url="http://en.wikipedia.org/wiki/Maximum_a_posteriori_estimation"
)
public class MaximumAPosterioriCategorizer<ObservationType,CategoryType>
    extends AbstractDistribution<ObservationType>
    implements DiscriminantCategorizer<ObservationType,CategoryType,Double>
{

    /**
     * PMF of the various categories. The masses accumulated here are the
     * (un-normalized) weights passed to addCategory; evaluating the PMF
     * yields the prior probability of a category.
     */
    DataDistribution.PMF<CategoryType> categoryPriors;

    /**
     * Map that contains the probability functions for the observations
     * for the given categories.
     */
    Map<CategoryType,ProbabilityFunction<ObservationType>> categoryConditionals;

    /**
     * Creates a new instance of MaximumAPosterioriCategorizer with no
     * categories.
     */
    public MaximumAPosterioriCategorizer()
    {
        this.categoryPriors = new DefaultDataDistribution.PMF<CategoryType>( 2 );
        this.categoryConditionals =
            new HashMap<CategoryType, ProbabilityFunction<ObservationType>>( 2 );
    }

    /**
     * Creates a copy of this categorizer. The prior PMF, the conditional map,
     * and each conditional probability function are copied, so adding
     * categories to the clone does not affect this instance (and vice versa).
     * @return
     * Independent copy of this categorizer.
     */
    @Override
    @SuppressWarnings("unchecked")
    public MaximumAPosterioriCategorizer<ObservationType,CategoryType> clone()
    {
        final MaximumAPosterioriCategorizer<ObservationType,CategoryType> clone =
            (MaximumAPosterioriCategorizer<ObservationType,CategoryType>) super.clone();

        // super.clone() is shallow: without these copies the clone would share
        // (and mutate) this instance's priors and conditionals.
        clone.categoryPriors = ObjectUtil.cloneSmart( this.categoryPriors );
        clone.categoryConditionals =
            new HashMap<CategoryType, ProbabilityFunction<ObservationType>>(
                Math.max( 2, this.categoryConditionals.size() ) );
        for( Map.Entry<CategoryType, ProbabilityFunction<ObservationType>> entry
            : this.categoryConditionals.entrySet() )
        {
            clone.categoryConditionals.put(
                entry.getKey(), ObjectUtil.cloneSmart( entry.getValue() ) );
        }
        return clone;
    }

    /**
     * Adds the given category with the given mass (which is divided by the
     * masses of all categories to determine the prior probability weight)
     * and the distribution function. Adding an existing category again
     * accumulates its mass and replaces its conditional distribution.
     * @param category
     * Category to add
     * @param mass
     * Mass of the category
     * @param conditional
     * Conditional probability function of observations for the category
     */
    public void addCategory(
        final CategoryType category,
        final double mass,
        final ProbabilityFunction<ObservationType> conditional )
    {
        this.categoryPriors.increment(category, mass);
        this.categoryConditionals.put( category, conditional );
    }

    /**
     * Gets the prior probability weight and conditional distribution for
     * the given category.
     * @param category
     * Category to consider
     * @return
     * Prior probability weight and conditional distribution for
     * the given category. For an unknown category the value is null.
     */
    public WeightedValue<ProbabilityFunction<ObservationType>> getCategory(
        final CategoryType category )
    {
        ProbabilityFunction<ObservationType> conditional =
            this.categoryConditionals.get(category);
        double prior = this.categoryPriors.evaluate(category);
        return new DefaultWeightedValue<ProbabilityFunction<ObservationType>>(
            conditional, prior );
    }

    /**
     * Gets the categories that have a conditional distribution.
     * @return
     * Read-only view of the known categories. Modifying the returned set
     * would otherwise desynchronize the conditionals from the priors.
     */
    @Override
    public Set<? extends CategoryType> getCategories()
    {
        return Collections.unmodifiableSet( this.categoryConditionals.keySet() );
    }

    /**
     * Returns the MAP category for the given observation.
     * @param input
     * Observation to categorize
     * @return
     * Category with the largest (unnormalized) posterior, or null when there
     * are no categories.
     */
    @Override
    public CategoryType evaluate(
        final ObservationType input)
    {
        return this.evaluateWithDiscriminant(input).getValue();
    }

    /**
     * Returns the MAP category for the given observation together with its
     * (unnormalized) posterior as the discriminant. Ties are resolved in
     * favor of the first category encountered while iterating the categories.
     * @param input
     * Observation to categorize
     * @return
     * The best category and its posterior. If there are no categories (or
     * every posterior is NaN) the category is null and the discriminant is
     * negative infinity.
     */
    @Override
    public DefaultWeightedValueDiscriminant<CategoryType> evaluateWithDiscriminant(
        final ObservationType input)
    {
        CategoryType maxCategory = null;
        double maxPosterior = Double.NEGATIVE_INFINITY;
        for( CategoryType category : this.categoryConditionals.keySet() )
        {
            double posterior = this.computePosterior(input, category);
            if( maxPosterior < posterior )
            {
                maxPosterior = posterior;
                maxCategory = category;
            }
        }
        return DefaultWeightedValueDiscriminant.create(maxCategory, maxPosterior);
    }

    /**
     * Computes the posterior of the category given the observation.
     * Actually, this is the conjunctive likelihood since we're not normalizing
     * by the likelihood of the observation over all categories. Since we're
     * only interested in finding the MAP category, we're doing the standard
     * thing and not normalizing.
     * @param observation
     * Observation to consider
     * @param category
     * Category to consider
     * @return
     * Unnormalized posterior of the category given the observation, that is,
     * the conditional likelihood of the observation times the category prior;
     * 0.0 if the category has no conditional distribution.
     */
    public double computePosterior(
        final ObservationType observation,
        final CategoryType category )
    {
        ProbabilityFunction<ObservationType> categoryConditional =
            this.categoryConditionals.get(category);
        double posterior;
        if( categoryConditional != null )
        {
            double prior = this.categoryPriors.evaluate(category);
            double conditional = categoryConditional.evaluate(observation);
            posterior = conditional*prior;
        }
        else
        {
            posterior = 0.0;
        }
        return posterior;
    }

    /**
     * Gets the mean of the mixture of category conditionals: the sum over
     * categories of the category prior times the mean of that category's
     * conditional distribution. Supported when every conditional exposes a
     * mean and those means are numbers or rings.
     *
     * @return
     * The mean, or null if there are no categories.
     * @throws UnsupportedOperationException
     * If a conditional does not provide a mean, or its mean is neither a
     * Number nor a Ring.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public ObservationType getMean()
    {
        ObservationType mean = null;
        for( Map.Entry<CategoryType, ProbabilityFunction<ObservationType>> entry
            : this.categoryConditionals.entrySet() )
        {
            final ProbabilityFunction<ObservationType> conditional = entry.getValue();
            if( !(conditional instanceof DistributionWithMean<?>) )
            {
                throw new UnsupportedOperationException(
                    "Mean not supported for conditional type "
                    + ((conditional == null) ? null : conditional.getClass()) );
            }

            // Previously this called this.getMean(), which recursed forever.
            final Object categoryMean =
                ((DistributionWithMean<?>) conditional).getMean();
            final double prior = this.categoryPriors.evaluate( entry.getKey() );
            if( categoryMean instanceof Number )
            {
                final double weightedCategoryMean =
                    prior * ((Number) categoryMean).doubleValue();
                final double sumSoFar =
                    (mean == null) ? 0.0 : ((Number) mean).doubleValue();
                mean = (ObservationType) Double.valueOf( sumSoFar + weightedCategoryMean );
            }
            else if( categoryMean instanceof Ring<?> )
            {
                // scale() returns a new object, so the conditional's own mean
                // is never modified by the plusEquals below.
                final Ring<?> weightedCategoryMean = ((Ring<?>) categoryMean).scale(prior);
                if( mean == null )
                {
                    mean = (ObservationType) weightedCategoryMean;
                }
                else
                {
                    ((Ring) mean).plusEquals( weightedCategoryMean );
                }
            }
            else
            {
                throw new UnsupportedOperationException(
                    "Mean not supported for type "
                    + ((categoryMean == null) ? null : categoryMean.getClass()) );
            }
        }
        return mean;
    }

    /**
     * Samples observations by first sampling categories from the prior and
     * then sampling one observation from each sampled category's conditional.
     * @param random
     * Random number generator to use
     * @param numSamples
     * Number of observations to draw
     * @param output
     * Collection that receives the sampled observations
     */
    @Override
    public void sampleInto(
        final Random random,
        final int numSamples,
        final Collection<? super ObservationType> output)
    {
        ArrayList<? extends CategoryType> categories =
            this.categoryPriors.sample(random, numSamples);
        for( CategoryType category : categories )
        {
            ProbabilityFunction<ObservationType> pdf =
                this.categoryConditionals.get(category);
            output.add( pdf.sample(random) );
        }
    }

    /**
     * Learner for the MAP categorizer. The category priors are the relative
     * frequencies of the labels in the data, and each category's conditional
     * is learned from the inputs carrying that label.
     * @param <ObservationType> Type of observations
     * @param <CategoryType> Type of categories
     */
    public static class Learner<ObservationType,CategoryType>
        extends AbstractCloneableSerializable
        implements SupervisedBatchLearner<ObservationType,CategoryType,MaximumAPosterioriCategorizer<ObservationType,CategoryType>>
    {

        /**
         * Learner that creates the conditional distributions for each
         * category.
         */
        private BatchLearner<Collection<? extends ObservationType>, ? extends ComputableDistribution<ObservationType>> conditionalLearner;

        /**
         * Default constructor. The conditional learner must be set before
         * calling learn.
         */
        public Learner()
        {
            this( null );
        }

        /**
         * Creates a new instance of Learner
         * @param conditionalLearner
         * Learner that creates the conditional distributions for each
         * category.
         */
        public Learner(
            final BatchLearner<Collection<? extends ObservationType>, ? extends ComputableDistribution<ObservationType>> conditionalLearner)
        {
            this.conditionalLearner = conditionalLearner;
        }

        @Override
        public MaximumAPosterioriCategorizer.Learner<ObservationType,CategoryType> clone()
        {
            @SuppressWarnings("unchecked")
            Learner<ObservationType,CategoryType> clone =
                (Learner<ObservationType,CategoryType>) super.clone();
            clone.setConditionalLearner(
                ObjectUtil.cloneSmart( this.getConditionalLearner() ) );
            return clone;
        }

        /**
         * Learns a MAP categorizer from labeled data.
         * @param data
         * Labeled observations; each output is the category of its input.
         * @return
         * Categorizer whose category masses are the label counts and whose
         * conditionals were learned per category.
         * @throws IllegalStateException
         * If no conditional learner has been set.
         */
        @Override
        public MaximumAPosterioriCategorizer<ObservationType, CategoryType> learn(
            final Collection<? extends InputOutputPair<? extends ObservationType, CategoryType>> data)
        {
            if( this.conditionalLearner == null )
            {
                throw new IllegalStateException(
                    "conditionalLearner must be set before calling learn" );
            }

            // Count labels (priors) and group the inputs by label.
            DataDistribution.PMF<CategoryType> categoryPrior =
                new DefaultDataDistribution.PMF<CategoryType>();
            Map<CategoryType,LinkedList<ObservationType>> categoryData =
                new HashMap<CategoryType, LinkedList<ObservationType>>();
            for( InputOutputPair<? extends ObservationType,CategoryType> pair : data )
            {
                categoryPrior.increment( pair.getOutput() );
                LinkedList<ObservationType> categoryValues = categoryData.get( pair.getOutput() );
                if( categoryValues == null )
                {
                    categoryValues = new LinkedList<ObservationType>();
                    categoryData.put( pair.getOutput(), categoryValues );
                }
                categoryValues.add( pair.getInput() );
            }

            // Fit one conditional distribution per category.
            MaximumAPosterioriCategorizer<ObservationType,CategoryType> categorizer =
                new MaximumAPosterioriCategorizer<ObservationType, CategoryType>();
            for( CategoryType category : categoryPrior.getDomain() )
            {
                LinkedList<ObservationType> categoryValues =
                    categoryData.get(category);
                ComputableDistribution<ObservationType> distribution =
                    this.conditionalLearner.learn(categoryValues);
                ProbabilityFunction<ObservationType> conditional =
                    distribution.getProbabilityFunction();
                categorizer.addCategory(
                    category, categoryPrior.get(category), conditional );
            }
            return categorizer;
        }

        /**
         * Getter for conditionalLearner
         * @return
         * Learner that creates the conditional distributions for each
         * category.
         */
        public BatchLearner<Collection<? extends ObservationType>,? extends ComputableDistribution<ObservationType>> getConditionalLearner()
        {
            return this.conditionalLearner;
        }

        /**
         * Setter for conditionalLearner
         * @param conditionalLearner
         * Learner that creates the conditional distributions for each
         * category.
         */
        public void setConditionalLearner(
            BatchLearner<Collection<? extends ObservationType>, ? extends ComputableDistribution<ObservationType>> conditionalLearner)
        {
            this.conditionalLearner = conditionalLearner;
        }

    }

}