/** SegmentViterbi.java
*
* @author Sunita Sarawagi
* @since 1.2
* @version 1.3
*/
package iitb.CRF;
import gnu.trove.iterator.TIntIterator;
import gnu.trove.map.hash.TIntFloatHashMap;
import gnu.trove.procedure.TIntProcedure;
import gnu.trove.set.hash.TIntHashSet;
import java.util.Arrays;
import java.util.Iterator;
import java.util.TreeSet;
import cern.colt.matrix.tdouble.DoubleMatrix1D;
import cern.colt.matrix.tdouble.DoubleMatrix2D;
/**
 * Viterbi decoder for segment-level (semi-Markov) CRF models, built on {@link SparseViterbi}.
 *
 * <p>Unless the model's {@code modelGraph} option is set to something other than
 * {@code "semi-markov"}, each step considers the candidate segments ending at a position
 * (see {@link SegmentIter}) instead of single tokens.
 *
 * <p>When a data sequence carries a {@link Constraint#PAIR_DISALLOW} constraint, every partial
 * solution also records the (canonical) conflicting labels already used on its path. Label
 * choices that would form a disallowed pair are pruned.
 *
 * <p>NOTE(review): instances hold per-search mutable state ({@code labelConstraints},
 * {@code prevSegEnd}, and the inherited {@code beamsize}), so they should not be shared across
 * concurrent searches.
 *
 * @author Sunita Sarawagi
 */
public class SegmentViterbi extends SparseViterbi {
    private static final long serialVersionUID = -3055942733801323491L;

    /** The segment model. Only set by {@link #SegmentViterbi(SegmentCRF, int)}; null otherwise. */
    protected SegmentCRF segmentModel;

    /** Nested feature generator used to compute segment potentials in {@link #computeLogMi}. */
    protected FeatureGeneratorNested featureGenNested;

    /**
     * Pairwise label constraints, backed by a {@link ConstraintDisallowedPairs}, that are checked
     * while partial Viterbi solutions are extended.
     *
     * <p>NOTE(review): {@link #intersectTest} is shared scratch state, so an instance is not safe
     * for concurrent use.
     */
    public static class LabelConstraints {
        private static final long serialVersionUID = 1L;

        /** Upper bound on the value returned by {@link #countConflicting(int)}. */
        private static final int MAX_STATE_COMBINATIONS = 20;

        /** The disallowed label pairs currently in force. */
        protected ConstraintDisallowedPairs disallowedPairs;

        /**
         * The same object as {@link #disallowedPairs} when that object is a
         * {@link ConstraintDisallowedPairsExtended}; otherwise null. Used for canonical label ids.
         */
        protected ConstraintDisallowedPairsExtended disallowedPairsExt;

        /**
         * Trove procedure that returns false, which stops the iteration, when the visited label
         * forms a disallowed pair with {@link #label} given {@link #prevLabel}.
         */
        public class Intersects implements TIntProcedure {
            public int label;
            public int prevLabel;

            @Override
            public boolean execute(int arg0) {
                return !disallowedPairs.conflictingPair(label, arg0, prevLabel);
            }
        }

        /** Reusable procedure instance; its fields are overwritten before every use. */
        protected Intersects intersectTest = new Intersects();

        /**
         * Creates constraints backed by the given disallowed-pairs constraint.
         *
         * @param pairs the disallowed label pairs. If this is a
         *     {@link ConstraintDisallowedPairsExtended}, its canonical ids are also used.
         */
        public LabelConstraints(ConstraintDisallowedPairs pairs) {
            disallowedPairs = pairs;
            disallowedPairsExt = extendedOrNull(pairs);
        }

        /**
         * Copy constructor. The copy shares the underlying pair constraint of {@code labelCons}.
         *
         * @param labelCons constraints to copy
         */
        public LabelConstraints(LabelConstraints labelCons) {
            this(labelCons.disallowedPairs);
        }

        /** Returns {@code pairs} as the extended type when it is one, otherwise null. */
        private static ConstraintDisallowedPairsExtended extendedOrNull(ConstraintDisallowedPairs pairs) {
            return (pairs instanceof ConstraintDisallowedPairsExtended)
                    ? (ConstraintDisallowedPairsExtended) pairs
                    : null;
        }

        /**
         * Checks whether {@code label} may extend a path.
         *
         * @param set conflicting labels already on the path
         * @param label candidate label for the new segment
         * @param prevLabel label of the immediately preceding segment
         * @return true if {@code label} is not conflicting at all, or if it forms no disallowed
         *     pair with {@code prevLabel} or with any label in {@code set}
         */
        public boolean valid(TIntHashSet set, int label, int prevLabel) {
            if (!conflicting(label))
                return true;
            if (disallowedPairs.conflictingPair(label, prevLabel, -1))
                return false;
            intersectTest.label = label;
            intersectTest.prevLabel = prevLabel;
            return set.forEach(intersectTest);
        }

        /**
         * Checks two label sets against each other. {@code yp} (only if it is conflicting) and
         * every label in {@code set} must form no disallowed pair with any label in {@code set2}.
         * The previous label is passed as -1.
         *
         * @param set first group of labels
         * @param yp an additional label checked like the members of {@code set}
         * @param set2 labels to check against
         * @return true if no disallowed pair is found
         */
        public boolean valid(TIntHashSet set, int yp, TIntHashSet set2) {
            intersectTest.label = yp;
            intersectTest.prevLabel = -1;
            boolean isValid = !conflicting(yp) || set2.forEach(intersectTest);
            for (TIntIterator iter = set.iterator(); iter.hasNext();) {
                if (!isValid) return false;
                intersectTest.label = iter.next();
                isValid = set2.forEach(intersectTest);
            }
            return isValid;
        }

        /** Returns true if the two label sets are equal. */
        public boolean match(TIntHashSet set1, TIntHashSet set2) {
            return set1.equals(set2);
        }

        /**
         * Builds the set of conflicting labels on a path after {@code label} is appended.
         *
         * @param set labels already on the path
         * @param label label being appended
         * @param labelsOnPath scratch set that receives the result when {@code label} is
         *     conflicting. It is cleared first, so it must not be the same object as {@code set}.
         * @return {@code set} itself if {@code label} is not conflicting; otherwise
         *     {@code labelsOnPath}, filled with {@code canonicalId(label)} and all of {@code set}
         */
        public TIntHashSet formPathLabels(TIntHashSet set, int label, TIntHashSet labelsOnPath) {
            if (!conflicting(label))
                return set;
            labelsOnPath.clear();
            labelsOnPath.add(canonicalId(label));
            for (TIntIterator iter = set.iterator(); iter.hasNext();)
                labelsOnPath.add(iter.next());
            return labelsOnPath;
        }

        /**
         * Maps a label to its canonical id under the extended constraint. If the current
         * constraint is not extended, the label is returned unchanged.
         */
        public int canonicalId(int label) {
            return disallowedPairsExt != null ? disallowedPairsExt.canonicalId(label) : label;
        }

        /**
         * Looks for the first {@link Constraint#PAIR_DISALLOW} constraint attached to the whole
         * sequence.
         *
         * @param dataSeq the sequence whose constraints are inspected
         * @param labelCons an existing instance to re-initialise and reuse, or null
         * @return {@code labelCons}, re-initialised with the found constraint; a new instance if
         *     {@code labelCons} is null; or null if the sequence has no such constraint
         */
        public static LabelConstraints checkConstraints(CandSegDataSequence dataSeq, LabelConstraints labelCons) {
            Iterator<?> constraints = dataSeq.constraints(-1, dataSeq.length());
            if (constraints == null)
                return null;
            while (constraints.hasNext()) {
                Constraint constraint = (Constraint) constraints.next();
                if (constraint.type() == Constraint.PAIR_DISALLOW) {
                    if (labelCons != null) {
                        labelCons.init((ConstraintDisallowedPairs) constraint);
                        return labelCons;
                    }
                    return new LabelConstraints((ConstraintDisallowedPairs) constraint);
                }
            }
            return null;
        }

        /**
         * Fills {@code labelsOnPath} with {@code prevLabelsOnPath}. If {@code prevLabel} is
         * conflicting, its canonical id is added as well.
         *
         * @param prevLabelsOnPath labels on the path up to the previous segment
         * @param labelsOnPath destination set. It is cleared first, so it must not be the same
         *     object as {@code prevLabelsOnPath}.
         * @param prevLabel label of the previous segment
         * @return {@code labelsOnPath}
         */
        public TIntHashSet formPreviousLabel(TIntHashSet prevLabelsOnPath, TIntHashSet labelsOnPath, int prevLabel) {
            labelsOnPath.clear();
            for (TIntIterator iter = prevLabelsOnPath.iterator(); iter.hasNext();)
                labelsOnPath.add(iter.next());
            if (conflicting(prevLabel))
                labelsOnPath.add(canonicalId(prevLabel));
            return labelsOnPath;
        }

        /**
         * Re-targets this instance at a new pair constraint. {@link #disallowedPairsExt} is updated
         * as well, so that {@link #canonicalId} never uses the previous sequence's mapping.
         *
         * @param pairs the new disallowed label pairs
         */
        protected void init(ConstraintDisallowedPairs pairs) {
            disallowedPairs = pairs;
            disallowedPairsExt = extendedOrNull(pairs);
        }

        /**
         * Returns true if {@code label} takes part in some disallowed pair.
         *
         * @param label the label to test
         */
        public boolean conflicting(int label) {
            return disallowedPairs.conflicting(label);
        }

        /**
         * Returns the number of solution slots an entry should hold for each beam position.
         * The value is {@code 2^k}, capped at 20, where {@code k} is the number of distinct
         * canonical ids among the conflicting labels in {@code [0, numY)}.
         *
         * @param numY number of labels
         * @return a value in {@code [1, 20]}
         */
        public int countConflicting(int numY) {
            TIntHashSet maxSet = new TIntHashSet();
            for (int i = 0; i < numY; i++) {
                if (conflicting(i))
                    maxSet.add(canonicalId(i));
            }
            int numCanonical = maxSet.size();
            // Java masks int shift distances to 5 bits: 1 << 31 is negative and 1 << 32 == 1.
            // Guard the shift so a large label set cannot produce a negative or tiny slot count.
            if (numCanonical >= Integer.SIZE - 1)
                return MAX_STATE_COMBINATIONS;
            return Math.min(1 << numCanonical, MAX_STATE_COMBINATIONS);
        }

        /**
         * Returns true if every label in {@code labelsOnPath} is also in {@code prevLabels}.
         * A null {@code labelsOnPath} counts as empty (result true). A null {@code prevLabels}
         * counts as empty (result true only if {@code labelsOnPath} is empty).
         */
        public boolean contained(TIntHashSet labelsOnPath, TIntHashSet prevLabels) {
            if (labelsOnPath == null) return true;
            for (TIntIterator iter = labelsOnPath.iterator(); iter.hasNext();) {
                int thisL = iter.next();
                if (prevLabels == null || !prevLabels.contains(thisL))
                    return false;
            }
            return true;
        }
    }

    /**
     * Label constraints in force for the current search, or null when none apply. Every
     * {@code viterbiSearch} and {@code viterbiSearchBackward} entry point reassigns it.
     */
    LabelConstraints labelConstraints = null;

    /** A Viterbi solution that also records the conflicting labels used earlier on its path. */
    public class SolnWithLabelsOnPath extends Soln {
        private static final long serialVersionUID = 1L;

        /**
         * Canonical ids of the conflicting labels on the path before this solution. This is the
         * previous solution's set plus the previous label, if that label is conflicting.
         */
        public TIntHashSet labelsOnPath;

        /**
         * @param id label of this solution
         * @param p position of this solution
         */
        SolnWithLabelsOnPath(int id, int p) {
            super(id, p);
            labelsOnPath = new TIntHashSet();
        }

        public void clear() {
            super.clear();
            labelsOnPath.clear();
        }

        protected void copy(Soln soln) {
            super.copy(soln);
            labelsOnPath.clear();
            TIntHashSet source = ((SolnWithLabelsOnPath) soln).labelsOnPath;
            for (TIntIterator iter = source.iterator(); iter.hasNext();)
                labelsOnPath.add(iter.next());
        }

        public void setPrevSoln(Soln prevSoln, float score) {
            super.setPrevSoln(prevSoln, score);
            if ((prevSoln != null) && (labelConstraints != null)) {
                labelConstraints.formPreviousLabel(((SolnWithLabelsOnPath) prevSoln).labelsOnPath, labelsOnPath, prevSoln.label);
            }
        }
    }

    /**
     * Beam entry used under label constraints. It holds {@code beamsize * numStatComb} solution
     * slots, so that a lower-scoring but less constrained path is kept alongside a
     * higher-scoring, more constrained one.
     */
    public class EntryForLabelConstraints extends Entry {
        /** Scratch set reused by {@link #findInsert} to avoid allocating on every call. */
        TIntHashSet tmpLabels = new TIntHashSet();

        /**
         * @param beamsize beam size; only values up to 1 are supported
         * @param id label of this entry
         * @param pos position of this entry
         * @param numStatComb solution slots per beam position
         *     (see {@link LabelConstraints#countConflicting(int)})
         * @throws UnsupportedOperationException if {@code beamsize > 1}
         */
        protected EntryForLabelConstraints(int beamsize, int id, int pos, int numStatComb) {
            super();
            if (beamsize > 1)
                throw new UnsupportedOperationException();
            solns = new Soln[beamsize * numStatComb];
            for (int i = 0; i < solns.length; i++)
                solns[i] = new SolnWithLabelsOnPath(id, pos);
        }

        /**
         * Offers a candidate solution with the given score and predecessor.
         *
         * <p>The candidate is rejected if it violates the constraints. It is also rejected if an
         * existing solution scores at least as well while carrying a subset of its path labels.
         * Otherwise it replaces the first weaker solution whose path labels are a superset of its
         * own; failing that, it replaces the weakest lower-scoring solution.
         *
         * @return the slot written, the incoming {@code insertPos} if rejected for a constraint
         *     violation, 0 if dominated, or -1 if no slot scored lower
         */
        protected int findInsert(int insertPos, float score, Soln prev) {
            SolnWithLabelsOnPath prevSolL = (SolnWithLabelsOnPath) prev;
            // The path through prev conflicts with the label of this entry.
            if ((prev != null) && labelConstraints != null
                    && !labelConstraints.valid(prevSolL.labelsOnPath, get(0).label, prev.label)) {
                return insertPos;
            }
            TIntHashSet prevLabels = ((prevSolL != null && labelConstraints != null)
                    ? labelConstraints.formPreviousLabel(prevSolL.labelsOnPath, tmpLabels, prevSolL.label)
                    : null);
            for (insertPos = 0; insertPos < size(); insertPos++) {
                // If a better solution with a less restrictive condition exists, do not keep this one.
                if (score <= get(insertPos).score) {
                    if (prev == null || labelConstraints == null)
                        return 0;
                    SolnWithLabelsOnPath thisSolL = (SolnWithLabelsOnPath) get(insertPos);
                    if (labelConstraints.contained(thisSolL.labelsOnPath, prevLabels))
                        return 0;
                }
            }
            int minPos = -1;
            float minScore = Float.MAX_VALUE;
            for (insertPos = 0; insertPos < size(); insertPos++) {
                if (score > get(insertPos).score) {
                    if (minScore > get(insertPos).score) {
                        minScore = get(insertPos).score;
                        minPos = insertPos;
                    }
                    // The new solution dominates one with a more restrictive label set: replace it.
                    if (labelConstraints != null && labelConstraints.contained(prevLabels, ((SolnWithLabelsOnPath) get(insertPos)).labelsOnPath)) {
                        insert(insertPos, score, prev);
                        return insertPos;
                    }
                }
            }
            if (minPos >= 0) {
                insert(minPos, score, prev);
            }
            return minPos;
        }

        /**
         * Orders the solution slots by descending score. Soln's natural ordering is assumed to be
         * ascending by score, as the assertion below checks; the sorted array is then reversed.
         */
        public void sortEntries() {
            Arrays.sort(solns);
            for (int i = 0; i < solns.length / 2; i++) {
                Soln tmp = solns[i];
                solns[i] = solns[solns.length - 1 - i];
                solns[solns.length - 1 - i] = tmp;
            }
            for (int i = 1; i < solns.length; i++) {
                assert (solns[i - 1].score >= solns[i].score);
            }
        }
    }

    /**
     * Context that creates {@link EntryForLabelConstraints} entries lazily while label constraints
     * are active.
     */
    class ContextForLabelConstraints extends Context {
        private static final long serialVersionUID = 1L;

        ContextForLabelConstraints(int numY, int beamsize, int pos, int startPos) {
            super(numY, beamsize, pos, startPos);
        }

        public void add(int y, Entry prevSoln, float thisScore) {
            if (labelConstraints != null && getQuick(y) == null) {
                // The first position gets a beam of 1; every entry reserves a slot per
                // combination of conflicting labels.
                setQuick(y, new EntryForLabelConstraints((pos == startPos) ? 1 : beamsize, y, pos,
                        labelConstraints.countConflicting((int) size())));
            }
            super.add(y, prevSoln, thisScore);
        }
    }

    /**
     * True when the {@code modelGraph} option is not {@code "semi-markov"}. In that case
     * {@link #getIter()} returns the token-level {@code Iter} instead of {@link SegmentIter}.
     */
    boolean markovModel = false;

    /**
     * @param nestedModel the segment CRF to decode with
     * @param bs beam size
     */
    public SegmentViterbi(SegmentCRF nestedModel, int bs) {
        super(nestedModel, bs);
        this.segmentModel = nestedModel;
        this.featureGenNested = segmentModel.featureGenNested;
        setMarkovState(segmentModel.params);
    }

    /** Sets {@link #markovModel} from the {@code modelGraph} option (default {@code "semi-markov"}). */
    private void setMarkovState(CrfParams params) {
        if (!params.miscOptions.getProperty("modelGraph", "semi-markov").equalsIgnoreCase("semi-markov"))
            markovModel = true;
    }

    /**
     * @param model a CRF whose feature generator must be a {@link FeatureGeneratorNested}
     * @param bs beam size
     */
    public SegmentViterbi(CRF model, int bs) {
        super(model, bs);
        this.featureGenNested = (FeatureGeneratorNested) model.featureGenerator;
        setMarkovState(model.params);
    }

    /** Fills {@code Mi}/{@code Ri} for the segment {@code (i-ell, i]} via {@link SegmentTrainer}. */
    protected void computeLogMi(DataSequence dataSeq, int i, int ell, double lambda[]) {
        if (featureGenNested == null) {
            featureGenNested = segmentModel.featureGenNested;
        }
        SegmentTrainer.computeLogMi((CandSegDataSequence) dataSeq, i - ell, i, featureGenNested, lambda, Mi, Ri);
    }

    /**
     * Iterates over the candidate segments ending at position {@code i}. It yields each
     * segment's length, from the highest candidate index down to 0.
     */
    class SegmentIter extends Iter {
        /** Index of the next candidate segment to return, plus one. */
        int nc;
        CandidateSegments candidateSegs;

        protected void start(int i, DataSequence dataSeq) {
            candidateSegs = (CandidateSegments) dataSeq;
            nc = candidateSegs.numCandSegmentsEndingAt(i);
        }

        /** Returns the length of the next candidate segment ending at {@code i}, or -1 when done. */
        protected int nextEll(int i) {
            nc--;
            if (nc >= 0)
                return i - candidateSegs.candSegmentStart(i, nc) + 1;
            return -1;
        }
    }

    protected Iter getIter() {
        return (markovModel ? new Iter() : new SegmentIter());
    }

    /**
     * End position of the last true segment accepted by {@link #getCorrectScore}. The value is
     * carried between calls.
     */
    int prevSegEnd = -1;

    /**
     * Returns the score contribution of the candidate segment {@code (i-ell, i]} to the correct
     * labelling stored in {@code dataSeq}.
     *
     * <p>The result is 0 if that candidate is not a true segment. It is {@code LOG0} if the true
     * segment does not start right after the previously accepted one, or if its label forms a
     * disallowed pair with an earlier segment. Otherwise it is the node plus edge potential. This
     * assumes true segments are visited left to right, because {@link #prevSegEnd} is carried
     * between calls (TODO confirm against the caller in SparseViterbi).
     */
    protected double getCorrectScore(DataSequence dataSeq, int i, int ell, double[] lambda) {
        SegmentDataSequence data = (SegmentDataSequence) dataSeq;
        if (data.getSegmentEnd(i - ell + 1) != i)
            return 0;
        if ((i - ell >= 0) && (prevSegEnd != i - ell))
            return RobustMath.LOG0;
        prevSegEnd = i;
        if ((labelConstraints != null) && labelConstraints.conflicting(data.y(i))) {
            for (int segStart = 0; segStart < i - ell + 1; segStart = data.getSegmentEnd(segStart) + 1) {
                int segEnd = data.getSegmentEnd(segStart);
                if (labelConstraints.disallowedPairs.conflictingPair(data.y(i), data.y(segStart), (segEnd == i - ell) ? -1 : 0)) // TODO: 0 here is not correct.
                    return RobustMath.LOG0;
            }
        }
        if (model.params.debugLvl > 1) {
            // Print the features that hold on this segment in the training data.
            featureGenNested.startScanFeaturesAt(dataSeq, i - ell, i);
            while (featureGenNested.hasNext()) {
                Feature f = featureGenNested.next();
                if (((CandSegDataSequence) data).holdsInTrainingData(f, i - ell, i)) {
                    System.out.println("Feature " + (i - ell) + " " + i + " " + featureGenNested.featureName(f.index()) + " " + lambda[f.index()] + " " + f.value());
                }
            }
        }
        double val = (Ri.getQuick(dataSeq.y(i)) + ((i - ell >= 0) ? Mi.get(dataSeq.y(i - ell), dataSeq.y(i)) : 0));
        if (Double.isInfinite(val)) {
            System.out.println("Infinite score");
        }
        return val;
    }

    /** Records the segment {@code [prevPos+1, pos]} with {@code label} on the data sequence. */
    protected void setSegment(DataSequence dataSeq, int prevPos, int pos, int label) {
        ((CandSegDataSequence) dataSeq).setSegment(prevPos + 1, pos, label);
    }

    /**
     * Scores each label under the assumption that the whole sequence is one segment. Each score
     * is {@code exp(score - logSumExp(all scores))} of the top solution of each non-null entry at
     * the last position.
     *
     * @param dataSeq the sequence to classify
     * @param lambda model weights
     * @param scores cleared, then filled with label to normalised score
     */
    public void singleSegmentClassScores(CandSegDataSequence dataSeq, double lambda[], TIntFloatHashMap scores) {
        viterbiSearch(dataSeq, lambda, false);
        scores.clear();
        int i = dataSeq.length() - 1;
        if (i < 0)
            return;
        double norm = RobustMath.LOG0;
        for (int y = 0; y < context[i].size(); y++) {
            if (context[i].entryNotNull(y)) {
                Soln soln = ((Entry) context[i].getQuick(y)).get(0);
                assert (soln.prevSoln == null); // only applicable for single segment.
                norm = RobustMath.logSumExp(norm, soln.score);
            }
        }
        for (int y = 0; y < context[i].size(); y++) {
            if (context[i].entryNotNull(y)) {
                Soln soln = ((Entry) context[i].getQuick(y)).get(0);
                scores.put(soln.label, (float) Math.exp(soln.score - norm));
            }
        }
    }

    protected Context newContext(int numY, int beamsize, int pos, int startPos) {
        if (labelConstraints == null)
            return new Context(numY, beamsize, pos, startPos);
        return new ContextForLabelConstraints(numY, beamsize, pos, startPos);
    }

    /**
     * Sets {@link #labelConstraints} for a new search. When {@code constraints} is true, the pair
     * constraint found on {@code dataSeq} is used (or null if there is none). When it is false,
     * the field is cleared.
     */
    private void selectLabelConstraints(DataSequence dataSeq, boolean constraints) {
        labelConstraints = constraints
                ? LabelConstraints.checkConstraints((CandSegDataSequence) dataSeq, labelConstraints)
                : null;
    }

    /** Runs a forward search that honours any pair constraints attached to {@code dataSeq}. */
    public double viterbiSearch(DataSequence dataSeq, double[] lambda,
            boolean calcCorrectScore) {
        return viterbiSearch(dataSeq, lambda, null, null, true, calcCorrectScore);
    }

    /**
     * Runs a forward search.
     *
     * @param constraints whether to honour pair constraints attached to {@code dataSeq}
     */
    public double viterbiSearch(DataSequence dataSeq, double lambda[],
            DoubleMatrix2D[][] Mis, DoubleMatrix1D[][] Ris,
            boolean constraints, boolean calCorrectScore) {
        selectLabelConstraints(dataSeq, constraints);
        return super.viterbiSearch(dataSeq, lambda, Mis, Ris, calCorrectScore);
    }

    /**
     * Runs a forward search seeded with {@code soln}.
     *
     * @param constraints whether to honour pair constraints attached to {@code dataSeq}
     */
    public double viterbiSearch(DataSequence dataSeq, double lambda[],
            DoubleMatrix2D[][] Mis, DoubleMatrix1D[][] Ris,
            Soln soln, boolean constraints, boolean calCorrectScore) {
        selectLabelConstraints(dataSeq, constraints);
        return super.viterbiSearch(dataSeq, lambda, Mis, Ris, soln, calCorrectScore);
    }

    /**
     * Returns the log-sum of the scores of top-k unconstrained solutions that violate the
     * sequence's pair constraints. The search runs without constraints and with a beam of 20.
     * The caller's beam size is restored afterwards, even if the search throws.
     *
     * <p>Side effect: the unconstrained search leaves {@link #labelConstraints} null.
     *
     * @return the log-sum of violating scores, or {@code LOG0} if the sequence has no pair
     *     constraint or no solution violates it
     */
    public double sumScoreTopKViolators(DataSequence dataSeq, double lambda[]) {
        LabelConstraints labelCons = LabelConstraints.checkConstraints((CandSegDataSequence) dataSeq, labelConstraints);
        if (labelCons == null)
            return RobustMath.LOG0;
        int oldbeamsize = beamsize;
        beamsize = 20;
        try {
            viterbiSearch(dataSeq, lambda, null, null, false, false);
            double totalScore = RobustMath.LOG0;
            for (int k = finalSoln.numSolns() - 1; k >= 0; k--) {
                Soln ybest = finalSoln.get(k);
                if (pathViolates(labelCons, ybest.prevSoln))
                    totalScore = RobustMath.logSumExp(ybest.score, totalScore);
            }
            return totalScore;
        } finally {
            beamsize = oldbeamsize;
        }
    }

    /**
     * Walks the path backwards from {@code start}. Returns true as soon as a label is invalid
     * given the conflicting labels seen so far.
     */
    private boolean pathViolates(LabelConstraints labelCons, Soln start) {
        TIntHashSet labelsSeen = new TIntHashSet();
        for (Soln s = start; s != null; s = s.prevSoln) {
            if (!labelCons.valid(labelsSeen, s.label, -1))
                return true;
            if (labelCons.conflicting(s.label))
                labelsSeen.add(s.label);
        }
        return false;
    }

    /** Runs a backward search with label constraints disabled. */
    public double viterbiSearchBackward(DataSequence dataSeq, double[] lambda,
            DoubleMatrix2D Mis[][], DoubleMatrix1D Ris[][],
            boolean calcCorrectScore) {
        labelConstraints = null;
        return super.viterbiSearchBackward(dataSeq, lambda, Mis, Ris, calcCorrectScore);
    }

    /**
     * Runs a backward search.
     *
     * @param constraints whether to honour pair constraints attached to {@code dataSeq}
     */
    public double viterbiSearchBackward(DataSequence dataSeq, double[] lambda,
            DoubleMatrix2D Mis[][], DoubleMatrix1D Ris[][], boolean constraints,
            boolean calcCorrectScore) {
        selectLabelConstraints(dataSeq, constraints);
        return super.viterbiSearchBackward(dataSeq, lambda, Mis, Ris, calcCorrectScore);
    }

    /**
     * A {@link Segmentation} backed by segments sorted by end position.
     *
     * <p>{@link #doneAdd()} must be called after the last {@link #setSegment} and before any of
     * the index-based accessors are used. Segments are ordered and de-duplicated by end
     * position, so adding a second segment with an existing end position has no effect.
     */
    public static class SegmentationImpl extends LabelSequence implements Segmentation {
        /** A labelled segment {@code [start, end]}, ordered by {@code end}. */
        class Segment implements Comparable {
            int start;
            int end;
            int label;
            /** Index in end-position order; assigned by {@link SegmentationImpl#doneAdd()}. */
            int id;

            Segment(int start, int end, int label) {
                this.start = start;
                this.end = end;
                this.label = label;
            }

            @Override
            public int compareTo(Object arg0) {
                int otherEnd = ((Segment) arg0).end;
                // Explicit three-way comparison; subtraction could overflow.
                return (end < otherEnd) ? -1 : ((end == otherEnd) ? 0 : 1);
            }
        }

        TreeSet<Segment> segments = new TreeSet<Segment>();
        /** Snapshot of {@link #segments} in order; built by {@link #doneAdd()}. */
        Segment segmentArr[] = null;
        /** Reusable search key for {@link #getSegmentId}; not safe for concurrent use. */
        Segment dummySegment = new Segment(0, 0, 0);

        public int numSegments() {
            return segments.size();
        }

        public int segmentLabel(int segmentNum) {
            return segmentArr[segmentNum].label;
        }

        public int segmentStart(int segmentNum) {
            return segmentArr[segmentNum].start;
        }

        public int segmentEnd(int segmentNum) {
            return segmentArr[segmentNum].end;
        }

        /**
         * Returns the id of the first segment whose end is at or after {@code offset}.
         * {@code TreeSet.first()} throws {@link java.util.NoSuchElementException} when there is
         * no such segment.
         */
        public int getSegmentId(int offset) {
            dummySegment.end = offset;
            return segments.tailSet(dummySegment).first().id;
        }

        public void setSegment(int segmentStart, int segmentEnd, int label) {
            segments.add(new Segment(segmentStart, segmentEnd, label));
        }

        /**
         * Freezes the segments into {@link #segmentArr} and assigns each segment its index as
         * its id.
         */
        public void doneAdd() {
            segmentArr = new Segment[segments.size()];
            int p = 0;
            for (Segment segment : segments) {
                segment.id = p;
                segmentArr[p++] = segment;
            }
        }

        /**
         * Writes every segment onto {@code data}, which must be a {@link CandSegDataSequence}.
         */
        public void apply(DataSequence data) {
            for (int i = 0; i < numSegments(); i++)
                ((CandSegDataSequence) data).setSegment(segmentStart(i), segmentEnd(i), segmentLabel(i));
        }

        /**
         * Adds the segment {@code [prevPos+1, pos]} with the given label.
         *
         * @param prevPos end of the previous segment
         * @param pos end of this segment
         * @param label segment label
         */
        public void add(int prevPos, int pos, int label) {
            setSegment(prevPos + 1, pos, label);
        }
    }

    /**
     * Decodes up to {@code numLabelSeqs} best segmentations without constraints.
     *
     * @param scores if non-null, receives for each returned segmentation
     *     {@code min(exp(score - logZx), 1)}. If it has room, {@code scores[numSols]} receives
     *     {@code logZx}. It must hold at least as many entries as segmentations are returned.
     * @return the segmentations, best first
     */
    public Segmentation[] segmentSequences(CandSegDataSequence dataSeq, double lambda[], int numLabelSeqs, double[] scores) {
        viterbiSearch(dataSeq, lambda, false);
        int numSols = Math.min(finalSoln.numSolns(), numLabelSeqs);
        Segmentation segments[] = new Segmentation[numSols];
        for (int k = numSols - 1; k >= 0; k--) {
            Soln ybest = finalSoln.get(k);
            if (scores != null) scores[k] = ybest.score;
            ybest = ybest.prevSoln;
            segments[k] = new SegmentationImpl();
            while (ybest != null) {
                segments[k].setSegment(ybest.prevPos() + 1, ybest.pos, ybest.label);
                ybest = ybest.prevSoln;
            }
            ((SegmentationImpl) segments[k]).doneAdd();
        }
        if (scores != null) {
            double lZx = model.getLogZx(dataSeq);
            if (scores.length > numSols) scores[numSols] = lZx;
            for (int i = 0; i < numSols; i++) {
                scores[i] = Math.min(Math.exp(scores[i] - lZx), 1);
            }
        }
        return segments;
    }

    protected LabelSequence newLabelSequence(int len) {
        return new SegmentationImpl();
    }
}