/* * File: OnlineMultiPerceptron.java * Authors: Justin Basilico * Company: Sandia National Laboratories * Project: Cognitive Foundry Learning Core * * Copyright January 28, 2011, Sandia Corporation. * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive * license for use of this work by or on behalf of the U.S. Government. Export * of this program may require a license from the United States Government. * */ package gov.sandia.cognition.learning.algorithm.perceptron; import gov.sandia.cognition.annotation.PublicationReference; import gov.sandia.cognition.annotation.PublicationType; import gov.sandia.cognition.learning.function.categorization.LinearMultiCategorizer; import gov.sandia.cognition.learning.algorithm.AbstractBatchAndIncrementalLearner; import gov.sandia.cognition.learning.data.InputOutputPair; import gov.sandia.cognition.learning.function.categorization.LinearBinaryCategorizer; import gov.sandia.cognition.math.matrix.Vector; import gov.sandia.cognition.math.matrix.VectorFactory; import gov.sandia.cognition.math.matrix.VectorFactoryContainer; import gov.sandia.cognition.math.matrix.Vectorizable; import gov.sandia.cognition.util.ArgumentChecker; import gov.sandia.cognition.util.DefaultWeightedValue; import gov.sandia.cognition.util.ObjectUtil; import java.util.LinkedList; /** * An online, multiple category version of the Perceptron algorithm. It learns * a separate linear binary categorizer for each category. * * @param <CategoryType> * The type of output categories. Can be any type that has a valid * equals and hashCode method. * @author Justin Basilico * @since 3.3.0 */ public class OnlineMultiPerceptron<CategoryType> extends AbstractBatchAndIncrementalLearner<InputOutputPair<? extends Vectorizable, CategoryType>, LinearMultiCategorizer<CategoryType>> implements VectorFactoryContainer { /** The default minimum margin is {@value}. */ public static final double DEFAULT_MIN_MARGIN = 0.0; /** The minimum margin to enforce. Must be non-negative. */ protected double minMargin; /** The factory to create weight vectors. */ protected VectorFactory<?> vectorFactory; /** * Creates a new {@code OnlineMultiPerceptron}. */ public OnlineMultiPerceptron() { this(DEFAULT_MIN_MARGIN); } /** * Creates a new {@code OnlineMultiPerceptron} with the * given minimum margin. * * @param minMargin * The minimum margin to consider an example correct. */ public OnlineMultiPerceptron( final double minMargin) { this(minMargin, VectorFactory.getDefault()); } /** * Creates a new {@code OnlineMultiPerceptron} with the * given minimum margin and backing vector factory. * * @param minMargin * The minimum margin to consider an example correct. * @param vectorFactory * The vector factory used to create the weight vectors. */ public OnlineMultiPerceptron( final double minMargin, final VectorFactory<?> vectorFactory) { super(); this.setMinMargin(minMargin); this.setVectorFactory(vectorFactory); } @Override public LinearMultiCategorizer<CategoryType> createInitialLearnedObject() { return new LinearMultiCategorizer<CategoryType>(); } @Override public void update( final LinearMultiCategorizer<CategoryType> target, final InputOutputPair<? extends Vectorizable, CategoryType> example) { final Vector input = example.getInput().convertToVector(); final CategoryType actual = example.getOutput(); final boolean knownCategory = target.getCategories().contains(actual); if (!knownCategory) { // This category was never seen, so create a new prototype for it. final Vector initialWeights = this.getVectorFactory().createVector( input.getDimensionality()); target.getPrototypes().put(actual, new LinearBinaryCategorizer(initialWeights, 0.0)); } // See what the predicted category is. CategoryType predicted = null; double predictedScore = Double.NEGATIVE_INFINITY; for (CategoryType category : target.getCategories()) { double score = target.evaluateAsDouble(input, category); if (actual.equals(category)) { score -= this.minMargin; } if (score > predictedScore) { predicted = category; predictedScore = score; } } final boolean correct = ObjectUtil.equalsSafe(actual, predicted); if (!correct) { // Increment the prototype for the actual category. final LinearBinaryCategorizer actualPrototype = target.getPrototypes().get(actual); actualPrototype.getWeights().plusEquals(input); actualPrototype.setBias(actualPrototype.getBias() + 1.0); // Decrement the prototype for the predicted category. final LinearBinaryCategorizer predictedPrototype = target.getPrototypes().get(predicted); predictedPrototype.getWeights().minusEquals(input); predictedPrototype.setBias(predictedPrototype.getBias() - 1.0); } // else - Not an error, no need to update. } /** * Gets the minimum margin to enforce. Any value less than or equal to * this is considered an error for the algorithm. * * @return * The minimum margin. Cannot be negative. */ public double getMinMargin() { return this.minMargin; } /** * Gets the minimum margin to enforce. Any value less than or equal to * this is considered an error for the algorithm. * * @param minMargin * The minimum margin. Cannot be negative. */ public void setMinMargin( final double minMargin) { ArgumentChecker.assertIsNonNegative("minMargin", minMargin); this.minMargin = minMargin; } /** * Gets the VectorFactory used to create the weight vector. * * @return The VectorFactory used to create the weight vector. */ @Override public VectorFactory<?> getVectorFactory() { return this.vectorFactory; } /** * Sets the VectorFactory used to create the weight vector. * * @param vectorFactory The VectorFactory used to create the weight vector. */ public void setVectorFactory( final VectorFactory<?> vectorFactory) { this.vectorFactory = vectorFactory; } /** * Variant of a multi-category Perceptron that performs a uniform weight * update on all categories that are scored higher than the true category * such that the weights are equal and sum to -1. * * @param <CategoryType> * The type of output categories. Can be any type that has a valid * equals and hashCode method. */ @PublicationReference( title="Ultraconservative online algorithms for multiclass problems", author={"Koby Crammer", "Yoram Singer"}, year=2003, type=PublicationType.Journal, publication="The Journal of Machine Learning Research", pages={951, 991}, url="http://portal.acm.org/citation.cfm?id=944936") public static class UniformUpdate<CategoryType> extends OnlineMultiPerceptron<CategoryType> { /** * Creates a new {@code OnlineMultiPerceptron.UniformUpdate}. */ public UniformUpdate() { super(); } /** * Creates a new {@code OnlineMultiPerceptron.UniformUpdate} with the * given minimum margin. * * @param minMargin * The minimum margin to consider an example correct. */ public UniformUpdate( final double minMargin) { super(minMargin); } /** * Creates a new {@code OnlineMultiPerceptron.UniformUpdate} with the * given minimum margin and backing vector factory. * * @param minMargin * The minimum margin to consider an example correct. * @param vectorFactory * The vector factory used to create the weight vectors. */ public UniformUpdate( final double minMargin, final VectorFactory<?> vectorFactory) { super(minMargin, vectorFactory); } @Override public void update( final LinearMultiCategorizer<CategoryType> target, final InputOutputPair<? extends Vectorizable, CategoryType> example) { // TODO: This shares a lot of code with the parent class update. // Figure out a way to combine them in a sensible way. // -- jdbasil (2011-01-30) final Vector input = example.getInput().convertToVector(); final CategoryType actual = example.getOutput(); final boolean knownCategory = target.getCategories().contains(actual); if (!knownCategory) { // This category was never seen, so create a new prototype for it. final Vector initialWeights = this.getVectorFactory().createVector( input.getDimensionality()); target.getPrototypes().put(actual, new LinearBinaryCategorizer(initialWeights, 0.0)); } // See what the predicted category is. final double actualScore = target.evaluateAsDouble(input, actual) - this.minMargin; final LinkedList<CategoryType> errors = new LinkedList<CategoryType>(); for (CategoryType category : target.getCategories()) { final double score = target.evaluateAsDouble(input, category); if (!actual.equals(category) && score >= actualScore) { errors.add(category); } } final boolean correct = errors.isEmpty(); if (!correct) { // Increment the prototype for the actual category. final LinearBinaryCategorizer actualPrototype = target.getPrototypes().get(actual); actualPrototype.getWeights().plusEquals(input); actualPrototype.setBias(actualPrototype.getBias() + 1.0); // Decrement the prototype for the predicted category. final double errorWeight = 1.0 / errors.size(); final Vector errorUpdate = input.scale(errorWeight); for (CategoryType category : errors) { final LinearBinaryCategorizer prototype = target.getPrototypes().get(category); prototype.getWeights().minusEquals(errorUpdate); prototype.setBias(prototype.getBias() - errorWeight); } } // else - Not an error, no need to update. } } /** * Variant of a multi-category Perceptron that performs a proportional * weight update on all categories that are scored higher than the true * category such that the weights sum to 1.0 and are proportional how much * larger the score was for each incorrect category than the true category. * * @param <CategoryType> * The type of output categories. Can be any type that has a valid * equals and hashCode method. */ @PublicationReference( title="Ultraconservative online algorithms for multiclass problems", author={"Koby Crammer", "Yoram Singer"}, year=2003, type=PublicationType.Journal, publication="The Journal of Machine Learning Research", pages={951, 991}, url="http://portal.acm.org/citation.cfm?id=944936") public static class ProportionalUpdate<CategoryType> extends OnlineMultiPerceptron<CategoryType> { /** The default minimum margin is {@value}. */ public static final double DEFAULT_MIN_MARGIN = 0.001; /** * Creates a new {@code OnlineMultiPerceptron.ProportionalUpdate}. */ public ProportionalUpdate() { this(DEFAULT_MIN_MARGIN); } /** * Creates a new {@code OnlineMultiPerceptron.ProportionalUpdate} with the * given minimum margin. * * @param minMargin * The minimum margin to consider an example correct. */ public ProportionalUpdate( final double minMargin) { super(minMargin); } /** * Creates a new {@code OnlineMultiPerceptron.ProportionalUpdate} with the * given minimum margin and backing vector factory. * * @param minMargin * The minimum margin to consider an example correct. * @param vectorFactory * The vector factory used to create the weight vectors. */ public ProportionalUpdate( final double minMargin, final VectorFactory<?> vectorFactory) { super(minMargin, vectorFactory); } @Override public void update( final LinearMultiCategorizer<CategoryType> target, final InputOutputPair<? extends Vectorizable, CategoryType> example) { // TODO: This shares a lot of code with the parent class update. // Figure out a way to combine them in a sensible way. // -- jdbasil (2011-01-30) final Vector input = example.getInput().convertToVector(); final CategoryType actual = example.getOutput(); final boolean knownCategory = target.getCategories().contains(actual); if (!knownCategory) { // This category was never seen, so create a new prototype for it. final Vector initialWeights = this.getVectorFactory().createVector( input.getDimensionality()); target.getPrototypes().put(actual, new LinearBinaryCategorizer(initialWeights, 0.0)); } // See what the predicted category is. final double actualScore = target.evaluateAsDouble(input, actual) - minMargin; final LinkedList<DefaultWeightedValue<CategoryType>> errors = new LinkedList<DefaultWeightedValue<CategoryType>>(); double differenceSum = 0.0; for (CategoryType category : target.getCategories()) { final double score = target.evaluateAsDouble(input, category); double difference = score - actualScore; if (difference >= 0.0 && !actual.equals(category)) { errors.add(DefaultWeightedValue.create(category, difference)); differenceSum += difference; } } final boolean correct = errors.isEmpty(); if (!correct) { // Increment the prototype for the actual category. final LinearBinaryCategorizer actualPrototype = target.getPrototypes().get(actual); actualPrototype.getWeights().plusEquals(input); actualPrototype.setBias(actualPrototype.getBias() + 1.0); // Decrement the prototype for the predicted category. for (DefaultWeightedValue<CategoryType> category : errors) { final LinearBinaryCategorizer prototype = target.getPrototypes().get(category.getValue()); final double errorWeight = category.getWeight() / differenceSum; prototype.getWeights().minusEquals(input.scale(errorWeight)); prototype.setBias(prototype.getBias() - errorWeight); } } // else - Not an error, no need to update. } @Override public void setMinMargin( final double minMargin) { ArgumentChecker.assertIsPositive("minMargin", minMargin); super.setMinMargin(minMargin); } } }