package edu.stanford.nlp.ie.crf;
import edu.stanford.nlp.util.logging.Redwood;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.optimization.AbstractCachingDiffFunction;
import edu.stanford.nlp.util.Index;
import java.util.*;
/**
* @author Mengqiu Wang
* TODO(mengqiu) currently only works with disjoint feature sets
* for non-disjoint feature sets, need to recompute EHat each iteration, and multiply in the scale
* in EHat and E calculations for each lopExpert
*/
public class CRFLogConditionalObjectiveFunctionForLOP extends AbstractCachingDiffFunction implements HasCliquePotentialFunction {
/** A logger for this class */
private static Redwood.RedwoodChannels log = Redwood.channels(CRFLogConditionalObjectiveFunctionForLOP.class);
/** label indices - for all possible label sequences - for each feature */
List<Index<CRFLabel>> labelIndices;
Index<String> classIndex; // didn't have <String> before. Added since that's what is assumed everywhere.
double[][][] Ehat; // empirical counts of all the features [lopIter][feature][class]
double[] sumOfObservedLogPotential; // empirical sum of all log potentials [lopIter]
double[][][][][] sumOfExpectedLogPotential; // sumOfExpectedLogPotential[m][i][j][lopIter][k] m-docNo;i-position;j-cliqueNo;k-label
List<Set<Integer>> featureIndicesSetArray;
List<List<Integer>> featureIndicesListArray;
int window;
int numClasses;
int[] map;
int[][][][] data; // data[docIndex][tokenIndex][][]
double[][] lopExpertWeights; // lopExpertWeights[expertIter][weightIndex]
double[][][] lopExpertWeights2D;
int[][] labels; // labels[docIndex][tokenIndex]
int[][] learnedParamsMapping;
int numLopExpert;
boolean backpropTraining;
int domainDimension = -1;
String crfType = "maxent";
String backgroundSymbol;
public static boolean VERBOSE = false;
CRFLogConditionalObjectiveFunctionForLOP(int[][][][] data, int[][] labels, double[][] lopExpertWeights, int window,
Index<String> classIndex, List<Index<CRFLabel>> labelIndices, int[] map, String backgroundSymbol, int numLopExpert,
List<Set<Integer>> featureIndicesSetArray, List<List<Integer>> featureIndicesListArray, boolean backpropTraining) {
this.window = window;
this.classIndex = classIndex;
this.numClasses = classIndex.size();
this.labelIndices = labelIndices;
this.map = map;
this.data = data;
this.lopExpertWeights = lopExpertWeights;
this.labels = labels;
this.backgroundSymbol = backgroundSymbol;
this.numLopExpert = numLopExpert;
this.featureIndicesSetArray = featureIndicesSetArray;
this.featureIndicesListArray = featureIndicesListArray;
this.backpropTraining = backpropTraining;
initialize2DWeights();
if (backpropTraining) {
computeEHat();
} else {
logPotential(lopExpertWeights2D);
}
}
@Override
public int domainDimension() {
if (domainDimension < 0) {
domainDimension = numLopExpert;
if (backpropTraining) {
// for (int i = 0; i < map.length; i++) {
// domainDimension += labelIndices[map[i]].size();
// }
for (int i = 0; i < numLopExpert; i++) {
List<Integer> featureIndicesList = featureIndicesListArray.get(i);
double[][] expertWeights2D = lopExpertWeights2D[i];
for (int fIndex: featureIndicesList) {
int len = expertWeights2D[fIndex].length;
domainDimension += len;
}
}
}
}
return domainDimension;
}
@Override
public double[] initial() {
double[] initial = new double[domainDimension()];
if (backpropTraining) {
learnedParamsMapping = new int[domainDimension()][3];
int index = 0;
for (; index < numLopExpert; index++) {
initial[index] = 1.0;
}
for (int i = 0; i < numLopExpert; i++) {
List<Integer> featureIndicesList = featureIndicesListArray.get(i);
double[][] expertWeights2D = lopExpertWeights2D[i];
for (int fIndex: featureIndicesList) {
for (int j = 0; j < expertWeights2D[fIndex].length; j++) {
initial[index] = expertWeights2D[fIndex][j];
learnedParamsMapping[index] = new int[]{i, fIndex, j};
index++;
}
}
}
} else {
Arrays.fill(initial, 1.0);
}
return initial;
}
public double[][][] empty2D() {
double[][][] d2 = new double[numLopExpert][][];
for (int lopIter = 0; lopIter < numLopExpert; lopIter++) {
double[][] d = new double[map.length][];
// int index = 0;
for (int i = 0; i < map.length; i++) {
d[i] = new double[labelIndices.get(map[i]).size()];
// cdm july 2005: below array initialization isn't necessary: JLS (3rd ed.) 4.12.5
// Arrays.fill(d[i], 0.0);
// index += labelIndices[map[i]].size();
}
d2[lopIter] = d;
}
return d2;
}
private void initialize2DWeights() {
lopExpertWeights2D = new double[numLopExpert][][];
for (int lopIter = 0; lopIter < numLopExpert; lopIter++) {
lopExpertWeights2D[lopIter] = to2D(lopExpertWeights[lopIter], labelIndices, map);
}
}
public double[][] to2D(double[] weights, List<Index<CRFLabel>> labelIndices, int[] map) {
double[][] newWeights = new double[map.length][];
int index = 0;
for (int i = 0; i < map.length; i++) {
newWeights[i] = new double[labelIndices.get(map[i]).size()];
System.arraycopy(weights, index, newWeights[i], 0, labelIndices.get(map[i]).size());
index += labelIndices.get(map[i]).size();
}
return newWeights;
}
private void computeEHat() {
Ehat = empty2D();
for (int m = 0; m < data.length; m++) {
int[][][] docData = data[m];
int[] docLabels = labels[m];
int[] windowLabels = new int[window];
Arrays.fill(windowLabels, classIndex.indexOf(backgroundSymbol));
if (docLabels.length>docData.length) { // only true for self-training
// fill the windowLabel array with the extra docLabels
System.arraycopy(docLabels, 0, windowLabels, 0, windowLabels.length);
// shift the docLabels array left
int[] newDocLabels = new int[docData.length];
System.arraycopy(docLabels, docLabels.length-newDocLabels.length, newDocLabels, 0, newDocLabels.length);
docLabels = newDocLabels;
}
for (int i = 0; i < docData.length; i++) {
System.arraycopy(windowLabels, 1, windowLabels, 0, window - 1);
windowLabels[window - 1] = docLabels[i];
int[][] docDataI = docData[i];
for (int j = 0; j < docDataI.length; j++) { // j iterates over cliques
int[] docDataIJ = docDataI[j];
int[] cliqueLabel = new int[j + 1];
System.arraycopy(windowLabels, window - 1 - j, cliqueLabel, 0, j + 1);
CRFLabel crfLabel = new CRFLabel(cliqueLabel);
Index<CRFLabel> labelIndex = labelIndices.get(j);
int observedLabelIndex = labelIndex.indexOf(crfLabel);
//log.info(crfLabel + " " + observedLabelIndex);
for (int lopIter = 0; lopIter < numLopExpert; lopIter++) {
double[][] ehatOfIter = Ehat[lopIter];
Set<Integer> indicesSet = featureIndicesSetArray.get(lopIter);
for (int featureIdx : docDataIJ) { // k iterates over features
if (indicesSet.contains(featureIdx)) {
ehatOfIter[featureIdx][observedLabelIndex]++;
}
}
}
}
}
}
}
private void logPotential(double[][][] learnedLopExpertWeights2D) {
sumOfExpectedLogPotential = new double[data.length][][][][];
sumOfObservedLogPotential = new double[numLopExpert];
for (int m = 0; m < data.length; m++) {
int[][][] docData = data[m];
int[] docLabels = labels[m];
int[] windowLabels = new int[window];
Arrays.fill(windowLabels, classIndex.indexOf(backgroundSymbol));
double[][][][] sumOfELPm = new double[docData.length][][][];
if (docLabels.length>docData.length) { // only true for self-training
// fill the windowLabel array with the extra docLabels
System.arraycopy(docLabels, 0, windowLabels, 0, windowLabels.length);
// shift the docLabels array left
int[] newDocLabels = new int[docData.length];
System.arraycopy(docLabels, docLabels.length-newDocLabels.length, newDocLabels, 0, newDocLabels.length);
docLabels = newDocLabels;
}
for (int i = 0; i < docData.length; i++) {
System.arraycopy(windowLabels, 1, windowLabels, 0, window - 1);
windowLabels[window - 1] = docLabels[i];
double[][][] sumOfELPmi = new double[docData[i].length][][];
int[][] docDataI = docData[i];
for (int j = 0; j < docDataI.length; j++) { // j iterates over cliques
int[] docDataIJ = docDataI[j];
int[] cliqueLabel = new int[j + 1];
System.arraycopy(windowLabels, window - 1 - j, cliqueLabel, 0, j + 1);
CRFLabel crfLabel = new CRFLabel(cliqueLabel);
Index<CRFLabel> labelIndex = labelIndices.get(j);
double[][] sumOfELPmij = new double[numLopExpert][];
int observedLabelIndex = labelIndex.indexOf(crfLabel);
//log.info(crfLabel + " " + observedLabelIndex);
for (int lopIter = 0; lopIter < numLopExpert; lopIter++) {
double[] sumOfELPmijIter = new double[labelIndex.size()];
Set<Integer> indicesSet = featureIndicesSetArray.get(lopIter);
for (int featureIdx : docDataIJ) { // k iterates over features
if (indicesSet.contains(featureIdx)) {
sumOfObservedLogPotential[lopIter] += learnedLopExpertWeights2D[lopIter][featureIdx][observedLabelIndex];
// sum over potential of this clique over all possible labels, used later in calculating expected counts
for (int l = 0; l < labelIndex.size(); l++) {
sumOfELPmijIter[l] += learnedLopExpertWeights2D[lopIter][featureIdx][l];
}
}
}
sumOfELPmij[lopIter] = sumOfELPmijIter;
}
sumOfELPmi[j] = sumOfELPmij;
}
sumOfELPm[i] = sumOfELPmi;
}
sumOfExpectedLogPotential[m] = sumOfELPm;
}
}
public static double[] combineAndScaleLopWeights(int numLopExpert, double[][] lopExpertWeights, double[] lopScales) {
double[] newWeights = new double[lopExpertWeights[0].length];
for (int i = 0; i < newWeights.length; i++) {
double tempWeight = 0;
for (int lopIter = 0; lopIter < numLopExpert; lopIter++) {
tempWeight += lopExpertWeights[lopIter][i] * lopScales[lopIter];
}
newWeights[i] = tempWeight;
}
return newWeights;
}
public static double[][] combineAndScaleLopWeights2D(int numLopExpert, double[][][] lopExpertWeights2D, double[] lopScales) {
double[][] newWeights = new double[lopExpertWeights2D[0].length][];
for (int i = 0; i < newWeights.length; i++) {
int innerDim = lopExpertWeights2D[0][i].length;
double[] innerWeights = new double[innerDim];
for (int j = 0; j < innerDim; j++) {
double tempWeight = 0;
for (int lopIter = 0; lopIter < numLopExpert; lopIter++) {
tempWeight += lopExpertWeights2D[lopIter][i][j] * lopScales[lopIter];
}
innerWeights[j] = tempWeight;
}
newWeights[i] = innerWeights;
}
return newWeights;
}
public double[][][] separateLopExpertWeights2D(double[] learnedParams) {
double[][][] learnedWeights2D = empty2D();
for (int paramIndex = numLopExpert; paramIndex < learnedParams.length; paramIndex++) {
int[] mapping = learnedParamsMapping[paramIndex];
learnedWeights2D[mapping[0]][mapping[1]][mapping[2]] = learnedParams[paramIndex];
}
return learnedWeights2D;
}
public double[][] separateLopExpertWeights(double[] learnedParams) {
double[][] learnedWeights = new double[numLopExpert][];
double[][][] learnedWeights2D = separateLopExpertWeights2D(learnedParams);
for (int i = 0; i < numLopExpert; i++) {
learnedWeights[i] = CRFLogConditionalObjectiveFunction.to1D(learnedWeights2D[i], lopExpertWeights[i].length);
}
return learnedWeights;
}
public double[] separateLopScales(double[] learnedParams) {
double[] rawScales = new double[numLopExpert];
System.arraycopy(learnedParams, 0, rawScales, 0, numLopExpert);
return rawScales;
}
public CliquePotentialFunction getCliquePotentialFunction(double[] x) {
double[] rawScales = separateLopScales(x);
double[] scales = ArrayMath.softmax(rawScales);
double[][][] learnedLopExpertWeights2D = lopExpertWeights2D;
if (backpropTraining) {
learnedLopExpertWeights2D = separateLopExpertWeights2D(x);
}
double[][] combinedWeights2D = combineAndScaleLopWeights2D(numLopExpert, learnedLopExpertWeights2D, scales);
return new LinearCliquePotentialFunction(combinedWeights2D);
}
// todo [cdm]: Below data[m] --> docData
/**
* Calculates both value and partial derivatives at the point x, and save them internally.
*/
@Override
public void calculate(double[] x) {
double prob = 0.0; // the log prob of the sequence given the model, which is the negation of value at this point
double[][][] E = empty2D();
double[] eScales = new double[numLopExpert];
double[] rawScales = separateLopScales(x);
double[] scales = ArrayMath.softmax(rawScales);
double[][][] learnedLopExpertWeights2D = lopExpertWeights2D;
if (backpropTraining) {
learnedLopExpertWeights2D = separateLopExpertWeights2D(x);
logPotential(learnedLopExpertWeights2D);
}
double[][] combinedWeights2D = combineAndScaleLopWeights2D(numLopExpert, learnedLopExpertWeights2D, scales);
// iterate over all the documents
for (int m = 0; m < data.length; m++) {
int[][][] docData = data[m];
int[] docLabels = labels[m];
double[][][][] sumOfELPm = sumOfExpectedLogPotential[m]; // sumOfExpectedLogPotential[m][i][j][lopIter][k] m-docNo;i-position;j-cliqueNo;k-label
// make a clique tree for this document
CliquePotentialFunction cliquePotentialFunc = new LinearCliquePotentialFunction(combinedWeights2D);
CRFCliqueTree cliqueTree = CRFCliqueTree.getCalibratedCliqueTree(docData, labelIndices, numClasses, classIndex, backgroundSymbol, cliquePotentialFunc, null);
// compute the log probability of the document given the model with the parameters x
int[] given = new int[window - 1];
Arrays.fill(given, classIndex.indexOf(backgroundSymbol));
if (docLabels.length > docData.length) { // only true for self-training
// fill the given array with the extra docLabels
System.arraycopy(docLabels, 0, given, 0, given.length);
// shift the docLabels array left
int[] newDocLabels = new int[docData.length];
System.arraycopy(docLabels, docLabels.length-newDocLabels.length, newDocLabels, 0, newDocLabels.length);
docLabels = newDocLabels;
}
// iterate over the positions in this document
for (int i = 0; i < docData.length; i++) {
int label = docLabels[i];
double p = cliqueTree.condLogProbGivenPrevious(i, label, given);
if (VERBOSE) {
log.info("P(" + label + "|" + ArrayMath.toString(given) + ")=" + p);
}
prob += p;
System.arraycopy(given, 1, given, 0, given.length - 1);
given[given.length - 1] = label;
}
// compute the expected counts for this document, which we will need to compute the derivative
// iterate over the positions in this document
for (int i = 0; i < docData.length; i++) {
// for each possible clique at this position
double[][][] sumOfELPmi = sumOfELPm[i];
for (int j = 0; j < docData[i].length; j++) {
double[][] sumOfELPmij = sumOfELPmi[j];
Index<CRFLabel> labelIndex = labelIndices.get(j);
// for each possible labeling for that clique
for (int l = 0; l < labelIndex.size(); l++) {
int[] label = labelIndex.get(l).getLabel();
double p = cliqueTree.prob(i, label); // probability of these labels occurring in this clique with these features
for (int lopIter = 0; lopIter < numLopExpert; lopIter++) {
Set<Integer> indicesSet = featureIndicesSetArray.get(lopIter);
double scale = scales[lopIter];
double expected = sumOfELPmij[lopIter][l];
for (int innerLopIter = 0; innerLopIter < numLopExpert; innerLopIter++) {
expected -= scales[innerLopIter] * sumOfELPmij[innerLopIter][l];
}
expected *= scale;
eScales[lopIter] += (p * expected);
double[][] eOfIter = E[lopIter];
if (backpropTraining) {
for (int k = 0; k < docData[i][j].length; k++) { // k iterates over features
int featureIdx = docData[i][j][k];
if (indicesSet.contains(featureIdx)) {
eOfIter[featureIdx][l] += p;
}
}
}
}
}
}
}
}
if (Double.isNaN(prob)) { // shouldn't be the case
throw new RuntimeException("Got NaN for prob in CRFLogConditionalObjectiveFunctionForLOP.calculate()");
}
value = -prob;
if(VERBOSE){
log.info("value is " + value);
}
// compute the partial derivative for each feature by comparing expected counts to empirical counts
for (int lopIter = 0; lopIter < numLopExpert; lopIter++) {
double scale = scales[lopIter];
double observed = sumOfObservedLogPotential[lopIter];
for (int j = 0; j < numLopExpert; j++) {
observed -= scales[j] * sumOfObservedLogPotential[j];
}
observed *= scale;
double expected = eScales[lopIter];
derivative[lopIter] = (expected - observed);
if (VERBOSE) {
log.info("deriv(" + lopIter + ") = " + expected + " - " + observed + " = " + derivative[lopIter]);
}
}
if (backpropTraining) {
int dIndex = numLopExpert;
for (int lopIter = 0; lopIter < numLopExpert; lopIter++) {
double scale = scales[lopIter];
double[][] eOfExpert = E[lopIter];
double[][] ehatOfExpert = Ehat[lopIter];
List<Integer> featureIndicesList = featureIndicesListArray.get(lopIter);
for (int fIndex: featureIndicesList) {
for (int j = 0; j < eOfExpert[fIndex].length; j++) {
derivative[dIndex++] = scale * (eOfExpert[fIndex][j] - ehatOfExpert[fIndex][j]);
if (VERBOSE) {
log.info("deriv[" + lopIter+ "](" + fIndex + "," + j + ") = " + scale + " * (" + eOfExpert[fIndex][j] + " - " + ehatOfExpert[fIndex][j] + ") = " + derivative[dIndex - 1]);
}
}
}
}
assert(dIndex == domainDimension());
}
}
}