/*
* Created on Jul 9, 2005
*
*/
package iitb.CRF;
/**
* @author sunita
*
*/
public class MaxentTrainer extends Trainer {
/**
* @param p
*/
public MaxentTrainer(CrfParams p) {
super(p);
}
protected double sumProduct(DataSequence dataSeq, FeatureGenerator featureGenerator, double[] lambda,
double[] grad, double[] expFVals, boolean onlyForwardPass, int numRecord, FeatureGenerator fgenForExpVals) {
double thisSeqLogli=0;
for (int i = 0; i < dataSeq.length(); i++) {
// compute the Mi matrix
initMDone = computeLogMi(featureGenerator,lambda,dataSeq,i,Mi_YY,Ri_Y,false,reuseM,initMDone);
if (i > 0) {
Ri_Y.assign(Mi_YY.viewRow(dataSeq.y(i-1)),sumFunc);
}
if ((grad !=null)||(expFVals!=null)) {
for (int f = 0; f < lambda.length; f++)
ExpF[f] = RobustMath.LOG0;
// find features that fire at this position..
featureGenerator.startScanFeaturesAt(dataSeq, i);
while (featureGenerator.hasNext()) {
Feature feature = featureGenerator.next();
int f = feature.index();
int yp = feature.y();
int yprev = feature.yprev();
float val = feature.value();
if ((i > 0) && (yprev >= 0) && (yprev != dataSeq.y(i-1)))
continue;
if ((grad != null) && (dataSeq.y(i) == yp) && (((i-1 >= 0) && (yprev == dataSeq.y(i-1))) || (yprev < 0))) {
grad[f] += val;
thisSeqLogli += val*lambda[f];
if (params.debugLvl > 2) {
System.out.println("Feature fired " + f + " " + feature);
}
}
ExpF[f] = RobustMath.logSumExp(ExpF[f], Ri_Y.get(yp) +RobustMath.log(val));
}
}
if (params.debugLvl > 2) {
System.out.println("Ri " + Ri_Y.toString());
}
double lZx = RobustMath.logSumExp(Ri_Y);
thisSeqLogli -= lZx;
// update grad.
if (grad != null) {
for (int f = 0; f < grad.length; f++) {
grad[f] -= RobustMath.exp(ExpF[f]-lZx);
}
}
if (expFVals!=null) {
for (int f = 0; f < lambda.length; f++) {
expFVals[f] += RobustMath.exp(ExpF[f]-lZx);
}
}
if (params.debugLvl > 1) {
System.out.println("Sequence " + thisSeqLogli + " log(Zx) " + lZx + " Zx " + Math.exp(lZx));
}
}
return thisSeqLogli;
}
}