/* * File: FactorizationMachineStochasticGradient.java * Authors: Justin Basilico * Project: Cognitive Foundry * * Copyright 2013 Cognitive Foundry. All rights reserved. */ package gov.sandia.cognition.learning.algorithm.factor.machine; import gov.sandia.cognition.algorithm.MeasurablePerformanceAlgorithm; import gov.sandia.cognition.annotation.PublicationReference; import gov.sandia.cognition.annotation.PublicationReferences; import gov.sandia.cognition.annotation.PublicationType; import gov.sandia.cognition.collection.CollectionUtil; import gov.sandia.cognition.learning.data.DatasetUtil; import gov.sandia.cognition.learning.data.InputOutputPair; import gov.sandia.cognition.math.matrix.Matrix; import gov.sandia.cognition.math.matrix.Vector; import gov.sandia.cognition.math.matrix.VectorEntry; import gov.sandia.cognition.util.ArgumentChecker; import gov.sandia.cognition.util.DefaultNamedValue; import gov.sandia.cognition.util.NamedValue; import java.util.ArrayList; import java.util.Random; /** * Implements a Stochastic Gradient Descent (SGD) algorithm for learning a * Factorization Machine. * * @author Justin Basilico * @since 3.4.0 * @see FactorizationMachine */ @PublicationReferences(references={ @PublicationReference( title="Factorization Machines", author={"Steffen Rendle"}, year=2010, type=PublicationType.Conference, publication="Proceedings of the 10th IEEE International Conference on Data Mining (ICDM)", url="http://www.inf.uni-konstanz.de/~rendle/pdf/Rendle2010FM.pdf"), @PublicationReference( title="Factorization Machines with libFM", author="Steffen Rendle", year=2012, type=PublicationType.Journal, publication="ACM Transactions on Intelligent Systems Technology", url="http://www.csie.ntu.edu.tw/~b97053/paper/Factorization%20Machines%20with%20libFM.pdf", notes="Algorithm 1: Stochastic Gradient Descent (SGD)") }) public class FactorizationMachineStochasticGradient extends AbstractFactorizationMachineLearner implements MeasurablePerformanceAlgorithm { // TODO: Support a version that does binary categorization. /** The default learning rate is {@value}. */ public static final double DEFAULT_LEARNING_RATE = 0.001; /** The learning rate for the algorithm. Must be positive. */ protected double learningRate; /** The input data represented as a list for fast access. */ protected transient ArrayList<? extends InputOutputPair<? extends Vector, Double>> dataList; /** The total error for the current iteration. */ protected transient double totalError; /** The total change in factorization machine parameters for the current * iteration. */ protected transient double totalChange; /** * Creates a new {@link FactorizationMachineStochasticGradient} with * default parameters. */ public FactorizationMachineStochasticGradient() { this(DEFAULT_FACTOR_COUNT, DEFAULT_LEARNING_RATE, DEFAULT_BIAS_REGULARIZATION, DEFAULT_WEIGHT_REGULARIZATION, DEFAULT_FACTOR_REGULARIZATION, DEFAULT_SEED_SCALE, DEFAULT_MAX_ITERATIONS, new Random()); } /** * Creates a new {@link AbstractFactorizationMachineLearner}. * * @param factorCount * The number of factors to use. Zero means no factors. Cannot be * negative. * @param learningRate * The learning rate. Must be positive. * @param biasRegularization * The regularization term for the bias. Cannot be negative. * @param weightRegularization * The regularization term for the linear weights. Cannot be negative. * @param factorRegularization * The regularization term for the factor matrix. Cannot be negative. * @param seedScale * The random initialization scale for the factors. * Multiplied by a random Gaussian to initialize each factor value. * Cannot be negative. * @param maxIterations * The maximum number of iterations for the algorithm to run. Cannot * be negative. * @param random * The random number generator. */ public FactorizationMachineStochasticGradient( final int factorCount, final double learningRate, final double biasRegularization, final double weightRegularization, final double factorRegularization, final double seedScale, final int maxIterations, final Random random) { super(factorCount, biasRegularization, weightRegularization, factorRegularization, seedScale, maxIterations, random); this.setLearningRate(learningRate); } @Override protected boolean initializeAlgorithm() { if (!super.initializeAlgorithm()) { return false; } this.dataList = CollectionUtil.asArrayList(this.data); this.totalError = 0.0; this.totalChange = 0.0; return true; } @Override protected boolean step() { this.totalError = 0.0; this.totalChange = 0.0; // TODO: Should there be a more general SGD harness that does permutation of // the order and block SGD? for (final InputOutputPair<? extends Vector, Double> example : this.dataList) { this.update(example); } // TODO: Stopping conditions. return true; } /** * Performs a single update of step of the stochastic gradient descent * by updating according to the given example. * * @param example * The example to do a stochastic gradient step for. */ protected void update( final InputOutputPair<? extends Vector, Double> example) { final Vector input = example.getInput(); final double label = example.getOutput(); final double weight = DatasetUtil.getWeight(example); final double prediction = this.result.evaluateAsDouble(input); final double error = prediction - label; // Compute the step size for this example. final double stepSize = this.learningRate * weight / this.data.size(); if (this.isBiasEnabled()) { // Update the bias term. final double oldBias = this.result.getBias(); final double biasChange = stepSize * (2.0 * error + 2.0 * this.biasRegularization * oldBias); this.result.setBias(oldBias - biasChange); this.totalChange += Math.abs(biasChange); } if (this.isWeightsEnabled()) { // Update the weight terms. final Vector weights = this.result.getWeights(); for (final VectorEntry entry : input) { final int index = entry.getIndex(); final double value = entry.getValue(); final double weightChange = stepSize * (2.0 * error * value + 2.0 * this.weightRegularization * weights.getElement(index)); weights.decrement(index, weightChange); this.totalChange += Math.abs(weightChange); } this.result.setWeights(weights); } if (this.isFactorsEnabled()) { // Update the factor terms. final Matrix factors = this.result.getFactors(); // TODO: This same calculation is needed in model evaluation. for (int k = 0; k < this.factorCount; k++) { double sum = 0.0; for (final VectorEntry entry : input) { sum += entry.getValue() * factors.getElement(k, entry.getIndex()); } for (final VectorEntry entry : input) { final int index = entry.getIndex(); final double value = entry.getValue(); final double factorElement = factors.getElement(k, index); final double gradient = value * (sum - value * factorElement); final double factorChange = stepSize * (2.0 * error * gradient + 2.0 * this.factorRegularization * factorElement); factors.decrement(k, index, factorChange); this.totalChange += Math.abs(factorChange); } } this.result.setFactors(factors); } this.totalError += error * error; } @Override protected void cleanupAlgorithm() { this.dataList = null; } /** * Gets the total change from the current iteration. * * @return * The total change in the parameters of the factorization machine. */ public double getTotalChange() { return this.totalChange; } /** * Gets the total squared error from the current iteration. * * @return * The total squared error. */ public double getTotalError() { return this.totalError; } /** * Gets the regularization penalty term for the current result. It * computes the squared 2-norm of the parameters of the factorization * machine, each multiplied with their appropriate regularization weight. * * @return * The regularization penalty term for the objective. */ public double getRegularizationPenalty() { final double bias = this.result.getBias(); double penalty = this.biasRegularization * bias * bias; if (this.result.hasWeights()) { penalty += this.weightRegularization * this.result.getWeights().norm2Squared(); } if (this.result.hasFactors()) { penalty += this.factorRegularization * this.result.getFactors().normFrobeniusSquared(); } return penalty; } /** * Gets the total objective, which is the mean squared error plus the * regularization terms. * * @return * The value of the optimization objective. */ public double getObjective() { return this.getTotalError() / this.data.size() + this.getRegularizationPenalty(); } @Override public NamedValue<? extends Number> getPerformance() { return DefaultNamedValue.create("objective", this.getObjective()); } /** * Gets the learning rate. It governs the step size of the algorithm. * * @return * The learning rate. Must be positive. */ public double getLearningRate() { return this.learningRate; } /** * * Gets the learning rate. It governs the step size of the algorithm. * * @param learningRate * The learning rate. Must be positive. */ public void setLearningRate( final double learningRate) { ArgumentChecker.assertIsPositive("learningRate", learningRate); this.learningRate = learningRate; } }