/*
* File: FactorizationMachineStochasticGradientTest.java
* Authors: Justin Basilico
* Project: Cognitive Foundry
*
* Copyright 2013 Cognitive Foundry. All rights reserved.
*/
package gov.sandia.cognition.learning.algorithm.factor.machine;
import gov.sandia.cognition.algorithm.IterativeAlgorithm;
import gov.sandia.cognition.algorithm.event.AbstractIterativeAlgorithmListener;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.performance.MeanSquaredErrorEvaluator;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import org.junit.Test;
import static org.junit.Assert.*;
/**
 * Unit tests for class {@link FactorizationMachineStochasticGradient}.
 *
 * @author Justin Basilico
 * @since 3.4.0
 */
public class FactorizationMachineStochasticGradientTest
    extends Object
{

    /** Format used when printing per-iteration diagnostics in testLearn. */
    protected final NumberFormat NUMBER_FORMAT = new DecimalFormat("0.0000");

    /** Fixed-seed random number generator so that test runs are reproducible. */
    protected Random random = new Random(47474747);

    /**
     * Creates a new test.
     */
    public FactorizationMachineStochasticGradientTest()
    {
        super();
    }

    /**
     * Test of constructors of class FactorizationMachineStochasticGradient.
     * Checks the defaults of the no-argument constructor, that every value
     * passed to the full constructor is stored, and that invalid arguments
     * are rejected with an IllegalArgumentException.
     */
    @Test
    public void testConstructors()
    {
        int factorCount = FactorizationMachineStochasticGradient.DEFAULT_FACTOR_COUNT;
        double learningRate = FactorizationMachineStochasticGradient.DEFAULT_LEARNING_RATE;
        double biasRegularization = FactorizationMachineStochasticGradient.DEFAULT_BIAS_REGULARIZATION;
        double weightRegularization = FactorizationMachineStochasticGradient.DEFAULT_WEIGHT_REGULARIZATION;
        double factorRegularization = FactorizationMachineStochasticGradient.DEFAULT_FACTOR_REGULARIZATION;
        double seedScale = FactorizationMachineStochasticGradient.DEFAULT_SEED_SCALE;
        int maxIterations = FactorizationMachineStochasticGradient.DEFAULT_MAX_ITERATIONS;

        FactorizationMachineStochasticGradient instance =
            new FactorizationMachineStochasticGradient();
        assertEquals(factorCount, instance.getFactorCount());
        assertEquals(learningRate, instance.getLearningRate(), 0.0);
        assertEquals(biasRegularization, instance.getBiasRegularization(), 0.0);
        assertEquals(weightRegularization, instance.getWeightRegularization(), 0.0);
        assertEquals(factorRegularization, instance.getFactorRegularization(), 0.0);
        assertEquals(seedScale, instance.getSeedScale(), 0.0);
        assertEquals(maxIterations, instance.getMaxIterations());
        assertNotNull(instance.getRandom());
        assertSame(instance.getRandom(), instance.getRandom());

        factorCount = 22;
        learningRate = 0.12321;
        biasRegularization = 3.33;
        weightRegularization = 44.44;
        factorRegularization = 555.55;
        seedScale = 0.6;
        maxIterations = 777;
        // Named distinctly so it does not shadow the test's random field.
        final Random customRandom = new Random();
        instance = new FactorizationMachineStochasticGradient(factorCount,
            learningRate, biasRegularization, weightRegularization, factorRegularization,
            seedScale, maxIterations, customRandom);
        assertEquals(factorCount, instance.getFactorCount());
        // The learning rate must also be stored by the full constructor.
        assertEquals(learningRate, instance.getLearningRate(), 0.0);
        assertEquals(biasRegularization, instance.getBiasRegularization(), 0.0);
        assertEquals(weightRegularization, instance.getWeightRegularization(), 0.0);
        assertEquals(factorRegularization, instance.getFactorRegularization(), 0.0);
        assertEquals(seedScale, instance.getSeedScale(), 0.0);
        assertEquals(maxIterations, instance.getMaxIterations());
        assertSame(customRandom, instance.getRandom());

        // Each invalid-argument check below uses the try/fail/catch idiom:
        // an unexpected exception type propagates unmasked instead of being
        // replaced by an assertion failure from a finally block.

        // No negative factor counts.
        try
        {
            new FactorizationMachineStochasticGradient(-1, learningRate,
                biasRegularization, weightRegularization, factorRegularization,
                seedScale, maxIterations, customRandom);
            fail("Expected IllegalArgumentException for a negative factor count");
        }
        catch (IllegalArgumentException expected)
        {
            // Expected.
        }

        // No zero learning rate.
        try
        {
            new FactorizationMachineStochasticGradient(factorCount, 0,
                biasRegularization, weightRegularization, factorRegularization,
                seedScale, maxIterations, customRandom);
            fail("Expected IllegalArgumentException for a zero learning rate");
        }
        catch (IllegalArgumentException expected)
        {
            // Expected.
        }

        // No negative bias regularization.
        try
        {
            new FactorizationMachineStochasticGradient(factorCount, learningRate,
                -1.0, weightRegularization, factorRegularization,
                seedScale, maxIterations, customRandom);
            fail("Expected IllegalArgumentException for a negative bias regularization");
        }
        catch (IllegalArgumentException expected)
        {
            // Expected.
        }

        // No negative weight regularization.
        try
        {
            new FactorizationMachineStochasticGradient(factorCount, learningRate,
                biasRegularization, -1.0, factorRegularization,
                seedScale, maxIterations, customRandom);
            fail("Expected IllegalArgumentException for a negative weight regularization");
        }
        catch (IllegalArgumentException expected)
        {
            // Expected.
        }

        // No negative factor regularization.
        try
        {
            new FactorizationMachineStochasticGradient(factorCount, learningRate,
                biasRegularization, weightRegularization, -1.0,
                seedScale, maxIterations, customRandom);
            fail("Expected IllegalArgumentException for a negative factor regularization");
        }
        catch (IllegalArgumentException expected)
        {
            // Expected.
        }

        // No negative seed scale.
        try
        {
            new FactorizationMachineStochasticGradient(factorCount, learningRate,
                biasRegularization, weightRegularization, factorRegularization,
                -1.0, maxIterations, customRandom);
            fail("Expected IllegalArgumentException for a negative seed scale");
        }
        catch (IllegalArgumentException expected)
        {
            // Expected.
        }

        // No negative max iterations.
        try
        {
            new FactorizationMachineStochasticGradient(factorCount, learningRate,
                biasRegularization, weightRegularization, factorRegularization,
                seedScale, -1, customRandom);
            fail("Expected IllegalArgumentException for negative max iterations");
        }
        catch (IllegalArgumentException expected)
        {
            // Expected.
        }
    }

    /**
     * Test of learn method, of class FactorizationMachineStochasticGradient.
     * Generates data from a randomly-parameterized FactorizationMachine,
     * trains on one half, and checks that the learned model's RMSE on the
     * other half is below 0.05.
     */
    @Test
    public void testLearn()
    {
        System.out.println("learn");
        int n = 400;
        int d = 5;
        int k = 2;

        // The "true" model that generates the labels.
        FactorizationMachine actual = new FactorizationMachine(d, k);
        actual.setBias(this.random.nextGaussian() * 10.0);
        actual.setWeights(VectorFactory.getDenseDefault().createUniformRandom(d,
            -1.0, 1.0, this.random));
        actual.setFactors(MatrixFactory.getDenseDefault().createUniformRandom(k,
            d, -1.0, 1.0, this.random));

        int trainSize = n;
        int testSize = n;
        int totalSize = trainSize + testSize;
        List<InputOutputPair<Vector, Double>> trainData =
            new ArrayList<InputOutputPair<Vector, Double>>(trainSize);
        final List<InputOutputPair<Vector, Double>> testData =
            new ArrayList<InputOutputPair<Vector, Double>>(testSize);
        for (int i = 0; i < totalSize; i++)
        {
            Vector input = VectorFactory.getDenseDefault().createUniformRandom(
                d, -10.0, 10.0, this.random);
            final DefaultInputOutputPair<Vector, Double> example =
                DefaultInputOutputPair.create(input, actual.evaluateAsDouble(input));
            if (i < trainSize)
            {
                trainData.add(example);
            }
            else
            {
                testData.add(example);
            }
        }

        FactorizationMachineStochasticGradient instance =
            new FactorizationMachineStochasticGradient();
        instance.setFactorCount(k);
        instance.setSeedScale(0.2);
        instance.setBiasRegularization(0.0);
        instance.setWeightRegularization(0.01);
        instance.setFactorRegularization(0.1);
        instance.setLearningRate(0.005);
        instance.setMaxIterations(1000);
        instance.setRandom(this.random);

        // Shared by the per-iteration listener and the final check; it is
        // final so the anonymous listener class can capture it.
        final MeanSquaredErrorEvaluator<Vector> performance =
            new MeanSquaredErrorEvaluator<Vector>();
        instance.addIterativeAlgorithmListener(new AbstractIterativeAlgorithmListener()
        {
            @Override
            public void stepEnded(IterativeAlgorithm algorithm)
            {
                final FactorizationMachineStochasticGradient a = (FactorizationMachineStochasticGradient) algorithm;
                System.out.println("Iteration " + a.getIteration()
                    + " RMSE: Train: " + NUMBER_FORMAT.format(Math.sqrt(performance.evaluatePerformance(a.getResult(), a.getData())))
                    + " Validation: " + NUMBER_FORMAT.format(Math.sqrt(performance.evaluatePerformance(a.getResult(), testData)))
                    + " Objective: " + NUMBER_FORMAT.format(a.getObjective())
                    + " Change: " + NUMBER_FORMAT.format(a.getTotalChange())
                    + " Error: " + NUMBER_FORMAT.format(Math.sqrt(a.getTotalError() / a.getData().size()))
                    + " Regularization: " + NUMBER_FORMAT.format(a.getRegularizationPenalty()));
            }
        });

        // TODO: Figure out why this doesn't work sometimes with real factors. Is it just learning rate?
        FactorizationMachine result = instance.learn(trainData);
        assertEquals(d, result.getInputDimensionality());
        assertEquals(k, result.getFactorCount());

        System.out.println(actual.getBias());
        System.out.println(actual.getWeights());
        System.out.println(actual.getFactors());
        System.out.println(result.getBias());
        System.out.println(result.getWeights());
        System.out.println(result.getFactors());

        // Compute the held-out RMSE once and reuse it for reporting and checking.
        final double rmse = Math.sqrt(performance.evaluatePerformance(result, testData));
        System.out.println("RMSE: " + rmse);
        assertTrue("Test RMSE too high: " + rmse, rmse < 0.05);
    }

    /**
     * Test of getLearningRate method, of class FactorizationMachineStochasticGradient.
     */
    @Test
    public void testGetLearningRate()
    {
        this.testSetLearningRate();
    }

    /**
     * Test of setLearningRate method, of class FactorizationMachineStochasticGradient.
     * Checks the default, that a valid value is stored, and that non-positive
     * or NaN values are rejected without changing the stored value.
     */
    @Test
    public void testSetLearningRate()
    {
        double learningRate = FactorizationMachineStochasticGradient.DEFAULT_LEARNING_RATE;
        FactorizationMachineStochasticGradient instance =
            new FactorizationMachineStochasticGradient();
        assertEquals(learningRate, instance.getLearningRate(), 0.0);

        learningRate = 0.2;
        instance.setLearningRate(learningRate);
        assertEquals(learningRate, instance.getLearningRate(), 0.0);

        double[] badValues = {0.0, -0.1, -2.2, Double.NaN };
        for (double badValue : badValues)
        {
            try
            {
                instance.setLearningRate(badValue);
                fail("Expected IllegalArgumentException for learning rate " + badValue);
            }
            catch (IllegalArgumentException expected)
            {
                // Expected.
            }
            // A rejected value must not overwrite the previous valid one.
            assertEquals(learningRate, instance.getLearningRate(), 0.0);
        }
    }

}