/** SegmentTrainer.java
*
* @author Sunita Sarawagi
* @since 1.2
* @version 1.3
*/
package iitb.CRF;
import gnu.trove.iterator.TIntDoubleIterator;
import gnu.trove.map.hash.TIntDoubleHashMap;
import java.util.Iterator;
import cern.colt.matrix.tdouble.DoubleMatrix1D;
import cern.colt.matrix.tdouble.DoubleMatrix2D;
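/**
* Trainer that runs its forward/backward computation over candidate segments
* (spans of consecutive positions) of a sequence instead of over single
* positions, building on {@link SparseTrainer}.
*
* @author Sunita Sarawagi
*
*/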
public class SegmentTrainer extends SparseTrainer {
// Forward vectors, one per slot. sumProductInner stores the vector for
// position p in slot p+1 and points slot 0 at allZeroVector.
protected DoubleMatrix1D alpha_Y_Array[];
// Per-slot cache of Mi_YY.zMult(alpha_Y_Array[slot], ...); only filled and
// read by sumProductInner when reuseM is set.
protected DoubleMatrix1D alpha_Y_ArrayM[];
// initAlphaMDone[slot] is true once alpha_Y_ArrayM[slot] has been computed
// for the sequence currently being processed.
protected boolean initAlphaMDone[];
// Vector of numY entries, all assigned 0 in init(); used as the start
// forward vector and as the backward vector at the last position.
protected DoubleMatrix1D allZeroVector;
/**
 * Creates a segment trainer and switches on the {@code logTrainer} flag.
 *
 * @param p training parameters, forwarded unchanged to the superclass
 */
public SegmentTrainer(CrfParams p) {
    super(p);
    this.logTrainer = true;
}
/**
 * Performs the superclass initialisation, turns on {@code logProcessing}
 * and builds the shared all-zero vector of {@code numY} entries.
 *
 * @param model model being trained, forwarded to the superclass
 * @param data  training data, forwarded to the superclass
 * @param l     array forwarded to the superclass
 */
protected void init(CRF model, DataIter data, double[] l) {
    super.init(model, data, l);
    logProcessing = true;

    // Build the vector locally, fill it, then publish it to the field.
    final DoubleMatrix1D zeros = newLogDoubleMatrix1D(numY);
    zeros.assign(0);
    allZeroVector = zeros;
}
/**
 * Convenience overload that runs the full computation without collecting
 * segment or edge marginals.
 *
 * @param data              sequence to process
 * @param featureGenerator  generator used to build the potentials
 * @param lambda            current feature weights
 * @param grad              gradient accumulator, may be {@code null}
 * @param onlyForwardPass   when {@code true} the backward pass is skipped
 * @param numRecord         record number, passed through unchanged
 * @param fgenForExpCompute generator scanned for expected values, may be {@code null}
 * @return whatever the nine-argument overload returns when both marginal
 *         arguments are {@code null}
 */
protected double sumProductInner(DataSequence data, FeatureGenerator featureGenerator,
        double lambda[], double grad[], boolean onlyForwardPass, int numRecord,
        FeatureGenerator fgenForExpCompute) {
    // No marginals requested by this entry point.
    final TIntDoubleHashMap[][] noSegmentMarginals = null;
    final TIntDoubleHashMap[][][] noEdgeMarginals = null;
    return sumProductInner(data, featureGenerator, lambda, grad, onlyForwardPass,
            numRecord, fgenForExpCompute, noSegmentMarginals, noEdgeMarginals);
}
/**
* Runs a backward pass (unless suppressed) and a forward pass over the
* candidate segments of one sequence. Along the way it optionally
* accumulates the gradient, the expected feature values in ExpF, and
* segment/edge marginals. The field lZx is always overwritten.
*
* @param data sequence to process; cast to CandSegDataSequence and
*        CandidateSegments, and to SegmentDataSequence when grad is non-null
* @param featureGenerator cast to FeatureGeneratorNested and handed to computeLogMi
* @param lambda feature weights, handed to computeLogMi and used for the
*        training log-score
* @param grad if non-null, incremented by the value of every scanned feature
*        that holds in the training data; also enables the training-segment checks
* @param onlyForwardPass when true the backward pass is skipped and beta_Y is
*        left as it was on entry
* @param numRecord not read by this method
* @param fgenForExpCompute if non-null, cast to FeatureGeneratorNested and
*        scanned per candidate segment to update ExpF (and grad)
* @param segmentMarginals if non-null, indexed [label][segmentStart]; entries
*        whose map already contains the segment end as key are overwritten
*        with the un-normalised log value alpha+beta
* @param edgeMarginals only consulted when segmentMarginals is non-null;
*        indexed [prevLabel][label][segmentStart], maps are created on demand
* @return lZx when segmentMarginals is non-null, otherwise the sum of
*         value*lambda over scanned features that hold in the training data
*/
double sumProductInner(DataSequence data, FeatureGenerator featureGenerator, double lambda[], double grad[],
boolean onlyForwardPass, int numRecord, FeatureGenerator fgenForExpCompute,TIntDoubleHashMap segmentMarginals[][],
TIntDoubleHashMap edgeMarginals[][][]) {
FeatureGeneratorNested featureGenNested = (FeatureGeneratorNested)featureGenerator;
CandSegDataSequence dataSeq = (CandSegDataSequence)data;
FeatureGeneratorNested featureGenNestedForExpVals = (FeatureGeneratorNested)fgenForExpCompute;
// Forward arrays are indexed by (position - base), i.e. position+1, so that
// slot 0 can hold the state before the first position.
int base = -1;
// (Re)allocate when the buffers are missing or have fewer than length+1 slots.
if ((alpha_Y_Array == null) || (alpha_Y_Array.length < dataSeq.length()-base)) {
allocateAlphaBeta(2*dataSeq.length()+1);
}
int dataSize = dataSeq.length();
CandidateSegments candidateSegs = (CandidateSegments)dataSeq;
// The last beta slot is temporarily pointed at allZeroVector below; remember
// the original vector so it can be put back before returning.
DoubleMatrix1D oldBeta = (dataSize > 0)?beta_Y[dataSeq.length()-1]:null;
if (!onlyForwardPass) {
// ---- Backward pass ----
if (dataSize > 0) beta_Y[dataSize-1] = allZeroVector;
for (int i = dataSeq.length()-2; i >= 0; i--) {
beta_Y[i].assign(RobustMath.LOG0);
}
for (int segEnd = dataSeq.length()-1; segEnd >= 0; segEnd--) {
int numCands = candidateSegs.numCandSegmentsEndingAt(segEnd)-1;
for (int nc = 0; nc <= numCands; nc++) {
int segStart = candidateSegs.candSegmentStart(segEnd,nc);
int ell = segEnd-segStart+1;
// i is the position immediately before the segment; a segment starting at
// position 0 has no predecessor and contributes nothing to any beta.
int i = segStart-1;
if (i < 0)
continue;
// compute the Mi matrix
initMDone = computeLogMi(dataSeq,i,i+ell,featureGenNested,lambda,Mi_YY,Ri_Y,reuseM,initMDone);
tmp_Y.assign(Ri_Y);
// i+ell == segEnd. At the last position beta is the all-zero vector, so the
// combination with beta is skipped there (sumFunc presumably adds
// element-wise - confirm against its definition).
if (i+ell < dataSize-1) tmp_Y.assign(beta_Y[i+ell], sumFunc);
// Without M reuse, fold Mi in per segment and accumulate into beta_Y[i]
// (Colt's zMult contract is z = alpha*A*y + beta*z; the fourth argument 1
// keeps the existing content). With reuse, only accumulate here and apply
// Mi once per position below.
if (!reuseM) Mi_YY.zMult(tmp_Y, beta_Y[i],1,1,false);
else beta_Y[i].assign(tmp_Y, RobustMath.logSumExpFunc);
}
// segEnd runs downwards, so every candidate segment that starts at segEnd
// (and therefore ends at segEnd or later) has been visited by now and
// beta_Y[segEnd-1] has received all its contributions; apply Mi once,
// overwriting the slot (fourth zMult argument 0).
if (reuseM && (segEnd-1 >= 0)) {
tmp_Y.assign(beta_Y[segEnd-1]);
Mi_YY.zMult(tmp_Y, beta_Y[segEnd-1],1,0,false);
}
}
}
// ---- Forward pass ----
double thisSeqLogli = 0;
if (reuseM) {
// Invalidate the per-slot Mi*alpha cache for this sequence.
for (int i = dataSeq.length(); i >= 0; i--)
initAlphaMDone[i] = false;
}
alpha_Y_Array[0] = allZeroVector; //.assign(0);
// Bounds of the training segment currently being tracked (only maintained
// when grad is non-null).
int trainingSegmentEnd=-1;
int trainingSegmentStart = 0;
boolean trainingSegmentFound = true;
// Becomes false as soon as any scanned feature holds in the training data.
boolean noneFired=true;
for (int segEnd = 0; segEnd < dataSize; segEnd++) {
alpha_Y_Array[segEnd-base].assign(RobustMath.LOG0);
if ((grad != null) && (trainingSegmentEnd < segEnd)) {
// We have moved past the previous training segment: report it if it never
// matched a candidate, then look up the training segment starting here.
// NOTE(review): the last training segment of the sequence is never
// re-checked after the loop, so a miss there is not reported - confirm
// whether that is intended.
if ((!trainingSegmentFound)&& noneFired) {
System.out.println("Error: Training segment ("+trainingSegmentStart + " "+ trainingSegmentEnd + ") not found amongst candidate segments");
}
trainingSegmentFound = false;
trainingSegmentStart = segEnd;
trainingSegmentEnd =((SegmentDataSequence)dataSeq).getSegmentEnd(segEnd);
}
int numCands = candidateSegs.numCandSegmentsEndingAt(segEnd)-1;
for (int nc = 0; nc <= numCands; nc++) {
//for (int nc = candidateSegs.numCandSegmentsEndingAt(segEnd)-1; nc >= 0; nc--) {
// ell is the length of this candidate; segEnd-ell is the position just
// before it starts (-1 when the candidate starts at position 0).
int ell = segEnd - candidateSegs.candSegmentStart(segEnd,nc)+1;
// compute the Mi matrix
initMDone=computeLogMi(dataSeq,segEnd-ell,segEnd,featureGenNested,lambda,Mi_YY,Ri_Y,reuseM,initMDone);
// NOTE(review): mAdded and rAdded are never read in this method.
boolean mAdded = false, rAdded = false;
if (segEnd-ell >= 0) {
// Candidate has a predecessor: newAlpha_Y = (Mi applied, transposed, to the
// forward vector at the predecessor) combined with Ri_Y.
if (!reuseM) Mi_YY.zMult(alpha_Y_Array[segEnd-ell-base],newAlpha_Y,1,0,true);
else {
// With M reuse the product depends only on the predecessor slot, so it is
// computed once per slot and cached in alpha_Y_ArrayM.
if (!initAlphaMDone[segEnd-ell-base]) {
alpha_Y_ArrayM[segEnd-ell-base].assign(RobustMath.LOG0);
Mi_YY.zMult(alpha_Y_Array[segEnd-ell-base],alpha_Y_ArrayM[segEnd-ell-base],1,0,true);
initAlphaMDone[segEnd-ell-base] = true;
}
newAlpha_Y.assign(alpha_Y_ArrayM[segEnd-ell-base]);
}
newAlpha_Y.assign(Ri_Y,sumFunc);
} else
// Candidate starts at position 0: only Ri_Y contributes.
newAlpha_Y.assign(Ri_Y);
// Accumulate this candidate into the forward vector for segEnd.
alpha_Y_Array[segEnd-base].assign(newAlpha_Y, RobustMath.logSumExpFunc);
if (featureGenNestedForExpVals != null) {
// find features that fire at this position..
featureGenNestedForExpVals.startScanFeaturesAt(dataSeq, segEnd-ell,segEnd);
while (featureGenNestedForExpVals.hasNext()) {
Feature feature = featureGenNestedForExpVals.next();
int f = feature.index();
int yp = feature.y();
int yprev = feature.yprev();
float val = feature.value();
// Empirical side: features consistent with the training labelling feed the
// gradient and the training log-score.
if ((grad != null) && dataSeq.holdsInTrainingData(feature,segEnd-ell,segEnd)) {
grad[f] += val;
thisSeqLogli += val*lambda[f];
noneFired=false;
if (params.debugLvl > 2) {
System.out.println("Feature fired " + f + " " + feature);
}
}
// Expected side, accumulated in ExpF via logSumExp. A negative yprev is
// handled as a feature with no previous label: only the combined forward
// vector for this candidate and beta at segEnd are used. Otherwise the
// specific (yprev, yp) path through Mi_YY and Ri_Y is used.
if (yprev < 0) {
ExpF[f] = RobustMath.logSumExp(ExpF[f], (newAlpha_Y.get(yp)+RobustMath.log(val)+beta_Y[segEnd].get(yp)));
} else {
ExpF[f] = RobustMath.logSumExp(ExpF[f], (alpha_Y_Array[segEnd-ell-base].get(yprev)+Ri_Y.get(yp)+Mi_YY.get(yprev,yp)+RobustMath.log(val)+beta_Y[segEnd].get(yp)));
}
}
}
if (segmentMarginals != null) {
// segEnd-ell+1 is the start position of this candidate.
for (int yp = (int) (newAlpha_Y.size()-1); yp >= 0; yp--) {
// Only slots the caller pre-populated with this segEnd key are filled in.
if ((segmentMarginals[yp][segEnd-ell+1]!=null) && (segmentMarginals[yp][segEnd-ell+1].containsKey(segEnd))) {
// segmentMarginals[yp][segEnd-ell+1] = new TIntDoubleHashMap();
segmentMarginals[yp][segEnd-ell+1].put(segEnd,newAlpha_Y.get(yp)+beta_Y[segEnd].get(yp));
//segmentMarginals[yp][segEnd-ell+1].put(segEnd,Ri_Y.get(yp));
}
if (edgeMarginals != null) {
// Edge marginals are recorded for every (yprev, yp) pair; the map is
// created lazily. For a candidate starting at position 0 the forward
// slot read here is slot 0 (allZeroVector).
for (int yprev = (int) (newAlpha_Y.size()-1); yprev >= 0; yprev--) {
if (edgeMarginals[yprev][yp][segEnd-ell+1]==null)
edgeMarginals[yprev][yp][segEnd-ell+1] = new TIntDoubleHashMap();
edgeMarginals[yprev][yp][segEnd-ell+1].put(segEnd,alpha_Y_Array[segEnd-ell-base].get(yprev)+Ri_Y.get(yp)+Mi_YY.get(yprev,yp)+beta_Y[segEnd].get(yp));
//edgeMarginals[yprev][yp][segEnd-ell+1].put(segEnd,Mi_YY.get(yprev,yp));
}
}
}
}
// This candidate is exactly the current training segment: mark it found and
// verify that the generated features gave its labels a non-LOG0 score.
if ((grad != null) && (segEnd == trainingSegmentEnd) && (segEnd-ell+1==trainingSegmentStart)) {
trainingSegmentFound = true;
double val1 = Ri_Y.get(dataSeq.y(trainingSegmentEnd));
double val2 = 0;
if (trainingSegmentStart > 0) {
val2 = Mi_YY.get(dataSeq.y(trainingSegmentStart-1), dataSeq.y(trainingSegmentEnd));
}
if ((val1 == RobustMath.LOG0) || (val2 == RobustMath.LOG0)) {
System.out.println("Error: training labels not covered in generated features " + val1 + " "+val2
+ " y " + dataSeq.y(trainingSegmentEnd));
System.out.println(dataSeq);
// Dump every feature generated for the segment to help diagnose the gap.
featureGenNested.startScanFeaturesAt(dataSeq, segEnd-ell,segEnd);
while (featureGenNested.hasNext()) {
Feature feature = featureGenNested.next();
System.out.println(feature + " " + feature.yprev() + " "+feature.y());
}
}
}
}
if (params.debugLvl > 2) {
// Ri/Mi printed here are those of the last candidate processed for segEnd.
System.out.println("Alpha-i " + alpha_Y_Array[segEnd-base].toString());
System.out.println("Ri " + Ri_Y.toString());
System.out.println("Mi " + Mi_YY.toString());
System.out.println("Beta-i " + beta_Y[segEnd].toString());
}
}
// zSum() over the forward vector of the last position (slot 0 for an empty
// sequence). NOTE(review): presumably the log-matrix type makes this a
// log-sum-exp rather than a plain sum - confirm in newLogDoubleMatrix1D's
// return type.
lZx = alpha_Y_Array[dataSeq.length()-1-base].zSum();
// Undo the temporary replacement of the last beta slot.
if (dataSize > 0) beta_Y[dataSize-1] = oldBeta;
if (segmentMarginals != null) {
// normalize with respect to thisSeqLogLi.
// The flag is hard-coded to false, so the normalisation below never runs
// and the marginals are returned as un-normalised log values.
// NOTE(review): if it is ever enabled, edgeMarginals[yprev][y][segStart]
// is dereferenced without a null check - verify before turning it on.
boolean normalize=false;
if (normalize) {
for (int y = 0; y < segmentMarginals.length; y++) {
for (int segStart = 0; segStart < segmentMarginals[y].length; segStart++) {
if (segmentMarginals[y][segStart] == null) continue;
for (TIntDoubleIterator segEndProbIter = segmentMarginals[y][segStart].iterator(); segEndProbIter.hasNext();) {
segEndProbIter.advance();
segEndProbIter.setValue(Math.exp(segEndProbIter.value()-lZx));
//System.out.println(segEndProbIter.key() + " " + segEndProbIter.value());
assert (segmentMarginals[y][segStart].get(segEndProbIter.key()) < 1+0.0001);
}
if (edgeMarginals != null) {
for (int yprev = 0; yprev < edgeMarginals.length; yprev++) {
for (TIntDoubleIterator segEndProbIter = edgeMarginals[yprev][y][segStart].iterator(); segEndProbIter.hasNext();) {
segEndProbIter.advance();
segEndProbIter.setValue(Math.exp(segEndProbIter.value()-lZx));
assert (segEndProbIter.value() < 1+0.0001);
}
}
}
}
}
}
// When marginals were requested the caller gets lZx instead of the
// training log-score.
return lZx;
}
return thisSeqLogli;
}
/**
 * Lets the superclass allocate its buffers, then allocates this class's
 * two forward-vector arrays and the matching cache-validity flags.
 *
 * @param newSize number of slots to create in each array
 */
protected void allocateAlphaBeta(int newSize) {
    super.allocateAlphaBeta(newSize);

    alpha_Y_Array = new DoubleMatrix1D[newSize];
    for (int slot = 0; slot < newSize; slot++) {
        alpha_Y_Array[slot] = newLogDoubleMatrix1D(numY);
    }

    alpha_Y_ArrayM = new DoubleMatrix1D[newSize];
    for (int slot = 0; slot < newSize; slot++) {
        alpha_Y_ArrayM[slot] = newLogDoubleMatrix1D(numY);
    }

    // Java zero-initialises the array, so every flag starts out false.
    initAlphaMDone = new boolean[newSize];
}
// TODO..
/**
 * Starts a feature scan for the span {@code (prevPos, pos]} and initialises
 * {@code Mi} and {@code Ri} from the sequence's constraints for that span.
 *
 * @param dataSeq          sequence supplying the constraints
 * @param prevPos          position immediately before the span
 * @param pos              last position of the span
 * @param featureGenNested generator whose scan is started here
 * @param lambda           not read by this method
 * @param Mi               matrix handed to the four-argument {@code initLogMi}
 * @param Ri               vector handed to the four-argument {@code initLogMi}
 * @return the value returned by the four-argument {@code initLogMi}
 */
public static double initLogMi(CandSegDataSequence dataSeq, int prevPos, int pos,
        FeatureGeneratorNested featureGenNested, double[] lambda,
        DoubleMatrix2D Mi, DoubleMatrix1D Ri) {
    // The scan must be started before the constraints are fetched, exactly
    // as in the original ordering.
    featureGenNested.startScanFeaturesAt(dataSeq, prevPos, pos);
    final Iterator spanConstraints = dataSeq.constraints(prevPos, pos);
    return initLogMi(0.0, spanConstraints, Mi, Ri);
}
/**
 * Computes the potentials for the span {@code (prevPos, pos]}, optionally
 * skipping the matrix when it has already been initialised and is reused.
 *
 * @param dataSeq          sequence being processed
 * @param prevPos          position immediately before the span
 * @param pos              last position of the span
 * @param featureGenNested generator supplying the features
 * @param lambda           feature weights
 * @param Mi               matrix to fill; replaced by {@code null} in the
 *                         delegated call when {@code reuseM && initMDone}
 * @param Ri               vector to fill
 * @param reuseM           whether one matrix is shared across spans
 * @param initMDone        whether the shared matrix was already initialised
 * @return {@code true} if {@code initMDone} was already {@code true}, or if
 *         {@code reuseM} is set and {@code prevPos >= 0}; otherwise {@code false}
 */
public static boolean computeLogMi(CandSegDataSequence dataSeq, int prevPos, int pos,
        FeatureGeneratorNested featureGenNested,
        double[] lambda, DoubleMatrix2D Mi, DoubleMatrix1D Ri,
        boolean reuseM, boolean initMDone) {
    final boolean matrixAlreadyCached = reuseM && initMDone;
    final DoubleMatrix2D matrixToFill = matrixAlreadyCached ? null : Mi;

    computeLogMi(dataSeq, prevPos, pos, featureGenNested, lambda, matrixToFill, Ri);

    return initMDone || (reuseM && prevPos >= 0);
}
/**
 * Initialises {@code Mi} and {@code Ri} for the span {@code (prevPos, pos]}
 * via {@code initLogMi} and then completes them through
 * {@code SparseTrainer.computeLogMiInitDone}.
 *
 * @param dataSeq          sequence being processed
 * @param prevPos          position immediately before the span
 * @param pos              last position of the span
 * @param featureGenNested generator supplying the features
 * @param lambda           feature weights
 * @param Mi               matrix to fill
 * @param Ri               vector to fill
 */
public static void computeLogMi(CandSegDataSequence dataSeq, int prevPos, int pos,
        FeatureGeneratorNested featureGenNested, double[] lambda,
        DoubleMatrix2D Mi, DoubleMatrix1D Ri) {
    final double initialValue =
            initLogMi(dataSeq, prevPos, pos, featureGenNested, lambda, Mi, Ri);
    SparseTrainer.computeLogMiInitDone(featureGenNested, lambda, Mi, Ri, initialValue);
}
};