/*
 * File:            FactorizationMachineAlternatingLeastSquares.java
 * Authors:         Justin Basilico
 * Project:         Cognitive Foundry
 *
 * Copyright 2015 Cognitive Foundry. All rights reserved.
 */

package gov.sandia.cognition.learning.algorithm.factor.machine;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationReferences;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorEntry;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.util.ArgumentChecker;
import gov.sandia.cognition.util.DefaultNamedValue;
import gov.sandia.cognition.util.NamedValue;
import java.util.ArrayList;
import java.util.Random;

/**
 * Implements an Alternating Least Squares (ALS) algorithm for learning a
 * Factorization Machine. Each iteration performs a closed-form, per-coordinate
 * least-squares update of the bias, the linear weights, and each row of the
 * factor matrix in turn, while incrementally maintaining the per-example
 * error vector so that each coordinate update is cheap.
 *
 * @author  Justin Basilico
 * @since   3.4.1
 * @see     FactorizationMachine
 */
@PublicationReferences(references={
    @PublicationReference(
        title="Fast Context-aware Recommendations with Factorization Machines",
        author={"Steffen Rendle", "Zeno Gantner", "Christoph Freudenthaler", "Lars Schmidt-Thieme"},
        year=2011,
        type=PublicationType.Conference,
        publication="Proceeding of the 34th international ACM SIGIR conference on Research and development in Information Retrieval (SIGIR)",
        url="http://www.inf.uni-konstanz.de/~rendle/pdf/Rendle2011-CARS.pdf"),
    @PublicationReference(
        title="Factorization Machines with libFM",
        author="Steffen Rendle",
        year=2012,
        type=PublicationType.Journal,
        publication="ACM Transactions on Intelligent Systems Technology",
        url="http://www.csie.ntu.edu.tw/~b97053/paper/Factorization%20Machines%20with%20libFM.pdf",
        notes="Algorithm 2: Alternating Least Squares (ALS)")
})
public class FactorizationMachineAlternatingLeastSquares
    extends AbstractFactorizationMachineLearner
{
    /** The default minimum change is {@value}. */
    public static final double DEFAULT_MIN_CHANGE = 0.00001;

    /** The minimum change allowed in an iteration. The algorithm stops if the
     *  change is less than this value. Cannot be negative. */
    protected double minChange;

    /** The size of the data. */
    protected transient int dataSize;

    /** The data in the form that it can be accessed in O(1) as a list. */
    protected transient ArrayList<? extends InputOutputPair<? extends Vector, Double>> dataList;

    /** A list representing a transposed form of the matrix of inputs. It is a
     *  d by n sparse matrix stored as an array list of sparse vectors. This is
     *  used to speed up the computation of the per-coordinate updates. */
    protected transient ArrayList<Vector> inputsTransposed;

    /** The total change from the current iteration. */
    protected double totalChange;

    /** The total error from the current iteration. */
    protected double totalError;

    /**
     * Creates a new {@link FactorizationMachineAlternatingLeastSquares} with
     * default parameter values.
     */
    public FactorizationMachineAlternatingLeastSquares()
    {
        this(DEFAULT_FACTOR_COUNT, DEFAULT_BIAS_REGULARIZATION,
            DEFAULT_WEIGHT_REGULARIZATION, DEFAULT_FACTOR_REGULARIZATION,
            DEFAULT_SEED_SCALE, DEFAULT_MAX_ITERATIONS, DEFAULT_MIN_CHANGE,
            new Random());
    }

    /**
     * Creates a new {@link FactorizationMachineAlternatingLeastSquares}.
     *
     * @param   factorCount
     *      The number of factors to use. Zero means no factors. Cannot be
     *      negative.
     * @param   biasRegularization
     *      The regularization term for the bias. Cannot be negative.
     * @param   weightRegularization
     *      The regularization term for the linear weights. Cannot be negative.
     * @param   factorRegularization
     *      The regularization term for the factor matrix. Cannot be negative.
     * @param   seedScale
     *      The random initialization scale for the factors.
     *      Multiplied by a random Gaussian to initialize each factor value.
     *      Cannot be negative.
     * @param   maxIterations
     *      The maximum number of iterations for the algorithm to run. Cannot
     *      be negative.
     * @param   minChange
     *      The minimum change allowed in an iteration. The algorithm stops if
     *      the change is less than this value. Cannot be negative.
     * @param   random
     *      The random number generator.
     */
    public FactorizationMachineAlternatingLeastSquares(
        final int factorCount,
        final double biasRegularization,
        final double weightRegularization,
        final double factorRegularization,
        final double seedScale,
        final int maxIterations,
        final double minChange,
        final Random random)
    {
        super(factorCount, biasRegularization, weightRegularization,
            factorRegularization, seedScale, maxIterations, random);

        this.setMinChange(minChange);
    }

    @Override
    protected boolean initializeAlgorithm()
    {
        if (!super.initializeAlgorithm())
        {
            return false;
        }

        this.dataSize = this.data.size();
        if (this.dataSize <= 0)
        {
            return false;
        }

        // Convert the input data to an array list for O(1) indexed access.
        this.dataList = CollectionUtil.asArrayList(this.data);

        // Create a way to store the transposed input data in sparse vectors.
        // Column j of the (transposed) data lets each per-coordinate update
        // touch only the examples where feature j is non-zero.
        final VectorFactory<?> sparseFactory = VectorFactory.getSparseDefault();
        this.inputsTransposed = new ArrayList<>(this.dimensionality);
        for (int i = 0; i < this.dimensionality; i++)
        {
            this.inputsTransposed.add(sparseFactory.createVector(this.dataSize));
        }

        // Fill in the transposed data.
        for (int i = 0; i < this.dataSize; i++)
        {
            final InputOutputPair<? extends Vector, ?> example =
                this.dataList.get(i);
            for (final VectorEntry entry : example.getInput())
            {
                if (entry.getValue() != 0.0)
                {
                    this.inputsTransposed.get(entry.getIndex()).set(
                        i, entry.getValue());
                }
            }
        }

        return true;
    }

    @Override
    protected boolean step()
    {
        this.totalChange = 0.0;

        // TODO: Should errors just be computed once and then updated?
        final Vector errors = VectorFactory.getDenseDefault().createVector(
            this.dataSize);

        // Compute the initial prediction and error terms per input.
        for (int i = 0; i < this.dataSize; i++)
        {
            final InputOutputPair<? extends Vector, Double> example =
                this.dataList.get(i);
            final double prediction = this.result.evaluateAsDouble(
                example.getInput());
            final double actual = example.getOutput();
            final double error = actual - prediction;
            errors.set(i, error);
        }

        // Update the bias. The closed-form ALS update corresponds to
        // minimizing squared error plus biasRegularization * bias^2.
        if (this.isBiasEnabled())
        {
            final double oldBias = this.result.getBias();
            final double newBias = (oldBias * this.dataSize + errors.sum())
                / (this.dataSize + this.biasRegularization);
            this.result.setBias(newBias);

            // Update the running errors. The prediction for every example
            // shifts by (newBias - oldBias), so each error shifts by the
            // negation of that amount.
            final double biasChange = oldBias - newBias;
            for (int i = 0; i < this.dataSize; i++)
            {
                errors.increment(i, biasChange);
            }

            this.totalChange += Math.abs(biasChange);
        }

        // Update the weights.
        if (this.isWeightsEnabled())
        {
            final Vector weights = this.result.getWeights();
            for (int j = 0; j < this.dimensionality; j++)
            {
                final double oldWeight = weights.getElement(j);
                final Vector inputs = this.inputsTransposed.get(j);

                // For the linear term, the derivative of the prediction with
                // respect to weight j is just the input value x_j.
                // TODO: This could be cached and computed once.
                final Vector derivative = inputs;
                final double sumOfSquares = derivative.norm2Squared();
                final double newWeight = sumOfSquares == 0.0 ? 0.0 :
                    (oldWeight * sumOfSquares + derivative.dotProduct(errors))
                    / (sumOfSquares + this.weightRegularization);
                weights.set(j, newWeight);

                // Update the running errors.
                final double weightChange = oldWeight - newWeight;
                errors.scaledPlusEquals(weightChange, inputs);
                this.totalChange += Math.abs(weightChange);
            }
            this.result.setWeights(weights);
        }

        // Update the factors.
        if (this.isFactorsEnabled())
        {
            final Matrix factors = this.result.getFactors();
            for (int k = 0; k < this.factorCount; k++)
            {
                // Precompute q_i = v_k . x_i per example; it is maintained
                // incrementally as each coordinate of factor row k changes.
                final Vector factorTimesInput =
                    VectorFactory.getDefault().createVector(this.dataSize);
                final Vector factorRow = factors.getRow(k);
                for (int i = 0; i < this.dataSize; i++)
                {
                    factorTimesInput.set(i,
                        this.dataList.get(i).getInput().dotProduct(factorRow));
                }

                for (int j = 0; j < this.dimensionality; j++)
                {
                    final double oldFactor = factors.get(k, j);
                    final Vector inputs = this.inputsTransposed.get(j);

                    // Derivative of the prediction with respect to v_{k,j} is
                    // x_j * (q - v_{k,j} * x_j).
                    final Vector derivative = inputs.dotTimes(factorTimesInput);
                    // TODO: This inputs^2 could be cached and computed once.
                    derivative.scaledMinusEquals(oldFactor,
                        inputs.dotTimes(inputs));
                    final double sumOfSquares = derivative.norm2Squared();
                    final double newFactor = sumOfSquares == 0.0 ? 0.0 :
                        (oldFactor * sumOfSquares + derivative.dotProduct(errors))
                        / (sumOfSquares + this.factorRegularization);
                    factors.set(k, j, newFactor);

                    // Update the running errors and factor times input.
                    final double factorChange = oldFactor - newFactor;
                    errors.scaledPlusEquals(factorChange, derivative);
                    factorTimesInput.scaledPlusEquals(-factorChange, inputs);
                    this.totalChange += Math.abs(factorChange);
                }
            }
            this.result.setFactors(factors);
        }

        this.totalError = errors.norm2Squared();
        // Continue iterating while the parameters are still changing enough.
        return this.totalChange >= this.minChange;
    }

    @Override
    protected void cleanupAlgorithm()
    {
        this.dataList = null;
        this.inputsTransposed = null;
    }

    /**
     * Gets the total change in the model parameters from the current iteration.
     *
     * @return
     *      The total change in model parameters.
     */
    public double getTotalChange()
    {
        return this.totalChange;
    }

    /**
     * Gets the total squared error from the current iteration.
     *
     * @return
     *      The total squared error.
     */
    public double getTotalError()
    {
        return this.totalError;
    }

    /**
     * Gets the regularization penalty term in the error for the objective.
     * This is the sum of each regularization coefficient times the squared
     * norm of the corresponding block of parameters.
     *
     * @return
     *      The regularization penalty term. Cannot be negative.
     */
    public double getRegularizationPenalty()
    {
        if (this.result == null)
        {
            return 0.0;
        }

        // The bias term must be squared to match the L2 penalty implied by
        // the ALS bias update; otherwise a negative bias would produce a
        // negative "penalty".
        final double bias = this.result.getBias();
        double penalty = this.biasRegularization * bias * bias;

        if (this.result.hasWeights())
        {
            penalty += this.weightRegularization
                * this.result.getWeights().norm2Squared();
        }

        if (this.result.hasFactors())
        {
            penalty += this.factorRegularization
                * this.result.getFactors().normFrobeniusSquared();
        }

        return penalty;
    }

    /**
     * Gets the total objective, which is the mean squared error from the
     * current iteration plus half the regularization penalty.
     *
     * @return
     *      The value of the optimization objective.
     */
    public double getObjective()
    {
        return this.getTotalError() / Math.max(1, this.dataSize)
            + 0.5 * this.getRegularizationPenalty();
    }

    /**
     * Computes the objective directly from the data and the current model,
     * which is the mean squared error plus half the regularization penalty.
     * Unlike {@link #getObjective()}, this does not use the error accumulated
     * during the iteration; it re-evaluates every example.
     *
     * @return
     *      The value of the optimization objective.
     */
    public double computeObjective()
    {
        double error = 0.0;
        for (int i = 0; i < this.dataSize; i++)
        {
            final InputOutputPair<? extends Vector, Double> example =
                this.dataList.get(i);
            final double prediction = this.result.evaluateAsDouble(
                example.getInput());
            final double actual = example.getOutput();
            double difference = actual - prediction;
            error += difference * difference;
        }

        return error / Math.max(1, this.dataSize)
            + 0.5 * this.getRegularizationPenalty();
    }

    @Override
    public NamedValue<? extends Number> getPerformance()
    {
        return DefaultNamedValue.create("objective", this.getObjective());
    }

    /**
     * Gets the minimum change allowed in an iteration. The algorithm stops if
     * the change is less than this value.
     *
     * @return
     *      The minimum change. Cannot be negative.
     */
    public double getMinChange()
    {
        return this.minChange;
    }

    /**
     * Sets the minimum change allowed in an iteration. The algorithm stops if
     * the change is less than this value.
     *
     * @param   minChange
     *      The minimum change. Cannot be negative.
     */
    public void setMinChange(
        final double minChange)
    {
        ArgumentChecker.assertIsNonNegative("minChange", minChange);
        this.minChange = minChange;
    }
}