package edu.stanford.nlp.ie.crf;
import edu.stanford.nlp.util.Index;
import java.util.*;
/**
 * Variant of {@link CRFLogConditionalObjectiveFunction} that, for every document, calibrates two
 * clique trees: one driven by a {@code NoisyLabelLinearCliquePotentialFunction} (built from the
 * document's labels and an error matrix) and one driven by the inherited
 * {@code cliquePotentialFunc}. The per-document value is the difference of their total masses.
 *
 * @author Mengqiu Wang
 */
public class CRFLogConditionalObjectiveFunctionNoisyLabel extends CRFLogConditionalObjectiveFunction {

  /**
   * Matrix handed unchanged to each {@code NoisyLabelLinearCliquePotentialFunction} created by
   * {@link #getFunc(int)}.
   * NOTE(review): presumably a label confusion/error-rate matrix — confirm against the
   * potential function that consumes it.
   */
  protected final double[][] errorMatrix;

  /**
   * Forwards every argument except {@code errorMatrix} to the superclass constructor (with a
   * trailing {@code false} flag whose meaning is defined there) and stores {@code errorMatrix}.
   *
   * @param errorMatrix matrix later passed to the noisy-label potential function; stored by
   *                    reference, not copied
   */
  CRFLogConditionalObjectiveFunctionNoisyLabel(int[][][][] data, int[][] labels, int window, Index<String> classIndex, List<Index<CRFLabel>> labelIndices, int[] map, String priorType, String backgroundSymbol, double sigma, double[][][][] featureVal, int multiThreadGrad, double[][] errorMatrix) {
    super(data, labels, window, classIndex, labelIndices, map, priorType, backgroundSymbol, sigma, featureVal, multiThreadGrad, false);
    this.errorMatrix = errorMatrix;
  }

  /**
   * Creates the noisy-label potential function for one document.
   *
   * @param docIndex index into {@code labels} selecting the document
   * @return a new potential function built from the current {@code weights}, the document's
   *         label row and {@link #errorMatrix}
   */
  public CliquePotentialFunction getFunc(int docIndex) {
    return new NoisyLabelLinearCliquePotentialFunction(weights, labels[docIndex], errorMatrix);
  }

  /**
   * Sets the weights by delegating directly to the superclass.
   *
   * @param weights the weight matrix to install
   */
  public void setWeights(double[][] weights) {
    super.setWeights(weights);
  }

  /**
   * Processes a single document: calibrates the noisy-label clique tree and the plain clique
   * tree, accumulates counts from the plain tree into {@code E} and from the noisy-label tree
   * into {@code Ehat}, and returns the difference of the two trees' total masses.
   *
   * @param E        accumulator updated from the tree using {@code cliquePotentialFunc}
   * @param Ehat     accumulator updated from the tree using the noisy-label potential function
   * @param docIndex index of the document in {@code data}
   * @return total mass of the noisy-label tree minus total mass of the plain tree
   */
  @Override
  protected double expectedAndEmpiricalCountsAndValueForADoc(double[][] E, double[][] Ehat, int docIndex) {
    final int[][][] doc = data[docIndex];
    // Per-document feature values are optional; null means none were supplied.
    final double[][][] docFeatureVals = (featureVal == null) ? null : featureVal[docIndex];

    // The noisy-label tree is built first, then the plain one, matching the established order.
    final CliquePotentialFunction noisyFunc = getFunc(docIndex);
    final CRFCliqueTree noisyTree = CRFCliqueTree.getCalibratedCliqueTree(
        doc, labelIndices, numClasses, classIndex, backgroundSymbol, noisyFunc, docFeatureVals);
    final CRFCliqueTree plainTree = CRFCliqueTree.getCalibratedCliqueTree(
        doc, labelIndices, numClasses, classIndex, backgroundSymbol, cliquePotentialFunc, docFeatureVals);

    // Take the mass difference before the trees are used for count accumulation.
    final double massDifference = noisyTree.totalMass() - plainTree.totalMass();

    documentExpectedCounts(E, doc, docFeatureVals, plainTree);
    documentExpectedCounts(Ehat, doc, docFeatureVals, noisyTree);

    return massDifference;
  }

  /**
   * Runs {@code multiThreadGradient} over the indices of all documents in {@code data}, passing
   * {@code true} as its second argument.
   *
   * @return whatever {@code multiThreadGradient} returns for the full document list
   */
  @Override
  protected double regularGradientAndValue() {
    final int numDocs = data.length;
    List<Integer> allDocs = new ArrayList<>(numDocs);
    int docIndex = 0;
    while (docIndex < numDocs) {
      allDocs.add(docIndex);
      docIndex++;
    }
    return multiThreadGradient(allDocs, true);
  }

  /**
   * Calculates both value and partial derivatives at the point x, and saves them internally.
   * {@code Ehat} is cleared via {@code clear2D} before the superclass computation runs.
   *
   * @param x the point at which to evaluate
   */
  @Override
  public void calculate(double[] x) {
    clear2D(Ehat);
    super.calculate(x);
  }
}