/*
* Created on Oct 15, 2005
* This class implements the piecewise approximate trainer as described
* in this paper: Piecewise Training for Undirected Models. Charles Sutton and Andrew McCallum. UAI, 2005
*/
package iitb.CRF;
/**
 * Piecewise approximate CRF trainer, following "Piecewise Training for
 * Undirected Models" (Charles Sutton and Andrew McCallum, UAI 2005).
 *
 * <p>Rather than one partition function for the whole sequence,
 * {@code sumProduct} computes a separate local log partition value at every
 * position and subtracts each of them from the sequence log-likelihood.
 */
public class PiecewiseTrainer extends Trainer {

    /**
     * Creates a piecewise trainer.
     *
     * @param p training parameters, handed unchanged to the {@code Trainer} constructor
     */
    public PiecewiseTrainer(CrfParams p) {
        super(p);
    }

    /**
     * Accumulates the piecewise log-likelihood of one sequence and, when a
     * gradient array is supplied, its gradient.
     *
     * <p>For every position {@code i} the method:
     * <ol>
     *   <li>calls {@code computeLogMi} with {@code Mi_YY} and {@code Ri_Y}
     *       (defined outside this class; presumably it fills them with the
     *       edge and node log-potentials for position {@code i} - confirm
     *       against {@code Trainer});</li>
     *   <li>sets {@code alpha_Y} to {@code Ri_Y}, and for {@code i > 0} adds to
     *       each entry the log-sum-exp of the matching column of
     *       {@code Mi_YY};</li>
     *   <li>stores the log-sum-exp of {@code alpha_Y} in {@code lZx}, the local
     *       log partition value for this position;</li>
     *   <li>if {@code fgenForExpVals} is non-null, scans the features firing at
     *       {@code i}, accumulating the log of their unnormalized expected
     *       values in {@code ExpF} and, if {@code grad} is non-null, the
     *       observed-minus-expected gradient;</li>
     *   <li>subtracts {@code lZx} from the running log-likelihood.</li>
     * </ol>
     *
     * @param dataSeq          sequence being scored; its labels are read through {@code y(i)}
     * @param featureGenerator feature generator passed on to {@code computeLogMi}
     * @param lambda           feature weights, indexed by feature index
     * @param grad             gradient accumulator, updated in place; may be {@code null},
     *                         in which case neither the gradient nor the weighted score of
     *                         the features matching the observed labels is accumulated
     * @param expFVals         not read by this implementation
     * @param onlyForwardPass  not read by this implementation
     * @param numRecord        not read by this implementation
     * @param fgenForExpVals   feature generator scanned for expected values; when
     *                         {@code null} the feature scan is skipped entirely
     * @return the sum over positions of the weighted matching-feature values
     *         (only when {@code grad} and {@code fgenForExpVals} are non-null)
     *         minus the local log partition values
     */
    protected double sumProduct(DataSequence dataSeq, FeatureGenerator featureGenerator,
            double[] lambda, double[] grad, double[] expFVals, boolean onlyForwardPass, int numRecord,
            FeatureGenerator fgenForExpVals) {
        double thisSeqLogli = 0;
        for (int i = 0; i < dataSeq.length(); i++) {
            // compute the Mi matrix
            initMDone = computeLogMi(featureGenerator, lambda, dataSeq, i, Mi_YY, Ri_Y, false, reuseM, initMDone);

            // compute the local log partition function for this position
            alpha_Y.assign(Ri_Y);
            if (i > 0) {
                // Each label's mass is its Ri_Y entry plus the log-sum-exp of its Mi_YY column.
                for (int colNum = 0; colNum < Mi_YY.columns(); colNum++) {
                    alpha_Y.set(colNum, Ri_Y.get(colNum) + RobustMath.logSumExp(Mi_YY.viewColumn(colNum)));
                }
            }
            lZx = RobustMath.logSumExp(alpha_Y);

            if (fgenForExpVals != null) {
                // ExpF holds log-domain sums, so reset it to log(0) before accumulating.
                for (int f = 0; f < ExpF.length; f++) {
                    ExpF[f] = RobustMath.LOG0;
                }

                // find features that fire at this position..
                fgenForExpVals.startScanFeaturesAt(dataSeq, i);
                while (fgenForExpVals.hasNext()) {
                    Feature feature = fgenForExpVals.next();
                    int f = feature.index();
                    int yp = feature.y();
                    int yprev = feature.yprev();
                    float val = feature.value();

                    // The feature agrees with the observed labelling when its label matches
                    // position i and it either has no previous label (yprev < 0) or its
                    // previous label matches position i-1.
                    if ((grad != null) && (dataSeq.y(i) == yp)
                            && (((i - 1 >= 0) && (yprev == dataSeq.y(i - 1))) || (yprev < 0))) {
                        grad[f] += val;
                        thisSeqLogli += val * lambda[f];
                        if (params.debugLvl > 2) {
                            System.out.println("Feature fired " + f + " " + feature);
                        }
                    }

                    if (yprev < 0) {
                        // Feature without a previous label: weight by the total mass of label yp.
                        ExpF[f] = RobustMath.logSumExp(ExpF[f], alpha_Y.get(yp) + RobustMath.log(val));
                    } else {
                        // Feature with a previous label: weight by the (yprev, yp) pair only.
                        ExpF[f] = RobustMath.logSumExp(ExpF[f],
                                Ri_Y.get(yp) + Mi_YY.get(yprev, yp) + RobustMath.log(val));
                    }
                }

                // grad is treated as optional above, so it must not be dereferenced
                // unconditionally here; ExpF is still filled when grad is null.
                if (grad != null) {
                    for (int f = 0; f < grad.length; f++) {
                        grad[f] -= RobustMath.exp(ExpF[f] - lZx);
                    }
                }
            }

            thisSeqLogli -= lZx;

            if (params.debugLvl > 2) {
                System.out.println("Alpha-i " + alpha_Y.toString());
                System.out.println("Ri " + Ri_Y.toString());
                System.out.println("Mi " + Mi_YY.toString());
                // beta_Y is never written by this method; guard the print so that a
                // missing entry cannot abort training just because debugging is on.
                if ((beta_Y != null) && (i < beta_Y.length) && (beta_Y[i] != null)) {
                    System.out.println("Beta-i " + beta_Y[i].toString());
                }
            }
        }
        return thisSeqLogli;
    }
}