/*
* File: FactorizationMachineAlternatingLeastSquaresTest.java
* Authors: Justin Basilico
* Project: Cognitive Foundry
*
* Copyright 2015 Cognitive Foundry. All rights reserved.
*/
package gov.sandia.cognition.learning.algorithm.factor.machine;
import gov.sandia.cognition.algorithm.IterativeAlgorithm;
import gov.sandia.cognition.algorithm.event.AbstractIterativeAlgorithmListener;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.performance.MeanSquaredErrorEvaluator;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.util.NamedValue;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import org.junit.Test;
import static org.junit.Assert.*;
/**
 * Unit tests for class {@link FactorizationMachineAlternatingLeastSquares}.
 *
 * @author  Justin Basilico
 * @since   3.4.1
 */
public class FactorizationMachineAlternatingLeastSquaresTest
    extends Object
{

    /** Formatter used for the per-iteration metrics printed by {@link #testLearn()}. */
    public static final NumberFormat NUMBER_FORMAT = new DecimalFormat("0.0000");

    /** Random number generator; the fixed seed keeps generated test data repeatable. */
    protected Random random = new Random(47474747);

    /**
     * Creates a new test.
     */
    public FactorizationMachineAlternatingLeastSquaresTest()
    {
        super();
    }

    /**
     * Test of constructors, of class FactorizationMachineAlternatingLeastSquares.
     *
     * Checks the defaults applied by the no-argument constructor, that the
     * full constructor stores every argument, and that each negative argument
     * is rejected with an {@link IllegalArgumentException}.
     */
    @Test
    public void testConstructors()
    {
        int factorCount = FactorizationMachineAlternatingLeastSquares.DEFAULT_FACTOR_COUNT;
        double biasRegularization = FactorizationMachineAlternatingLeastSquares.DEFAULT_BIAS_REGULARIZATION;
        double weightRegularization = FactorizationMachineAlternatingLeastSquares.DEFAULT_WEIGHT_REGULARIZATION;
        double factorRegularization = FactorizationMachineAlternatingLeastSquares.DEFAULT_FACTOR_REGULARIZATION;
        double seedScale = FactorizationMachineAlternatingLeastSquares.DEFAULT_SEED_SCALE;
        int maxIterations = FactorizationMachineAlternatingLeastSquares.DEFAULT_MAX_ITERATIONS;
        double minChange = FactorizationMachineAlternatingLeastSquares.DEFAULT_MIN_CHANGE;

        FactorizationMachineAlternatingLeastSquares instance =
            new FactorizationMachineAlternatingLeastSquares();
        assertEquals(factorCount, instance.getFactorCount());
        assertEquals(biasRegularization, instance.getBiasRegularization(), 0.0);
        assertEquals(weightRegularization, instance.getWeightRegularization(), 0.0);
        assertEquals(factorRegularization, instance.getFactorRegularization(), 0.0);
        assertEquals(seedScale, instance.getSeedScale(), 0.0);
        assertEquals(maxIterations, instance.getMaxIterations());
        assertEquals(minChange, instance.getMinChange(), 0.0);
        assertNotNull(instance.getRandom());
        // The getter must keep returning the same generator instance.
        assertSame(instance.getRandom(), instance.getRandom());

        factorCount = 22;
        biasRegularization = 3.33;
        weightRegularization = 44.44;
        factorRegularization = 555.55;
        seedScale = 0.6;
        maxIterations = 777;
        minChange = 0.88;
        // Named differently from the field so the field is not shadowed.
        final Random customRandom = new Random();
        instance = new FactorizationMachineAlternatingLeastSquares(factorCount,
            biasRegularization, weightRegularization, factorRegularization,
            seedScale, maxIterations, minChange, customRandom);
        assertEquals(factorCount, instance.getFactorCount());
        assertEquals(biasRegularization, instance.getBiasRegularization(), 0.0);
        assertEquals(weightRegularization, instance.getWeightRegularization(), 0.0);
        assertEquals(factorRegularization, instance.getFactorRegularization(), 0.0);
        assertEquals(seedScale, instance.getSeedScale(), 0.0);
        assertEquals(maxIterations, instance.getMaxIterations());
        assertEquals(minChange, instance.getMinChange(), 0.0);
        assertSame(customRandom, instance.getRandom());

        // Replace one argument at a time with a negative value; every other
        // argument stays valid so the rejection can only come from that one.
        assertConstructorRejects("negative factor count",
            -1, biasRegularization, weightRegularization,
            factorRegularization, seedScale, maxIterations, minChange,
            customRandom);
        assertConstructorRejects("negative bias regularization",
            factorCount, -1.0, weightRegularization,
            factorRegularization, seedScale, maxIterations, minChange,
            customRandom);
        assertConstructorRejects("negative weight regularization",
            factorCount, biasRegularization, -1.0,
            factorRegularization, seedScale, maxIterations, minChange,
            customRandom);
        assertConstructorRejects("negative factor regularization",
            factorCount, biasRegularization, weightRegularization,
            -1.0, seedScale, maxIterations, minChange,
            customRandom);
        assertConstructorRejects("negative seed scale",
            factorCount, biasRegularization, weightRegularization,
            factorRegularization, -1.0, maxIterations, minChange,
            customRandom);
        assertConstructorRejects("negative max iterations",
            factorCount, biasRegularization, weightRegularization,
            factorRegularization, seedScale, -1, minChange,
            customRandom);
        assertConstructorRejects("negative min change",
            factorCount, biasRegularization, weightRegularization,
            factorRegularization, seedScale, maxIterations, -0.1,
            customRandom);
    }

    /**
     * Asserts that the full constructor throws an
     * {@link IllegalArgumentException} for the given arguments.
     *
     * The failure is raised from inside the {@code try} block rather than from
     * a {@code finally} block, so an unexpected exception type propagates with
     * its own stack trace instead of being replaced by an assertion error.
     *
     * @param   description
     *      Short description of the invalid argument, used in the failure
     *      message.
     * @param   factorCount
     *      The factor count to pass to the constructor.
     * @param   biasRegularization
     *      The bias regularization to pass to the constructor.
     * @param   weightRegularization
     *      The weight regularization to pass to the constructor.
     * @param   factorRegularization
     *      The factor regularization to pass to the constructor.
     * @param   seedScale
     *      The seed scale to pass to the constructor.
     * @param   maxIterations
     *      The maximum number of iterations to pass to the constructor.
     * @param   minChange
     *      The minimum change to pass to the constructor.
     * @param   random
     *      The random number generator to pass to the constructor.
     */
    private static void assertConstructorRejects(
        final String description,
        final int factorCount,
        final double biasRegularization,
        final double weightRegularization,
        final double factorRegularization,
        final double seedScale,
        final int maxIterations,
        final double minChange,
        final Random random)
    {
        try
        {
            new FactorizationMachineAlternatingLeastSquares(factorCount,
                biasRegularization, weightRegularization, factorRegularization,
                seedScale, maxIterations, minChange, random);
            fail("Expected IllegalArgumentException for " + description);
        }
        catch (IllegalArgumentException expected)
        {
            // Expected: the invalid argument was rejected.
        }
    }

    /**
     * Test of learn method, of class FactorizationMachineAlternatingLeastSquares.
     *
     * Generates noise-free examples from a randomly initialized
     * {@link FactorizationMachine}, trains on the first half and requires the
     * root mean squared error on the held-out half to be below 0.05.
     */
    @Test
    public void testLearn()
    {
        System.out.println("learn");
        boolean useBias = true;
        boolean useWeights = true;
        boolean useFactors = true;

        int n = 500;
        int d = 5;
        int k = 2;

        // The ground-truth model; disabled components are scaled to zero.
        FactorizationMachine actual = new FactorizationMachine(d, k);
        actual.setBias(this.random.nextGaussian() * 10.0 * (useBias ? 1.0 : 0.0));
        actual.setWeights(VectorFactory.getDenseDefault().createUniformRandom(d,
            -1.0, 1.0, this.random).scale(useWeights ? 1.0 : 0.0));
        actual.setFactors(MatrixFactory.getDenseDefault().createUniformRandom(k,
            d, -1.0, 1.0, this.random).scale(useFactors ? 1.0 : 0.0));

        int trainSize = n;
        int testSize = n;
        int totalSize = trainSize + testSize;
        List<InputOutputPair<Vector, Double>> trainData =
            new ArrayList<InputOutputPair<Vector, Double>>();
        final List<InputOutputPair<Vector, Double>> testData =
            new ArrayList<InputOutputPair<Vector, Double>>();

        // The first trainSize examples are for training, the rest for testing.
        for (int i = 0; i < totalSize; i++)
        {
            Vector input = VectorFactory.getDenseDefault().createUniformRandom(
                d, -10.0, 10.0, this.random);
            final DefaultInputOutputPair<Vector, Double> example =
                DefaultInputOutputPair.create(input, actual.evaluateAsDouble(input));
            if (i < trainSize)
            {
                trainData.add(example);
            }
            else
            {
                testData.add(example);
            }
        }

        FactorizationMachineAlternatingLeastSquares instance =
            new FactorizationMachineAlternatingLeastSquares();
        instance.setFactorCount(useFactors ? k : 0);
        instance.setSeedScale(0.2);
        instance.setBiasRegularization(0.0);
        instance.setWeightRegularization(0.01);
        instance.setFactorRegularization(0.1);
        instance.setMaxIterations(1000);
        instance.setMinChange(1e-4);
        instance.setWeightsEnabled(useWeights);
        instance.setBiasEnabled(useBias);
        instance.setRandom(random);

//        instance.addIterativeAlgorithmListener(new IterationMeasurablePerformanceReporter());
// TODO: Part of this may be good as a general class (printing validation metrics).
        // Print training and validation metrics after every iteration.
        instance.addIterativeAlgorithmListener(new AbstractIterativeAlgorithmListener()
        {
            @Override
            public void stepEnded(IterativeAlgorithm algorithm)
            {
                final FactorizationMachineAlternatingLeastSquares a =
                    (FactorizationMachineAlternatingLeastSquares) algorithm;
                MeanSquaredErrorEvaluator<Vector> performance =
                    new MeanSquaredErrorEvaluator<>();
                System.out.println("Iteration " + a.getIteration()
                    + " RMSE: Train: " + NUMBER_FORMAT.format(Math.sqrt(performance.evaluatePerformance(a.getResult(), a.getData())))
                    + " Validation: " + NUMBER_FORMAT.format(Math.sqrt(performance.evaluatePerformance(a.getResult(), testData)))
                    + " Objective: " + NUMBER_FORMAT.format(a.getObjective())
                    + " Change: " + NUMBER_FORMAT.format(a.getTotalChange())
                    + " Error: " + NUMBER_FORMAT.format(Math.sqrt(a.getTotalError() / a.getData().size()))
                    + " Regularization: " + NUMBER_FORMAT.format(a.getRegularizationPenalty()));
            }
        });

// TODO: Figure out why this doesn't work with real factors.
        FactorizationMachine result = instance.learn(trainData);

        // Print the true parameters followed by the learned ones.
        System.out.println(actual.getBias());
        System.out.println(actual.getWeights());
        System.out.println(actual.getFactors());
        System.out.println(result.getBias());
        System.out.println(result.getWeights());
        System.out.println(result.getFactors());

        MeanSquaredErrorEvaluator<Vector> performance = new MeanSquaredErrorEvaluator<>();
        // Evaluate once and reuse the value for both the report and the check.
        final double rmse = Math.sqrt(performance.evaluatePerformance(result, testData));
        System.out.println("RMSE: " + rmse);
        assertTrue("Test RMSE " + rmse + " should be below 0.05", rmse < 0.05);
    }

    /**
     * Test of getRegularizationPenalty method, of class FactorizationMachineAlternatingLeastSquares.
     *
     * Sets the result directly on the learner and switches the individual
     * regularization terms on and off to check each contribution to the
     * penalty.
     */
    @Test
    public void testGetRegularizationPenalty()
    {
        FactorizationMachineAlternatingLeastSquares instance =
            new FactorizationMachineAlternatingLeastSquares();
        // Before any result has been assigned the penalty is expected to be zero.
        assertEquals(0.0, instance.getRegularizationPenalty(), 0.0);

        instance.result = new FactorizationMachine(10, 5);
        assertEquals(0.0, instance.getRegularizationPenalty(), 0.0);

        instance.result.setBias(3);
        assertEquals(0.0, instance.getRegularizationPenalty(), 0.0);

        instance.setBiasRegularization(2.0);
        // NOTE(review): 6.0 == 2.0 * 3, i.e. the bias is not squared here,
        // unlike the weight and factor terms below - confirm this is intended.
        assertEquals(6.0, instance.getRegularizationPenalty(), 0.0);
        instance.setBiasRegularization(0);

        instance.result.getWeights().set(0, 4);
        instance.result.getWeights().set(2, 3);
        instance.setWeightRegularization(2.0);
        // 2.0 * (4^2 + 3^2) = 50.
        assertEquals(50.0, instance.getRegularizationPenalty(), 0.0);
        instance.setWeightRegularization(0.0);
        assertEquals(0.0, instance.getRegularizationPenalty(), 0.0);

        instance.result.getFactors().set(0, 0, 5);
        instance.result.getFactors().set(4, 2, 2);
        instance.setFactorRegularization(3.0);
        // 3.0 * (5^2 + 2^2) = 87.
        assertEquals(87.0, instance.getRegularizationPenalty(), 1e-10);
        instance.setFactorRegularization(0.0);
        assertEquals(0.0, instance.getRegularizationPenalty(), 0.0);

        instance.setWeightRegularization(2.0);
        instance.setFactorRegularization(3.0);
        // Weight and factor terms combined: 50 + 87 = 137.
        assertEquals(137.0, instance.getRegularizationPenalty(), 1e-10);
    }

    /**
     * Test of getPerformance method, of class FactorizationMachineAlternatingLeastSquares.
     */
    @Test
    public void testGetPerformance()
    {
        FactorizationMachineAlternatingLeastSquares instance =
            new FactorizationMachineAlternatingLeastSquares();
        NamedValue<? extends Number> result = instance.getPerformance();
        assertEquals("objective", result.getName());
        // Compare numerically: the declared type is only Number, so an
        // Object.equals comparison would depend on the boxed type returned.
        assertEquals(0.0, result.getValue().doubleValue(), 0.0);
    }

    /**
     * Test of getMinChange method, of class FactorizationMachineAlternatingLeastSquares.
     */
    @Test
    public void testGetMinChange()
    {
        // The getter is exercised together with the setter.
        this.testSetMinChange();
    }

    /**
     * Test of setMinChange method, of class FactorizationMachineAlternatingLeastSquares.
     *
     * Checks the default, a valid update, and that negative or NaN values are
     * rejected without changing the stored value.
     */
    @Test
    public void testSetMinChange()
    {
        double minChange = FactorizationMachineAlternatingLeastSquares.DEFAULT_MIN_CHANGE;
        FactorizationMachineAlternatingLeastSquares instance =
            new FactorizationMachineAlternatingLeastSquares();
        assertEquals(minChange, instance.getMinChange(), 0.0);

        minChange = 0.1;
        instance.setMinChange(minChange);
        assertEquals(minChange, instance.getMinChange(), 0.0);

        double[] badValues = {-0.1, -2.2, Double.NaN};
        for (double badValue : badValues)
        {
            try
            {
                instance.setMinChange(badValue);
                fail("Expected IllegalArgumentException for minChange = " + badValue);
            }
            catch (IllegalArgumentException expected)
            {
                // Expected: the invalid value was rejected.
            }

            // A rejected value must leave the previous setting untouched.
            assertEquals(minChange, instance.getMinChange(), 0.0);
        }
    }

}