/*
* Encog(tm) Core v3.4 - Java Version
* http://www.heatonresearch.com/encog/
* https://github.com/encog/encog-java-core
* Copyright 2008-2016 Heaton Research, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
package org.encog.ml.hmm.alog;
import java.util.EnumSet;
import java.util.Iterator;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.hmm.HiddenMarkovModel;
/**
 * The forward-backward algorithm is an inference algorithm for hidden Markov
 * models which computes the posterior marginals of all hidden state variables
 * given a sequence of observations.
 *
 * This class computes the forward ("alpha") and/or backward ("beta") matrices
 * for one observation sequence, together with the probability of that
 * sequence under the model. Each matrix has one row per observation and one
 * column per hidden state. The values are plain (unscaled) products of
 * probabilities.
 */
public class ForwardBackwardCalculator {

    /**
     * Selects which matrices the constructor computes.
     */
    public static enum Computation {
        ALPHA, BETA
    }

    /**
     * Alpha matrix, indexed as [observation index][state]; null until
     * {@link #computeAlpha} has run.
     */
    protected double[][] alpha = null;

    /**
     * Beta matrix, indexed as [observation index][state]; null until
     * {@link #computeBeta} has run.
     */
    protected double[][] beta = null;

    /**
     * Probability of the observation sequence given the model.
     */
    protected double probability;

    /**
     * Construct an empty object. Nothing is computed; subclasses are expected
     * to drive the computation themselves.
     */
    protected ForwardBackwardCalculator() {
    }

    /**
     * Construct the forward/backward calculator, computing only the alpha
     * matrix.
     *
     * @param oseq The observation sequence to use; must not be empty.
     * @param hmm The hidden markov model to use.
     * @throws IllegalArgumentException If the sequence is empty.
     */
    public ForwardBackwardCalculator(final MLDataSet oseq,
            final HiddenMarkovModel hmm) {
        this(oseq, hmm, EnumSet.of(Computation.ALPHA));
    }

    /**
     * Construct the object, computing the requested matrices and then the
     * sequence probability.
     *
     * @param oseq The observation sequence; must not be empty.
     * @param hmm The hidden markov model to use.
     * @param flags Which matrices to compute: alpha, beta or both. Must
     *            contain at least one value.
     * @throws IllegalArgumentException If the sequence is empty or no
     *             computation was requested.
     */
    public ForwardBackwardCalculator(final MLDataSet oseq,
            final HiddenMarkovModel hmm, final EnumSet<Computation> flags) {
        if (oseq.size() < 1) {
            throw new IllegalArgumentException("Empty sequence");
        }
        // Without either matrix the probability cannot be derived; fail
        // clearly here instead of dereferencing a null beta array later.
        if (flags.isEmpty()) {
            throw new IllegalArgumentException(
                    "At least one of ALPHA or BETA must be requested");
        }

        if (flags.contains(Computation.ALPHA)) {
            computeAlpha(hmm, oseq);
        }
        if (flags.contains(Computation.BETA)) {
            computeBeta(hmm, oseq);
        }

        computeProbability(oseq, hmm, flags);
    }

    /**
     * Read one element of the alpha matrix.
     *
     * @param t The row (observation index).
     * @param i The column (state).
     * @return The element.
     * @throws UnsupportedOperationException If alpha has not been computed.
     */
    public double alphaElement(final int t, final int i) {
        if (this.alpha == null) {
            throw new UnsupportedOperationException("Alpha array has not "
                    + "been computed");
        }
        return this.alpha[t][i];
    }

    /**
     * Read one element of the beta matrix.
     *
     * @param t The row (observation index).
     * @param i The column (state).
     * @return The element.
     * @throws UnsupportedOperationException If beta has not been computed.
     */
    public double betaElement(final int t, final int i) {
        if (this.beta == null) {
            throw new UnsupportedOperationException("Beta array has not "
                    + "been computed");
        }
        return this.beta[t][i];
    }

    /**
     * Compute the whole alpha matrix: row 0 via
     * {@link #computeAlphaInit}, every later row via
     * {@link #computeAlphaStep}.
     *
     * @param hmm The hidden markov model.
     * @param oseq The observation sequence.
     */
    protected void computeAlpha(final HiddenMarkovModel hmm,
            final MLDataSet oseq) {
        final int length = oseq.size();
        final int stateCount = hmm.getStateCount();
        this.alpha = new double[length][stateCount];

        for (int i = 0; i < stateCount; i++) {
            computeAlphaInit(hmm, oseq.get(0), i);
        }

        // Walk the remaining observations with an iterator; the first
        // element was already consumed by the initialisation above.
        final Iterator<MLDataPair> seqIterator = oseq.iterator();
        if (seqIterator.hasNext()) {
            seqIterator.next();
        }

        for (int t = 1; t < length; t++) {
            final MLDataPair observation = seqIterator.next();
            for (int i = 0; i < stateCount; i++) {
                computeAlphaStep(hmm, observation, t, i);
            }
        }
    }

    /**
     * Compute the alpha init: alpha[0][i] = pi(i) * P(o | state i).
     *
     * @param hmm The hidden markov model.
     * @param o The first observation of the sequence.
     * @param i The state.
     */
    protected void computeAlphaInit(final HiddenMarkovModel hmm,
            final MLDataPair o, final int i) {
        this.alpha[0][i] = hmm.getPi(i)
                * hmm.getStateDistribution(i).probability(o);
    }

    /**
     * Compute the alpha step:
     * alpha[t][j] = (sum over i of alpha[t-1][i] * a(i, j)) * P(o | state j),
     * where a(i, j) is the transition probability from state i to state j.
     *
     * @param hmm The hidden markov model.
     * @param o The observation at index t.
     * @param t The alpha row to fill; row t - 1 must already be computed.
     * @param j The column (state).
     */
    protected void computeAlphaStep(final HiddenMarkovModel hmm,
            final MLDataPair o, final int t, final int j) {
        final int stateCount = hmm.getStateCount();
        double sum = 0.;

        for (int i = 0; i < stateCount; i++) {
            sum += this.alpha[t - 1][i] * hmm.getTransitionProbability(i, j);
        }

        this.alpha[t][j] = sum * hmm.getStateDistribution(j).probability(o);
    }

    /**
     * Compute the whole beta matrix: the last row is set to 1 and the
     * earlier rows are filled backwards via {@link #computeBetaStep}.
     *
     * @param hmm The hidden markov model.
     * @param oseq The observation sequence.
     */
    protected void computeBeta(final HiddenMarkovModel hmm, final MLDataSet oseq) {
        final int length = oseq.size();
        final int stateCount = hmm.getStateCount();
        this.beta = new double[length][stateCount];

        for (int i = 0; i < stateCount; i++) {
            this.beta[length - 1][i] = 1.;
        }

        for (int t = length - 2; t >= 0; t--) {
            for (int i = 0; i < stateCount; i++) {
                // Row t depends on the observation that follows it.
                computeBetaStep(hmm, oseq.get(t + 1), t, i);
            }
        }
    }

    /**
     * Compute the beta step:
     * beta[t][i] = sum over j of beta[t+1][j] * a(i, j) * P(o | state j),
     * where a(i, j) is the transition probability from state i to state j.
     *
     * @param hmm The hidden markov model.
     * @param o The observation at index t + 1.
     * @param t The matrix row to fill; row t + 1 must already be computed.
     * @param i The matrix column (state).
     */
    protected void computeBetaStep(final HiddenMarkovModel hmm,
            final MLDataPair o, final int t, final int i) {
        final int stateCount = hmm.getStateCount();
        double sum = 0.;

        for (int j = 0; j < stateCount; j++) {
            sum += this.beta[t + 1][j] * hmm.getTransitionProbability(i, j)
                    * hmm.getStateDistribution(j).probability(o);
        }

        this.beta[t][i] = sum;
    }

    /**
     * Compute the probability of the sequence. When alpha was requested the
     * last alpha row is summed; otherwise the probability is derived from
     * the first beta row, the initial state probabilities and the first
     * observation.
     *
     * @param oseq The observation sequence.
     * @param hmm The hidden markov model.
     * @param flags The flags passed to the constructor; the constructor
     *            guarantees that at least one of ALPHA or BETA is present.
     */
    private void computeProbability(final MLDataSet oseq,
            final HiddenMarkovModel hmm, final EnumSet<Computation> flags) {
        final int stateCount = hmm.getStateCount();
        this.probability = 0.;

        if (flags.contains(Computation.ALPHA)) {
            final int last = oseq.size() - 1;
            for (int i = 0; i < stateCount; i++) {
                this.probability += this.alpha[last][i];
            }
        } else {
            for (int i = 0; i < stateCount; i++) {
                this.probability += hmm.getPi(i)
                        * hmm.getStateDistribution(i).probability(oseq.get(0))
                        * this.beta[0][i];
            }
        }
    }

    /**
     * @return The probability of the observation sequence given the model.
     */
    public double probability() {
        return this.probability;
    }
}