/* * File: BaumWelchAlgorithm.java * Authors: Kevin R. Dixon * Company: Sandia National Laboratories * Project: Cognitive Foundry * * Copyright Jan 19, 2010, Sandia Corporation. * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive * license for use of this work by or on behalf of the U.S. Government. * Export of this program may require a license from the United States * Government. See CopyrightHistory.txt for complete details. * */ package gov.sandia.cognition.learning.algorithm.hmm; import gov.sandia.cognition.algorithm.MeasurablePerformanceAlgorithm; import gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner; import gov.sandia.cognition.learning.algorithm.BatchLearner; import gov.sandia.cognition.statistics.ComputableDistribution; import gov.sandia.cognition.util.DefaultNamedValue; import gov.sandia.cognition.util.NamedValue; import gov.sandia.cognition.util.ObjectUtil; import gov.sandia.cognition.util.WeightedValue; import java.util.Collection; /** * Partial implementation of the Baum-Welch algorithm. * @param <ObservationType> Type of Observations handled by the HMM. * @param <DataType> Type of data (Collection of ObservationType, for instance) * sent to the learn method. * @author Kevin R. Dixon * @since 3.0 */ public abstract class AbstractBaumWelchAlgorithm<ObservationType,DataType> extends AbstractAnytimeBatchLearner<DataType,HiddenMarkovModel<ObservationType>> implements MeasurablePerformanceAlgorithm { /** * Default maximum number of iterations, {@value}. */ public static final int DEFAULT_MAX_ITERATIONS = 100; /** * Default flag to re-estimate initial probabilities, {@value}. */ public static final boolean DEFAULT_REESTIMATE_INITIAL_PROBABILITY = true; /** * Name of the performance statistic, {@value}. */ public static final String PERFORMANCE_NAME = "Log Likelihood"; /** * Learner for the Distribution Functions of the HMM. */ protected BatchLearner<Collection<? extends WeightedValue<? extends ObservationType>>,? extends ComputableDistribution<ObservationType>> distributionLearner; /** * Result of the Baum-Welch Algorithm */ protected HiddenMarkovModel<ObservationType> result; /** * Initial guess for the iterations. */ protected HiddenMarkovModel<ObservationType> initialGuess; /** * Last Log Likelihood of the iterations */ protected double lastLogLikelihood; /** * Flag to re-estimate the initial probability Vector. */ protected boolean reestimateInitialProbabilities; /** * Creates a new instance of AbstractBaumWelchAlgorithm * @param initialGuess * Initial guess for the iterations. * @param distributionLearner * Learner for the Distribution Functions of the HMM. * @param reestimateInitialProbabilities * Flag to re-estimate the initial probability Vector. */ public AbstractBaumWelchAlgorithm( HiddenMarkovModel<ObservationType> initialGuess, BatchLearner<Collection<? extends WeightedValue<? extends ObservationType>>,? extends ComputableDistribution<ObservationType>> distributionLearner, boolean reestimateInitialProbabilities ) { super(DEFAULT_MAX_ITERATIONS); this.setInitialGuess(initialGuess); this.setDistributionLearner(distributionLearner); this.setReestimateInitialProbabilities(reestimateInitialProbabilities); this.result = null; this.lastLogLikelihood = Double.NEGATIVE_INFINITY; } @Override public AbstractBaumWelchAlgorithm<ObservationType,DataType> clone() { AbstractBaumWelchAlgorithm<ObservationType,DataType> clone = (AbstractBaumWelchAlgorithm<ObservationType,DataType>) super.clone(); clone.setDistributionLearner( ObjectUtil.cloneSafe( this.getDistributionLearner() ) ); clone.result = ObjectUtil.cloneSafe( this.getResult() ); clone.setInitialGuess( ObjectUtil.cloneSafe( this.getInitialGuess() ) ); return clone; } public NamedValue<Double> getPerformance() { return new DefaultNamedValue<Double>( PERFORMANCE_NAME, this.getLastLogLikelihood()); } public HiddenMarkovModel<ObservationType> getResult() { return this.result; } /** * Getter for initialGuess. * @return * Initial guess for the iterations. */ public HiddenMarkovModel<ObservationType> getInitialGuess() { return this.initialGuess; } /** * Setter for initialGuess. * @param initialGuess * Initial guess for the iterations. */ public void setInitialGuess( HiddenMarkovModel<ObservationType> initialGuess) { this.initialGuess = initialGuess; } /** * Getter for reestimateInitialProbabilities * @return the reestimateInitialProbabilities * Flag to re-estimate the initial probability Vector. */ public boolean getReestimateInitialProbabilities() { return this.reestimateInitialProbabilities; } /** * Setter for reestimateInitialProbabilities * @param reestimateInitialProbabilities * Flag to re-estimate the initial probability Vector. */ public void setReestimateInitialProbabilities( boolean reestimateInitialProbabilities) { this.reestimateInitialProbabilities = reestimateInitialProbabilities; } /** * Getter for distributionLearner * @return * Learner for the Distribution Functions of the HMM. */ public BatchLearner<Collection<? extends WeightedValue<? extends ObservationType>>, ? extends ComputableDistribution<ObservationType>> getDistributionLearner() { return this.distributionLearner; } /** * Setter for distributionLearner * @param distributionLearner * Learner for the Distribution Functions of the HMM. */ public void setDistributionLearner( BatchLearner<Collection<? extends WeightedValue<? extends ObservationType>>, ? extends ComputableDistribution<ObservationType>> distributionLearner) { this.distributionLearner = distributionLearner; } /** * Gets the log likelihood of the last completed step of the algorithm. * * @return * The last log likelihood. */ public double getLastLogLikelihood() { return this.lastLogLikelihood; } }