/** SparseViterbi.java
* Viterbi search.
*
* @author Sunita Sarawagi
* @since 1.2
* @version 1.3
*/
package iitb.CRF;
import iitb.Utils.StaticObjectHeap;
import java.util.Stack;
import cern.colt.function.tdouble.IntDoubleFunction;
import cern.colt.function.tdouble.IntIntDoubleFunction;
import cern.colt.matrix.tdouble.DoubleMatrix1D;
import cern.colt.matrix.tdouble.DoubleMatrix2D;
import cern.colt.matrix.tobject.impl.DenseObjectMatrix1D;
public class SparseViterbi extends Viterbi {
/** Serialization version identifier. */
private static final long serialVersionUID = -496598232351755202L;

/**
 * Creates a sparse Viterbi searcher by forwarding both arguments to the
 * superclass constructor.
 *
 * @param model the CRF model handed to the superclass
 * @param bs    second superclass argument (presumably the beam size; confirm against Viterbi)
 */
protected SparseViterbi(CRF model, int bs) {
    super(model, bs);
}
public class Context extends DenseObjectMatrix1D {
/**
*
*/
private static final long serialVersionUID = 6590594788796602787L;
protected int pos;
protected int beamsize;
protected int startPos=0;
protected Context(int numY, int beamsize, int pos, int startPos){
super(numY);
this.pos = pos;
this.beamsize = beamsize;
this.startPos = startPos;
}
protected Entry newEntry(int beamsize, int label, int pos) {
return new Entry(beamsize,label,pos);
}
public void add(int y, Entry prevEntry, float thisScore) {
if (getQuick(y) == null) {
setQuick(y, newEntry((pos==startPos)?1:beamsize, y, pos));
}
getEntry(y).valid = true;
getEntry(y).add(prevEntry,thisScore);
}
public void clear() {
// assign((Object)null);
for (int i = 0; i < size(); i++)
if (getQuick(i) != null)
getEntry(i).clear();
}
public Entry getEntry(int y) {return (Entry)getQuick(y);}
/**
* @param y
* @return
*/
public boolean entryNotNull(int y) {
return ((getQuick(y) != null) && getEntry(y).valid);
}
void assign(LogSparseDoubleMatrix1D Ri) {
for (int y = 0; y < Ri.size(); y++) {
if (Ri.getQuick(y) != 0)
add(y,null,(float)Ri.get(y));
}
}
public String toString(){
String toString = "";
for (int i = 0; i < size(); i++)
if (getQuick(i) != null)
toString += getEntry(i).toString() + " ";
toString += "\n";
return toString;
}
};
// Search contexts indexed by sequence position; (re)allocated by allocateContext.
public Context context[];
// Scratch vector handed to SparseTrainer.computeLogMi together with Mi.
protected LogSparseDoubleMatrix1D Ri;
// Object heaps from which allocateCacheArray draws the cached matrices.
StaticHeapLogSparseDoubleMatrix2D staticHeapMi = null;
StaticHeapLogSparseDoubleMatrix1D staticHeapRi = null;
// Caches filled by cacheMis, indexed as [position][ell].
DoubleMatrix2D Mis[][] = null;
DoubleMatrix1D Ris[][] = null;
// Fill routines for the backward and forward passes; both are created in allocateScratch.
protected BackwardContextUpdate backwardContextUpdate;
protected ContextUpdate contextUpdate;
/**
 * Refreshes the scratch matrices {@code Mi} and {@code Ri} for position {@code i}:
 * the model's feature generator is positioned at {@code i} and the computation
 * is delegated to {@code SparseTrainer.computeLogMi}.
 *
 * @param dataSeq sequence being scored
 * @param i       position to scan features at
 * @param ell     not used by this implementation
 * @param lambda  weights forwarded to {@code SparseTrainer.computeLogMi}
 */
protected void computeLogMi(DataSequence dataSeq, int i, int ell, double[] lambda) {
    model.featureGenerator.startScanFeaturesAt(dataSeq, i);
    SparseTrainer.computeLogMi(model.featureGenerator, lambda, Mi, Ri);
}
/**
 * Supplies the {@code ell} values tried at a position. After {@link #start},
 * the first {@link #nextEll} call yields 1 and each later call yields one less;
 * the fill loops stop at the first non-positive value.
 */
protected class Iter {
    protected int ell;

    /** Resets the iterator for position {@code i}. */
    protected void start(int i, DataSequence dataSeq) {
        ell = 1;
    }

    /** Returns the current {@code ell} and then decrements it. */
    protected int nextEll(int i) {
        int current = ell;
        ell = current - 1;
        return current;
    }
}
/** Factory for the {@code ell} iterator used by the fill loops. */
protected Iter getIter() {
    return new Iter();
}
/**
 * Hook called by the fill loops once all {@code ell} values of a position
 * have been processed. Does nothing in this class.
 *
 * @param i2 position that was just completed
 */
protected void finishContext(int i2) {
    // intentionally empty
}
/**
 * Score of the labels recorded in {@code dataSeq} at position {@code i}, read
 * from the current {@code Ri} and {@code Mi}.
 *
 * @param dataSeq sequence carrying the reference labels
 * @param i       position being scored
 * @param ell     not used by this implementation
 * @param lambda  not used by this implementation
 * @return the {@code Ri} value for label {@code y(i)}, plus the {@code Mi}
 *         value for the pair {@code (y(i-1), y(i))} when {@code i > 0}
 */
protected double getCorrectScore(DataSequence dataSeq, int i, int ell, double[] lambda) {
    // NOTE(review): Ri is read through getQuick while Mi is read through get;
    // confirm that this asymmetry is intended.
    double stateScore = Ri.getQuick(dataSeq.y(i));
    double edgeScore = 0;
    if (i > 0) {
        edgeScore = Mi.get(dataSeq.y(i - 1), dataSeq.y(i));
    }
    return stateScore + edgeScore;
}
protected class ContextUpdate implements IntIntDoubleFunction, IntDoubleFunction {
protected int i, ell;
protected Iter iter;
public double apply(int yp, int yi, double val) {
if (context[i-ell].entryNotNull(yp))
context[i].add(yi, context[i-ell].getEntry(yp),(float)(Mi.get(yp,yi)+Ri.get(yi)));
return val;
}
public double apply(int yi, double val) {
context[i].add(yi,null,(float)Ri.get(yi));
return val;
}
double fillArray(DataSequence dataSeq, double lambda[], boolean calcScore) {
return fillArray(dataSeq, lambda, null, null, -1, calcScore);
}
public double fillArray(DataSequence dataSeq, double[] lambda,
DoubleMatrix2D[][] Mis, DoubleMatrix1D[][] Ris,
boolean calcScore) {
return fillArray(dataSeq, lambda, Mis, Ris, -1, calcScore);
}
public double fillArray(DataSequence dataSeq, double[] lambda,
DoubleMatrix2D[][] Mis, DoubleMatrix1D[][] Ris,
Soln soln, boolean calcScore) {
if(soln == null)
return fillArray(dataSeq, lambda, Mis, Ris, -1, calcScore);
for (i = soln.pos; i >= 0; i--)
context[i].clear();
Stack<Soln> stack = new Stack<Soln>();
while(soln != null){
stack.push(soln);
soln = soln.prevSoln;
}
int lastPos = -1;
while(!stack.empty()){
soln = (Soln)stack.pop();
switch(soln.prevPos()){
case -1:
context[soln.pos].add(soln.label,null,(float)soln.score);
break;
default:
if (context[soln.prevPos()].entryNotNull(soln.prevLabel()))
context[soln.pos].add(soln.label, context[soln.prevPos()].getEntry(soln.prevLabel()),soln.score - soln.prevSoln.score);
break;
}
if(lastPos < soln.pos)
lastPos = soln.pos;
}
return fillArray(dataSeq, lambda, Mis, Ris, lastPos, calcScore);
}
private double fillArray(DataSequence dataSeq, double[] lambda,
DoubleMatrix2D[][] Mis, DoubleMatrix1D[][] Ris,
int lastPos, boolean calcScore) {
double corrScore = 0;
DoubleMatrix1D tempRi = null;
DoubleMatrix2D tempMi = null;
if(Mis != null){
tempRi = Ri;
tempMi = Mi;
}
for (i = lastPos + 1; i < dataSeq.length(); i++) {
context[i].clear();
for (iter.start(i,dataSeq); (ell = iter.nextEll(i)) > 0;) {
// i - ell = i'
if(lastPos < 0 || (i - ell) >= lastPos){
// compute Mi.
if(Mis != null){
Ri = (LogSparseDoubleMatrix1D) Ris[i][ell];
Mi = Mis[i][ell];
}else
computeLogMi(dataSeq, i, ell, lambda);
if (i - ell < 0) {
Ri.forEachNonZero(this);
} else {
Mi.forEachNonZero(this);
}
}
if (model.params.debugLvl > 1) {
System.out.println("Ri :"+Ri);
System.out.println("Mi :"+Mi);
}
if (calcScore) {
corrScore += getCorrectScore(dataSeq, i, ell, null);
}
}
finishContext(i);
}
/*
i = dataSeq.length();
context[i].clear();
if (i >= 1) {
for (int yp = 0; yp < context[i-1].size(); yp++) {
if (context[i-1].entryNotNull(yp))
context[i].add(0, context[i-1].getEntry(yp),0);
}
}
*/
if(Mis != null){
Ri = (LogSparseDoubleMatrix1D) tempRi;
Mi = tempMi;
}
return corrScore;
}
};
/**
 * Backward pass: walks positions from the end of the sequence to the start.
 * {@code context[length]} is seeded with one zero-score candidate per label,
 * and the {@code apply} callbacks propagate entries from {@code context[i+1]}
 * towards {@code context[0]}.
 */
class BackwardContextUpdate implements IntIntDoubleFunction, IntDoubleFunction {
    /** Position and {@code ell} currently being processed; read by the callbacks. */
    int i, ell;
    /** Supplies the {@code ell} values tried at each position. */
    Iter iter;
    /** Sequence passed to the most recent {@code fillArray} call. */
    DataSequence dataSeq;
    Context firstContext;

    /**
     * Callback for a non-zero cell {@code (yp, yi)} of {@code Mi}: extends the
     * entry for {@code yi} in {@code context[i+1]}, if valid, to label
     * {@code yp} in {@code context[i-ell+1]}.
     */
    public double apply(int yp, int yi, double val) {
        if (context[i+1].entryNotNull(yi)) {
            context[i-ell+1].add(yp, context[i+1].getEntry(yi),(float)(Mi.get(yp,yi)+Ri.get(yi)));
        }
        return val;
    }

    /** Callback for a non-zero cell {@code yi} of {@code Ri}; writes into slot 0 of {@code context[0]}. */
    public double apply(int yi, double val) {
        // this is not quite right since there is no yp value..
        context[0].add(0, context[i+1].getEntry(yi),(float) val);
        return val;
    }

    /** Backward fill that computes {@code Mi}/{@code Ri} on the fly. */
    double fillArray(DataSequence dataSeq, double lambda[], boolean calcScore) {
        return fillArray(dataSeq, lambda, null, null, calcScore);
    }

    /**
     * Backward fill over the whole sequence.
     * When {@code Mis} is supplied, the shared scratch fields {@code Ri} and
     * {@code Mi} temporarily point at cache entries; they are restored in a
     * {@code finally} block so that an exception cannot leave them aliased to
     * the cache, where a later {@code computeLogMi} would overwrite it.
     *
     * @param Mis cached matrices indexed {@code [i][ell]}, or {@code null} to
     *            compute them on the fly
     * @param Ris cached vectors indexed {@code [i][ell]}; read only when
     *            {@code Mis} is non-null
     * @return accumulated correct-sequence score when {@code calcScore} is set, else 0
     */
    double fillArray(DataSequence dataSeq, double lambda[],
            DoubleMatrix2D Mis[][], DoubleMatrix1D Ris[][], boolean calcScore) {
        this.dataSeq = dataSeq;
        double corrScore = 0;
        final boolean useCache = (Mis != null);
        final LogSparseDoubleMatrix1D savedRi = Ri;
        final DoubleMatrix2D savedMi = Mi;
        try {
            for (i = dataSeq.length(); i >= 0; i--) {
                context[i].clear();
            }
            boolean notInit = true;
            for (i = dataSeq.length() - 1; i >= 0; i--) {
                for (iter.start(i,dataSeq); (ell = iter.nextEll(i)) > 0;) {
                    // i - ell is the position preceding the current span.
                    if (useCache) {
                        Mi = Mis[i][ell];
                        Ri = (LogSparseDoubleMatrix1D) Ris[i][ell];
                    } else {
                        computeLogMi(dataSeq, i, ell, lambda);
                    }
                    if (notInit) {
                        // Seed the context past the last position once Ri's size is known.
                        for (int yi = 0; yi < Ri.size(); yi++) {
                            context[dataSeq.length()].add(yi, null, 0);
                        }
                        notInit = false;
                    }
                    if (i - ell >= 0) {
                        Mi.forEachNonZero(this);
                    } else {
                        Ri.forEachNonZero(this);
                    }
                    if (model.params.debugLvl > 1) {
                        System.out.println("Ri "+Ri);
                        System.out.println("Mi "+ Mi);
                    }
                    if (calcScore) {
                        // NOTE(review): null is passed for lambda here, not the caller's array.
                        corrScore += getCorrectScore(dataSeq, i, ell, null);
                    }
                }
                finishContext(i);
            }
        } finally {
            if (useCache) {
                Ri = savedRi;
                Mi = savedMi;
            }
        }
        return corrScore;
    }
}
/** Factory for the forward fill routine; called from allocateScratch. */
protected ContextUpdate newContextUpdate() {
    ContextUpdate update = new ContextUpdate();
    return update;
}
/**
 * Creates the scratch state used by the searches: the {@code Mi}/{@code Ri}
 * matrices, an empty context array, the final-solution entry, both fill
 * routines with their iterators, and the static heaps.
 *
 * @param numY dimension of the scratch {@code Mi} and {@code Ri}
 */
protected void allocateScratch(int numY) {
    // Scratch matrices.
    Mi = new LogSparseDoubleMatrix2D(numY, numY);
    Ri = new LogSparseDoubleMatrix1D(numY);

    // Contexts start out empty; allocateContext sizes them later.
    context = new Context[0];
    finalSoln = new Entry(beamsize, 0, 0);

    // Backward routine first, then forward, each with its own iterator.
    backwardContextUpdate = new BackwardContextUpdate();
    backwardContextUpdate.iter = getIter();
    contextUpdate = newContextUpdate();
    contextUpdate.iter = getIter();

    allocateStaticHeaps();
}
/**
 * Creates the two object heaps behind the matrix caches, each with an initial
 * capacity of 0 and a dimension of {@code model.numY}.
 */
void allocateStaticHeaps() {
    final int initCapacity = 0;
    staticHeapMi = new StaticHeapLogSparseDoubleMatrix2D(initCapacity, model.numY);
    staticHeapRi = new StaticHeapLogSparseDoubleMatrix1D(initCapacity, model.numY);
}
/**
 * Replaces the context array with one of {@code seqLength + 1} slots.
 * Existing contexts are carried over, except that a context whose recorded
 * start position equals its own index while differing from the new
 * {@code startPos} is replaced by a fresh one. Slots beyond the old length
 * are filled with fresh contexts.
 *
 * @param numY      number of labels per context
 * @param seqLength length of the sequence to be searched
 * @param startPos  start position recorded in newly created contexts
 */
void allocateContext(int numY, int seqLength, int startPos) {
    final Context[] previous = context;
    context = new Context[seqLength + 1];

    for (int idx = 0; idx < previous.length; idx++) {
        Context carried = previous[idx];
        context[idx] = carried;
        boolean wasStart = (carried.startPos == idx);
        if (wasStart && (idx != startPos)) {
            context[idx] = newContext(numY, beamsize, idx, startPos);
        }
    }

    int idx = previous.length;
    while (idx < context.length) {
        context[idx] = newContext(numY, beamsize, idx, startPos);
        idx++;
    }
}
/**
 * Factory for a single position's context.
 *
 * @param numY     number of label slots
 * @param beamsize beam size for the context's entries
 * @param pos      sequence position of the context
 * @param startPos start position of the search
 */
protected Context newContext(int numY, int beamsize, int pos, int startPos) {
    Context created = new Context(numY, beamsize, pos, startPos);
    return created;
}
/**
 * Forward search without cached matrices; delegates to the five-argument
 * overload with {@code null} caches.
 *
 * @return the value returned by the delegated search
 */
public double viterbiSearch(DataSequence dataSeq, double[] lambda, boolean calcCorrectScore) {
    final DoubleMatrix2D[][] noMis = null;
    final DoubleMatrix1D[][] noRis = null;
    return viterbiSearch(dataSeq, lambda, noMis, noRis, calcCorrectScore);
}
/**
 * Forward search: initialises the scratch state, fills the contexts and, for
 * a non-empty sequence, collects the final solution from the last position.
 *
 * @param Mis      cached matrices, or {@code null} to compute them on the fly
 * @param Ris      cached vectors used together with {@code Mis}
 * @param calScore whether the correct-sequence score is accumulated
 * @return the score returned by the fill routine
 */
public double viterbiSearch(DataSequence dataSeq, double[] lambda,
        DoubleMatrix2D[][] Mis, DoubleMatrix1D[][] Ris, boolean calScore) {
    initSearch(dataSeq.length());
    final double corrScore = contextUpdate.fillArray(dataSeq, lambda, Mis, Ris, calScore);

    final int lastIndex = dataSeq.length() - 1;
    if (lastIndex >= 0) {
        calculateFinalSolution(context[lastIndex]);
    }

    final boolean debugging = model.params.debugLvl > 1;
    if (debugging) {
        System.out.println("Score of best sequence "+finalSoln.get(0).score + " corrScore " + corrScore);
    }
    return corrScore;
}
/**
 * Forward search seeded with a partial solution. A {@code null} solution
 * falls back to the unseeded five-argument overload.
 *
 * @param Mis      cached matrices, or {@code null} to compute them on the fly
 * @param Ris      cached vectors used together with {@code Mis}
 * @param soln     solution chain passed to the fill routine as its seed
 * @param calScore whether the correct-sequence score is accumulated
 * @return the score returned by the fill routine
 */
public double viterbiSearch(DataSequence dataSeq, double[] lambda,
        DoubleMatrix2D[][] Mis, DoubleMatrix1D[][] Ris,
        Soln soln, boolean calScore) {
    if (soln == null) {
        return viterbiSearch(dataSeq, lambda, Mis, Ris, calScore);
    }

    initSearch(dataSeq.length());
    final double corrScore = contextUpdate.fillArray(dataSeq, lambda, Mis, Ris, soln, calScore);

    final int lastIndex = dataSeq.length() - 1;
    if (lastIndex >= 0) {
        calculateFinalSolution(context[lastIndex]);
    }

    final boolean debugging = model.params.debugLvl > 1;
    if (debugging) {
        System.out.println("Score of best sequence "+finalSoln.get(0).score + " corrScore " + corrScore);
    }
    return corrScore;
}
/**
 * Backward search: initialises the scratch state with the sequence length as
 * start position, runs the backward fill and collects the final solution
 * from {@code context[0]}.
 *
 * @param Mis              cached matrices, or {@code null} to compute them on the fly
 * @param Ris              cached vectors used together with {@code Mis}
 * @param calcCorrectScore whether the correct-sequence score is accumulated
 * @return the score returned by the backward fill routine
 */
public double viterbiSearchBackward(DataSequence dataSeq, double[] lambda,
        DoubleMatrix2D[][] Mis, DoubleMatrix1D[][] Ris, boolean calcCorrectScore) {
    initSearch(dataSeq.length(), dataSeq.length());
    final double corrScore =
            backwardContextUpdate.fillArray(dataSeq, lambda, Mis, Ris, calcCorrectScore);

    final boolean haveContexts = context.length > 0;
    if (haveContexts) {
        calculateFinalSolution(context[0]);
    }

    if (model.params.debugLvl > 1) {
        System.out.println("Score of best sequence "+finalSoln.get(0).score + " corrScore " + corrScore);
    }
    return corrScore;
}
/**
 * Prepares a search that starts at position 0.
 *
 * @param seqLength length of the sequence to be searched
 */
protected void initSearch(int seqLength) {
    final int startPos = 0;
    initSearch(seqLength, startPos);
}
/**
 * Prepares a search: allocates the scratch state on first use, grows the
 * context array when it has no more than {@code seqLength} slots, and clears
 * the final solution.
 *
 * @param seqLength length of the sequence to be searched
 * @param startPos  start position forwarded to {@code allocateContext}
 */
protected void initSearch(int seqLength, int startPos) {
    if (Mi == null) {
        allocateScratch(model.numY);
    }
    final boolean tooSmall = (context.length <= seqLength);
    if (tooSmall) {
        allocateContext(model.numY, seqLength, startPos);
    }
    finalSoln.clear();
}
/**
 * Builds the final solution from one context: every valid entry is sorted
 * and then added to {@code finalSoln} with a score of 0.
 *
 * @param context context whose entries are collected
 */
protected void calculateFinalSolution(Context context) {
    finalSoln.valid = true;
    final int numLabels = context.size();
    for (int label = 0; label < numLabels; label++) {
        if (!context.entryNotNull(label)) {
            continue;
        }
        Entry candidate = (Entry) context.getQuick(label);
        candidate.sortEntries();
        finalSoln.add(candidate, 0);
    }
}
/**
 * Computes {@code Mi} and {@code Ri} for every position and {@code ell} of
 * the sequence and copies them into the {@code Mis}/{@code Ris} caches, which
 * are reallocated first. Positions are visited from last to first.
 *
 * @param dataSeq sequence whose matrices are cached
 * @param lambda  weights forwarded to {@code computeLogMi}
 */
public void cacheMis(DataSequence dataSeq, double[] lambda) {
    if (Mi == null) {
        allocateScratch(model.numY);
    }
    allocateCacheArray(dataSeq);

    final Iter ellIter = contextUpdate.iter;
    int pos = dataSeq.length();
    while (--pos >= 0) {
        ellIter.start(pos, dataSeq);
        for (int ell = ellIter.nextEll(pos); ell > 0; ell = ellIter.nextEll(pos)) {
            // Refresh the scratch matrices, then snapshot them into the cache.
            computeLogMi(dataSeq, pos, ell, lambda);
            Mis[pos][ell].assign(Mi);
            Ris[pos][ell].assign(Ri);
        }
    }
}
/**
 * Allocates the {@code Mis}/{@code Ris} cache rows for a sequence, drawing
 * the matrices from the static heaps after resetting them. Row {@code i}
 * gets {@code size + 1} slots, where {@code size} is {@code i + 1} while
 * {@code i} is below the largest {@code ell} seen so far and that largest
 * {@code ell} otherwise.
 *
 * @param dataSeq sequence the caches are sized for
 */
private void allocateCacheArray(DataSequence dataSeq) {
    final int seqLength = dataSeq.length();
    final Iter ellIter = getIter();
    Mis = new LogSparseDoubleMatrix2D[seqLength][];
    Ris = new LogSparseDoubleMatrix1D[seqLength][];
    staticHeapMi.reset();
    staticHeapRi.reset();

    // Running maximum of ell; it is carried across positions and never reset.
    int maxEll = 0;
    for (int pos = 0; pos < seqLength; pos++) {
        ellIter.start(pos, dataSeq);
        for (int ell = ellIter.nextEll(pos); ell > 0; ell = ellIter.nextEll(pos)) {
            maxEll = Math.max(maxEll, ell);
        }

        int size;
        if (pos < maxEll) {
            size = pos + 1;
        } else {
            size = maxEll;
        }
        final int slots = size + 1;
        Mis[pos] = new LogSparseDoubleMatrix2D[slots];
        Ris[pos] = new LogSparseDoubleMatrix1D[slots];
        for (int slot = 0; slot < slots; slot++) {
            Mis[pos][slot] = (DoubleMatrix2D) staticHeapMi.getObject();
            Ris[pos][slot] = (DoubleMatrix1D) staticHeapRi.getObject();
        }
    }
}
/**
 * Object heap whose elements are square {@code LogSparseDoubleMatrix2D}
 * instances of dimension {@code numY}.
 */
class StaticHeapLogSparseDoubleMatrix2D extends StaticObjectHeap {
    /** Row and column count of every matrix created by this heap. */
    int numY;

    /**
     * @param initCapacity initial capacity forwarded to the superclass
     * @param numY         dimension of the matrices to create
     */
    public StaticHeapLogSparseDoubleMatrix2D(int initCapacity, int numY) {
        super(initCapacity);
        this.numY = numY;
    }

    /** Creates a new {@code numY} x {@code numY} matrix. */
    protected Object newObject() {
        return new LogSparseDoubleMatrix2D(numY, numY);
    }

    /** Returns whatever {@code getFreeObject()} hands out. */
    protected Object getObject() {
        Object free = getFreeObject();
        return free;
    }
}
/**
 * Object heap whose elements are {@code LogSparseDoubleMatrix1D} instances
 * of size {@code numY}.
 */
class StaticHeapLogSparseDoubleMatrix1D extends StaticObjectHeap {
    /** Size of every vector created by this heap. */
    int numY;

    /**
     * @param initCapacity initial capacity forwarded to the superclass
     * @param numY         size of the vectors to create
     */
    public StaticHeapLogSparseDoubleMatrix1D(int initCapacity, int numY) {
        super(initCapacity);
        this.numY = numY;
    }

    /** Returns whatever {@code getFreeObject()} hands out. */
    protected Object getObject() {
        Object free = getFreeObject();
        return free;
    }

    /** Creates a new vector of size {@code numY}. */
    protected Object newObject() {
        return new LogSparseDoubleMatrix1D(numY);
    }
}
/**
 * @return the {@code Mis} cache array itself (not a copy); {@code null}
 *         until the cache has been allocated
 */
public DoubleMatrix2D[][] getMis() {
    return this.Mis;
}
/**
 * @return the {@code Ris} cache array itself (not a copy); {@code null}
 *         until the cache has been allocated
 */
public DoubleMatrix1D[][] getRis() {
    return this.Ris;
}
};