/*
 * File:                MultivariateGaussianMeanCovarianceBayesianEstimatorTest.java
 * Authors:             Kevin R. Dixon
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 *
 * Copyright Mar 29, 2010, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
 * license for use of this work by or on behalf of the U.S. Government.
 * Export of this program may require a license from the United States
 * Government. See CopyrightHistory.txt for complete details.
 *
 */

package gov.sandia.cognition.statistics.bayesian.conjugate;

import gov.sandia.cognition.math.MultivariateStatisticsUtil;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.statistics.distribution.MultivariateGaussian;
import gov.sandia.cognition.statistics.distribution.NormalInverseWishartDistribution;
import java.util.ArrayList;

/**
 * Unit tests for MultivariateGaussianMeanCovarianceBayesianEstimator.
 *
 * @author krdixon
 */
public class MultivariateGaussianMeanCovarianceBayesianEstimatorTest
    extends ConjugatePriorBayesianEstimatorTestHarness<Vector,Matrix,NormalInverseWishartDistribution>
{

    /**
     * Dimensionality shared by the estimator under test and by the
     * conditional distributions that generate the test data.
     */
    public int DIM = 3;

    /**
     * Tests for class MultivariateGaussianMeanCovarianceBayesianEstimator.
     * @param testName Name of the test.
     */
    public MultivariateGaussianMeanCovarianceBayesianEstimatorTest(
        String testName)
    {
        super(testName);
    }

    /**
     * Tests the constructors of class
     * MultivariateGaussianMeanCovarianceBayesianEstimator: the default
     * constructor must supply a non-null initial belief, and the
     * belief-taking constructor must keep the very same belief instance.
     */
    public void testConstructors()
    {
        System.out.println( "Constructors" );

        MultivariateGaussianMeanCovarianceBayesianEstimator instance =
            new MultivariateGaussianMeanCovarianceBayesianEstimator();
        NormalInverseWishartDistribution belief = instance.getInitialBelief();
        assertNotNull( belief );

        instance = new MultivariateGaussianMeanCovarianceBayesianEstimator( belief );
        assertSame( belief, instance.getInitialBelief() );
    }

    /**
     * Creates the estimator under test, sized by {@link #DIM} so that it
     * always agrees with the dimensionality of the data produced by
     * {@link #createConditionalDistribution()}.
     * @return A new estimator constructed with DIM.
     */
    @Override
    public MultivariateGaussianMeanCovarianceBayesianEstimator createInstance()
    {
        // NOTE(review): DIM is an instance field, so this assumes the harness
        // does not call createInstance() from its constructor -- confirm.
        return new MultivariateGaussianMeanCovarianceBayesianEstimator( DIM );
    }

    /**
     * Creates a random DIM-dimensional Gaussian to generate test data from.
     * The mean is created by a uniform-random draw with bounds -3.0 and 3.0,
     * and the covariance is built as R*R' from a DIM-by-DIM matrix R created
     * with bounds -1.0 and 1.0.
     * @return A new Gaussian with the random mean and covariance.
     */
    @Override
    public MultivariateGaussian createConditionalDistribution()
    {
        Vector mean = VectorFactory.getDefault().createUniformRandom(
            DIM, -3.0, 3.0, RANDOM );
        Matrix R = MatrixFactory.getDefault().createUniformRandom(
            DIM, DIM, -1.0, 1.0, RANDOM );

        // R*R' is symmetric and positive semi-definite by construction.
        Matrix C = R.times( R.transpose() );
        return new MultivariateGaussian( mean, C );
    }

    /**
     * Learns a posterior from samples of a random Gaussian and prints the
     * learning time, the generating distribution, and two reconstructions of
     * it: one from the average of parameters sampled from the posterior and
     * one from the posterior's mean. This method only prints its results for
     * inspection; it makes no assertions.
     */
    @Override
    public void testKnownValues()
    {
        System.out.println( "Known Values" );

        MultivariateGaussian g = this.createConditionalDistribution();
        ArrayList<? extends Vector> samples = g.sample( RANDOM, NUM_SAMPLES );

        MultivariateGaussianMeanCovarianceBayesianEstimator instance =
            this.createInstance();

        // Time only the learning step.
        long start = System.currentTimeMillis();
        NormalInverseWishartDistribution result = instance.learn( samples );
        long stop = System.currentTimeMillis();
        System.out.println( "NUM: " + samples.size() + ", LearnTime: " + (stop-start)/1000.0);

        // Reconstruction 1: average of parameters sampled from the posterior.
        ArrayList<? extends Matrix> parameters =
            result.sample( RANDOM, NUM_SAMPLES );
        Matrix averageParameter =
            MultivariateStatisticsUtil.computeMean( parameters );
        MultivariateGaussian ghat =
            instance.createConditionalDistribution( averageParameter );

        // Reconstruction 2: the mean reported by the posterior itself.
        Matrix parameterMean = result.getMean();
        MultivariateGaussian ghat2 =
            instance.createConditionalDistribution( parameterMean );

        System.out.println( "G:\n" + g );
        System.out.println( "Ghat:\n" + ghat );
        System.out.println( "Ghat2:\n" + ghat2 );
    }

}