/*
* File: FactorizationMachineTest.java
* Authors: Justin Basilico
* Project: Cognitive Foundry
*
* Copyright 2013 Cognitive Foundry. All rights reserved.
*/
package gov.sandia.cognition.learning.algorithm.factor.machine;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.mtj.Vector1;
import gov.sandia.cognition.math.matrix.mtj.Vector3;
import java.util.Arrays;
import java.util.Random;
import org.junit.Test;
import static org.junit.Assert.*;
/**
* Unit tests for {@link FactorizationMachine}.
*
* @author Justin Basilico
* @since 3.4.0
*/
public class FactorizationMachineTest
extends Object
{
protected Random random = new Random(4444);
protected double epsilon = 1e-5;
/**
* Creates a new test.
*/
public FactorizationMachineTest()
{
}
/**
* Test of constructors, of class FactorizationMachine.
*/
@Test
public void testConstructors()
{
double bias = 0.0;
Vector weights = null;
Matrix factors = null;
FactorizationMachine instance = new FactorizationMachine();
assertEquals(bias, instance.getBias(), 0.0);
assertSame(weights, instance.getWeights());
assertSame(factors, instance.getFactors());
instance = new FactorizationMachine(12, 5);
assertEquals(bias, instance.getBias(), 0.0);
assertEquals(12, instance.getWeights().getDimensionality());
assertEquals(5, instance.getFactors().getNumRows());
assertEquals(12, instance.getFactors().getNumColumns());
assertEquals(0.0, instance.getWeights().sum(), 0.0);
assertEquals(0.0, instance.getFactors().sumOfRows().sum(), 0.0);
bias = 0.4;
weights = VectorFactory.getSparseDefault().createVector(11);
factors = MatrixFactory.getSparseDefault().createMatrix(11, 6);
instance = new FactorizationMachine(bias, weights, factors);
assertEquals(bias, instance.getBias(), 0.0);
assertSame(weights, instance.getWeights());
assertSame(factors, instance.getFactors());
}
/**
* Test of clone method, of class FactorizationMachine.
*/
@Test
public void testClone()
{
FactorizationMachine instance = new FactorizationMachine();
FactorizationMachine clone = instance.clone();
assertNotSame(instance, clone);
assertNotNull(clone);
assertNotSame(clone, instance.clone());
assertEquals(instance.getBias(), clone.getBias(), 0.0);
assertNull(clone.getWeights());
assertNull(clone.getFactors());
instance.setBias(3.4);
instance.setWeights(VectorFactory.getDefault().createUniformRandom(
5, -1, 1, random));
instance.setFactors(MatrixFactory.getDefault().createUniformRandom(
3, 5, -1, 1, random));
clone = instance.clone();
assertNotSame(instance, clone);
assertNotNull(clone);
assertNotSame(clone, instance.clone());
assertEquals(instance.getBias(), clone.getBias(), 0.0);
assertEquals(instance.getWeights(), clone.getWeights());
assertEquals(instance.getFactors(), clone.getFactors());
assertNotSame(instance.getWeights(), clone.getWeights());
assertNotSame(instance.getFactors(), clone.getFactors());
}
/**
* Test of evaluateAsDouble method, of class FactorizationMachine.
*/
@Test
public void testEvaluateAsDouble()
{
int d = 1 + this.random.nextInt(10);
int k = 1 + this.random.nextInt(d - 1);
Vector x1 = VectorFactory.getDefault().createUniformRandom(d, -10, 10, random);
Vector x2 = VectorFactory.getDefault().createUniformRandom(d, -10, 10, random);
FactorizationMachine instance = new FactorizationMachine();
assertEquals(0.0, instance.evaluateAsDouble(null), 0.0);
double b = this.random.nextGaussian();
instance.setBias(b);
assertEquals(b, instance.evaluateAsDouble(x1), 0.0);
assertEquals(b, instance.evaluateAsDouble(x2), 0.0);
Vector w = VectorFactory.getDefault().createVector(d);
instance.setWeights(w);
assertEquals(b, instance.evaluateAsDouble(x1), 0.0);
assertEquals(b, instance.evaluateAsDouble(x2), 0.0);
w = VectorFactory.getDefault().createUniformRandom(d, -1, 1, random);
instance.setWeights(w);
assertEquals(b + w.dotProduct(x1), instance.evaluateAsDouble(x1), 0.0);
assertEquals(b + w.dotProduct(x2), instance.evaluateAsDouble(x2), 0.0);
Matrix v = MatrixFactory.getDenseDefault().createMatrix(k, d);
instance.setFactors(v);
assertEquals(b + w.dotProduct(x1), instance.evaluateAsDouble(x1), 0.0);
assertEquals(b + w.dotProduct(x2), instance.evaluateAsDouble(x2), 0.0);
v = MatrixFactory.getDenseDefault().createUniformRandom(k, d, -1, 1,
random);
instance.setFactors(v);
// This is the way in the base formula to compute it that is O(kn^2)
// rather than the way it is really computed as O(kn).
for (Vector x : Arrays.asList(x1, x2))
{
double expected = b + w.dotProduct(x);
for (int i = 0; i < d; i++)
{
for (int j = i + 1; j < d; j++)
{
expected += x.getElement(i) * x.getElement(j)
* v.getColumn(i).dotProduct(v.getColumn(j));
}
}
assertEquals(expected, instance.evaluateAsDouble(x), epsilon);
}
}
/**
* Test of getInputDimensionality method, of class FactorizationMachine.
*/
@Test
public void testGetInputDimensionality()
{
FactorizationMachine instance = new FactorizationMachine();
assertEquals(0, instance.getInputDimensionality());
instance.setWeights(new Vector3());
assertEquals(3, instance.getInputDimensionality());
instance.setWeights(null);
instance.setFactors(MatrixFactory.getDefault().createMatrix(12, 4));
assertEquals(4, instance.getInputDimensionality());
instance = new FactorizationMachine(4, 7);
assertEquals(4, instance.getInputDimensionality());
}
/**
* Test of getFactorCount method, of class FactorizationMachine.
*/
@Test
public void testGetFactorCount()
{
FactorizationMachine instance = new FactorizationMachine();
assertEquals(0, instance.getFactorCount());
instance.setFactors(MatrixFactory.getDefault().createMatrix(12, 4));
assertEquals(12, instance.getFactorCount());
instance = new FactorizationMachine(4, 7);
assertEquals(7, instance.getFactorCount());
}
/**
* Test of computeParameterGradient method, of class FactorizationMachine.
*/
@Test
public void testComputeParameterGradient()
{
VectorFactory<?> vf = VectorFactory.getSparseDefault();
FactorizationMachine instance = new FactorizationMachine();
Vector input = vf.createVector(0);
Vector result = instance.computeParameterGradient(input);
assertEquals(1, result.getDimensionality());
assertEquals(1.0, result.getElement(0), 0.0);
int d = 3;
instance.setWeights(VectorFactory.getDenseDefault().createVector(d));
input = vf.createUniformRandom(d, -10, 10, random);
result = instance.computeParameterGradient(input);
assertEquals(1 + d, result.getDimensionality());
assertEquals(1.0, result.getElement(0), 0.0);
assertEquals(input, result.subVector(1, d));
int k = 2;
instance.setFactors(MatrixFactory.getDenseDefault().createUniformRandom(k, d, -10, 10, random));
input = vf.createUniformRandom(d, -10, 10, random);
result = instance.computeParameterGradient(input);
assertEquals(10, result.getDimensionality());
assertEquals(1.0, result.getElement(0), 0.0);
assertEquals(input, result.subVector(1, d));
Vector factorGradients = result.subVector(d + 1, d + d * k);
for (int f = 0; f < k; f++)
{
for (int l = 0; l < d; l++)
{
double actual = factorGradients.getElement(f * d + l);
double expected = 0.0;
for (int j = 0; j < d; j++)
{
if (j != l)
{
double xl = input.getElement(l);
expected += xl * instance.getFactors().getElement(f, j) * input.getElement(j);
}
}
assertEquals(expected, actual, epsilon);
}
}
}
/**
* Test of getParameterCount method, of class FactorizationMachine.
*/
@Test
public void testGetParameterCount()
{
FactorizationMachine instance = new FactorizationMachine();
assertEquals(1, instance.getParameterCount());
instance.setFactors(MatrixFactory.getDefault().createMatrix(12, 4));
assertEquals(1 + 12 * 4, instance.getParameterCount());
instance.setFactors(null);
instance.setWeights(VectorFactory.getDefault().createVector(4));
assertEquals(1 + 4, instance.getParameterCount());
instance = new FactorizationMachine(4, 7);
assertEquals(1 + 4 + 4 * 7, instance.getParameterCount());
}
/**
* Test of convertToVector method, of class FactorizationMachine.
*/
@Test
public void testConvertToVector()
{
FactorizationMachine instance = new FactorizationMachine();
Vector result = instance.convertToVector();
assertEquals(instance.getParameterCount(), result.getDimensionality());
assertTrue(result.isZero());
int d = 7;
int k = 4;
instance = new FactorizationMachine(d, k);
result = instance.convertToVector();
assertEquals(instance.getParameterCount(), result.getDimensionality());
assertTrue(result.isZero());
double bias = this.random.nextGaussian();
Vector weights = VectorFactory.getDefault().createUniformRandom(d, -1, 1, random);
Matrix factors = MatrixFactory.getDefault().createUniformRandom(k, d, -1, 1, random);
instance = new FactorizationMachine(bias, weights.clone(), factors.clone());
result = instance.convertToVector();
assertEquals(instance.getParameterCount(), result.getDimensionality());
assertTrue(result.equals(new Vector1(bias).stack(weights).stack(factors.transpose().convertToVector())));
// Try with weights disabled.
instance.setWeights(null);
result = instance.convertToVector();
assertEquals(instance.getParameterCount(), result.getDimensionality());
assertTrue(result.equals(new Vector1(bias).stack(factors.transpose().convertToVector())));
// Try with factors disabled.
instance.setWeights(weights.clone());
instance.setFactors(null);
result = instance.convertToVector();
assertEquals(instance.getParameterCount(), result.getDimensionality());
assertTrue(result.equals(new Vector1(bias).stack(weights)));
}
/**
* Test of convertFromVector method, of class FactorizationMachine.
*/
@Test
public void testConvertFromVector()
{
FactorizationMachine instance = new FactorizationMachine();
Vector converted = instance.convertToVector();
Vector expected = converted.clone();
instance.convertFromVector(converted);
assertTrue(expected.equals(instance.convertToVector()));
int d = 7;
int k = 4;
instance = new FactorizationMachine(d, k);
converted = instance.convertToVector();
expected = converted.clone();
instance.convertFromVector(converted);
assertTrue(expected.equals(instance.convertToVector()));
double bias = this.random.nextGaussian();
Vector weights = VectorFactory.getDefault().createUniformRandom(d, -1, 1, random);
Matrix factors = MatrixFactory.getDefault().createUniformRandom(k, d, -1, 1, random);
instance = new FactorizationMachine(bias, weights.clone(), factors.clone());
converted = instance.convertToVector();
expected = converted.clone();
instance.convertFromVector(converted);
assertEquals(expected, instance.convertToVector());
assertEquals(bias, instance.getBias(), 0.0);
assertEquals(weights, instance.getWeights());
assertEquals(factors, instance.getFactors());
instance = new FactorizationMachine(d, k);
instance.convertFromVector(converted);
assertTrue(expected.equals(instance.convertToVector()));
assertEquals(bias, instance.getBias(), 0.0);
assertEquals(weights, instance.getWeights());
assertEquals(factors, instance.getFactors());
// Try with weights disabled.
instance.setWeights(null);
converted = instance.convertToVector();
expected = converted.clone();
instance.convertFromVector(converted);
assertTrue(expected.equals(instance.convertToVector()));
instance.setBias(0.0);
instance.getFactors().zero();
instance.convertFromVector(converted);
assertTrue(expected.equals(instance.convertToVector()));
assertEquals(bias, instance.getBias(), 0.0);
assertNull(instance.getWeights());
assertEquals(factors, instance.getFactors());
// Try with factors disabled.
instance.setWeights(weights.clone());
instance.setFactors(null);
converted = instance.convertToVector();
expected = converted.clone();
instance.convertFromVector(converted);
assertTrue(expected.equals(instance.convertToVector()));
instance.setBias(0.0);
instance.getWeights().zero();
instance.convertFromVector(converted);
assertTrue(expected.equals(instance.convertToVector()));
assertEquals(bias, instance.getBias(), 0.0);
assertEquals(weights, instance.getWeights());
assertNull(instance.getFactors());
}
/**
* Test of hasWeights method, of class FactorizationMachine.
*/
@Test
public void testHasWeights()
{
FactorizationMachine instance = new FactorizationMachine();
assertFalse(instance.hasWeights());
instance.setWeights(new Vector3());
assertTrue(instance.hasWeights());
instance.setWeights(null);
assertFalse(instance.hasWeights());
instance = new FactorizationMachine(4, 7);
assertTrue(instance.hasWeights());
}
/**
* Test of hasFactors method, of class FactorizationMachine.
*/
@Test
public void testHasFactors()
{
FactorizationMachine instance = new FactorizationMachine();
assertFalse(instance.hasFactors());
instance.setFactors(MatrixFactory.getDefault().createMatrix(12, 4));
assertTrue(instance.hasFactors());
instance.setFactors(null);
assertFalse(instance.hasFactors());
instance = new FactorizationMachine(4, 7);
assertTrue(instance.hasFactors());
}
/**
* Test of getBias method, of class FactorizationMachine.
*/
@Test
public void testGetBias()
{
this.testSetBias();
}
/**
* Test of setBias method, of class FactorizationMachine.
*/
@Test
public void testSetBias()
{
double bias = 0.0;
FactorizationMachine instance = new FactorizationMachine();
assertEquals(bias, instance.getBias(), 0.0);
double[] values = {0.4, -0.4, 40, -40};
for (double value : values)
{
bias = value;
instance.setBias(bias);
assertEquals(bias, instance.getBias(), 0.0);
}
}
/**
* Test of getWeights method, of class FactorizationMachine.
*/
@Test
public void testGetWeights()
{
this.testSetWeights();;
}
/**
* Test of setWeights method, of class FactorizationMachine.
*/
@Test
public void testSetWeights()
{
Vector weights = null;
FactorizationMachine instance = new FactorizationMachine();
assertSame(weights, instance.getWeights());
weights = VectorFactory.getSparseDefault().createVector(11);
instance.setWeights(weights);
assertSame(weights, instance.getWeights());
weights = null;
instance.setWeights(weights);
assertSame(weights, instance.getWeights());
}
/**
* Test of getFactors method, of class FactorizationMachine.
*/
@Test
public void testGetFactors()
{
this.testSetFactors();
}
/**
* Test of setFactors method, of class FactorizationMachine.
*/
@Test
public void testSetFactors()
{
Matrix factors = null;
FactorizationMachine instance = new FactorizationMachine();
assertSame(factors, instance.getFactors());
factors = MatrixFactory.getSparseDefault().createMatrix(11, 6);
instance.setFactors(factors);
assertSame(factors, instance.getFactors());
factors = null;
instance.setFactors(factors);
assertSame(factors, instance.getFactors());
}
}