/* * File: MultivariateDecorrelatorTest.java * Authors: Kevin R. Dixon * Company: Sandia National Laboratories * Project: Cognitive Foundry * * Copyright Jul 6, 2009, Sandia Corporation. * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive * license for use of this work by or on behalf of the U.S. Government. * Export of this program may require a license from the United States * Government. See CopyrightHistory.txt for complete details. * */ package gov.sandia.cognition.learning.data.feature; import gov.sandia.cognition.math.MultivariateStatisticsUtil; import gov.sandia.cognition.math.UnivariateStatisticsUtil; import gov.sandia.cognition.math.matrix.Matrix; import gov.sandia.cognition.math.matrix.MatrixFactory; import gov.sandia.cognition.math.matrix.Vector; import gov.sandia.cognition.math.matrix.mtj.DenseMatrixFactoryMTJ; import gov.sandia.cognition.math.matrix.mtj.Vector3; import gov.sandia.cognition.math.matrix.mtj.decomposition.CholeskyDecompositionMTJ; import gov.sandia.cognition.statistics.distribution.MultivariateGaussian; import java.util.ArrayList; import junit.framework.TestCase; import java.util.Random; /** * Unit tests for MultivariateDecorrelatorTest. * * @author krdixon */ public class MultivariateDecorrelatorTest extends TestCase { /** * Random number generator to use for a fixed random seed. */ public final Random RANDOM = new Random( 1 ); /** * Default tolerance of the regression tests, {@value}. */ public final double TOLERANCE = 1e-5; /** * Tests for class MultivariateDecorrelatorTest. * @param testName Name of the test. */ public MultivariateDecorrelatorTest( String testName) { super(testName); } public MultivariateDecorrelator createInstance() { Vector mean = this.createRandomInput(); Matrix A = MatrixFactory.getDefault().copyColumnVectors( this.createRandomInput(), this.createRandomInput(), this.createRandomInput() ); Matrix C = A.transpose().times(A); return new MultivariateDecorrelator( mean, C ); } public Vector createRandomInput() { return Vector3.createRandom(RANDOM); } /** * Tests the constructors of class MultivariateDecorrelatorTest. */ public void testConstructors() { System.out.println( "Constructors" ); MultivariateDecorrelator f = new MultivariateDecorrelator(); assertNotNull( f ); assertNull( f.getGaussian() ); assertNull( f.getCovarianceInverseSquareRoot() ); f = this.createInstance(); assertNotNull( f ); assertNotNull( f.getGaussian() ); assertNotNull( f.getMean() ); assertNotNull( f.getCovariance() ); assertNotNull( f.getCovarianceInverseSquareRoot() ); } /** * Test of clone method, of class MultivariateDecorrelator. */ public void testClone() { System.out.println("clone"); MultivariateDecorrelator instance = this.createInstance(); MultivariateDecorrelator clone = instance.clone(); assertNotNull( clone ); assertNotSame( instance, clone ); assertNotSame( instance.getGaussian(), clone.getGaussian() ); assertNotSame( instance.getCovarianceInverseSquareRoot(), clone.getCovarianceInverseSquareRoot() ); assertEquals( instance.getMean(), clone.getMean() ); assertEquals( instance.getCovariance(), clone.getCovariance() ); } /** * Test of evaluate method, of class MultivariateDecorrelator. */ public void testEvaluate() { System.out.println("evaluate"); MultivariateDecorrelator instance = this.createInstance(); Vector input = this.createRandomInput(); Vector expected = input.minus(instance.getMean()).times( instance.getCovarianceInverseSquareRoot() ); Vector result = instance.evaluate(input); if( !expected.equals( result, TOLERANCE ) ) { assertEquals( expected, result ); } } /** * Test of getInputDimensionality method, of class MultivariateDecorrelator. */ public void testGetInputDimensionality() { MultivariateDecorrelator instance = this.createInstance(); assertEquals(3, instance.getInputDimensionality()); } /** * Test of getOutputDimensionality method, of class MultivariateDecorrelator. */ public void testGetOutputDimensionality() { MultivariateDecorrelator instance = this.createInstance(); assertEquals(3, instance.getOutputDimensionality()); } /** * Test of getMean method, of class MultivariateDecorrelator. */ public void testGetMean() { System.out.println("getMean"); MultivariateDecorrelator instance = this.createInstance(); assertNotNull( instance.getMean() ); instance = new MultivariateDecorrelator(); try { instance.getMean(); fail( "No gaussian, throws NullPointerException" ); } catch (Exception e) { System.out.println( "Good: " + e ); } } /** * Test of getCovariance method, of class MultivariateDecorrelator. */ public void testGetCovariance() { System.out.println("getCovariance"); MultivariateDecorrelator instance = this.createInstance(); assertNotNull( instance.getCovariance() ); instance = new MultivariateDecorrelator(); try { instance.getCovariance(); fail( "No gaussian, throws NullPointerException" ); } catch (Exception e) { System.out.println( "Good: " + e ); } } /** * Test of getGaussian method, of class MultivariateDecorrelator. */ public void testGetGaussian() { System.out.println("getGaussian"); MultivariateDecorrelator instance = this.createInstance(); assertNotNull( instance.getGaussian() ); } /** * Test of setGaussian method, of class MultivariateDecorrelator. */ public void testSetGaussian() { System.out.println("setGaussian"); MultivariateDecorrelator instance = this.createInstance(); MultivariateGaussian gaussian = instance.getGaussian(); assertNotNull( gaussian ); instance.setGaussian(null); assertNull( instance.getGaussian() ); assertNull( instance.getCovarianceInverseSquareRoot() ); instance.setGaussian(gaussian); assertNotSame( gaussian, instance.getGaussian() ); assertNotNull( instance.getCovarianceInverseSquareRoot() ); assertEquals( gaussian.getMean(), instance.getMean() ); assertEquals( gaussian.getCovariance(), instance.getCovariance() ); } /** * Test of getCovarianceInverseSquareRoot method, of class MultivariateDecorrelator. */ public void testGetCovarianceInverseSquareRoot() { System.out.println("getCovarianceInverseSquareRoot"); MultivariateDecorrelator instance = this.createInstance(); Matrix sqrt = CholeskyDecompositionMTJ.create( DenseMatrixFactoryMTJ.INSTANCE.copyMatrix( instance.getCovariance().inverse() ) ).getR(); if( !sqrt.equals( instance.getCovarianceInverseSquareRoot(), TOLERANCE ) ) { assertEquals( sqrt, instance.getCovarianceInverseSquareRoot() ); } } public ArrayList<Vector> createDataset() { final int num = 100; ArrayList<Vector> data = new ArrayList<Vector>( num ); for( int n = 0; n < num; n++ ) { data.add( this.createRandomInput() ); } return data; } /** * Test of learnFullCovariance method, of class MultivariateDecorrelator. */ public void testLearnFullCovariance() { System.out.println("learnFullCovariance"); MultivariateDecorrelator.FullCovarianceLearner learner = new MultivariateDecorrelator.FullCovarianceLearner(); learner.setDefaultCovariance(0.0); ArrayList<Vector> data = this.createDataset(); MultivariateDecorrelator instance = learner.learn( data ); Vector mean = MultivariateStatisticsUtil.computeMean(data); if( !mean.equals( instance.getMean() ) ) { assertEquals( mean, instance.getMean() ); } Matrix covariance = MultivariateStatisticsUtil.computeVariance(data,mean); if( !covariance.equals(instance.getCovariance(), TOLERANCE ) ) { assertEquals( covariance, instance.getCovariance() ); } } /** * Test of learnDiagonalCovariance method, of class MultivariateDecorrelator. */ public void testLearnDiagonalCovariance() { System.out.println("learnDiagonalCovariance"); ArrayList<Vector> data = this.createDataset(); MultivariateDecorrelator.DiagonalCovarianceLearner learner = new MultivariateDecorrelator.DiagonalCovarianceLearner(); learner.setDefaultCovariance(0.0); MultivariateDecorrelator instance = learner.learn(data); Vector mean = MultivariateStatisticsUtil.computeMean(data); if( !mean.equals( instance.getMean() ) ) { assertEquals( mean, instance.getMean() ); } Matrix Chat = instance.getCovariance(); final int M = mean.getDimensionality(); assertEquals( M, Chat.getNumRows() ); assertEquals( M, Chat.getNumColumns() ); double biasedAdjustment = (data.size()-1.0)/data.size(); for( int i = 0; i < M; i++ ) { ArrayList<Double> di = new ArrayList<Double>( data.size() ); for( Vector v : data ) { di.add(v.getElement(i)); } for( int j = 0; j < M; j++ ) { if( i == j ) { double variance = biasedAdjustment * UnivariateStatisticsUtil.computeVariance(di); assertEquals( variance, Chat.getElement(i,i), TOLERANCE ); } else { assertEquals( 0.0, Chat.getElement(i, j) ); } } } } }