/*
* File: BatchMultiPerceptron.java
* Authors: Justin Basilico
* Company: Sandia National Laboratories
* Project: Cognitive Foundry Learning Core
*
* Copyright April 21, 2011, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government. Export
* of this program may require a license from the United States Government.
*
*/
package gov.sandia.cognition.learning.algorithm.perceptron;
import gov.sandia.cognition.algorithm.MeasurablePerformanceAlgorithm;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.algorithm.AbstractAnytimeSupervisedBatchLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.LinearBinaryCategorizer;
import gov.sandia.cognition.learning.function.categorization.LinearMultiCategorizer;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.VectorFactoryContainer;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.util.ArgumentChecker;
import gov.sandia.cognition.util.DefaultNamedValue;
import gov.sandia.cognition.util.NamedValue;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.Set;
/**
* Implements a multi-class version of the standard batch Perceptron learning
* algorithm. It learns over labeled examples that have a vector input and an
* arbitrary output label. This version keeps a separate Perceptron as the
* prototype for each category. When it makes an error, it subtracts the input
* from the incorrect prototype and adds it to the correct prototype.
* This implementation also allows a margin to be enforced, which
* in the multi-class case means that the output value for the actual class
* must be at least the given margin from the next highest class.
*
* @param <CategoryType>
* The type of output categories. Can be any type that has a valid
* equals and hashCode method.
* @author Justin Basilico
* @since 3.2.0
* @see Perceptron
* @see OnlinePerceptron
*/
@PublicationReference(
title="Ultraconservative Online Algorithms for Multiclass Problems",
author={"Koby Crammer", "Yoram Singer"},
year=2003,
type=PublicationType.Journal,
publication="Journal of Machine Learning Research",
pages={951, 991},
url="http://portal.acm.org/citation.cfm?id=944936")
public class BatchMultiPerceptron<CategoryType>
extends AbstractAnytimeSupervisedBatchLearner<Vectorizable, CategoryType, LinearMultiCategorizer<CategoryType>>
implements MeasurablePerformanceAlgorithm, VectorFactoryContainer
{
/** The default maximum number of iterations, {@value}. */
public static final int DEFAULT_MAX_ITERATIONS = 100;
/** The default minimum margin is {@value}. */
public static final double DEFAULT_MIN_MARGIN = 0.0;
/** The minimum margin to enforce. Must be non-negative. */
protected double minMargin;
/** The factory to create weight vectors. */
protected VectorFactory<?> vectorFactory;
/** The linear categorizer created by the algorithm. */
protected transient LinearMultiCategorizer<CategoryType> result;
/** The number of errors on the most recent iteration. */
protected transient int errorCount;
/**
* Creates a new {@code BatchMultiPerceptron} with default parameters.
*/
public BatchMultiPerceptron()
{
this(DEFAULT_MAX_ITERATIONS);
}
/**
* Creates a new {@code BatchMultiPerceptron} with the given maximum number
* of iterations and a default margin of 0.0.
*
* @param maxIterations
* The maximum number of iterations. Must be positive.
*/
public BatchMultiPerceptron(
final int maxIterations)
{
this(maxIterations, DEFAULT_MIN_MARGIN);
}
/**
* Creates a new {@code BatchMultiPerceptron} with the given maximum
* number of iterations and margin to enforce.
*
* @param maxIterations
* The maximum number of iterations. Must be positive.
* @param minMargin
* The minimum margin to enforce. Must be non-negative.
*/
public BatchMultiPerceptron(
final int maxIterations,
final double minMargin)
{
this(maxIterations, minMargin, VectorFactory.getDefault());
}
/**
* Creates a new {@code BatchMultiPerceptron} with the given parameters.
*
* @param maxIterations
* The maximum number of iterations. Must be positive.
* @param minMargin
* The minimum margin to enforce. Must be non-negative.
* @param vectorFactory
* The factory to use to create weight vectors.
*/
public BatchMultiPerceptron(
final int maxIterations,
final double minMargin,
final VectorFactory<?> vectorFactory)
{
super(maxIterations);
this.setMinMargin(minMargin);
this.setVectorFactory(vectorFactory);
}
@Override
protected boolean initializeAlgorithm()
{
if (CollectionUtil.isEmpty(this.getData()))
{
// No data to learn from.
return false;
}
// Get the dimensionality of the data.
final int dimensionality = DatasetUtil.getInputDimensionality(
this.getData());
// Create the categorizer we will learn and create the prototypes for
// each category.
this.result = new LinearMultiCategorizer<CategoryType>();
final Set<CategoryType> categories = DatasetUtil.findUniqueOutputs(
this.getData());
for (CategoryType category : categories)
{
final LinearBinaryCategorizer prototype =
new LinearBinaryCategorizer(
this.getVectorFactory().createVector(dimensionality), 0.0);
this.result.getPrototypes().put(category, prototype);
}
// The algorithm is now initialized.
return true;
}
@Override
protected boolean step()
{
// Reset the number of errors for the new iteration.
this.setErrorCount(0);
// Loop over all the training instances.
for (InputOutputPair<? extends Vectorizable, CategoryType> example
: this.getData())
{
if (example == null)
{
// Ignore null examples.
continue;
}
// Get the input as a Vector and the actual category.
final Vector input = example.getInput().convertToVector();
final CategoryType actual = example.getOutput();
// See what the predicted category is.
CategoryType predicted = null;
double predictedScore = Double.NEGATIVE_INFINITY;
for (CategoryType category : this.result.getCategories())
{
double score = this.result.evaluateAsDouble(input, category);
if (this.minMargin != 0.0 && actual.equals(category))
{
// Enforce a margin on the correct category.
score -= this.minMargin;
}
if (score > predictedScore)
{
// This is the predicted category.
predicted = category;
predictedScore = score;
}
}
// See if the algorithm was correct or not.
final boolean correct = ObjectUtil.equalsSafe(actual, predicted);
if (!correct)
{
// The classification was incorrect so we need to update.
this.setErrorCount(this.getErrorCount() + 1);
// Increment the prototype for the actual category.
final LinearBinaryCategorizer actualPrototype =
this.result.getPrototypes().get(actual);
actualPrototype.getWeights().plusEquals(input);
actualPrototype.setBias(actualPrototype.getBias() + 1.0);
// Decrement the prototype for the predicted category.
final LinearBinaryCategorizer predictedPrototype =
this.result.getPrototypes().get(predicted);
predictedPrototype.getWeights().minusEquals(input);
predictedPrototype.setBias(predictedPrototype.getBias() - 1.0);
}
// else - Not an error, no need to update.
}
// Keep going while the error count is positive.
return this.getErrorCount() > 0;
}
@Override
protected void cleanupAlgorithm()
{
// Nothing to clean up.
}
@Override
public LinearMultiCategorizer<CategoryType> getResult()
{
return this.result;
}
/**
* Sets the result of the algorithm.
*
* @param result
* The result of the algorithm.
*/
protected void setResult(
final LinearMultiCategorizer<CategoryType> result)
{
this.result = result;
}
/**
* Gets the minimum margin to enforce. Any value less than or equal to
* this is considered an error for the algorithm.
*
* @return
* The minimum margin. Cannot be negative.
*/
public double getMinMargin()
{
return this.minMargin;
}
/**
* Gets the minimum margin to enforce. Any value less than or equal to
* this is considered an error for the algorithm.
*
* @param minMargin
* The minimum margin. Cannot be negative.
*/
public void setMinMargin(
final double minMargin)
{
ArgumentChecker.assertIsNonNegative("minMargin", minMargin);
this.minMargin = minMargin;
}
/**
* Gets the VectorFactory used to create the weight vector.
*
* @return The VectorFactory used to create the weight vector.
*/
@Override
public VectorFactory<?> getVectorFactory()
{
return this.vectorFactory;
}
/**
* Sets the VectorFactory used to create the weight vector.
*
* @param vectorFactory The VectorFactory used to create the weight vector.
*/
public void setVectorFactory(
final VectorFactory<?> vectorFactory)
{
this.vectorFactory = vectorFactory;
}
/**
* Gets the error count of the most recent iteration.
*
* @return The current error count.
*/
public int getErrorCount()
{
return this.errorCount;
}
/**
* Sets the error count of the most recent iteration.
*
* @param errorCount The current error count.
*/
protected void setErrorCount(
final int errorCount)
{
this.errorCount = errorCount;
}
/**
* Gets the performance, which is the error count on the last iteration.
*
* @return The performance of the algorithm.
*/
@Override
public NamedValue<Integer> getPerformance()
{
return new DefaultNamedValue<Integer>("error count",
this.getErrorCount());
}
}