/*
* File: MultivariateGaussianTest.java
* Authors: Justin Basilico
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright March 27, 2006, Sandia Corporation. Under the terms of Contract
* DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
* or on behalf of the U.S. Government. Export of this program may require a
* license from the United States Government. See CopyrightHistory.txt for
* complete details.
*
*/
package gov.sandia.cognition.statistics.distribution;
import gov.sandia.cognition.annotation.CodeReview;
import gov.sandia.cognition.math.MultivariateStatisticsUtil;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.mtj.DenseMatrixFactoryMTJ;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.mtj.DenseMatrix;
import gov.sandia.cognition.math.matrix.mtj.Vector3;
import gov.sandia.cognition.math.matrix.mtj.decomposition.CholeskyDecompositionMTJ;
import gov.sandia.cognition.statistics.MultivariateClosedFormComputableDistributionTestHarness;
import java.util.ArrayList;
/**
* This class implements JUnit tests for the following classes:
*
* MultivariateGaussian
*
* @author Justin Basilico
* @since 1.0
*/
@CodeReview(
reviewer="Jonathan McClain",
date="2006-05-16",
changesNeeded=false,
comments={
"Fixed missing documentation.",
"Replace calls to MathAssert.assertFuzzyEquals with assertEquals(double, double, delta)."
}
)
public class MultivariateGaussianTest
extends MultivariateClosedFormComputableDistributionTestHarness<Vector>
{
/**
* Creates a new instance of MultivariateGaussianTest.
*
* @param testName The name of the test; it is passed unchanged to the
* test-harness superclass constructor.
*/
public MultivariateGaussianTest(
String testName )
{
super( testName );
}
/**
 * Checks the PDF against a known value: a three-dimensional Gaussian with
 * zero mean and identity covariance is evaluated at the origin and the
 * result is compared with 0.063494 (that is, (2*pi)^(-3/2) rounded to six
 * decimals) to within TOLERANCE.
 */
@Override
public void testProbabilityFunctionKnownValues()
{
    final Vector3 origin = new Vector3( 0.0, 0.0, 0.0 );
    final Vector3 zeroMean = new Vector3( 0.0, 0.0, 0.0 );
    final Matrix identityCovariance =
        MatrixFactory.getDefault().createIdentity( 3, 3 );
    final MultivariateGaussian.PDF pdf =
        new MultivariateGaussian.PDF( zeroMean, identityCovariance );

    final double expected = 0.063494;
    final double actual = pdf.evaluate( origin );
    assertEquals( expected, actual, TOLERANCE );
}
/**
 * Test of getCovarianceInverse method, of class MultivariateGaussian.
 *
 * A Gaussian is built with an identity covariance; the reported covariance
 * inverse must be non-null and equal to that identity within TOLERANCE.
 */
public void testGetCovarianceInverse()
{
    final Vector3 center = new Vector3( 1.0, -2.0, 3.0 );
    final Matrix identity = MatrixFactory.getDefault().createIdentity( 3, 3 );
    final MultivariateGaussian gaussian =
        new MultivariateGaussian( center, identity );

    assertNotNull( gaussian.getCovarianceInverse() );

    final boolean inverseIsIdentity =
        identity.equals( gaussian.getCovarianceInverse(), TOLERANCE );
    assertTrue( inverseIsIdentity );
}
/**
 * Test of MaximumLikelihoodEstimator.learn, of class MultivariateGaussian.
 *
 * Five hand-picked three-dimensional samples are passed to the estimator
 * (with 0.0 as its second argument) and the resulting mean and covariance
 * are compared with hand-computed values. The expected covariance entries
 * correspond to dividing the summed products of deviations by 4, i.e. by
 * (number of samples - 1).
 */
public void testMaximumLikelihoodEstimate()
{
    ArrayList<Vector> samples = new ArrayList<Vector>();
    samples.add( new Vector3( 0.0, 0.0, 0.0 ) );
    samples.add( new Vector3( 0.0, 0.1, 0.2 ) );
    samples.add( new Vector3( 0.0, -0.1, 0.3 ) );
    samples.add( new Vector3( 1.0, 0.0, 0.0 ) );
    samples.add( new Vector3( -1.0, 0.0, 0.0 ) );

    MultivariateGaussian estimate =
        MultivariateGaussian.MaximumLikelihoodEstimator.learn( samples, 0.0 );

    // Verify the estimate exists BEFORE dereferencing it; previously this
    // check came after estimate.getCovariance() and so could never fire.
    assertNotNull( estimate );

    Vector3 expectedMean = new Vector3( 0.0, 0.0, 0.1 );
    Vector3 covarianceColumn1 = new Vector3( 0.5, 0.0, 0.0 );
    Vector3 covarianceColumn2 = new Vector3( 0.0, 0.005, -0.0025 );
    Vector3 covarianceColumn3 = new Vector3( 0.0, -0.0025, 0.02 );
    Matrix expectedCovariance = MatrixFactory.getDefault().copyColumnVectors(
        covarianceColumn1, covarianceColumn2, covarianceColumn3 );

    Matrix estimatedCovariance = estimate.getCovariance();
    System.out.println( "Estimated:\n" + estimatedCovariance );
    System.out.println( "Expected:\n" + expectedCovariance );

    assertEquals( expectedMean, estimate.getMean() );
    assertTrue( expectedCovariance.equals( estimatedCovariance, TOLERANCE ) );
}
/**
 * Test of getCovariance method, of class MultivariateGaussian.
 *
 * Builds a 2-dimensional Gaussian whose covariance is A * A^T for a random
 * matrix A drawn from RANDOM, then checks that getCovariance() returns a
 * matrix equal to the one supplied to the constructor.
 */
public void testGetCovariance()
{
    System.out.println( "getCovariance" );
    int N = 2;
    double range = 1;
    Vector mean = VectorFactory.getDefault().createUniformRandom( N, -range, range, RANDOM );
    Matrix sqrt = MatrixFactory.getDefault().createUniformRandom( N, N, -range, range, RANDOM );
    Matrix covariance = sqrt.times( sqrt.transpose() );
    MultivariateGaussian instance =
        new MultivariateGaussian( mean, covariance );

    // JUnit convention is (expected, actual); the arguments used to be
    // reversed, which would label the values wrongly in a failure message.
    assertEquals( covariance, instance.getCovariance() );
}
/**
 * Test of sampling from a covariance square root, of class
 * MultivariateGaussian.
 *
 * A random 2-dimensional mean and covariance (A * A^T) are generated, the
 * covariance is factored with a Cholesky decomposition, NUM_SAMPLES draws
 * are taken from that factor, and a Gaussian re-estimated from the draws is
 * compared with the originals using a loose tolerance of
 * N / log(numDraws) / range.
 */
public void testRandomCovarianceSquareRoot()
{
    System.out.println( "randomCovarianceSquareRoot" );

    final int dim = 2;
    final double bound = 1;

    // Ground-truth parameters (the order of the RANDOM draws matters).
    final Vector trueMean =
        VectorFactory.getDefault().createUniformRandom( dim, -bound, bound, RANDOM );
    final Matrix factor =
        MatrixFactory.getDefault().createUniformRandom( dim, dim, -bound, bound, RANDOM );
    final Matrix product = factor.times( factor.transpose() );
    final DenseMatrix trueCovariance =
        DenseMatrixFactoryMTJ.INSTANCE.copyMatrix( product );
    final DenseMatrix choleskyFactor =
        CholeskyDecompositionMTJ.create( trueCovariance ).getR();

    // Sample from the factor and re-estimate a Gaussian from the draws.
    final int drawCount = NUM_SAMPLES;
    final ArrayList<Vector> draws =
        MultivariateGaussian.sample( trueMean, choleskyFactor, RANDOM, drawCount );
    final MultivariateGaussian fitted =
        MultivariateGaussian.MaximumLikelihoodEstimator.learn( draws, 0.0 );

    final double tolerance = dim / Math.log( drawCount ) / bound;
    System.out.println( "Tolerance: " + tolerance );

    System.out.println( "Expected Mean:\n" + trueMean );
    System.out.println( "Resulting Mean:\n" + fitted.getMean() );
    assertTrue( trueMean.equals( fitted.getMean(), tolerance ) );

    System.out.println( "Expected Covariance:\n" + trueCovariance );
    System.out.println( "Resulting Covariance:\n" + fitted.getCovariance() );
    assertTrue( trueCovariance.equals( fitted.getCovariance(), tolerance ) );
}
/**
 * Test of setMean method, of class MultivariateGaussian.
 *
 * Checks that a mean that was set can be read back as an equal vector, that
 * setting a clone does not hand back the original object, and that a null
 * mean is rejected with an exception.
 */
public void testSetMean()
{
    System.out.println( "setMean" );

    final MultivariateGaussian gaussian = this.createInstance();
    final int dim = gaussian.getInputDimensionality();
    final Vector replacement =
        VectorFactory.getDefault().createVector( dim, RANDOM.nextGaussian() );

    gaussian.setMean( replacement );
    assertEquals( replacement, gaussian.getMean() );

    gaussian.setMean( replacement.clone() );
    assertNotSame( replacement, gaussian.getMean() );

    // A null mean must be refused; fail() raises an Error, so it is not
    // swallowed by the catch clause below.
    try
    {
        gaussian.setMean( null );
        fail( "Cannot set null mean" );
    }
    catch (Exception expected)
    {
        System.out.println( "Good: " + expected );
    }
}
/**
 * Test of scale method, of class MultivariateGaussian.
 *
 * Scales a Gaussian g1 by a random M-by-N matrix A (M between 1 and 10) and
 * checks that the result has dimensionality M, mean A * mean(g1), and
 * covariance A * cov(g1) * A^T.
 */
public void testScale()
{
    System.out.println( "scale" );
    MultivariateGaussian g1 = this.createInstance();
    final int M = RANDOM.nextInt( 10 ) + 1;
    final int N = g1.getInputDimensionality();

    // Positive half-width, so that (-range, range) is a proper (min, max)
    // pair; it used to be -1.0, which passed the bounds in inverted order.
    final double range = 1.0;
    Matrix test = MatrixFactory.getDefault().createUniformRandom( M, N, -range, range, RANDOM );

    MultivariateGaussian g2 = g1.scale( test );
    assertEquals( M, g2.getInputDimensionality() );
    assertEquals( test.times( g1.getMean() ), g2.getMean() );
    assertEquals( test.times( g1.getCovariance() ).times( test.transpose() ), g2.getCovariance() );
}
/**
 * Test of plus method, of class MultivariateGaussian.
 *
 * Checks that the sum of two Gaussians has the sum of their means and the
 * sum of their covariances, and that adding a Gaussian of a different
 * dimensionality, or null, raises an exception.
 */
public void testPlus()
{
    System.out.println( "plus" );
    MultivariateGaussian g1 = this.createInstance();
    MultivariateGaussian g2 = this.createInstance();
    MultivariateGaussian sum = g1.plus( g2 );
    assertEquals( g1.getMean().plus( g2.getMean() ), sum.getMean() );
    assertEquals( g1.getCovariance().plus( g2.getCovariance() ),
        sum.getCovariance() );

    // Build a Gaussian with one extra dimension to provoke a mismatch.
    final int N = g1.getInputDimensionality() + 1;

    // Positive half-width, so that (-range, range) is a proper (min, max)
    // pair; it used to be -1.0, which passed the bounds in inverted order.
    double range = 1.0;
    Vector mean = VectorFactory.getDefault().createUniformRandom( N, -range, range, RANDOM );
    Matrix sqrt = MatrixFactory.getDefault().createUniformRandom( N, N, -range, range, RANDOM );
    Matrix covariance = sqrt.times( sqrt.transpose() );
    MultivariateGaussian g3 = new MultivariateGaussian( mean, covariance );
    try
    {
        g1.plus( g3 );
        fail( "Should have thrown exception" );
    }
    catch (Exception e)
    {
        System.out.println( "Good: " + e );
    }

    try
    {
        g1.plus( null );
        fail( "Should have thrown exception" );
    }
    catch (Exception e)
    {
        System.out.println( "Good: " + e );
    }
}
/**
 * Test of times method, of class MultivariateGaussian.
 *
 * Starts from a one-dimensional Gaussian (mean 0.0, variance 0.1) and
 * multiplies it ten times by another one-dimensional Gaussian (mean 1.0,
 * variance 0.1), printing the running result. Each intermediate result
 * must be non-null and remain one-dimensional.
 */
public void testTimes()
{
    System.out.println( "times" );

    Vector m1 = VectorFactory.getDefault().createVector( 1 );
    m1.setElement( 0, 0.0 );
    Matrix c1 = MatrixFactory.getDefault().createMatrix( 1, 1 );
    c1.setElement( 0, 0, 0.1 );
    MultivariateGaussian g1 = new MultivariateGaussian( m1, c1 );

    Vector m2 = VectorFactory.getDefault().createVector( 1 );
    m2.setElement( 0, 1.0 );
    Matrix c2 = MatrixFactory.getDefault().createMatrix( 1, 1 );
    c2.setElement( 0, 0, 0.1 );
    MultivariateGaussian g2 = new MultivariateGaussian( m2, c2 );

    MultivariateGaussian belief = g1;
    for (int i = 0; i < 10; i++)
    {
        belief = belief.times( g2 );
        System.out.println( i + ": Belief: " + belief );

        // Previously this test asserted nothing at all; these are
        // minimal sanity checks on every intermediate product.
        assertNotNull( belief );
        assertEquals( 1, belief.getInputDimensionality() );

        // NOTE(review): if times() is the standard product of Gaussian
        // densities, the mean here would be (i+1)/(i+2) and the variance
        // 0.1/(i+2) -- confirm against the implementation before
        // asserting those values.
    }
}
/**
 * Test of convolve method, of class MultivariateGaussian.
 *
 * Convolves two one-dimensional Gaussians and checks that the result has
 * the sum of the two means and the sum of the two covariances.
 */
public void testConvolve()
{
    System.out.println( "convolve" );

    final MultivariateGaussian left = createUnivariateForConvolve( 0.0, 0.1 );
    final MultivariateGaussian right = createUnivariateForConvolve( 1.0, 0.1 );

    final MultivariateGaussian convolved = left.convolve(right);

    assertEquals( left.getMean().plus(right.getMean()), convolved.getMean() );
    assertEquals( left.getCovariance().plus( right.getCovariance()), convolved.getCovariance() );
}

/**
 * Builds a one-dimensional Gaussian from a scalar mean and variance.
 *
 * @param meanValue The single element of the mean vector.
 * @param variance The single element of the covariance matrix.
 * @return A new one-dimensional MultivariateGaussian.
 */
private static MultivariateGaussian createUnivariateForConvolve(
    double meanValue,
    double variance )
{
    final Vector mean = VectorFactory.getDefault().createVector( 1 );
    mean.setElement( 0, meanValue );
    final Matrix covariance = MatrixFactory.getDefault().createMatrix( 1, 1 );
    covariance.setElement( 0, 0, variance );
    return new MultivariateGaussian( mean, covariance );
}
/**
 * Test of setCovariance method, of class MultivariateGaussian.
 *
 * Covers four cases: a null covariance is rejected; an identity covariance
 * is stored and its inverse is reported; an asymmetric 3x3 matrix is
 * accepted but the stored covariance is a different, symmetric object; and
 * a non-square (3x2) matrix is rejected.
 */
public void testSetCovariance()
{
    System.out.println( "setCovariance" );
    final MultivariateGaussian gaussian = this.createInstance();

    // Case 1: null must be refused.
    try
    {
        gaussian.setCovariance( null );
        fail( "Should have thrown NullPointerException" );
    }
    catch (Exception expected)
    {
        System.out.println( "Properly thrown NullPointerException: " + expected );
    }

    // Case 2: an identity covariance round-trips, and so does its inverse.
    final int dim = gaussian.getInputDimensionality();
    final Matrix identity = MatrixFactory.getDefault().createIdentity( dim, dim );
    assertFalse( identity.equals( gaussian.getCovariance() ) );
    gaussian.setCovariance( identity );
    assertEquals( identity, gaussian.getCovariance() );
    assertTrue( identity.inverse().equals( gaussian.getCovarianceInverse(), TOLERANCE ) );

    // Case 3: an asymmetric matrix is accepted, but what is stored is a
    // distinct object that is symmetric.
    final Matrix asymmetric = MatrixFactory.getDefault().createMatrix(3,3);
    asymmetric.setElement(1, 0, 1.0);
    gaussian.setCovariance(asymmetric);
    assertNotSame( asymmetric, gaussian.getCovariance() );
    assertTrue( gaussian.getCovariance().isSymmetric() );

    // Case 4: a non-square matrix must be refused.
    final Matrix nonSquare = MatrixFactory.getDefault().createIdentity(3, 2);
    try
    {
        gaussian.setCovariance(nonSquare);
        fail( "Covariance must be PD" );
    }
    catch (Exception expected)
    {
        System.out.println( "Good: " + expected );
    }
}
/**
 * Test of equals method, of class MultivariateGaussian.
 *
 * Two independently created random instances must be distinct objects, each
 * must equal its own clone, and the two must not equal each other.
 */
public void testEquals()
{
    System.out.println( "equals" );

    final MultivariateGaussian first = this.createInstance();
    final MultivariateGaussian second = this.createInstance();
    assertNotSame( first, second );

    assertEquals( first, first.clone() );
    assertEquals( second, second.clone() );

    final boolean sameDistribution = first.equals( second );
    assertFalse( sameDistribution );
}
/**
 * Test of getInputDimensionality method, of class MultivariateGaussian.PDF.
 *
 * The input dimensionality reported by the PDF must equal the
 * dimensionality of its mean vector.
 */
public void testGetInputDimensionality()
{
    System.out.println( "getInputDimensionality" );

    final MultivariateGaussian gaussian = this.createInstance();
    final MultivariateGaussian.PDF pdf = gaussian.getProbabilityFunction();
    final int meanDimensionality = pdf.getMean().getDimensionality();

    assertEquals( meanDimensionality, pdf.getInputDimensionality() );
}
/**
 * Creates a random three-dimensional Gaussian for use by the tests.
 *
 * The mean and a 3x3 factor matrix A are both produced by
 * createUniformRandom with bounds -2.0 and 2.0 using RANDOM; the covariance
 * is A * A^T.
 *
 * @return A new MultivariateGaussian with the random mean and covariance.
 */
@Override
public MultivariateGaussian createInstance()
{
    final int dim = 3;
    final double bound = 2.0;

    final Vector randomMean = VectorFactory.getDefault().createUniformRandom(
        dim, -bound, bound, RANDOM );
    final Matrix factor = MatrixFactory.getDefault().createUniformRandom(
        dim, dim, -bound, bound, RANDOM );

    // A * A^T is symmetric by construction.
    final Matrix randomCovariance = factor.times( factor.transpose() );
    return new MultivariateGaussian( randomMean, randomCovariance );
}
/**
 * Runs the harness's getMean test with a larger sample count (10000) and a
 * looser tolerance (1e-1), then restores the previous NUM_SAMPLES and
 * TOLERANCE values.
 */
@Override
public void testGetMean()
{
    final double savedTolerance = TOLERANCE;
    final int savedNumSamples = NUM_SAMPLES;
    NUM_SAMPLES = 10000;
    TOLERANCE = 1e-1;
    try
    {
        super.testGetMean();
    }
    finally
    {
        // Restore even when the superclass test fails, so the temporary
        // settings cannot leak into anything that runs afterwards.
        TOLERANCE = savedTolerance;
        NUM_SAMPLES = savedNumSamples;
    }
}
/**
 * Exercises the four MultivariateGaussian.PDF constructors: default,
 * dimensionality, mean/covariance (which must keep the supplied references)
 * and copy (which must produce equal but distinct mean and covariance).
 */
@Override
public void testProbabilityFunctionConstructors()
{
    System.out.println( "PDF.Constructors" );

    // Default constructor.
    final MultivariateGaussian.PDF byDefault = new MultivariateGaussian.PDF();
    assertEquals( MultivariateGaussian.DEFAULT_DIMENSIONALITY, byDefault.getInputDimensionality() );

    // Dimensionality constructor.
    final int dim = RANDOM.nextInt(10) + 1;
    final MultivariateGaussian.PDF bySize = new MultivariateGaussian.PDF( dim );
    assertEquals( dim, bySize.getInputDimensionality() );

    // Mean/covariance constructor stores the very same objects.
    final Vector mean = bySize.getMean();
    final Matrix covariance = bySize.getCovariance();
    final MultivariateGaussian.PDF byParameters =
        new MultivariateGaussian.PDF( mean, covariance );
    assertSame( mean, byParameters.getMean() );
    assertSame( covariance, byParameters.getCovariance() );

    // Copy constructor yields equal, but not shared, parameters.
    final MultivariateGaussian.PDF copy = new MultivariateGaussian.PDF( byParameters );
    assertNotSame( byParameters, copy );
    assertNotSame( byParameters.getMean(), copy.getMean() );
    assertEquals( byParameters.getMean(), copy.getMean() );
    assertNotSame( byParameters.getCovariance(), copy.getCovariance() );
    assertEquals( byParameters.getCovariance(), copy.getCovariance() );
}
/**
 * Exercises the four MultivariateGaussian constructors: default,
 * dimensionality, mean/covariance (which must keep the supplied references)
 * and copy (which must produce equal but distinct mean and covariance).
 */
@Override
public void testConstructors()
{
    System.out.println( "Constructors" );

    // Default constructor.
    final MultivariateGaussian byDefault = new MultivariateGaussian();
    assertEquals( MultivariateGaussian.DEFAULT_DIMENSIONALITY, byDefault.getInputDimensionality() );

    // Dimensionality constructor.
    final int dim = RANDOM.nextInt(10) + 1;
    final MultivariateGaussian bySize = new MultivariateGaussian( dim );
    assertEquals( dim, bySize.getInputDimensionality() );

    // Mean/covariance constructor stores the very same objects.
    final Vector mean = bySize.getMean();
    final Matrix covariance = bySize.getCovariance();
    final MultivariateGaussian byParameters =
        new MultivariateGaussian( mean, covariance );
    assertSame( mean, byParameters.getMean() );
    assertSame( covariance, byParameters.getCovariance() );

    // Copy constructor yields equal, but not shared, parameters.
    final MultivariateGaussian copy = new MultivariateGaussian( byParameters );
    assertNotSame( byParameters, copy );
    assertNotSame( byParameters.getMean(), copy.getMean() );
    assertEquals( byParameters.getMean(), copy.getMean() );
    assertNotSame( byParameters.getCovariance(), copy.getCovariance() );
    assertEquals( byParameters.getCovariance(), copy.getCovariance() );
}
/**
* Overrides the inherited known-values test with an empty body, so this
* test performs no checks.
*/
@Override
public void testKnownValues()
{
// Intentionally empty.
// NOTE(review): presumably known values are covered by
// testProbabilityFunctionKnownValues() instead -- confirm.
}
/**
 * Checks the layout of convertToVector(): the first d entries (d being the
 * input dimensionality) must equal the mean, and the remaining entries,
 * loaded into a d-by-d matrix via convertFromVector, must equal the
 * covariance.
 */
@Override
public void testKnownConvertToVector()
{
    System.out.println( "Known convertToVector" );

    final MultivariateGaussian gaussian = this.createInstance();
    final Vector parameters = gaussian.convertToVector();
    final int dim = gaussian.getInputDimensionality();

    // Leading block: the mean.
    final Vector meanPart = parameters.subVector(0, dim-1);
    assertEquals( gaussian.getMean(), meanPart );

    // Trailing block: the covariance, unpacked into a square matrix.
    final int lastIndex = parameters.getDimensionality()-1;
    final Vector covariancePart = parameters.subVector(dim, lastIndex);
    final Matrix unpacked = MatrixFactory.getDefault().createMatrix(dim, dim);
    unpacked.convertFromVector(covariancePart);
    assertEquals( gaussian.getCovariance(), unpacked );
}
/**
 * Tests the incremental estimator.
 *
 * Samples are drawn from a three-dimensional Gaussian; the sufficient
 * statistic learned incrementally must match the sample count, the sample
 * mean and the sample variance, the Gaussian it creates must match the
 * batch maximum-likelihood estimate, and a clone of the statistic must
 * carry the same count, mean and covariance.
 */
public void testIncrementalEstimator()
{
    System.out.println( "Incremental Estimator" );

    final MultivariateGaussian.IncrementalEstimator incremental =
        new MultivariateGaussian.IncrementalEstimator();
    final MultivariateGaussian.IncrementalEstimatorCovarianceInverse incrementalInverse =
        new MultivariateGaussian.IncrementalEstimatorCovarianceInverse();

    // Draw the data set and compute its mean directly.
    final MultivariateGaussian target = new MultivariateGaussian( 3 );
    final ArrayList<Vector> samples = target.sample(RANDOM,NUM_SAMPLES);
    final Vector sampleMean = MultivariateStatisticsUtil.computeMean(samples);

    // The incremental statistic must agree with the direct computations.
    final MultivariateGaussian.SufficientStatistic statistic =
        incremental.learn(samples);
    assertEquals( samples.size(), statistic.getCount() );
    assertTrue( sampleMean.equals( statistic.getMean(), TOLERANCE ) );
    assertTrue( MultivariateStatisticsUtil.computeVariance(samples, sampleMean).equals( statistic.getCovariance(), TOLERANCE ) );

    final MultivariateGaussian result = statistic.create();

    // The covariance-inverse variant is only built and printed here;
    // nothing below asserts on it.
    final MultivariateGaussian.SufficientStatisticCovarianceInverse inverseStatistic =
        incrementalInverse.learn(samples);
    final MultivariateGaussian inverseResult = inverseStatistic.create();

    final MultivariateGaussian batch =
        MultivariateGaussian.MaximumLikelihoodEstimator.learn(samples, 0.0);

    System.out.println( "Target: " + target );
    System.out.println( "Result: " + result );
    System.out.println( "Inverse: " + inverseResult );
    System.out.println( "Batch : " + batch );

    assertTrue( batch.getMean().equals( result.getMean(), TOLERANCE ) );
    assertTrue( batch.getCovariance().equals( result.getCovariance(), TOLERANCE ) );

    // Cloning must preserve the accumulated statistics.
    final MultivariateGaussian.SufficientStatistic copy = statistic.clone();
    assertEquals( statistic.getCount(), copy.getCount() );
    assertEquals( statistic.getMean(), copy.getMean() );
    assertEquals( statistic.getCovariance(), copy.getCovariance() );
}
}