/*
 * File:                ParallelBaumWelchAlgorithmTest.java
 * Authors:             Kevin R. Dixon
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 *
 * Copyright Feb 4, 2010, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
 * license for use of this work by or on behalf of the U.S. Government.
 * Export of this program may require a license from the United States
 * Government. See CopyrightHistory.txt for complete details.
 *
 */

package gov.sandia.cognition.learning.algorithm.hmm;

import gov.sandia.cognition.algorithm.ParallelUtil;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.statistics.distribution.MultivariateGaussian;
import java.util.Collection;
import java.util.concurrent.ThreadPoolExecutor;

/**
 * Unit tests for ParallelBaumWelchAlgorithm.
 * <p>
 * The class extends BaumWelchAlgorithmTest and overrides
 * {@link #createInstance()} so that the factory hands back the parallel
 * learner; presumably the inherited tests then run against it as well.
 *
 * @author krdixon
 */
public class ParallelBaumWelchAlgorithmTest
    extends BaumWelchAlgorithmTest
{

    /**
     * Creates the test case.
     *
     * @param testName Name of the test, forwarded to the superclass.
     */
    public ParallelBaumWelchAlgorithmTest(
        String testName )
    {
        super( testName );
    }

    /**
     * Builds a parallel learner configured from the serial learner that the
     * superclass factory produces: the serial initial guess is copied into a
     * ParallelHiddenMarkovModel, while the distribution learner and the
     * re-estimation flag are passed through unchanged.
     *
     * @return A new ParallelBaumWelchAlgorithm mirroring the serial setup.
     */
    @Override
    public ParallelBaumWelchAlgorithm<Vector> createInstance()
    {
        final BaumWelchAlgorithm<Vector> serialTemplate =
            super.createInstance();
        final HiddenMarkovModel<Vector> guess =
            serialTemplate.getInitialGuess();

        final ParallelHiddenMarkovModel<Vector> parallelGuess =
            new ParallelHiddenMarkovModel<Vector>(
                guess.getInitialProbability(),
                guess.getTransitionProbability(),
                guess.getEmissionFunctions() );

        return new ParallelBaumWelchAlgorithm<Vector>(
            parallelGuess,
            serialTemplate.getDistributionLearner(),
            serialTemplate.getReestimateInitialProbabilities() );
    }

    /**
     * Tests the constructors of class ParallelBaumWelchAlgorithm: first the
     * default constructor, then the three-argument constructor.
     */
    @Override
    public void testConstructors()
    {
        System.out.println( "Constructors" );

        // Default constructor: default flag, everything else unset.
        ParallelBaumWelchAlgorithm<Vector> algorithm =
            new ParallelBaumWelchAlgorithm<Vector>();
        assertEquals(
            BaumWelchAlgorithm.DEFAULT_REESTIMATE_INITIAL_PROBABILITY,
            algorithm.getReestimateInitialProbabilities() );
        assertNull( algorithm.getInitialGuess() );
        assertNull( algorithm.getDistributionLearner() );
        assertNull( algorithm.getResult() );

        // Full constructor: every argument must be stored as given. The
        // flag is the negation of the default-constructed value so that a
        // constructor which ignores it would be caught.
        final MultivariateGaussian.WeightedMaximumLikelihoodEstimator estimator =
            new MultivariateGaussian.WeightedMaximumLikelihoodEstimator();
        final HiddenMarkovModel<Vector> initialGuess =
            this.createHMMInstance();
        final boolean flippedFlag =
            !algorithm.getReestimateInitialProbabilities();

        algorithm = new ParallelBaumWelchAlgorithm<Vector>(
            initialGuess, estimator, flippedFlag );
        assertSame( initialGuess, algorithm.getInitialGuess() );
        assertSame( estimator, algorithm.getDistributionLearner() );
        assertEquals(
            flippedFlag, algorithm.getReestimateInitialProbabilities() );
        assertNull( algorithm.getResult() );
    }

    /**
     * Test of getThreadPool method, of class ParallelBaumWelchAlgorithm:
     * a freshly created instance must report a non-null pool.
     */
    public void testGetThreadPool()
    {
        System.out.println( "getThreadPool" );

        final ParallelBaumWelchAlgorithm<?> algorithm = this.createInstance();
        assertNotNull( algorithm.getThreadPool() );
    }

    /**
     * Test of setThreadPool method, of class ParallelBaumWelchAlgorithm:
     * the getter must return the very pool that was set.
     */
    public void testSetThreadPool()
    {
        System.out.println( "setThreadPool" );

        final ThreadPoolExecutor pool = ParallelUtil.createThreadPool();
        final ParallelBaumWelchAlgorithm<?> algorithm = this.createInstance();

        algorithm.setThreadPool( pool );
        assertSame( pool, algorithm.getThreadPool() );
    }

    /**
     * Test of getNumThreads method, of class ParallelBaumWelchAlgorithm:
     * at least one thread must be reported.
     */
    public void testGetNumThreads()
    {
        System.out.println( "getNumThreads" );

        final ParallelBaumWelchAlgorithm<Vector> algorithm =
            this.createInstance();
        assertTrue( 1 <= algorithm.getNumThreads() );
    }

    /**
     * Tests equivalence between parallel and serial versions: both learners
     * are trained on the same sampled sequence and must agree on the number
     * of iterations and on the log-likelihood of the data under the learned
     * model.
     */
    public void testEquivalenttoSerialBW()
    {
        System.out.println( "Tests equivalance" );

        final BaumWelchAlgorithm<Vector> serialLearner =
            super.createInstance();
        final ParallelBaumWelchAlgorithm<Vector> parallelLearner =
            this.createInstance();

        // Draw one shared training sequence from a separate HMM.
        final HiddenMarkovModel<Vector> generator = createHMMInstance();
        final int sampleCount = 100;
        final Collection<Vector> observations =
            generator.sample( RANDOM, sampleCount );

        serialLearner.learn( observations );
        parallelLearner.learn( observations );

        assertEquals(
            serialLearner.getIteration(), parallelLearner.getIteration() );

        // Two-argument assertEquals: no tolerance is supplied, so the two
        // log-likelihoods are compared for equality as-is.
        final double serialLogLikelihood =
            serialLearner.getResult().computeObservationLogLikelihood(
                observations );
        final double parallelLogLikelihood =
            parallelLearner.getResult().computeObservationLogLikelihood(
                observations );
        assertEquals( serialLogLikelihood, parallelLogLikelihood );
    }

}