/*
* File: GeneralizedLinearModelTest.java
* Authors: Kevin R. Dixon
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright February 28, 2007, Sandia Corporation. Under the terms of Contract
* DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
* or on behalf of the U.S. Government. Export of this program may require a
* license from the United States Government. See CopyrightHistory.txt for
* complete details.
*
*/
package gov.sandia.cognition.learning.function.vector;
import gov.sandia.cognition.learning.function.scalar.AtanFunction;
import gov.sandia.cognition.learning.function.scalar.IdentityScalarFunction;
import gov.sandia.cognition.learning.function.scalar.SigmoidFunction;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFunction;
import java.util.Random;
import junit.framework.TestCase;
/**
 * Unit tests for {@link GeneralizedLinearModel}.
 *
 * @author Kevin R. Dixon
 * @since 1.0
 */
public class GeneralizedLinearModelTest
    extends TestCase
{

    /** The random number generator for the tests (fixed seed so runs are reproducible). */
    public Random random = new Random(1);

    /**
     * Creates a new test case.
     *
     * @param testName
     *      The name of the test.
     */
    public GeneralizedLinearModelTest(
        String testName)
    {
        super(testName);
    }

    /**
     * Creates a random model for testing: a
     * {@link DifferentiableGeneralizedLinearModel} whose discriminant is an
     * M-by-N matrix (M and N each drawn from [1, 10]) with entries uniform in
     * [-1, 1], squashed by an {@link AtanFunction}.
     *
     * @return
     *      A new random model.
     */
    public DifferentiableGeneralizedLinearModel createRandom()
    {
        double A = 1.0;
        int M = random.nextInt(10) + 1;
        int N = random.nextInt(10) + 1;
        Matrix m = MatrixFactory.getDefault().createUniformRandom(M, N, -A, A, random);
        return new DifferentiableGeneralizedLinearModel(
            new MultivariateDiscriminant(m),
            new AtanFunction());
    }

    /**
     * Tests the constructors of {@link GeneralizedLinearModel}: the default
     * constructor, the (input, output, scalar function) constructor, the
     * (discriminant, scalar function) constructor, the (discriminant, vector
     * function) constructor, and the copy constructor.
     */
    public void testConstructors()
    {
        System.out.println( "Constructors" );

        // Default: 1-in / 1-out with an element-wise identity squashing function.
        GeneralizedLinearModel f =
            new GeneralizedLinearModel();
        assertEquals( 1, f.getInputDimensionality() );
        assertEquals( 1, f.getOutputDimensionality() );
        assertNotNull( f.getSquashingFunction() );
        assertTrue(f.getSquashingFunction() instanceof ElementWiseVectorFunction);
        assertTrue(((ElementWiseVectorFunction) f.getSquashingFunction()).getScalarFunction() instanceof IdentityScalarFunction);

        // A scalar squashing function is wrapped element-wise, keeping the same instance.
        SigmoidFunction s = new SigmoidFunction();
        f = new GeneralizedLinearModel(5, 2, s);
        assertEquals(5, f.getInputDimensionality());
        assertEquals(2, f.getOutputDimensionality());
        assertTrue(f.getSquashingFunction() instanceof ElementWiseVectorFunction);
        assertSame(s, ((ElementWiseVectorFunction) f.getSquashingFunction()).getScalarFunction());

        MultivariateDiscriminant d = new MultivariateDiscriminant(3, 4);
        f = new GeneralizedLinearModel(d, s);
        assertEquals(3, f.getInputDimensionality());
        assertEquals(4, f.getOutputDimensionality());
        assertSame(d, f.getDiscriminant());
        assertTrue(f.getSquashingFunction() instanceof ElementWiseVectorFunction);
        assertSame(s, ((ElementWiseVectorFunction) f.getSquashingFunction()).getScalarFunction());

        // A vector squashing function is used as-is.
        VectorFunction v = new LinearVectorFunction();
        f = new GeneralizedLinearModel(d, v);
        assertEquals(3, f.getInputDimensionality());
        assertEquals(4, f.getOutputDimensionality());
        assertSame(d, f.getDiscriminant());
        assertSame(v, f.getSquashingFunction());

        // Copy constructor: the discriminant is copied, the squashing function is shared.
        f = new GeneralizedLinearModel(f);
        assertEquals(3, f.getInputDimensionality());
        assertEquals(4, f.getOutputDimensionality());
        assertNotSame(d, f.getDiscriminant());
        assertEquals(d.convertToVector(), f.getDiscriminant().convertToVector());
        assertSame(v, f.getSquashingFunction());
    }

    /**
     * Test of {@link GeneralizedLinearModel#getDiscriminant()}.
     */
    public void testGetMatrixMultiply()
    {
        System.out.println("getMatrixMultiply");
        GeneralizedLinearModel instance = this.createRandom();
        assertNotNull(instance.getDiscriminant());
    }

    /**
     * Test of {@link GeneralizedLinearModel#setDiscriminant}.
     */
    public void testSetMatrixMultiply()
    {
        System.out.println("setMatrixMultiply");
        GeneralizedLinearModel instance = this.createRandom();
        assertNotNull(instance.getDiscriminant());
        Matrix m = MatrixFactory.getDefault().createUniformRandom(2, 3, -1, 1, random);
        MultivariateDiscriminant mult =
            new MultivariateDiscriminant(m);
        assertNotSame(mult, instance.getDiscriminant());
        instance.setDiscriminant(mult);
        assertSame(mult, instance.getDiscriminant());
    }

    /**
     * Test of {@link GeneralizedLinearModel#getSquashingFunction()}.
     */
    public void testGetSquashingFunction()
    {
        System.out.println("getSquashingFunction");
        GeneralizedLinearModel instance = this.createRandom();
        assertNotNull(instance.getSquashingFunction());
    }

    /**
     * Test of {@link GeneralizedLinearModel#setSquashingFunction}.
     */
    public void testSetSquashingFunction()
    {
        System.out.println("setSquashingFunction");
        GeneralizedLinearModel instance = this.createRandom();
        LinearVectorFunction f = new LinearVectorFunction(random.nextGaussian());
        assertNotSame(f, instance.getSquashingFunction());
        instance.setSquashingFunction(f);
        assertSame(f, instance.getSquashingFunction());
    }

    /**
     * Test of {@link GeneralizedLinearModel#convertToVector()}: the model's
     * parameters are expected to be exactly those of its discriminant.
     */
    public void testConvertToVector()
    {
        System.out.println("convertToVector");
        GeneralizedLinearModel instance = this.createRandom();
        Vector result = instance.convertToVector();
        Vector expected = instance.getDiscriminant().convertToVector();
        assertEquals(expected, result);
    }

    /**
     * Test of {@link GeneralizedLinearModel#convertFromVector}: parameters
     * written in must round-trip through convertToVector.
     */
    public void testConvertFromVector()
    {
        System.out.println("convertFromVector");
        GeneralizedLinearModel instance = this.createRandom();
        Vector p1 = instance.convertToVector();
        Vector p2 = p1.scale(random.nextGaussian());
        instance.convertFromVector(p2);
        Vector p3 = instance.convertToVector();
        assertEquals(p2, p3);
    }

    /**
     * Test of {@link GeneralizedLinearModel#evaluate}: the output must equal
     * the squashing function applied to (discriminant matrix * input).
     */
    public void testEvaluate()
    {
        System.out.println("evaluate");
        GeneralizedLinearModel instance = this.createRandom();
        int M = instance.getDiscriminant().getDiscriminant().getNumRows();
        int N = instance.getDiscriminant().getDiscriminant().getNumColumns();
        Vector input = VectorFactory.getDefault().createUniformRandom(N, -10.0, 10.0, random);
        Vector result = instance.evaluate(input);
        assertEquals(M, result.getDimensionality());
        Vector expected = instance.getSquashingFunction().evaluate(instance.getDiscriminant().getDiscriminant().times(input));
        assertEquals(expected.getDimensionality(), result.getDimensionality());
        assertEquals(expected, result);
    }

    /**
     * Test of {@link GeneralizedLinearModel#clone()}: the clone must start out
     * equivalent to the original, and modifying the clone's weights must not
     * affect the original.
     */
    public void testClone()
    {
        System.out.println("clone");
        GeneralizedLinearModel instance = this.createRandom();
        GeneralizedLinearModel clone = instance.clone();
        assertNotNull(clone);
        assertNotSame(instance, clone);

        // The clone must own its own weight matrix, initially equal to the original's.
        assertNotSame(instance.getDiscriminant().getDiscriminant(),
            clone.getDiscriminant().getDiscriminant());
        assertEquals(instance.getDiscriminant().getDiscriminant(),
            clone.getDiscriminant().getDiscriminant());

        int N = instance.getDiscriminant().getDiscriminant().getNumColumns();
        Vector input = VectorFactory.getDefault().createUniformRandom(N, -1.0, 1.0, random);
        Vector v1 = instance.evaluate(input);
        assertEquals(v1, clone.evaluate(input));

        // Mutating the clone's weights must leave the original untouched (deep copy).
        clone.getDiscriminant().getDiscriminant().setElement(0, 0, random.nextGaussian());
        assertEquals(v1, instance.evaluate(input));
        assertFalse(v1.equals(clone.evaluate(input)));
    }

    /**
     * Tests that the model reports the dimensionality of its discriminant and
     * that evaluate consumes and produces vectors of those dimensionalities.
     */
    public void testDimensionality()
    {
        System.out.println("dimensionality");
        GeneralizedLinearModel instance = this.createRandom();
        assertEquals( instance.getDiscriminant().getInputDimensionality(), instance.getInputDimensionality() );
        assertEquals( instance.getDiscriminant().getOutputDimensionality(), instance.getOutputDimensionality() );
        double r = 2.0;
        Vector x = VectorFactory.getDefault().createUniformRandom(
            instance.getInputDimensionality(), -r, r, random);
        Vector y = instance.evaluate(x);

        // JUnit convention: assertEquals(expected, actual).
        assertEquals( instance.getInputDimensionality(), x.getDimensionality() );
        assertEquals( instance.getOutputDimensionality(), y.getDimensionality() );
    }

}