/*
* Created on Apr 23, 2005
*
*/
package iitb.BSegmentCRF;
import cern.colt.function.*;
import cern.colt.function.tdouble.IntDoubleFunction;
import cern.colt.function.tdouble.IntIntDoubleFunction;
import cern.colt.matrix.tdouble.DoubleMatrix1D;
import cern.colt.matrix.tdouble.DoubleMatrix2D;
import iitb.BSegmentCRF.BSegmentTrainer.MatrixWithRange;
import iitb.CRF.*;
import iitb.CRF.SegmentViterbi.SegmentationImpl;
/**
* Viterbi decoder for {@link BSegmentCRF} models, built on top of {@link SparseViterbi}.
*
* @author Sunita Sarawagi
* @since 1.2
* @version 1.3
*/
public class BSegmentViterbi extends SparseViterbi {
/**
 * Returns the score of the labelled segment that ends at position {@code i},
 * or 0 when the pair ({@code i}, {@code ell}) is not the one at which that
 * segment is scored.
 *
 * <p>A labelled segment is scored exactly once: at its last position, with
 * {@code ell} equal to its length when the length is at most {@code m}, and
 * with {@code ell == 1} when it is longer than {@code m}. For the long case
 * {@code Ri} (and {@code Mi} unless {@code reuseM}) are refetched from the
 * feature store here; for the short case the current contents of {@code Ri}
 * and {@code Mi} are read as they are.
 * NOTE(review): presumably those were filled by computeLogMi for the same
 * (i, ell) - confirm against the calling loop in the superclass.
 *
 * @param dataSeq the sequence; cast to {@link Segmentation} to read the labelled segments
 * @param i       current end position
 * @param ell     segment length currently being considered
 * @param lambda  not read by this method
 * @return the segment score plus, when a label precedes the segment, the
 *         {@code Mi} entry for (label before the segment, label at {@code i}); otherwise 0
 */
protected double getCorrectScore(DataSequence dataSeq, int i, int ell, double[] lambda) {
    Segmentation gold = (Segmentation) dataSeq;
    int segId = gold.getSegmentId(i);
    int segLength = gold.segmentEnd(segId) - gold.segmentStart(segId) + 1;

    // Only the last position of a labelled segment contributes.
    if (gold.segmentEnd(segId) != i) {
        return 0;
    }

    // Short segments are scored at ell == length, long ones at ell == 1.
    boolean longSegment = segLength > m;
    int scoringEll = longSegment ? 1 : segLength;
    if (ell != scoringEll) {
        return 0;
    }

    if (longSegment) {
        int first = i - segLength + 1;
        fstore.getExactR(first, i, Ri);
        if (!reuseM) {
            fstore.getLogMi(first, Mi);
        }
    }

    double segmentScore = Ri.getQuick(dataSeq.y(i));
    // The transition term exists only when some position precedes the segment.
    double transitionScore = (i - segLength >= 0)
            ? Mi.get(dataSeq.y(i - segLength), dataSeq.y(i))
            : 0;
    double val = segmentScore + transitionScore;
    //System.out.println("Score for segment: "+ (i-segLength+1) + " " + i + " " + val);
    return val;
}
/**
 * Adds loss terms to the segment scores after the open range has been moved
 * to {@code openri.start}; called from computeLogMi when
 * {@code lossAugmentedScore} is set.
 *
 * <p>Steps, in order:
 * <ol>
 * <li>every label in {@code ri} gets +1;</li>
 * <li>if a labelled segment starts exactly at {@code openri.start}, the +1 is
 *     taken back for the label found at that start; otherwise every label
 *     gets another +1;</li>
 * <li>if position {@code openri.start + 1} is still within {@code [.., i]} and
 *     a labelled segment starts there, every label gets +1 in both
 *     {@code ri} and {@code openri.mat}.</li>
 * </ol>
 *
 * @param dataSeq the sequence; cast to {@link Segmentation}
 * @param ri      score vector updated in place
 * @param openri  open range; {@code start} is read and {@code mat} may be updated
 * @param i       end position of the segment being scored
 * @param ell     not read by this method
 */
private void adjustScore(DataSequence dataSeq, DoubleMatrix1D ri, MatrixWithRange openri, int i, int ell) {
    Segmentation gold = (Segmentation) dataSeq;
    int left = openri.start;
    int goldStart = gold.segmentStart(gold.getSegmentId(left));

    for (int label = 0; label < numY; label++) {
        ri.set(label, ri.get(label) + 1);
    }

    if (goldStart != left) {
        // because the previous segment ended wrongly..
        for (int label = 0; label < numY; label++) {
            ri.set(label, ri.get(label) + 1);
        }
    } else {
        int goldLabel = dataSeq.y(goldStart);
        ri.set(goldLabel, ri.get(goldLabel) - 1);
    }

    // now see if there is a corr seg included in this segment.
    int next = left + 1;
    if (next > i) {
        return;
    }
    int nextSeg = gold.getSegmentId(next);
    if (gold.segmentStart(nextSeg) != next) {
        return;
    }
    for (int label = 0; label < numY; label++) {
        ri.set(label, ri.get(label) + 1);
        openri.mat.set(label, openri.mat.get(label) + 1);
    }
}
// Feature store: created in allocateScratch and initialised for each sequence in viterbiSearch.
FeatureStore fstore;
// The model handed to the constructor; this class reads its bfgen and params.
BSegmentCRF bmodel;
// Matrix with a [start, end] range; computeLogMi moves its start via fstore.decrementLeftB
// until it equals the start of the segment being scored.
BSegmentTrainer.MatrixWithRange openRi;
// Scratch vectors filled by fstore.deltaR_RShift in finishContext.
LogSparseDoubleMatrix1D deltaRi,openDeltaRi;
// Copied from model.params.reuseM. When true, Mi is fetched once per sequence in
// viterbiSearch and is not refetched in computeLogMi / getCorrectScore.
boolean reuseM;
// One OpenContext per sequence position; allocated and grown in viterbiSearch.
Context openContext[];
// Set in allocateScratch from bmodel.bfgen.maxBoundaryGap(); upper bound on the
// segment lengths enumerated by MIter.
int m;
// Label count passed to the constructor.
int numY;
// When true, computeLogMi calls adjustScore after each decrementLeftB step.
boolean lossAugmentedScore=false;
/**
 * Creates a decoder without loss-augmented scoring.
 *
 * @param model the model; also passed to the superclass constructor
 * @param numY  number of labels
 * @param bs    passed to the superclass constructor together with {@code model}
 */
protected BSegmentViterbi(BSegmentCRF model, int numY, int bs) {
    this(model, numY, bs, false);
}
/**
 * Creates a decoder, optionally with loss-augmented scoring.
 *
 * @param model              the model; also passed to the superclass constructor
 * @param numY               number of labels
 * @param bs                 passed to the superclass constructor together with {@code model}
 * @param lossAugmentedScore when true, computeLogMi applies adjustScore to the segment scores
 */
public BSegmentViterbi(BSegmentCRF model, int numY, int bs, boolean lossAugmentedScore) {
    super(model, bs);
    this.bmodel = model;
    this.reuseM = model.params.reuseM;
    this.numY = numY;
    this.lossAugmentedScore = lossAugmentedScore;
}
/**
 * Brings {@code Ri} / {@code openRi} to the segment {@code [i-ell+1, i]} and,
 * unless {@code reuseM} is set, loads {@code Mi} for the segment start.
 *
 * <p>If the open range does not already end at {@code i}, or already starts
 * to the left of the requested start, it is reset with {@code init(i+1, i)}.
 * Then {@code fstore.decrementLeftB} is applied (followed by adjustScore when
 * {@code lossAugmentedScore} is set) until the range starts at
 * {@code i-ell+1}.
 *
 * @param dataSeq the sequence; only forwarded to adjustScore
 * @param i       end position of the segment
 * @param ell     length of the segment
 * @param lambda  not read by this method
 */
protected void computeLogMi(DataSequence dataSeq, int i, int ell, double lambda[]) {
    final int segStart = i - ell + 1;

    boolean unusable = (openRi.end != i) || (openRi.start < segStart);
    if (unusable) {
        openRi.init(i + 1, i);
    }

    while (openRi.start != segStart) {
        fstore.decrementLeftB(Ri, openRi);
        if (lossAugmentedScore) {
            adjustScore(dataSeq, Ri, openRi, i, ell);
        }
    }
    assert((openRi.start == segStart) && (openRi.end == i));

    // Mi is only needed when some position precedes the segment.
    if (segStart - 1 >= 0 && !reuseM) {
        fstore.getLogMi(segStart, Mi);
    }
}
/**
 * Callback passed to the {@code forEachNonZero} methods of the score
 * matrices. For each visited cell it adds a candidate to {@code thisContext}.
 * One instance is reused; call one of the {@code init} methods before each
 * traversal.
 */
class ApplyFunc implements IntIntDoubleFunction, IntDoubleFunction {
    DoubleMatrix1D matRi;
    DoubleMatrix2D matMi;
    Context prevContext;
    Context thisContext;

    /**
     * Prepares a traversal with no predecessor context.
     * {@code matMi} keeps whatever value it had before.
     */
    ApplyFunc init(DoubleMatrix1D matRi, Context thisContext) {
        this.matRi = matRi;
        this.thisContext = thisContext;
        this.prevContext = null;
        return this;
    }

    /** Prepares a traversal that links entries of {@code prevContext} to {@code thisContext}. */
    ApplyFunc init(DoubleMatrix1D matRi, DoubleMatrix2D matMi, Context prevContext, Context thisContext) {
        this.prevContext = prevContext;
        this.thisContext = thisContext;
        this.matMi = matMi;
        this.matRi = matRi;
        return this;
    }

    /**
     * Two-index callback: when {@code prevContext} has an entry for {@code yp},
     * adds it to {@code thisContext} under {@code yi} with score
     * {@code matMi(yp, yi) + matRi(yi)} narrowed to float.
     *
     * @return {@code val} unchanged
     */
    public double apply(int yp, int yi, double val) {
        if (!prevContext.entryNotNull(yp)) {
            return val;
        }
        thisContext.add(yi, prevContext.getEntry(yp), (float) (matMi.get(yp, yi) + matRi.get(yi)));
        return val;
    }

    /**
     * One-index callback: adds {@code val} (narrowed to float) under {@code yi},
     * with the {@code yi} entry of {@code prevContext} as predecessor, or
     * {@code null} when there is no predecessor context.
     * NOTE(review): unlike the two-index callback there is no entryNotNull
     * check here, so a missing predecessor entry is passed on as returned by
     * getEntry - confirm this is intended.
     *
     * @return {@code val} unchanged
     */
    public double apply(int yi, double val) {
        thisContext.add(
                yi,
                (prevContext == null) ? null : prevContext.getEntry(yi),
                (float) val);
        return val;
    }
}
// Reusable callback; finishContext re-initialises it with init(...) before every
// forEachNonZero call, and allocateScratch replaces it with a fresh instance.
ApplyFunc applyFunc = new ApplyFunc();
/**
 * Rebuilds {@code openContext[i]} once all segment lengths ending at
 * {@code i} have been processed, and (for {@code i >= m}) adds candidates to
 * {@code context[i]} from {@code openContext[i-1]}.
 *
 * <ul>
 * <li>{@code i < m-1}: {@code openContext[i]} is only cleared.</li>
 * <li>{@code i == m-1}: exact-end features are removed from {@code Ri} and
 *     its non-zero entries seed {@code openContext[i]} without predecessor.</li>
 * <li>{@code i >= m}: exact-end features are removed, the two delta vectors
 *     are obtained from {@code fstore.deltaR_RShift}, and three traversals
 *     run: {@code deltaRi} into {@code context[i]} and {@code openDeltaRi}
 *     into {@code openContext[i]} (both from {@code openContext[i-1]}), then
 *     {@code Mi} into {@code openContext[i]} from {@code context[i-m]}.</li>
 * </ul>
 *
 * @param i position whose contexts are being finished
 */
protected void finishContext(int i) {
    // Ri at this point contains features for segment (i-m+1, i)
    openContext[i].clear();

    final int before = i - m;
    if (before < -1) {
        return;
    }

    fstore.removeExactEndFeatures(Ri, before + 1, i);
    if (before == -1) {
        Ri.forEachNonZero(applyFunc.init(Ri, openContext[i]));
        return;
    }

    // before >= 0 from here on
    fstore.deltaR_RShift(before, i, deltaRi, openDeltaRi);
    deltaRi.forEachNonZero(applyFunc.init(deltaRi, null, openContext[i - 1], context[i]));
    openDeltaRi.forEachNonZero(applyFunc.init(openDeltaRi, null, openContext[i - 1], openContext[i]));
    Mi.forEachNonZero(applyFunc.init(Ri, Mi, context[before], openContext[i]));
}
/**
 * Enumerates candidate segment lengths for an end position {@code i}:
 * 1, 2, ... up to {@code min(m, i+1)}, then 0 to signal the end.
 */
class MIter extends Iter {
    /** Resets the enumeration; both arguments are ignored. */
    protected void start(int i, DataSequence dataSeq) {
        ell = 0;
    }

    /**
     * @param i end position of the segment
     * @return the next length, or 0 once {@code m} is reached or the segment
     *         would begin before position 0
     */
    protected int nextEll(int i) {
        boolean exhausted = (ell >= m) || (i - ell < 0);
        if (exhausted) {
            return 0;
        }
        ell += 1;
        return ell;
    }
}
/** @return a new {@link MIter}, which bounds segment lengths by {@code m} */
protected Iter getIter() {
    Iter lengths = new MIter();
    return lengths;
}
/**
 * Records the segment {@code [prevPos+1, pos]} with the given label on
 * {@code dataSeq}, which must be a {@link Segmentation}.
 *
 * @param dataSeq the sequence to write to
 * @param prevPos last position of the preceding segment
 * @param pos     last position of this segment
 * @param label   label assigned to the segment
 */
protected void setSegment(DataSequence dataSeq, int prevPos, int pos, int label) {
    Segmentation target = (Segmentation) dataSeq;
    target.setSegment(prevPos + 1, pos, label);
}
/**
 * Segmentation result whose {@link #apply(DataSequence)} copies the segments
 * to the target sequence through {@code Segmentation.setSegment}.
 */
public static class BSegmentationImpl extends SegmentViterbi.SegmentationImpl {
    /**
     * Writes every segment held by this object to {@code data}.
     *
     * @param data target; cast to {@link Segmentation} once per segment, so
     *             no cast happens when there are no segments
     */
    public void apply(DataSequence data) {
        int seg = 0;
        while (seg < numSegments()) {
            Segmentation target = (Segmentation) data;
            target.setSegment(segmentStart(seg), segmentEnd(seg), segmentLabel(seg));
            seg++;
        }
    }
}
/**
 * @param len not read by this method
 * @return a new, empty {@link BSegmentationImpl}
 */
protected LabelSequence newLabelSequence(int len) {
    LabelSequence labels = new BSegmentationImpl();
    return labels;
}
// Serialization version id (the class is presumably Serializable through a supertype - confirm).
private static final long serialVersionUID = 1L;
/**
 * Allocates the working objects of this decoder after letting the
 * superclass allocate its own: a fresh callback, feature store, open range
 * and the two delta vectors, all sized by {@code numY}. Also reads
 * {@code m} from {@code bmodel.bfgen.maxBoundaryGap()}.
 *
 * @param numY number of labels (shadows the field of the same name)
 */
protected void allocateScratch(int numY) {
    super.allocateScratch(numY);

    applyFunc = new ApplyFunc();
    fstore = new FeatureStore(false);

    LogSparseDoubleMatrix1D openVector = new LogSparseDoubleMatrix1D(numY);
    openRi = new MatrixWithRange(openVector);
    deltaRi = new LogSparseDoubleMatrix1D(numY);
    openDeltaRi = new LogSparseDoubleMatrix1D(numY);

    m = bmodel.bfgen.maxBoundaryGap();
}
/**
 * Solution node used by the open contexts. It does not represent a solution
 * of its own: {@link #getSoln()} hands back the predecessor instead.
 */
class OpenSoln extends Soln {
    private static final long serialVersionUID = 6332992368741370660L;

    /*
    Disabled override, kept for reference:
    protected void setPrevSoln(Soln prevSoln, float score) {
    if (prevSoln instanceof OpenSoln) {
    prevSoln = prevSoln.prevSoln;
    assert (!(prevSoln instanceof OpenSoln));
    }
    super.setPrevSoln(prevSoln, score);
    }
    */

    /** Forwards both arguments unchanged to the {@code Soln} constructor. */
    public OpenSoln(int id, int p) {
        super(id, p);
    }

    /**
     * @return {@code prevSoln}, because this object does not really store
     *         the solution
     */
    protected Soln getSoln() {
        return this.prevSoln;
    }
}
/** Entry whose solution nodes are {@link OpenSoln} instances. */
class OpenEntry extends Entry {
    /** Forwards all three arguments unchanged to the {@code Entry} constructor. */
    protected OpenEntry(int beamsize, int id, int pos) {
        super(beamsize, id, pos);
    }

    /** @return a new {@link OpenSoln} for the given label and position */
    protected Soln newSoln(int label, int pos) {
        Soln node = new OpenSoln(label, pos);
        return node;
    }
}
/** Context whose entries are {@link OpenEntry} instances. */
class OpenContext extends Context {
    private static final long serialVersionUID = 1L;

    /**
     * Forwards the three arguments to the four-argument {@code Context}
     * constructor with 0 as the last argument.
     *
     * @param arg0 first superclass argument; the call in viterbiSearch passes {@code numY}
     * @param arg1 second superclass argument; the call in viterbiSearch passes {@code beamsize}
     * @param arg2 third superclass argument; the call in viterbiSearch passes the array index
     */
    protected OpenContext(int arg0, int arg1, int arg2) {
        super(arg0, arg1, arg2, 0);
    }

    /** @return a new {@link OpenEntry} built from the three arguments */
    protected Entry newEntry(int beamsize, int label, int pos) {
        Entry slot = new OpenEntry(beamsize, label, pos);
        return slot;
    }
}
/**
 * Prepares the per-sequence state of this decoder and then delegates the
 * search to the superclass.
 *
 * <p>Preparation: working objects are re-allocated, the feature store is
 * initialised for {@code dataSeq}, the open range is reset to (1, 0),
 * {@code Mi} is loaded once when {@code reuseM} is set and the sequence is
 * non-empty, and {@code openContext} is grown to twice the sequence length
 * when it is missing or shorter than the sequence. Existing contexts are
 * kept; only the new slots get fresh {@link OpenContext} objects.
 * NOTE(review): allocateScratch runs on every call, so the feature store and
 * scratch vectors are rebuilt per sequence - confirm this is intended.
 *
 * @param dataSeq          the sequence to decode
 * @param lambda           model weights, passed to the feature store and the superclass
 * @param calcCorrectScore forwarded to the superclass search
 * @return the value returned by the superclass search
 */
public double viterbiSearch(DataSequence dataSeq, double lambda[], boolean calcCorrectScore) {
    allocateScratch(numY);
    fstore.init(dataSeq, bmodel.bfgen, lambda, numY);
    openRi.init(1, 0);

    if (reuseM && dataSeq.length() > 0) {
        fstore.getLogMi(1, Mi);
    }

    final int seqLen = dataSeq.length();
    if (openContext == null || openContext.length < seqLen) {
        Context[] previous = openContext;
        openContext = new Context[2 * seqLen];

        int firstNew = 0;
        if (previous != null) {
            // Keep the contexts that already exist.
            System.arraycopy(previous, 0, openContext, 0, previous.length);
            firstNew = previous.length;
        }
        for (int pos = firstNew; pos < openContext.length; pos++) {
            openContext[pos] = new OpenContext(numY, beamsize, pos);
        }
    }

    double corrScore = super.viterbiSearch(dataSeq, lambda, calcCorrectScore);
    /* Disabled debugging aid: recomputes the score of the labelled segmentation segment by segment.
    if (calcCorrectScore) {
    double score = 0;
    fstore.printFeatures = true;
    Segmentation segmentation = (Segmentation)dataSeq;
    for (int segNum = 0; segNum < segmentation.numSegments(); segNum++) {
    int segStart = segmentation.segmentStart(segNum);
    int segEnd = segmentation.segmentEnd(segNum);
    fstore.getExactR(segStart,segEnd,Ri);
    double val = Ri.get(dataSeq.y(segEnd));
    if (segNum > 0) {
    if (!reuseM) fstore.getLogMi(segStart,Mi);
    val += Mi.get(dataSeq.y(segStart-1),dataSeq.y(segEnd));
    }
    System.out.println("CScore for segment: "+ segStart + " "+segEnd + " " + val);
    score += val;
    }
    fstore.printFeatures = false;
    //assert (score == corrScore);
    }
    */
    return corrScore;
}
}