/*
 * File:                BaumWelchAlgorithmTest.java
 * Authors:             Kevin R. Dixon
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 *
 * Copyright Feb 4, 2010, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
 * license for use of this work by or on behalf of the U.S. Government.
 * Export of this program may require a license from the United States
 * Government. See CopyrightHistory.txt for complete details.
 *
 */

package gov.sandia.cognition.learning.algorithm.hmm;

import gov.sandia.cognition.collection.DefaultMultiCollection;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.statistics.ComputableDistribution;
import gov.sandia.cognition.statistics.distribution.MultivariateGaussian;
import gov.sandia.cognition.util.NamedValue;
import gov.sandia.cognition.util.WeightedValue;
import java.util.ArrayList;
import java.util.Collection;
import junit.framework.TestCase;
import java.util.Random;

/**
 * Unit tests for BaumWelchAlgorithm.
 *
 * @author krdixon
 */
public class BaumWelchAlgorithmTest
    extends TestCase
{

    /**
     * Random number generator to use for a fixed random seed.
     * This is an instance field, so every test-case instance starts from the
     * same seed.
     */
    public final Random RANDOM = new Random( 1 );

    /**
     * Default tolerance of the regression tests, {@value}.
     */
    public final double TOLERANCE = 1e-5;

    /**
     * Creates the test case for class BaumWelchAlgorithm.
     * @param testName Name of the test.
     */
    public BaumWelchAlgorithmTest(
        String testName)
    {
        super(testName);
    }

    /**
     * Creates a learner whose initial guess is an HMM with
     * HiddenMarkovModelTest.DEFAULT_NUM_STATES states. The emission function
     * of state i is a Gaussian PDF whose mean vector has every entry equal
     * to i and whose covariance is the identity matrix.
     * @return
     * Learner using a weighted maximum-likelihood Gaussian estimator, with
     * re-estimation of the initial probabilities enabled.
     */
    public BaumWelchAlgorithm<Vector> createInstance()
    {
        HiddenMarkovModel<Vector> hmm = new HiddenMarkovModel<Vector>(
            HiddenMarkovModelTest.DEFAULT_NUM_STATES );
        final int dim = HiddenMarkovModelTest.DEFAULT_OBSERVATION_DIM;
        final int numStates = hmm.getNumStates();

        ArrayList<MultivariateGaussian.PDF> pdfs =
            new ArrayList<MultivariateGaussian.PDF>( numStates );
        for( int i = 0; i < numStates; i++ )
        {
            Vector mean = VectorFactory.getDefault().createVector( dim, i );
            Matrix cov = MatrixFactory.getDefault().createIdentity( dim, dim );
            pdfs.add( new MultivariateGaussian.PDF( mean, cov ) );
        }
        hmm.setEmissionFunctions(pdfs);

        return new BaumWelchAlgorithm<Vector>(
            hmm, new MultivariateGaussian.WeightedMaximumLikelihoodEstimator(), true );
    }

    /**
     * Creates the HMM used as the data-generating target in these tests.
     * @return
     * HMM produced by HiddenMarkovModelTest.staticCreateInstance()
     */
    public HiddenMarkovModel<Vector> createHMMInstance()
    {
        return HiddenMarkovModelTest.staticCreateInstance();
    }

    /**
     * Tests the constructors of class BaumWelchAlgorithm.
     */
    public void testConstructors()
    {
        System.out.println( "Constructors" );

        // Default constructor: nothing is configured yet.
        BaumWelchAlgorithm<Vector> instance = new BaumWelchAlgorithm<Vector>();
        assertEquals( BaumWelchAlgorithm.DEFAULT_REESTIMATE_INITIAL_PROBABILITY,
            instance.getReestimateInitialProbabilities() );
        assertNull( instance.getInitialGuess() );
        assertNull( instance.getDistributionLearner() );
        assertNull( instance.getResult() );

        // Full constructor: use the opposite of the default flag so the
        // check cannot pass merely because the default was kept.
        MultivariateGaussian.WeightedMaximumLikelihoodEstimator learner =
            new MultivariateGaussian.WeightedMaximumLikelihoodEstimator();
        HiddenMarkovModel<Vector> hmm = this.createHMMInstance();
        boolean reestimate = !instance.getReestimateInitialProbabilities();
        instance = new BaumWelchAlgorithm<Vector>( hmm, learner, reestimate );
        assertSame( hmm, instance.getInitialGuess() );
        assertSame( learner, instance.getDistributionLearner() );
        assertEquals( reestimate, instance.getReestimateInitialProbabilities() );
        assertNull( instance.getResult() );
    }

    /**
     * Test of clone method, of class BaumWelchAlgorithm.
     */
    public void testClone()
    {
        System.out.println("clone");
        BaumWelchAlgorithm<Vector> instance = this.createInstance();
        BaumWelchAlgorithm<Vector> clone = instance.clone();

        // assertNotSame alone also passes for null, so check for null first.
        assertNotNull( clone );
        assertNotSame( instance, clone );

        assertNotNull( clone.getDistributionLearner() );
        assertNotSame( instance.getDistributionLearner(), clone.getDistributionLearner() );

        assertNotNull( clone.getInitialGuess() );
        assertNotSame( instance.getInitialGuess(), clone.getInitialGuess() );

        assertEquals( instance.getReestimateInitialProbabilities(),
            clone.getReestimateInitialProbabilities() );
    }

    /**
     * Test of learn method on a single observation sequence: the learned
     * model must have a higher log likelihood on the data than the initial
     * guess.
     */
    public void testLearn()
    {
        System.out.println( "learn" );

        HiddenMarkovModel<Vector> target = this.createHMMInstance();
        ArrayList<Vector> observations = target.sample(RANDOM, 1000);
        System.out.println( "TARGET: " + target );
        double l1 = target.computeObservationLogLikelihood(observations);
        System.out.println( "TARGET Log Likelihood: " + l1 );

        BaumWelchAlgorithm<Vector> learner = this.createInstance();
        double l0 = learner.getInitialGuess().computeObservationLogLikelihood(
            observations);
        System.out.println( "INITIAL Log Likelihood: " + l0 );

        learner.setMaxIterations(1000);
        HiddenMarkovModel<Vector> result = learner.learn(observations);
        System.out.println( "Result: " + learner.getIteration() + ": " + result );
        double l2 = result.computeObservationLogLikelihood(observations);
        System.out.println( "Result Log Likelihood: " + l2 );

        assertTrue( l2 > l0 );
    }

    /**
     * Test of getPerformance method, of class BaumWelchAlgorithm.
     */
    public void testGetPerformance()
    {
        System.out.println("getPerformance");
        BaumWelchAlgorithm<Vector> instance = this.createInstance();

        // Checked before any call to learn().
        NamedValue<Double> result = instance.getPerformance();
        assertEquals( BaumWelchAlgorithm.PERFORMANCE_NAME, result.getName() );
        assertEquals( Double.NEGATIVE_INFINITY, result.getValue() );
    }

    /**
     * Test of getResult method, of class BaumWelchAlgorithm.
     */
    public void testGetResult()
    {
        System.out.println("getResult");
        BaumWelchAlgorithm<Vector> instance = this.createInstance();
        assertNull( instance.getResult() );

        HiddenMarkovModel<Vector> hmm = HiddenMarkovModelTest.staticCreateInstance();
        ArrayList<Vector> y = hmm.sample( RANDOM, 100 );
        HiddenMarkovModel<Vector> result = instance.learn(y);
        assertSame( result, instance.getResult() );
        assertNotNull( result );
    }

    /**
     * Test of getInitialGuess method, of class BaumWelchAlgorithm.
     */
    public void testGetInitialGuess()
    {
        System.out.println("getInitialGuess");
        BaumWelchAlgorithm<Vector> instance = this.createInstance();
        assertNotNull( instance.getInitialGuess() );
    }

    /**
     * Test of setInitialGuess method, of class BaumWelchAlgorithm.
     */
    public void testSetInitialGuess()
    {
        System.out.println("setInitialGuess");

        HiddenMarkovModel<Vector> target = HiddenMarkovModelTest.staticCreateInstance();
        ArrayList<Vector> observations = target.sample(RANDOM, 1000);
        System.out.println( "TARGET: " + target );
        double l1 = target.computeObservationLogLikelihood(observations);
        System.out.println( "TARGET Log Likelihood: " + l1 );

        BaumWelchAlgorithm<Vector> learner = this.createInstance();
        HiddenMarkovModel<Vector> initial = learner.getInitialGuess().clone();
        assertNotSame( initial, learner.getInitialGuess() );

        // Exercise the setter itself: the getter must hand back the exact
        // object that was installed.
        learner.setInitialGuess( initial );
        assertSame( initial, learner.getInitialGuess() );

        // Learning must yield a model object distinct from the guess that
        // was installed above.
        learner.setMaxIterations(10);
        HiddenMarkovModel<Vector> result = learner.learn(observations);
        assertNotNull( result );
        assertNotSame( result, initial );
    }

    /**
     * Test of getReestimateInitialProbabilities method, of class BaumWelchAlgorithm.
     */
    public void testGetReestimateInitialProbabilities()
    {
        System.out.println("getReestimateInitialProbabilities");
        BaumWelchAlgorithm<Vector> instance = this.createInstance();
        boolean flag = instance.getReestimateInitialProbabilities();
        flag = !flag;
        instance.setReestimateInitialProbabilities(flag);
        assertEquals( flag, instance.getReestimateInitialProbabilities() );
    }

    /**
     * Test of setReestimateInitialProbabilities method, of class BaumWelchAlgorithm.
     */
    public void testSetReestimateInitialProbabilities()
    {
        System.out.println("setReestimateInitialProbabilities");
        BaumWelchAlgorithm<Vector> instance = this.createInstance();
        boolean flag = instance.getReestimateInitialProbabilities();
        flag = !flag;
        instance.setReestimateInitialProbabilities(flag);
        assertEquals( flag, instance.getReestimateInitialProbabilities() );
    }

    /**
     * Test of getDistributionLearner method, of class BaumWelchAlgorithm.
     */
    public void testGetDistributionLearner()
    {
        System.out.println("getDistributionLearner");
        BaumWelchAlgorithm<Vector> instance = this.createInstance();
        BatchLearner<Collection<? extends WeightedValue<? extends Vector>>, ? extends ComputableDistribution<Vector>> learner =
            instance.getDistributionLearner();
        assertNotNull( learner );
    }

    /**
     * Test of setDistributionLearner method, of class BaumWelchAlgorithm.
     */
    public void testSetDistributionLearner()
    {
        System.out.println("setDistributionLearner");
        BaumWelchAlgorithm<Vector> instance = this.createInstance();
        BatchLearner<Collection<? extends WeightedValue<? extends Vector>>, ? extends ComputableDistribution<Vector>> learner =
            instance.getDistributionLearner();
        assertNotNull( learner );

        instance.setDistributionLearner(null);
        assertNull( instance.getDistributionLearner() );

        instance.setDistributionLearner(learner);
        assertSame( learner, instance.getDistributionLearner() );
    }

    /**
     * Test of learn method on a collection of 100 sequences of 10
     * observations each: the learned model must have a higher log likelihood
     * on the data than the initial guess.
     */
    public void testMultiSequenceLearn()
    {
        System.out.println( "Multi-sequence learn" );

        HiddenMarkovModel<Vector> target = this.createHMMInstance();
        final int numSequences = 100;
        ArrayList<ArrayList<Vector>> sequences =
            new ArrayList<ArrayList<Vector>>( numSequences );
        for( int k = 0; k < numSequences; k++ )
        {
            sequences.add( target.sample(RANDOM, 10) );
        }
        DefaultMultiCollection<Vector> data =
            new DefaultMultiCollection<Vector>( sequences );
        sequences = null;

        System.out.println( "TARGET: " + target );
        double l1 = target.computeMultipleObservationLogLikelihood(data.subCollections());
        System.out.println( "TARGET Log Likelihood: " + l1 );

        // Original author's note, kept verbatim (meaning not confirmed):
        // FALSE: Result Log Likelihood: -114550.38265183996
        BaumWelchAlgorithm<Vector> learner = this.createInstance();
        learner.setReestimateInitialProbabilities(true);
        double l0 = learner.getInitialGuess().computeMultipleObservationLogLikelihood(
            data.subCollections() );
        System.out.println( "INITIAL Log Likelihood: " + l0 );

        learner.setMaxIterations(1000);
        HiddenMarkovModel<Vector> result = learner.learn(data);
        System.out.println( "Result: " + learner.getIteration() + ": " + result );
        double l2 = result.computeMultipleObservationLogLikelihood(data.subCollections());
        System.out.println( "Result Log Likelihood: " + l2 );

        assertTrue( l2 > l0 );
    }

}