/*
* Encog(tm) Core v3.4 - Java Version
* http://www.heatonresearch.com/encog/
* https://github.com/encog/encog-java-core
* Copyright 2008-2016 Heaton Research, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
package org.encog.ml.hmm.train.bw;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.encog.ml.MLMethod;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.MLSequenceSet;
import org.encog.ml.hmm.HiddenMarkovModel;
import org.encog.ml.hmm.alog.ForwardBackwardCalculator;
import org.encog.ml.hmm.distributions.StateDistribution;
import org.encog.ml.train.MLTrain;
import org.encog.ml.train.strategy.Strategy;
import org.encog.neural.networks.training.propagation.TrainingContinuation;
/**
* This class provides the base implementation for Baum-Welch learning for
* HMMs. There are currently two implementations provided.
*
* TrainBaumWelch - Regular Baum Welch Learning.
*
* TrainBaumWelchScaled - Regular Baum Welch Learning, which can handle
* underflows in long sequences.
*
* L. E. Baum, T. Petrie, G. Soules, and N. Weiss,
* "A maximization technique occurring in the statistical analysis of probabilistic functions of Markov chains"
* , Ann. Math. Statist., vol. 41, no. 1, pp. 164-171, 1970.
*
* Hidden Markov Models and the Baum-Welch Algorithm, IEEE Information Theory
* Society Newsletter, Dec. 2003.
*
*/
public abstract class BaseBaumWelch implements MLTrain {
private int iterations;
private HiddenMarkovModel method;
private final MLSequenceSet training;
public BaseBaumWelch(final HiddenMarkovModel hmm,
final MLSequenceSet training) {
this.method = hmm;
this.training = training;
}
@Override
public void addStrategy(final Strategy strategy) {
}
@Override
public boolean canContinue() {
return false;
}
protected double[][] estimateGamma(final double[][][] xi,
final ForwardBackwardCalculator fbc) {
final double[][] gamma = new double[xi.length + 1][xi[0].length];
for (int t = 0; t < (xi.length + 1); t++) {
Arrays.fill(gamma[t], 0.);
}
for (int t = 0; t < xi.length; t++) {
for (int i = 0; i < xi[0].length; i++) {
for (int j = 0; j < xi[0].length; j++) {
gamma[t][i] += xi[t][i][j];
}
}
}
for (int j = 0; j < xi[0].length; j++) {
for (int i = 0; i < xi[0].length; i++) {
gamma[xi.length][j] += xi[xi.length - 1][i][j];
}
}
return gamma;
}
public abstract double[][][] estimateXi(MLDataSet sequence,
ForwardBackwardCalculator fbc, HiddenMarkovModel hmm);
@Override
public void finishTraining() {
}
public abstract ForwardBackwardCalculator generateForwardBackwardCalculator(
MLDataSet sequence, HiddenMarkovModel hmm);
@Override
public double getError() {
return 0;
}
@Override
public TrainingImplementationType getImplementationType() {
return TrainingImplementationType.Iterative;
}
@Override
public int getIteration() {
return this.iterations;
}
@Override
public MLMethod getMethod() {
return this.method;
}
@Override
public List<Strategy> getStrategies() {
return null;
}
@Override
public MLDataSet getTraining() {
return this.training;
}
@Override
public boolean isTrainingDone() {
return false;
}
@Override
public void iteration() {
HiddenMarkovModel nhmm;
try {
nhmm = this.method.clone();
} catch (final CloneNotSupportedException e) {
throw new InternalError();
}
final double allGamma[][][] = new double[this.training
.getSequenceCount()][][];
final double aijNum[][] = new double[this.method.getStateCount()][this.method
.getStateCount()];
final double aijDen[] = new double[this.method.getStateCount()];
Arrays.fill(aijDen, 0.0);
for (int i = 0; i < this.method.getStateCount(); i++) {
Arrays.fill(aijNum[i], 0.);
}
int g = 0;
for (final MLDataSet obsSeq : this.training.getSequences()) {
final ForwardBackwardCalculator fbc = generateForwardBackwardCalculator(
obsSeq, this.method);
final double xi[][][] = estimateXi(obsSeq, fbc, this.method);
final double gamma[][] = allGamma[g++] = estimateGamma(xi, fbc);
for (int i = 0; i < this.method.getStateCount(); i++) {
for (int t = 0; t < (obsSeq.size() - 1); t++) {
aijDen[i] += gamma[t][i];
for (int j = 0; j < this.method.getStateCount(); j++) {
aijNum[i][j] += xi[t][i][j];
}
}
}
}
for (int i = 0; i < this.method.getStateCount(); i++) {
if (aijDen[i] == 0.0) {
for (int j = 0; j < this.method.getStateCount(); j++) {
nhmm.setTransitionProbability(i, j,
this.method.getTransitionProbability(i, j));
}
} else {
for (int j = 0; j < this.method.getStateCount(); j++) {
nhmm.setTransitionProbability(i, j, aijNum[i][j]
/ aijDen[i]);
}
}
}
/* compute pi */
for (int i = 0; i < this.method.getStateCount(); i++) {
nhmm.setPi(i, 0.);
}
for (int o = 0; o < this.training.getSequenceCount(); o++) {
for (int i = 0; i < this.method.getStateCount(); i++) {
nhmm.setPi(
i,
nhmm.getPi(i)
+ (allGamma[o][0][i] / this.training
.getSequenceCount()));
}
}
/* compute pdfs */
for (int i = 0; i < this.method.getStateCount(); i++) {
final double[] weights = new double[this.training.size()];
double sum = 0.;
int j = 0;
int o = 0;
for (final MLDataSet obsSeq : this.training.getSequences()) {
for (int t = 0; t < obsSeq.size(); t++, j++) {
sum += weights[j] = allGamma[o][t][i];
}
o++;
}
for (j--; j >= 0; j--) {
weights[j] /= sum;
}
final StateDistribution opdf = nhmm.getStateDistribution(i);
opdf.fit(this.training, weights);
}
this.method = nhmm;
}
@Override
public void iteration(final int count) {
for (int i = 0; i < count; i++) {
iteration();
}
}
@Override
public TrainingContinuation pause() {
return null;
}
@Override
public void resume(final TrainingContinuation state) {
}
@Override
public void setError(final double error) {
}
@Override
public void setIteration(final int iteration) {
this.iterations = iteration;
}
}