/** SegmentViterbiPartialLabeled.java
*
* @author Sunita Sarawagi
* @since 1.3
* @version 1.3
*/
package iitb.CRF;
import gnu.trove.set.hash.TIntHashSet;
import cern.colt.matrix.tdouble.DoubleMatrix1D;
import cern.colt.matrix.tdouble.DoubleMatrix2D;
/**
 * Segment Viterbi decoder for sequences in which some positions already carry
 * a label.
 *
 * <p>A position {@code i} is treated as pre-labeled when {@code dataSeq.y(i) >= 0}.
 * For such a position every other label's entry in {@code Ri} is overwritten
 * with {@link RobustMath#LOG0}, so only the given label keeps its score.
 * When label constraints are active, labels that are fixed at later positions
 * and reported as conflicting are additionally handed to
 * {@code labelConstraints.valid(...)} so that earlier positions can be barred
 * from labels that are invalid alongside them.
 */
public class SegmentViterbiPartialLabeled extends SegmentViterbi {
    /**
     * For each position {@code j} of the sequence most recently passed to
     * {@link #viterbiSearch}, the set of labels {@code y} such that some
     * position {@code i > j} is pre-labeled with {@code y} and
     * {@code labelConstraints.conflicting(y)} is true. The array is reused
     * across calls and may be longer than the current sequence; only the first
     * {@code dataSeq.length()} entries are refreshed.
     */
    TIntHashSet[] usedLabels;

    /**
     * Creates a decoder over a segment model; both arguments are forwarded
     * unchanged to the superclass constructor.
     *
     * @param nestedModel the segment CRF model
     * @param bs forwarded to the superclass (presumably the beam size -- confirm)
     */
    public SegmentViterbiPartialLabeled(SegmentCRF nestedModel, int bs) {
        super(nestedModel, bs);
    }

    /**
     * Creates a decoder over a CRF model; both arguments are forwarded
     * unchanged to the superclass constructor.
     *
     * @param model the CRF model
     * @param bs forwarded to the superclass (presumably the beam size -- confirm)
     */
    public SegmentViterbiPartialLabeled(CRF model, int bs) {
        super(model, bs);
    }

    /**
     * Lets the superclass fill in the scores for position {@code i}, then
     * masks entries of {@code Ri} that are ruled out by a pre-assigned label
     * or by the label constraints.
     *
     * @param dataSeq the sequence being decoded
     * @param i position in {@code dataSeq}
     * @param ell forwarded to the superclass (presumably the segment length -- confirm)
     * @param lambda forwarded to the superclass (presumably the model weights -- confirm)
     */
    @Override
    protected void computeLogMi(DataSequence dataSeq, int i, int ell, double[] lambda) {
        super.computeLogMi(dataSeq, i, ell, lambda);

        final int fixedLabel = dataSeq.y(i);
        if (fixedLabel >= 0) {
            // Position i is pre-labeled: knock out every other label.
            for (int j = 0; j < Ri.size(); j++) {
                if (j != fixedLabel) {
                    Ri.set(j, RobustMath.LOG0);
                }
            }
            assert Ri.get(fixedLabel) > RobustMath.LOG0;
        }

        // NOTE(review): usedLabels is only refreshed by the six-argument
        // viterbiSearch below. If labelConstraints can become non-null through
        // another entry point, the sets consulted here may belong to a previous
        // sequence -- confirm against the superclass.
        if (usedLabels != null && labelConstraints != null
                && usedLabels.length >= dataSeq.length() && usedLabels[i].size() > 0) {
            for (int y = 0; y < Ri.size(); y++) {
                if (!labelConstraints.valid(usedLabels[i], y, -1)) {
                    Ri.set(y, RobustMath.LOG0);
                }
            }
        }
    }

    /**
     * Prepares the per-position {@link #usedLabels} sets when constraints are
     * requested, then delegates the search to the superclass.
     *
     * @param dataSeq the sequence to decode; cast to {@code CandSegDataSequence}
     *        when {@code constraints} is true
     * @param lambda forwarded to the superclass
     * @param Mis forwarded to the superclass
     * @param Ris forwarded to the superclass
     * @param constraints when false, {@code labelConstraints} is set to
     *        {@code null} and no constraint handling takes place
     * @param calCorrectScore forwarded to the superclass
     * @return whatever the superclass search returns
     */
    public double viterbiSearch(DataSequence dataSeq, double[] lambda,
            DoubleMatrix2D[][] Mis, DoubleMatrix1D[][] Ris,
            boolean constraints, boolean calCorrectScore) {
        if (constraints) {
            labelConstraints = LabelConstraints.checkConstraints(
                    (CandSegDataSequence) dataSeq, labelConstraints);
        } else {
            labelConstraints = null;
        }

        if (labelConstraints != null) {
            final int seqLen = dataSeq.length();

            // Make sure positions [0, seqLen) each own an empty set.
            if (usedLabels == null || usedLabels.length < seqLen) {
                usedLabels = new TIntHashSet[seqLen];
                for (int i = 0; i < usedLabels.length; i++) {
                    usedLabels[i] = new TIntHashSet();
                }
            } else {
                for (int i = 0; i < seqLen; i++) {
                    usedLabels[i].clear();
                }
            }

            // Single right-to-left pass: the set for position i-1 is the set
            // for position i plus the conflicting label fixed at i, if any.
            // This yields, for every j, all conflicting labels fixed at i > j.
            // Position 0 has no predecessors, so its label is never propagated.
            for (int i = seqLen - 1; i > 0; i--) {
                usedLabels[i - 1].addAll(usedLabels[i]);
                final int label = dataSeq.y(i);
                if (label >= 0 && labelConstraints.conflicting(label)) {
                    usedLabels[i - 1].add(label);
                }
            }
        }
        return super.viterbiSearch(dataSeq, lambda, Mis, Ris, calCorrectScore);
    }
}