/** Viterbi.java
*
* Viterbi search.
*
* @author Sunita Sarawagi
* @since 1.1
* @version 1.3
*/
package iitb.CRF;
import java.io.Serializable;
import cern.colt.matrix.tdouble.DoubleMatrix1D;
import cern.colt.matrix.tdouble.DoubleMatrix2D;
import cern.colt.matrix.tdouble.impl.DenseDoubleMatrix1D;
import cern.colt.matrix.tdouble.impl.DenseDoubleMatrix2D;
/**
 * Beam-search Viterbi decoder over the label lattice of a {@link CRF} model.
 *
 * <p>The lattice is stored in {@code winningLabel[label][position]}; each cell is an
 * {@link Entry} holding up to {@code beamsize} candidate solutions. After a search the
 * overall best candidates are collected in {@code finalSoln}, whose solutions point
 * (through {@code prevSoln}) to the chain of per-position solutions.
 *
 * <p>The scratch structures are reused between calls, so one instance must not be used
 * for two searches at the same time.
 */
public class Viterbi implements Serializable {
    private static final long serialVersionUID = 8122L;

    /** Model providing {@code numY}, {@code edgeGen}, {@code featureGenerator} and {@code params}. */
    protected CRF model;

    /** Number of candidate solutions kept in each lattice cell. */
    protected int beamsize;

    /**
     * Creates a decoder for the given model.
     *
     * @param model the CRF model; may be {@code null}, in which case {@code bs} is used as is
     * @param bs    default beam size; overridden by the {@code "beamSize"} entry of
     *              {@code model.params.miscOptions} when that entry is present
     * @throws NumberFormatException if the configured {@code "beamSize"} is not an integer
     */
    public Viterbi(CRF model, int bs) {
        this.model = model;
        beamsize = bs;
        if (model != null) {
            String configured = model.params.miscOptions.getProperty("beamSize");
            if (configured != null) {
                beamsize = Integer.parseInt(configured);
            }
        }
    }

    /**
     * One lattice cell: a fixed-size beam of {@link Soln} objects.
     *
     * <p>{@link #findInsert} only ever inserts a candidate in front of the first slot whose
     * score it matches or beats, so the occupied slots stay in non-increasing score order.
     */
    protected class Entry {
        /** The beam; its length is fixed at construction time. */
        public Soln solns[];
        boolean valid = true;

        /** Leaves {@code solns} unallocated; intended for subclasses that build their own beam. */
        protected Entry() {}

        /**
         * Allocates a beam of {@code beamsize} solutions, each created by
         * {@link #newSoln(int, int)} for label {@code id} and position {@code pos}.
         */
        protected Entry(int beamsize, int id, int pos) {
            solns = new Soln[beamsize];
            for (int i = 0; i < solns.length; i++) {
                solns[i] = newSoln(id, pos);
            }
        }

        /** Factory hook for the solution objects stored in this entry. */
        protected Soln newSoln(int label, int pos) {
            return new Soln(label, pos);
        }

        /** Marks the entry invalid and clears every solution in the beam. */
        protected void clear() {
            valid = false;
            for (int i = 0; i < solns.length; i++) {
                solns[i].clear();
            }
        }

        /** @return the capacity of the beam */
        public int size() {
            return solns.length;
        }

        /** @return the solution stored in slot {@code i} */
        public Soln get(int i) {
            return solns[i];
        }

        /**
         * Inserts a candidate at slot {@code i}, shifting slots {@code i..size()-2} one place
         * down. The object in the last slot is dropped from the tail and recycled as the
         * holder of the new candidate, so no allocation happens here.
         */
        protected void insert(int i, float score, Soln prev) {
            Soln recycled = solns[size() - 1];
            for (int k = size() - 1; k > i; k--) {
                solns[k] = solns[k - 1];
            }
            solns[i] = recycled;
            solns[i].setPrevSoln(prev, score);
        }

        /**
         * Merges the solutions of a predecessor entry into this beam, each extended by
         * {@code thisScore}. A {@code null} predecessor is treated as a start state and
         * delegates to {@link #add(float)}.
         *
         * <p>Because the predecessor's solutions are sorted, each candidate can only land at
         * or after the slot following the previous insertion, so the scan resumes from there.
         */
        protected void add(Entry e, float thisScore) {
            assert (valid);
            if (e == null) {
                add(thisScore);
                return;
            }
            int insertPos = 0;
            for (int i = 0; (i < e.size()) && (insertPos < size()); i++) {
                float score = e.get(i).score + thisScore;
                insertPos = findInsert(insertPos, score, e.get(i));
            }
        }

        /**
         * Scans from {@code insertPos} for the first slot whose score is not greater than
         * {@code score} and inserts the candidate there.
         *
         * @return the slot after the insertion point, or {@code size()} if nothing was inserted
         */
        protected int findInsert(int insertPos, float score, Soln prev) {
            for (; insertPos < size(); insertPos++) {
                if (score >= get(insertPos).score) {
                    insert(insertPos, score, prev);
                    insertPos++;
                    break;
                }
            }
            return insertPos;
        }

        /** Adds a candidate with no predecessor and the given score. */
        protected void add(float thisScore) {
            findInsert(0, thisScore, null);
        }

        /** @return the index of the first cleared slot, or {@code size()} if the beam is full */
        public int numSolns() {
            for (int i = 0; i < solns.length; i++) {
                if (solns[i].isClear()) {
                    return i;
                }
            }
            return size();
        }

        public void setValid() {
            valid = true;
        }

        /** Debugging aid: writes every slot of the beam to standard output on one line. */
        void print() {
            StringBuilder str = new StringBuilder();
            for (int i = 0; i < size(); i++) {
                str.append("[").append(i)
                   .append(" ").append(solns[i].score)
                   .append(" i:").append(solns[i].pos)
                   .append(" y:").append(solns[i].label)
                   .append("]");
            }
            System.out.println(str.toString());
        }

        /** Describes the best solution of the beam and, if present, its predecessor. */
        @Override
        public String toString() {
            assert (solns != null && solns[0] != null);
            Soln best = solns[0];
            StringBuilder text = new StringBuilder();
            text.append("[").append(best.pos)
                .append(" ").append(best.label)
                .append(" ").append(best.score);
            if (best.prevSoln != null) {
                text.append(" : ").append(best.prevSoln.pos)
                    .append(" ").append(best.prevSoln.label)
                    .append(" ").append(best.prevSoln.score);
            }
            text.append("]");
            return text.toString();
        }

        /** Hook with no behaviour in this class. */
        public void sortEntries() {
        }
    }

    /** Lattice cells indexed as {@code [label][position]}. */
    Entry winningLabel[][];

    /** Best complete solutions found by the most recent search. */
    protected Entry finalSoln;

    /** Scratch matrices filled by {@link #computeLogMi} for the current position. */
    protected DoubleMatrix2D Mi;
    protected DoubleMatrix1D Ri;

    /** Allocates the scratch matrices, the per-label lattice rows and {@code finalSoln}. */
    void allocateScratch(int numY) {
        Mi = new DenseDoubleMatrix2D(numY, numY);
        Ri = new DenseDoubleMatrix1D(numY);
        winningLabel = new Entry[numY][];
        finalSoln = new Entry(beamsize, 0, 0);
    }

    /**
     * Fills {@code Mi} and {@code Ri} for position {@code i} by delegating to
     * {@code Trainer.computeLogMi}. The {@code ell} argument is not used here.
     */
    protected void computeLogMi(DataSequence dataSeq, int i, int ell, double lambda[]) {
        Trainer.computeLogMi(model.featureGenerator, lambda, dataSeq, i, Mi, Ri, false);
    }

    /**
     * Forward pass over the lattice. For each position the cells are reset and then, for
     * every label {@code yi} enumerated by {@code model.edgeGen.firstY/nextY}, the beams of
     * the predecessor labels {@code yp} (from {@code edgeGen.first/next}) are extended by
     * {@code Mi(yp,yi) + Ri(yi)}; at position 0 only {@code Ri(yi)} is used.
     *
     * @param calcScore when {@code true}, also accumulates the score of the labels already
     *                  stored in {@code dataSeq}
     * @return the accumulated score of {@code dataSeq}'s own labelling, or 0 if not requested
     */
    double fillArray(DataSequence dataSeq, double lambda[], boolean calcScore) {
        double corrScore = 0;
        int numY = model.numY;
        for (int i = 0; i < dataSeq.length(); i++) {
            computeLogMi(dataSeq, i, 1, lambda);
            for (int yi = 0; yi < numY; yi++) {
                winningLabel[yi][i].clear();
                winningLabel[yi][i].valid = true;
            }
            for (int yi = model.edgeGen.firstY(i); yi < numY; yi = model.edgeGen.nextY(yi, i)) {
                if (i > 0) {
                    for (int yp = model.edgeGen.first(yi); yp < numY; yp = model.edgeGen.next(yi, yp)) {
                        double val = Mi.get(yp, yi) + Ri.get(yi);
                        winningLabel[yi][i].add(winningLabel[yp][i - 1], (float) val);
                    }
                } else {
                    winningLabel[yi][i].add((float) Ri.get(yi));
                }
            }
            if (calcScore) {
                corrScore += (Ri.get(dataSeq.y(i)) + ((i > 0) ? Mi.get(dataSeq.y(i - 1), dataSeq.y(i)) : 0));
            }
        }
        return corrScore;
    }

    /**
     * Makes sure the scratch structures exist and the lattice has a cell for every
     * position of a sequence of length {@code seqLen}. The lattice is rebuilt only when it
     * is missing or shorter than {@code seqLen}; a longer lattice is reused unchanged.
     *
     * @param unitBeamAtStart when {@code true}, cells at position 0 get a beam of size 1
     *                        (as the forward search allocates them); otherwise every cell
     *                        gets {@code beamsize}
     */
    private void ensureLattice(int seqLen, boolean unitBeamAtStart) {
        if ((Mi == null) || (winningLabel == null)) {
            allocateScratch(model.numY);
        }
        if ((winningLabel[0] == null) || (winningLabel[0].length < seqLen)) {
            for (int yi = 0; yi < winningLabel.length; yi++) {
                winningLabel[yi] = new Entry[seqLen];
                for (int l = 0; l < seqLen; l++) {
                    int beam = (unitBeamAtStart && l == 0) ? 1 : beamsize;
                    winningLabel[yi][l] = new Entry(beam, yi, l);
                }
            }
        }
    }

    /**
     * Runs the search from the end of the sequence towards the start and stores the best
     * solutions in {@code finalSoln}. The per-position matrices are copied into
     * {@code Mis[i]} and {@code Ris[i]}, which the caller must have allocated.
     *
     * <p>For an empty sequence {@code finalSoln} is left cleared, matching
     * {@link #viterbiSearch}.
     *
     * @param calcCorrectScore when {@code true}, the score of {@code dataSeq}'s own labels
     *                         is computed and returned
     * @return the score of {@code dataSeq}'s own labelling, or 0 if not requested
     */
    public double viterbiSearchBackward(DataSequence dataSeq, double lambda[], DoubleMatrix2D Mis[], DoubleMatrix1D Ris[], boolean calcCorrectScore) {
        ensureLattice(dataSeq.length(), false);
        Entry firstEntries[] = new Entry[model.numY];
        for (int yi = 0; yi < winningLabel.length; yi++) {
            firstEntries[yi] = new Entry(1, yi, 0);
        }
        double corrScore = fillArrayBackward(dataSeq, lambda, firstEntries, Mis, Ris, calcCorrectScore);
        finalSoln.clear();
        finalSoln.valid = true;
        // With no positions there is nothing to chain to; adding the untouched
        // firstEntries would only inject placeholder solutions.
        if (dataSeq.length() > 0) {
            for (int yi = 0; yi < model.numY; yi++) {
                finalSoln.add(firstEntries[yi], 0);
            }
        }
        return corrScore;
    }

    /**
     * Backward pass over the lattice. Every cell is reset, the cells of the last position
     * are seeded with score 0, and then for each position {@code i > 0} the beam of
     * {@code winningLabel[yp][i-1]} is extended from {@code winningLabel[yi][i]} by
     * {@code Mi(yp,yi) + Ri(yi)}. Finally {@code firstEntries[yi]} receives the best
     * solution of {@code winningLabel[yi][0]} extended by {@code Ri(yi)} of position 0.
     *
     * @param firstEntries one entry per label, overwritten with the result for position 0
     * @param Mis          receives a copy of {@code Mi} for every position
     * @param Ris          receives a copy of {@code Ri} for every position
     * @param calcScore    when {@code true}, accumulates the score of {@code dataSeq}'s labels
     * @return the accumulated score of {@code dataSeq}'s own labelling, or 0 if not requested
     */
    double fillArrayBackward(DataSequence dataSeq, double lambda[], Entry firstEntries[], DoubleMatrix2D Mis[], DoubleMatrix1D Ris[], boolean calcScore) {
        double corrScore = 0;
        int numY = model.numY;
        int len = dataSeq.length();
        for (int i = len - 1; i >= 0; i--) {
            for (int yi = 0; yi < numY; yi++) {
                winningLabel[yi][i].clear();
                winningLabel[yi][i].valid = true;
                if (i == len - 1) {
                    winningLabel[yi][i].add(0);
                }
            }
        }
        for (int i = len - 1; i >= 0; i--) {
            computeLogMi(dataSeq, i, 1, lambda);
            Mis[i].assign(Mi);
            Ris[i].assign(Ri);
            // Accumulate before the i == 0 exit so that the node score of the first
            // position is counted, exactly as the forward pass does.
            if (calcScore) {
                corrScore += (Ri.get(dataSeq.y(i)) + ((i > 0) ? Mi.get(dataSeq.y(i - 1), dataSeq.y(i)) : 0));
            }
            if (i == 0) {
                // Position 0 has no predecessor cells; Ri is kept for the step below.
                break;
            }
            for (int yi = model.edgeGen.firstY(i); yi < numY; yi = model.edgeGen.nextY(yi, i)) {
                for (int yp = model.edgeGen.first(yi); yp < numY; yp = model.edgeGen.next(yi, yp)) {
                    double val = Mi.get(yp, yi) + Ri.get(yi);
                    winningLabel[yp][i - 1].add(winningLabel[yi][i], (float) val);
                }
            }
        }
        for (int yi = 0; yi < numY; yi++) {
            firstEntries[yi].clear();
            firstEntries[yi].valid = true;
            // An empty sequence has no cell at position 0 and Ri was not computed.
            if (len > 0) {
                firstEntries[yi].add(winningLabel[yi][0], (float) Ri.get(yi));
            }
        }
        return corrScore;
    }

    /**
     * Writes one decoded label into the sequence. This implementation labels the single
     * position {@code pos} and ignores {@code prevPos}.
     */
    protected void setSegment(DataSequence dataSeq, int prevPos, int pos, int label) {
        dataSeq.set_y(pos, label);
    }

    /**
     * Decodes the best label sequence and writes it into {@code dataSeq}.
     *
     * @return the score of the best solution in {@code finalSoln}
     */
    public double bestLabelSequence(DataSequence dataSeq, double lambda[]) {
        double corrScore = viterbiSearch(dataSeq, lambda, false);
        if (model.params.debugLvl > 1) {
            System.out.println("Score of best sequence " + finalSoln.get(0).score + " corrScore " + corrScore);
        }
        assignLabels(dataSeq);
        return finalSoln.get(0).score;
    }

    /**
     * Follows the {@code prevSoln} chain of the best final solution and passes every
     * solution on it to {@link #setSegment}.
     *
     * <p>If the best final solution has no predecessor there is nothing to assign; with
     * assertions enabled this is accepted only for an empty sequence.
     */
    protected void assignLabels(DataSequence dataSeq) {
        Soln ybest = finalSoln.get(0).prevSoln;
        if (ybest == null) {
            assert dataSeq.length() == 0 : "no best path found for a non-empty sequence";
            return;
        }
        assert (ybest.pos == dataSeq.length() - 1);
        int pos = -1;
        while (ybest != null) {
            pos = ybest.pos;
            setSegment(dataSeq, ybest.prevPos(), ybest.pos, ybest.label);
            ybest = ybest.prevSoln;
        }
        assert (pos >= 0);
    }

    /**
     * Runs the forward search and stores the best solutions in {@code finalSoln}. For an
     * empty sequence {@code finalSoln} is left cleared.
     *
     * @param calcCorrectScore when {@code true}, the score of {@code dataSeq}'s own labels
     *                         is computed and returned
     * @return the score of {@code dataSeq}'s own labelling, or 0 if not requested
     */
    public double viterbiSearch(DataSequence dataSeq, double lambda[], boolean calcCorrectScore) {
        ensureLattice(dataSeq.length(), true);
        double corrScore = fillArray(dataSeq, lambda, calcCorrectScore);
        finalSoln.clear();
        finalSoln.valid = true;
        if (dataSeq.length() > 0) {
            for (int yi = 0; yi < model.numY; yi++) {
                finalSoln.add(winningLabel[yi][dataSeq.length() - 1], 0);
            }
        }
        return corrScore;
    }

    /** @return the number of solutions found by the most recent search */
    public int numSolutions() {
        return finalSoln.numSolns();
    }

    /** @return the last-position solution of the {@code k}-th best path of the most recent search */
    public Soln getBestSoln(int k) {
        return finalSoln.get(k).prevSoln;
    }

    /** Factory hook for the label sequences returned by {@link #topKLabelSequences}. */
    protected LabelSequence newLabelSequence(int len) {
        return new LabelSequence(len);
    }

    /**
     * Runs the forward search and returns the best label sequences found.
     *
     * @param dataSeq      the sequence to decode; its labels are not modified by this method
     * @param lambda       the model weights
     * @param numLabelSeqs the maximum number of label sequences to return
     * @param getScores    when {@code true}, each returned score is
     *                     {@code exp(pathScore - model.getLogZx(dataSeq))}; otherwise it is
     *                     the raw path score
     * @return at most {@code numLabelSeqs} label sequences, best first; fewer if the search
     *         found fewer solutions
     */
    public LabelSequence[] topKLabelSequences(DataSequence dataSeq, double[] lambda, int numLabelSeqs, boolean getScores) {
        viterbiSearch(dataSeq, lambda, false);
        double lZx = 0;
        if (getScores) {
            lZx = model.getLogZx(dataSeq);
        }
        int numSols = Math.min(finalSoln.numSolns(), numLabelSeqs);
        LabelSequence labelSequences[] = new LabelSequence[numSols];
        for (int k = numSols - 1; k >= 0; k--) {
            Soln ybest = finalSoln.get(k);
            labelSequences[k] = newLabelSequence(dataSeq.length());
            labelSequences[k].score = ybest.score;
            if (getScores) {
                labelSequences[k].score = Math.exp((double) ybest.score - lZx);
            }
            ybest = ybest.prevSoln;
            while (ybest != null) {
                labelSequences[k].add(ybest.prevPos(), ybest.pos, ybest.label);
                ybest = ybest.prevSoln;
            }
            labelSequences[k].doneAdd();
        }
        return labelSequences;
    }
}