package iitb.CRF;
import cern.colt.matrix.tdouble.impl.*;
/**
 * Trainer whose sum-product pass works on segments rather than single
 * positions: the data sequence is treated as a {@code SegmentDataSequence}
 * and the feature generator as a {@code FeatureGeneratorNested}, and every
 * candidate segment of length {@code 1..maxMemory()} is scored.
 *
 * @author Sunita Sarawagi
 *
 */
class NestedTrainer extends Trainer {
    public NestedTrainer(CrfParams p) {
        super(p);
    }

    /**
     * Forward (alpha) vectors in log space. Entry {@code i+1} holds the vector
     * for position {@code i}; the buffer is grown on demand and reused across
     * calls.
     */
    DenseDoubleMatrix1D alpha_Y_Array[];

    /**
     * Runs the backward and forward log-space recursions over one sequence.
     *
     * <p>Side effects visible in this method: {@code beta_Y}, {@code alpha_Y_Array},
     * {@code Mi_YY}, {@code Ri_Y}, {@code tmp_Y} and {@code initMDone} are
     * overwritten; {@code lZx} is set to the log-sum-exp of the last alpha
     * vector; and, when {@code fgenForExpVals} is non-null, {@code grad[f]} is
     * incremented by the value of every feature that matches the labelled
     * segmentation while {@code ExpF[f]} is combined (via log-sum-exp) with
     * {@code alpha + Ri + Mi + log(value) + beta} terms (presumably the
     * unnormalised expected feature values, confirm against the caller).
     *
     * <p>A record containing a segment longer than {@code maxMemory()} is
     * rejected <em>before</em> any of the above is touched, so an ignored
     * record contributes nothing to {@code grad} or {@code ExpF}. {@code lZx}
     * is left unchanged in that case.
     *
     * <p>The sequence must be non-empty: the backward pass indexes
     * {@code beta_Y[length-1]}.
     *
     * @param data             the sequence; must be a {@code SegmentDataSequence}
     * @param featureGenerator must be a {@code FeatureGeneratorNested}; used to build the Mi/Ri potentials
     * @param lambda           feature weights, indexed by feature index
     * @param grad             accumulator indexed by feature index; only written when {@code fgenForExpVals} is non-null
     * @param onlyForwardPass  not read by this implementation
     * @param numRecord        not read by this implementation
     * @param fgenForExpVals   if non-null, must be a {@code FeatureGeneratorNested}; its features drive the
     *                         {@code grad}/{@code ExpF} updates
     * @return the sum of {@code value * lambda[f]} over features matching the labelled segmentation,
     *         or 0 if the record was ignored
     */
    protected double sumProductInner(DataSequence data, FeatureGenerator featureGenerator,
            double lambda[], double grad[],
            boolean onlyForwardPass, int numRecord, FeatureGenerator fgenForExpVals) {
        FeatureGeneratorNested featureGenNested = (FeatureGeneratorNested)featureGenerator;
        SegmentDataSequence dataSeq = (SegmentDataSequence)data;

        // Reject the record up front. Detecting this in the middle of the
        // forward pass would leave grad/ExpF holding partial contributions
        // from the positions that were already processed.
        if (hasSegmentLongerThanMaxMemory(dataSeq, featureGenNested)) {
            if (icall == 0) {
                System.out.println("Ignoring record with segment length greater than maxMemory " + dataSeq);
            }
            return 0;
        }

        // alpha is stored shifted by one: alpha_Y_Array[i-base] == alpha_Y_Array[i+1]
        // belongs to position i, so index 0 stands for "before the sequence".
        int base = -1;
        if ((alpha_Y_Array == null) || (alpha_Y_Array.length < dataSeq.length()-base)) {
            alpha_Y_Array = new DenseDoubleMatrix1D[2*dataSeq.length()];
            for (int i = 0; i < alpha_Y_Array.length; i++)
                alpha_Y_Array[i] = new DenseDoubleMatrix1D(numY);
        }
        if ((beta_Y == null) || (beta_Y.length < dataSeq.length())) {
            beta_Y = new DenseDoubleMatrix1D[2*dataSeq.length()];
            for (int i = 0; i < beta_Y.length; i++)
                beta_Y[i] = new DenseDoubleMatrix1D(numY);
        }

        // Backward scan: beta_Y[i] combines, over every segment (i+1 .. i+ell)
        // that fits in the sequence, the segment potentials with beta_Y[i+ell].
        beta_Y[dataSeq.length()-1].assign(0);
        for (int i = dataSeq.length()-2; i >= 0; i--) {
            beta_Y[i].assign(RobustMath.LOG0);
            for (int ell = 1; (ell <= featureGenNested.maxMemory()) && (i+ell < dataSeq.length()); ell++) {
                // compute the Mi matrix
                featureGenNested.startScanFeaturesAt(dataSeq, i, i+ell);
                initMDone = computeLogMi(featureGenNested,lambda,Mi_YY,Ri_Y,false,reuseM,initMDone);
                tmp_Y.assign(beta_Y[i+ell]);
                tmp_Y.assign(Ri_Y,sumFunc);
                RobustMath.logMult(Mi_YY, tmp_Y, beta_Y[i],1,1,false,edgeGen);
            }
        }

        double thisSeqLogli = 0;
        alpha_Y_Array[0].assign(0);
        // Boundaries of the labelled segment that contains the current position.
        int segmentStart = 0;
        int segmentEnd = -1;
        for (int i = 0; i < dataSeq.length(); i++) {
            if (segmentEnd < i) {
                segmentStart = i;
                segmentEnd = dataSeq.getSegmentEnd(i);
            }
            alpha_Y_Array[i-base].assign(RobustMath.LOG0);
            // Candidate segment covers positions (i-ell+1 .. i); i-ell == -1
            // means the segment starts at the beginning of the sequence.
            for (int ell = 1; (ell <= featureGenNested.maxMemory()) && (i-ell >= base); ell++) {
                // compute the Mi matrix
                featureGenNested.startScanFeaturesAt(dataSeq, i-ell,i);
                initMDone = computeLogMi(featureGenNested,lambda,Mi_YY,Ri_Y,false,reuseM,initMDone);
                if (fgenForExpVals != null) {
                    // find features that fire at this position..
                    ((FeatureGeneratorNested)fgenForExpVals).startScanFeaturesAt(dataSeq, i-ell,i);
                    // true when the candidate segment is exactly the labelled one
                    boolean isSegment = ((i-ell+1==segmentStart) && (i == segmentEnd));
                    while (fgenForExpVals.hasNext()) {
                        Feature feature = fgenForExpVals.next();
                        int f = feature.index();
                        int yp = feature.y();
                        int yprev = feature.yprev();
                        float val = feature.value();
                        boolean allEllMatch = isSegment && (dataSeq.y(i) == yp);
                        if (allEllMatch && (((i-ell >= 0) && (yprev == dataSeq.y(i-ell))) || (yprev < 0))) {
                            grad[f] += val;
                            thisSeqLogli += val*lambda[f];
                        }
                        if ((yprev < 0) && (i-ell >= 0)) {
                            // Feature does not name a previous label: add a term for every possible one.
                            for (int prevLabel = 0; prevLabel < Mi_YY.rows(); prevLabel++)
                                ExpF[f] = RobustMath.logSumExp(ExpF[f], (alpha_Y_Array[i-ell-base].get(prevLabel)+Ri_Y.get(yp)+Mi_YY.get(prevLabel,yp) + RobustMath.log(val)+beta_Y[i].get(yp)));
                        } else if (i-ell < 0) {
                            // Segment starts the sequence: there is no alpha/Mi term.
                            ExpF[f] = RobustMath.logSumExp(ExpF[f], (Ri_Y.get(yp)+RobustMath.log(val)+beta_Y[i].get(yp)));
                        } else {
                            ExpF[f] = RobustMath.logSumExp(ExpF[f], (alpha_Y_Array[i-ell-base].get(yprev)+Ri_Y.get(yp)+Mi_YY.get(yprev,yp)+RobustMath.log(val)+beta_Y[i].get(yp)));
                        }
                    }
                }
                if (i-ell >= 0) {
                    RobustMath.logMult(Mi_YY, alpha_Y_Array[i-ell-base],tmp_Y,1,0,true,edgeGen);
                    tmp_Y.assign(Ri_Y,sumFunc);
                    RobustMath.logSumExp(alpha_Y_Array[i-base],tmp_Y);
                } else {
                    RobustMath.logSumExp(alpha_Y_Array[i-base],Ri_Y);
                }
            }
            if (params.debugLvl > 2) {
                System.out.println("Alpha-i " + alpha_Y_Array[i-base].toString());
                System.out.println("Ri " + Ri_Y.toString());
                System.out.println("Mi " + Mi_YY.toString());
                System.out.println("Beta-i " + beta_Y[i].toString());
            }
            if (params.debugLvl > 1) {
                System.out.println(" pos " + i + " " + thisSeqLogli);
            }
        }
        lZx = RobustMath.logSumExp(alpha_Y_Array[dataSeq.length()-1-base]);
        return thisSeqLogli;
    }

    /**
     * Walks the labelled segmentation of {@code dataSeq} and reports whether any
     * segment spans more positions than {@code featureGenNested.maxMemory()}.
     */
    private static boolean hasSegmentLongerThanMaxMemory(SegmentDataSequence dataSeq,
            FeatureGeneratorNested featureGenNested) {
        int segmentStart = 0;
        int segmentEnd = -1;
        for (int i = 0; i < dataSeq.length(); i++) {
            if (segmentEnd < i) {
                segmentStart = i;
                segmentEnd = dataSeq.getSegmentEnd(i);
            }
            if (segmentEnd-segmentStart+1 > featureGenNested.maxMemory()) {
                return true;
            }
        }
        return false;
    }
}