/*
* avenir: Predictive analytic based on Hadoop Map Reduce
* Author: Pranab Ghosh
*
* Licensed under the Apache License, Version 2.0 (the "License"); you
* may not use this file except in compliance with the License. You may
* obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. See the License for the specific language governing
* permissions and limitations under the License.
*/
package org.avenir.markov;
import org.apache.log4j.Logger;
import org.chombo.util.DoubleTable;
import org.chombo.util.TabularData;
/**
* State sequence predictor based on given observation sequence and HMM model
* using viterbi algorithm
* @author pranab
*
*/
public class ViterbiDecoder {
private DoubleTable statePathProb;
private TabularData statePtr;
private HiddenMarkovModel model;
private boolean processed;
private int numObs;
private int curObsIndex;
private int numStates;
private static Logger LOG;
/**
* @param model
*/
public ViterbiDecoder(HiddenMarkovModel model, Logger LOG) {
ViterbiDecoder.LOG = LOG;
this.model = model;
numStates = model.getNumStates();
}
/**
* @param numObs
*/
public void initialize(int numObs) {
this.numObs = numObs;
statePathProb = new DoubleTable(numObs, numStates);
statePtr = new TabularData(numObs, numStates);
processed = false;
curObsIndex = 0;
LOG.debug("numObs:" + numObs);
}
/**
* process next observation
* @param observation
*/
public void nextObservation(String observation) {
int obsIndx = model.getObservationIndex(observation);
double stateProb, obsProb, pathProb, priorPathProb, transProb, maxPathProb;
int maxPathProbStateIndx;
LOG.debug("curObsIndex:" + curObsIndex);
if (!processed) {
//first use initial state probability
for (int stateIndx = 0; stateIndx < numStates; ++stateIndx) {
stateProb = model.getIntialStateProbability(stateIndx);
obsProb = model.getObservationProbabiility(stateIndx, obsIndx);
pathProb = stateProb * obsProb;
LOG.debug("pathProb:" + pathProb );
statePathProb.set(curObsIndex, stateIndx, pathProb);
statePtr.set(curObsIndex, stateIndx, -1);
}
processed = true;
} else {
//iterative for subsequent using prevoious state path probability
for (int stateIndx = 0; stateIndx < numStates; ++stateIndx) {
maxPathProb = 0;
maxPathProbStateIndx = 0;
obsProb = model.getObservationProbabiility(stateIndx, obsIndx);
for (int priorStateIndx = 0; priorStateIndx < numStates; ++priorStateIndx) {
priorPathProb =statePathProb.get(curObsIndex-1, priorStateIndx);
transProb = model.getDestStateProbility(priorStateIndx, stateIndx);
pathProb = priorPathProb * transProb;
if (pathProb > maxPathProb) {
maxPathProb = pathProb;
maxPathProbStateIndx = priorStateIndx;
}
}
LOG.debug("maxPathProb:" + maxPathProb + " maxPathProbStateIndx:" + maxPathProbStateIndx);
statePathProb.set(curObsIndex, stateIndx, maxPathProb * obsProb);
statePtr.set(curObsIndex, stateIndx, maxPathProbStateIndx);
}
}
++curObsIndex;
}
/**
* Get state sequence starting with latest
* @return
*/
public String[] getStateSequence() {
String[] states = new String[numObs];
int stateSeqIndx = 0;
double pathProb;
double maxPathProb = 0;
int maxProbStateIndx = -1;
int priorStateIndx;
int nextStateIndx = -1;
//state at end of observation sequence
LOG.debug("state seq" );
maxPathProb = 0;
for (int stateIndx = 0; stateIndx < numStates; ++stateIndx) {
//max path probability for the last observation
pathProb = statePathProb.get(numObs -1, stateIndx);
if (pathProb > maxPathProb) {
maxPathProb = pathProb;
maxProbStateIndx = stateIndx;
}
}
LOG.debug("maxProbStateIndx:" + maxProbStateIndx);
states[stateSeqIndx++] = model.getState(maxProbStateIndx);
nextStateIndx = maxProbStateIndx;
//backtrack for rest of the states going back ward
for (int obsIndx = numObs -1 ; obsIndx >= 1; --obsIndx) {
priorStateIndx = statePtr.get(obsIndx, nextStateIndx);
LOG.debug("priorStateIndx:" + priorStateIndx);
states[stateSeqIndx++] = model.getState(priorStateIndx);
nextStateIndx = priorStateIndx;
}
return states;
}
}