package dist.hmm;
import shared.DataSet;
import shared.Instance;
import shared.Trainer;
/**
* An implementation of the baum welch re estimation algorithm.
* Takes in a hidden markov model and set of observation sequences,
* then re estimates the parameters of the hidden markov model
* based on expected values calculated through the use
* of forward backward calculator
* @author Andrew Guillory gtg008g@mail.gatech.edu
* @version 1.0
*/
public class HiddenMarkovModelReestimator implements Trainer {
/**
* The array of observation sequences
*/
private DataSet[] observationSequences;
/**
* The hidden markov model itself
*/
private HiddenMarkovModel model;
/**
* The [k][t][i][j] value is the expected number of
* transitions between i and j at time t given
* observation sequence k
*/
private double[][][][] transitionExpectations;
/**
* The [k][t][i] value is the expected number of
* times in state i at time t given sequence k
*/
private double[][][] stateExpectations;
/**
* The output observations
*/
private DataSet outputObservations;
/**
* The transition observations
*/
private DataSet transitionObservations;
/**
* The initial observations
*/
private DataSet initialObservations;
/**
* Make a new reestimator
* @param model the hidden markov model
* @param observationSequences the observation sequencess
* @param inputSequences the corresponding input sequences
*/
public HiddenMarkovModelReestimator(HiddenMarkovModel model,
DataSet[] observationSequences) {
this.model = model;
this.observationSequences = observationSequences;
stateExpectations = new double[observationSequences.length][][];
transitionExpectations = new double[observationSequences.length][][][];
initializeObservations();
}
/**
* Initialize the sequences used in training
*/
public void initializeObservations() {
initializeOutputObservations();
initializeTransitionObservations();
initializeInitialObservations();
}
/**
* Initialize the output observations
*
*/
public void initializeOutputObservations() {
int totalTime = 0;
for (int k = 0; k < observationSequences.length; k++) {
totalTime += observationSequences[k].size();
}
Instance[] outputObservationsInstances = new Instance[totalTime];
int j = 0;
for (int k = 0; k < observationSequences.length; k++) {
Instance[] cur = observationSequences[k].getInstances();
System.arraycopy(cur, 0, outputObservationsInstances, j, cur.length);
j += cur.length;
}
outputObservations = new DataSet(outputObservationsInstances,
observationSequences[0].getDescription());
}
/**
* Initialize the initial observations array
*/
public void initializeInitialObservations() {
Instance[] initialObservationsInstances = new Instance[observationSequences.length];
for (int k = 0; k < observationSequences.length; k++) {
initialObservationsInstances[k] = observationSequences[k].get(0);
}
initialObservations = new DataSet(initialObservationsInstances,
observationSequences[0].getDescription());
}
/**
* Initialize the transition observations array
*/
public void initializeTransitionObservations() {
int totalTime = 0;
for (int k = 0; k < observationSequences.length; k++) {
totalTime += observationSequences[k].size() - 1;
}
Instance[] transitionObservationsInstances = new Instance[totalTime];
int j = 0;
for (int k = 0; k < observationSequences.length; k++) {
Instance[] cur = observationSequences[k].getInstances();
System.arraycopy(cur, 1, transitionObservationsInstances, j, cur.length - 1);
j += cur.length - 1;
}
transitionObservations = new DataSet(transitionObservationsInstances,
observationSequences[0].getDescription());
}
/**
* Restimate the model
* @return the sum of log probabilities for the model / sequences
*/
public double train() {
double probability = 0;
for (int k = 0; k < observationSequences.length; k++) {
DataSet observationSequence = observationSequences[k];
ForwardBackwardProbabilityCalculator fbc =
new ForwardBackwardProbabilityCalculator(model, observationSequence);
double[][] forwardProbabilities = fbc.calculateForwardProbabilities();
double[][] backwardProbabilities = fbc.calculateBackwardProbabilities();
stateExpectations[k] = calculateStateExpectations(
observationSequence, forwardProbabilities,
backwardProbabilities);
transitionExpectations[k] = calculateTransitionExpectations(
observationSequence, forwardProbabilities,
backwardProbabilities);
probability += fbc.calculateLogProbability();
}
reestimateInitialStateDistribution();
reestimateTransitionDistributions();
reestimateOutputDistributions();
return probability / observationSequences.length;
}
/**
* Calculate the transition probabilities for observation sequence k
*/
public double[][][] calculateTransitionExpectations(
DataSet observationSequence, double[][] forwardProbabilities,
double[][] backwardProbabilities) {
double[][][] transitions = new double[observationSequence.size() - 1]
[model.getStateCount()][model.getStateCount()];
for (int t = 0; t < observationSequence.size() - 1; t++) {
double sum = 0;
for (int i = 0; i < model.getStateCount(); i++) {
for (int j = 0; j < model.getStateCount(); j++) {
transitions[t][i][j] = forwardProbabilities[t][i]
* model.transitionProbability(i, j, observationSequence.get(t + 1))
* model.observationProbability(j, observationSequence.get(t + 1))
* backwardProbabilities[t + 1][j];
sum += transitions[t][i][j];
}
}
for (int i = 0; i < model.getStateCount(); i++) {
for (int j = 0; j < model.getStateCount(); j++) {
transitions[t][i][j] /= sum;
}
}
}
return transitions;
}
/**
* Calculate the state probabilities for observation sequence k
*/
public double[][] calculateStateExpectations(
DataSet observationSequence, double[][] forwardProbabilities,
double[][] backwardProbabilities) {
double[][] states = new double[observationSequence.size()]
[model.getStateCount()];
for (int t = 0; t < observationSequence.size(); t++) {
double sum = 0;
for (int i = 0; i < model.getStateCount(); i++) {
states[t][i] = forwardProbabilities[t][i] *
backwardProbabilities[t][i];
sum += states[t][i];
}
for (int i = 0; i < model.getStateCount(); i++) {
states[t][i] /= sum;
}
}
return states;
}
/**
* Reestimate the initial state probabilities
*/
public void reestimateInitialStateDistribution() {
double[][] initialStateProbabilities = new double[observationSequences.length]
[model.getStateCount()];
for (int k = 0; k < observationSequences.length; k++) {
for (int i = 0; i < model.getStateCount(); i++) {
initialStateProbabilities[k][i] = stateExpectations[k][0][i];
}
}
model.estimateIntialStateDistribution(initialStateProbabilities, initialObservations);
}
/**
* Reestimate the transition probabilities
*/
public void reestimateTransitionDistributions() {
double[][] probabilities = new double[transitionObservations.size()]
[model.getStateCount()];
for (int i = 0; i < model.getStateCount(); i++) {
for (int j = 0; j < model.getStateCount(); j++) {
int counter = 0;
for (int k = 0; k < observationSequences.length; k++) {
for (int t = 0; t < observationSequences[k].size() - 1; t++) {
probabilities[counter][j] = transitionExpectations[k][t][i][j];
counter++;
}
}
}
model.estimateTransitionDistribution(i, probabilities, transitionObservations);
}
}
/**
* Reestimate the output probabilities
*/
public void reestimateOutputDistributions() {
for (int i = 0; i < model.getStateCount(); i++) {
int counter = 0;
for (int k = 0; k < observationSequences.length; k++) {
for (int t = 0; t < observationSequences[k].size(); t++) {
observationSequences[k].get(t).setWeight(stateExpectations[k][t][i]);
counter++;
}
}
model.estimateOutputDistribution(i, outputObservations);
}
}
/**
* Get the model
* @return returns the model
*/
public HiddenMarkovModel getModel() {
return model;
}
/**
* Set the model}
* @param model The model to set
*/
public void setModel(HiddenMarkovModel model) {
this.model = model;
}
}