/*
* File: WeightedMostFrequentLearner.java
* Authors: Justin Basilico
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright April 18, 2008, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government. Export
* of this program may require a license from the United States Government.
* See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.learning.algorithm.baseline;
import gov.sandia.cognition.annotation.CodeReview;
import gov.sandia.cognition.learning.algorithm.SupervisedBatchLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.ConstantEvaluator;
import gov.sandia.cognition.statistics.distribution.DefaultDataDistribution;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import java.util.Collection;
/**
* The {@code WeightedMostFrequentLearner} class implements a baseline learning
* algorithm that finds the most frequent output of a given dataset based on
* the weights of the examples.
*
* @param <OutputType> The output type of the data.
* @author Justin Basilico
* @since 2.1
*/
@CodeReview(
reviewer="Kevin R. Dixon",
date="2008-07-22",
changesNeeded=false,
comments={
"Fixed a few typos in javadoc.",
"Removed implements Serializeable, as BatchLearner already does that.",
"I don't particularly like this class... I just don't think it's useful.",
"However, the code looks fine."
}
)
public class WeightedMostFrequentLearner<OutputType>
extends AbstractCloneableSerializable
implements SupervisedBatchLearner<Object,OutputType,ConstantEvaluator<OutputType>>
{
/**
* Creates a new {@code MostFrequentLearner}.
*/
public WeightedMostFrequentLearner()
{
super();
}
/**
* Creates a constant evaluator based on the most frequent output in a given
* collection of input-output pairs, taking the weight into account.
*
* @param data {@inheritDoc}
* @return {@inheritDoc}
*/
@Override
public ConstantEvaluator<OutputType> learn(
final Collection<? extends InputOutputPair<? extends Object, OutputType>> data )
{
// We are going to sum up the weight associated with each output value.
final DefaultDataDistribution<OutputType> weightDistribution =
new DefaultDataDistribution<OutputType>();
// Go through all the examples and increment the weight sum for each
// output value.
for (InputOutputPair<?, ? extends OutputType> example : data)
{
final double weight = DatasetUtil.getWeight(example);
final OutputType output = example.getOutput();
weightDistribution.increment(output, weight);
}
// Figure out the output with the highest weight.
final ConstantEvaluator<OutputType> result =
new ConstantEvaluator<OutputType>();
if (weightDistribution.getTotal() > 0.0)
{
result.setValue(weightDistribution.getMaxValueKey());
}
// Create the resulting evaluator.
return result;
}
}