/*
 * File: AbstractFactorizationMachineLearner.java
 * Authors: Justin Basilico
 * Project: Cognitive Foundry
 *
 * Copyright 2013 Cognitive Foundry. All rights reserved.
 */

package gov.sandia.cognition.learning.algorithm.factor.machine;

import gov.sandia.cognition.algorithm.MeasurablePerformanceAlgorithm;
import gov.sandia.cognition.learning.algorithm.AbstractAnytimeSupervisedBatchLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.util.ArgumentChecker;
import gov.sandia.cognition.util.Randomized;
import java.util.Random;

/**
 * Base class for algorithms that learn a {@link FactorizationMachine}. It
 * holds the parameters shared by such learners (factor count, enabled terms,
 * regularization values, seed scale, random number generator) and builds the
 * initial factorization machine that learning starts from.
 *
 * @author  Justin Basilico
 * @since   3.4.0
 */
public abstract class AbstractFactorizationMachineLearner
    extends AbstractAnytimeSupervisedBatchLearner<Vector, Double, FactorizationMachine>
    implements Randomized, MeasurablePerformanceAlgorithm
{

    /** The default number of factors is {@value}. */
    public static final int DEFAULT_FACTOR_COUNT = 10;

    /** The default for bias enabled is {@value}. */
    public static final boolean DEFAULT_BIAS_ENABLED = true;

    /** The default for weights enabled is {@value}. */
    public static final boolean DEFAULT_WEIGHTS_ENABLED = true;

    /** The default bias regularization parameter is {@value}. */
    public static final double DEFAULT_BIAS_REGULARIZATION = 0.0;

    /** The default weight regularization parameter is {@value}. */
    public static final double DEFAULT_WEIGHT_REGULARIZATION = 0.001;

    /** The default factor regularization parameter is {@value}. */
    public static final double DEFAULT_FACTOR_REGULARIZATION = 0.01;

    /** The default seed initialization scale is {@value}. */
    public static final double DEFAULT_SEED_SCALE = 0.01;

    /** The default maximum number of iterations is {@value}. */
    public static final int DEFAULT_MAX_ITERATIONS = 100;

    /** Whether the bias term is enabled. */
    protected boolean biasEnabled;

    /** Whether the linear weight term is enabled. */
    protected boolean weightsEnabled;

    /** How many factors to use. A value of zero means no factors. */
    protected int factorCount;

    /** Regularization value applied to the bias. Cannot be negative. */
    protected double biasRegularization;

    /** Regularization value applied to the linear weights. Cannot be negative. */
    protected double weightRegularization;

    /** Regularization value applied to the factor matrix. Cannot be negative. */
    protected double factorRegularization;

    /**
     * Scale multiplied by a random Gaussian draw to seed each factor value.
     * Cannot be negative.
     */
    protected double seedScale;

    /** Source of random numbers used when seeding the factors. */
    protected Random random;

    /** The factorization machine currently held as the learner's output. */
    protected transient FactorizationMachine result;

    /** Input dimensionality of the data, as determined at initialization. */
    protected transient int dimensionality;

    /**
     * Creates a new {@link AbstractFactorizationMachineLearner} using the
     * default value for every parameter and a freshly constructed
     * {@link Random}.
     */
    public AbstractFactorizationMachineLearner()
    {
        this(DEFAULT_FACTOR_COUNT,
            DEFAULT_BIAS_REGULARIZATION,
            DEFAULT_WEIGHT_REGULARIZATION,
            DEFAULT_FACTOR_REGULARIZATION,
            DEFAULT_SEED_SCALE,
            DEFAULT_MAX_ITERATIONS,
            new Random());
    }

    /**
     * Creates a new {@link AbstractFactorizationMachineLearner} with the
     * given parameters. The bias and linear weight terms start out with
     * their default enabled settings.
     *
     * @param   factorCount
     *      The number of factors to use. Zero means no factors. Cannot be
     *      negative.
     * @param   biasRegularization
     *      The regularization term for the bias. Cannot be negative.
     * @param   weightRegularization
     *      The regularization term for the linear weights. Cannot be
     *      negative.
     * @param   factorRegularization
     *      The regularization term for the factor matrix. Cannot be
     *      negative.
     * @param   seedScale
     *      The random initialization scale for the factors. It is
     *      multiplied by a random Gaussian to initialize each factor value.
     *      Cannot be negative.
     * @param   maxIterations
     *      The maximum number of iterations for the algorithm to run. Cannot
     *      be negative.
     * @param   random
     *      The random number generator.
     */
    public AbstractFactorizationMachineLearner(
        final int factorCount,
        final double biasRegularization,
        final double weightRegularization,
        final double factorRegularization,
        final double seedScale,
        final int maxIterations,
        final Random random)
    {
        super(maxIterations);

        // Everything goes through the setters so that their argument checks
        // apply to constructor arguments as well.
        this.setFactorCount(factorCount);
        this.setBiasEnabled(DEFAULT_BIAS_ENABLED);
        this.setWeightsEnabled(DEFAULT_WEIGHTS_ENABLED);
        this.setBiasRegularization(biasRegularization);
        this.setWeightRegularization(weightRegularization);
        this.setFactorRegularization(factorRegularization);
        this.setSeedScale(seedScale);
        this.setRandom(random);
    }

    @Override
    protected boolean initializeAlgorithm()
    {
        // The input dimensionality of the data fixes the size of the model.
        this.dimensionality = DatasetUtil.getInputDimensionality(this.data);

        // Linear weights: one entry per input dimension, created by the
        // default dense vector factory.
        final Vector initialWeights =
            VectorFactory.getDenseDefault().createVector(this.dimensionality);

        // Pairwise factors: randomly seeded, or absent if there are none.
        final Matrix initialFactors = this.createInitialFactors();

        // Start learning from a machine whose bias is 0.0.
        this.result = new FactorizationMachine(
            0.0, initialWeights, initialFactors);
        return true;
    }

    /**
     * Builds the starting factor matrix for the current factor count and
     * input dimensionality, filling every element with the seed scale times
     * a random Gaussian draw.
     *
     * @return
     *      A factorCount-by-dimensionality matrix of seeded values, or null
     *      when the factor count is not positive.
     */
    private Matrix createInitialFactors()
    {
        if (this.factorCount <= 0)
        {
            // No factors requested, so there is no matrix to seed.
            return null;
        }

        final Matrix seeded = MatrixFactory.getDenseDefault().createMatrix(
            this.factorCount, this.dimensionality);

        // Walk input dimensions in the outer loop and factors in the inner
        // loop; this fixes the order in which random values are consumed.
        for (int column = 0; column < this.dimensionality; column++)
        {
            for (int row = 0; row < this.factorCount; row++)
            {
                final double value =
                    this.seedScale * this.random.nextGaussian();
                seeded.setElement(row, column, value);
            }
        }

        return seeded;
    }

    @Override
    public FactorizationMachine getResult()
    {
        return this.result;
    }

    /**
     * Gets the number of factors, which are used to represent the pairwise
     * interaction between dimensions in the input vector.
     *
     * @return
     *      The number of factors to use. Zero means no factors. Cannot be
     *      negative.
     */
    public int getFactorCount()
    {
        return this.factorCount;
    }

    /**
     * Sets the number of factors, which are used to represent the pairwise
     * interaction between dimensions in the input vector.
     *
     * @param   factorCount
     *      The number of factors to use. Zero means no factors. Cannot be
     *      negative.
     */
    public void setFactorCount(
        final int factorCount)
    {
        ArgumentChecker.assertIsNonNegative("factorCount", factorCount);
        this.factorCount = factorCount;
    }

    /**
     * Indicates whether the bias term is enabled. When it is not enabled,
     * it will default to zero and not be updated.
     *
     * @return
     *      True if the bias term is enabled, otherwise false.
     */
    public boolean isBiasEnabled()
    {
        return this.biasEnabled;
    }

    /**
     * Enables or disables the bias term. When it is not enabled, it will
     * default to zero and not be updated.
     *
     * @param   biasEnabled
     *      True if the bias term is enabled, otherwise false.
     */
    public void setBiasEnabled(
        final boolean biasEnabled)
    {
        this.biasEnabled = biasEnabled;
    }

    /**
     * Indicates whether the linear weight term is enabled. When it is not
     * enabled, it will default to a zero vector and not be updated.
     *
     * @return
     *      True if the linear term is enabled, otherwise false.
     */
    public boolean isWeightsEnabled()
    {
        return this.weightsEnabled;
    }

    /**
     * Enables or disables the linear weight term. When it is not enabled,
     * it will default to a zero vector and not be updated.
     *
     * @param   weightsEnabled
     *      True if the linear term is enabled, otherwise false.
     */
    public void setWeightsEnabled(
        final boolean weightsEnabled)
    {
        this.weightsEnabled = weightsEnabled;
    }

    /**
     * Indicates whether the factors are enabled, which is the case exactly
     * when the number of factors is greater than zero.
     *
     * @return
     *      True if the number of factors is greater than zero, otherwise
     *      false.
     */
    public boolean isFactorsEnabled()
    {
        final int count = this.getFactorCount();
        return count > 0;
    }

    /**
     * Gets the parameter that controls the bias regularization.
     *
     * @return
     *      The regularization term for the bias. Cannot be negative.
     */
    public double getBiasRegularization()
    {
        return this.biasRegularization;
    }

    /**
     * Sets the parameter that controls the bias regularization.
     *
     * @param   biasRegularization
     *      The regularization term for the bias. Cannot be negative.
     */
    public void setBiasRegularization(
        final double biasRegularization)
    {
        ArgumentChecker.assertIsNonNegative("biasRegularization",
            biasRegularization);
        this.biasRegularization = biasRegularization;
    }

    /**
     * Gets the parameter that controls the linear weight regularization.
     *
     * @return
     *      The regularization term for the weights. Cannot be negative.
     */
    public double getWeightRegularization()
    {
        return this.weightRegularization;
    }

    /**
     * Sets the parameter that controls the linear weight regularization.
     *
     * @param   weightRegularization
     *      The regularization term for the weights. Cannot be negative.
     */
    public void setWeightRegularization(
        final double weightRegularization)
    {
        ArgumentChecker.assertIsNonNegative("weightRegularization",
            weightRegularization);
        this.weightRegularization = weightRegularization;
    }

    /**
     * Gets the parameter that controls the factor matrix regularization.
     *
     * @return
     *      The regularization term for the factors. Cannot be negative.
     */
    public double getFactorRegularization()
    {
        return this.factorRegularization;
    }

    /**
     * Sets the parameter that controls the factor matrix regularization.
     *
     * @param   factorRegularization
     *      The regularization term for the factors. Cannot be negative.
     */
    public void setFactorRegularization(
        final double factorRegularization)
    {
        ArgumentChecker.assertIsNonNegative("factorRegularization",
            factorRegularization);
        this.factorRegularization = factorRegularization;
    }

    /**
     * Gets the seed initialization scale, the value that is multiplied by a
     * random Gaussian to initialize each factor value.
     *
     * @return
     *      The random initialization scale for the factors. Cannot be
     *      negative.
     */
    public double getSeedScale()
    {
        return this.seedScale;
    }

    /**
     * Sets the seed initialization scale, the value that is multiplied by a
     * random Gaussian to initialize each factor value.
     *
     * @param   seedScale
     *      The random initialization scale for the factors. Cannot be
     *      negative.
     */
    public void setSeedScale(
        final double seedScale)
    {
        ArgumentChecker.assertIsNonNegative("seedScale", seedScale);
        this.seedScale = seedScale;
    }

    @Override
    public Random getRandom()
    {
        return this.random;
    }

    @Override
    public void setRandom(
        final Random random)
    {
        this.random = random;
    }

}