package edu.stanford.nlp.ie.crf;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.sequences.SequenceListener;
import edu.stanford.nlp.sequences.SequenceModel;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.GeneralizedCounter;
import edu.stanford.nlp.util.Index;
import java.util.Arrays;
import java.util.List;
/** Builds a CliqueTree (an array of FactorTable) and does message passing
* inference along it.
*
* @author Jenny Finkel
*/
public class CRFCliqueTree implements SequenceModel, SequenceListener {
private FactorTable[] factorTables;
private double z; // norm constant
private Index classIndex;
private String backgroundSymbol;
private int backgroundIndex;
// the window size, which is also the clique size
private int windowSize;
// the number of possible classes for each label
private int numClasses;
private int[] possibleValues;
private CRFCliqueTree() {}
private CRFCliqueTree(FactorTable[] factorTables, Index classIndex, String backgroundSymbol) {
this.factorTables = factorTables;
this.classIndex = classIndex;
this.backgroundSymbol = backgroundSymbol;
backgroundIndex = classIndex.indexOf(backgroundSymbol);
z = factorTables[0].totalMass();
windowSize = factorTables[0].windowSize();
numClasses = classIndex.size();
possibleValues = new int[numClasses];
for (int i=0; i<numClasses; i++) {
possibleValues[i] = i;
}
//Debug only
//System.out.println("CRFCliqueTree constructed::numClasses: " + numClasses);
}
public Index classIndex() {
return classIndex;
}
// SEQUENCE MODEL METHODS
public int length() {
return factorTables.length;
}
public int leftWindow() {
return windowSize;
}
public int rightWindow() {
return 0;
}
public int[] getPossibleValues(int position) {
return possibleValues;
}
public double scoreOf(int[] sequence, int pos) {
return scoresOf(sequence, pos)[sequence[pos]];
}
/**
* Computes the unnormalized log conditional distribution over values of the
* element at position pos in the sequence, conditioned on the values
* of the elements in all other positions of the provided sequence.
*
* @param sequence the sequence containing the rest of the values to condition on
* @param position the position of the element to give a distribution for
* @return an array of type double, representing a probability distribution; sums to 1.0
*/
public double[] scoresOf(int[] sequence, int position) {
if (position>=factorTables.length) throw new RuntimeException("Index out of bounds: " + position);
// DecimalFormat nf = new DecimalFormat("#0.000");
// if (position>0 && position<sequence.length-1) System.out.println(position + ": asking about " +sequence[position-1] + "(" + sequence[position] + ")" + sequence[position+1]);
double[] probThisGivenPrev = new double[numClasses];
double[] probNextGivenThis = new double[numClasses];
// double[] marginal = new double[numClasses]; // for debugging only
// compute prob of this tag given the window-1 previous tags, normalized
// extract the window-1 previous tags, pad left with background if necessary
int prevLength = windowSize-1;
int[] prev = new int[prevLength+1]; // leave an extra element for the label at this position
int i=0;
for ( ; i<prevLength-position; i++) { // will only happen if position-prevLength < 0
prev[i] = classIndex.indexOf(backgroundSymbol);
}
for ( ; i<prevLength; i++) {
prev[i] = sequence[position-prevLength+i];
}
for (int label=0; label<numClasses; label++) {
prev[prev.length-1] = label;
probThisGivenPrev[label] = factorTables[position].unnormalizedLogProb(prev);
// marginal[label] = factorTables[position].logProbEnd(label); // remove: for debugging only
}
// ArrayMath.logNormalize(probThisGivenPrev);
// compute the prob of the window-1 next tags given this tag
// extract the window-1 next tags
int nextLength = windowSize-1;
if (position+nextLength>=length()) {
nextLength = length()-position-1;
}
FactorTable nextFactorTable = factorTables[position+nextLength];
if (nextLength != windowSize - 1) {
for (int j = 0; j < windowSize - 1 - nextLength; j++) {
nextFactorTable = nextFactorTable.sumOutFront();
}
}
if (nextLength==0) { // we are asking about the prob of no sequence
Arrays.fill(probNextGivenThis, 1.0);
} else {
int[] next = new int[nextLength];
System.arraycopy(sequence, position+1, next, 0, nextLength);
for (int label=0; label<numClasses; label++) {
// ask the factor table such that pos is the first position in the window
//probNextGivenThis[label] = factorTables[position+nextLength].conditionalLogProbGivenFirst(label, next);
//probNextGivenThis[label] = nextFactorTable.conditionalLogProbGivenFirst(label, next);
probNextGivenThis[label] = nextFactorTable.unnormalizedConditionalLogProbGivenFirst(label, next);
}
}
// pointwise multiply
return ArrayMath.pairwiseAdd(probThisGivenPrev, probNextGivenThis);
}
/**
* Returns the log probability of this sequence given the CRF.
* Does so by computing the marginal of the first windowSize tags, and
* then computing the conditional probability for the rest of them, conditioned on the previous
* tags.
* @param sequence the sequence to compute a score for
* @return the score for the sequence
*/
public double scoreOf(int[] sequence) {
int[] given = new int[window() - 1];
Arrays.fill(given, classIndex.indexOf(backgroundSymbol));
double logProb = 0;
for (int i=0; i<length(); i++) {
int label = sequence[i];
logProb += condLogProbGivenPrevious(i, label, given);
System.arraycopy(given, 1, given, 0, given.length - 1);
given[given.length - 1] = label;
}
return logProb;
}
// OTHER
public int window() {
return windowSize;
}
public int getNumClasses() {
return numClasses;
}
public double totalMass() {
return z;
}
public int backgroundIndex() {
return backgroundIndex;
}
public String backgroundSymbol() {
return backgroundSymbol;
}
//
// MARGINAL PROB OF TAG AT SINGLE POSITION
//
public double logProb(int position, int label) {
double u = factorTables[position].unnormalizedLogProbEnd(label);
return u - z;
}
public double prob(int position, int label) {
return Math.exp(logProb(position, label));
}
public double logProb(int position, Object label) {
return logProb(position, classIndex.indexOf(label));
}
public double prob(int position, Object label) {
return Math.exp(logProb(position, label));
}
public ClassicCounter probs(int position) {
ClassicCounter c = new ClassicCounter();
for (int i = 0; i < classIndex.size(); i++) {
Object label = classIndex.get(i);
c.incrementCount(label, prob(position, i));
}
return c;
}
public ClassicCounter logProbs(int position) {
ClassicCounter c = new ClassicCounter();
for (int i = 0; i < classIndex.size(); i++) {
Object label = classIndex.get(i);
c.incrementCount(label, logProb(position, i));
}
return c;
}
//
// MARGINAL PROBS OF TAGS AT MULTIPLE POSITIONS
//
/**
* returns the log probability for the given labels (indexed using classIndex), where the
* last label corresponds to the label at the specified position. For instance if you called
* logProb(5, {1,2,3}) it will return the marginal log prob that the label at position
* 3 is 1, the label at position 4 is 2 and the label at position 5 is 3.
*/
public double logProb(int position, int[] labels) {
if (labels.length < windowSize) {
return factorTables[position].unnormalizedLogProbEnd(labels) - z;
} else if (labels.length == windowSize) {
return factorTables[position].unnormalizedLogProb(labels) - z;
} else {
int[] l = new int[windowSize];
System.arraycopy(labels, 0, l, 0, l.length);
int position1 = position-labels.length+windowSize;
double p = factorTables[position1].unnormalizedLogProb(l) - z;
l = new int[windowSize-1];
System.arraycopy(labels, 1, l, 0, l.length);
position1++;
for (int i = windowSize; i < labels.length; i++) {
p += condLogProbGivenPrevious(position1++, labels[i], l);
System.arraycopy(l, 1, l, 0, l.length-1);
l[windowSize-2] = labels[i];
}
return p;
}
}
/**
* Returns the probability for the given labels (indexed using classIndex), where the
* last label corresponds to the label at the specified position. For instance if you called
* prob(5, {1,2,3}) it will return the marginal prob that the label at position
* 3 is 1, the label at position 4 is 2 and the label at position 5 is 3.
*/
public double prob(int position, int[] labels) {
return Math.exp(logProb(position, labels));
}
/**
* returns the log probability for the given labels, where the
* last label corresponds to the label at the specified position. For instance if you called
* logProb(5, {"O", "PER", "ORG"}) it will return the marginal log prob that the label at position
* 3 is "O", the label at position 4 is "PER" and the label at position 5 is "ORG".
*/
public double logProb(int position, Object[] labels) {
return logProb(position, objectArrayToIntArray(labels));
}
/**
* returns the probability for the given labels, where the
* last label corresponds to the label at the specified position. For instance if you called
* logProb(5, {"O", "PER", "ORG"}) it will return the marginal prob that the label at position
* 3 is "O", the label at position 4 is "PER" and the label at position 5 is "ORG".
*/
public double prob(int position, Object[] labels) {
return Math.exp(logProb(position, labels));
}
public GeneralizedCounter logProbs(int position, int window) {
GeneralizedCounter gc = new GeneralizedCounter(window);
int[] labels = new int[window];
// cdm july 2005: below array initialization isn't necessary: JLS (3rd ed.) 4.12.5
// Arrays.fill(labels, 0);
OUTER: while (true) {
List labelsList = Arrays.asList(intArrayToObjectArray(labels));
gc.incrementCount(labelsList, logProb(position, labels));
for (int i = 0; i < labels.length; i++) {
labels[i]++;
if (labels[i] < numClasses) { break; }
if (i == labels.length - 1) { break OUTER; }
labels[i] = 0;
}
}
return gc;
}
public GeneralizedCounter probs(int position, int window) {
GeneralizedCounter gc = new GeneralizedCounter(window);
int[] labels = new int[window];
// cdm july 2005: below array initialization isn't necessary: JLS (3rd ed.) 4.12.5
// Arrays.fill(labels, 0);
OUTER: while (true) {
List labelsList = Arrays.asList(intArrayToObjectArray(labels));
gc.incrementCount(labelsList, prob(position, labels));
for (int i = 0; i < labels.length; i++) {
labels[i]++;
if (labels[i] < numClasses) { break; }
if (i == labels.length - 1) { break OUTER; }
labels[i] = 0;
}
}
return gc;
}
//
// HELPER METHODS
//
private int[] objectArrayToIntArray(Object[] os) {
int[] is = new int[os.length];
for (int i = 0; i < os.length; i++) {
is[i] = classIndex.indexOf(os[i]);
}
return is;
}
private Object[] intArrayToObjectArray(int[] is) {
Object[] os = new Object[is.length];
for (int i = 0; i < is.length; i++) {
os[i] = classIndex.get(is[i]);
}
return os;
}
/**
* Gives the probability of a tag at a single position conditioned on
* a sequence of previous labels.
*
* @param position Index in sequence
* @param label Label of item at index
* @param prevLabels
* @return conditional log probability
*/
public double condLogProbGivenPrevious(int position, int label, int[] prevLabels) {
if (prevLabels.length + 1 == windowSize) {
return factorTables[position].conditionalLogProbGivenPrevious(prevLabels, label);
} else if (prevLabels.length + 1 < windowSize) {
FactorTable ft = factorTables[position].sumOutFront();
while (ft.windowSize() > prevLabels.length + 1) {
ft = ft.sumOutFront();
}
return ft.conditionalLogProbGivenPrevious(prevLabels, label);
} else {
int[] p = new int[windowSize-1];
System.arraycopy(prevLabels, prevLabels.length - p.length, p, 0, p.length);
return factorTables[position].conditionalLogProbGivenPrevious(p, label);
}
}
public double condLogProbGivenPrevious(int position, Object label, Object[] prevLabels) {
return condLogProbGivenPrevious(position, classIndex.indexOf(label), objectArrayToIntArray(prevLabels));
}
public double condProbGivenPrevious(int position, int label, int[] prevLabels) {
return Math.exp(condLogProbGivenPrevious(position, label, prevLabels));
}
public double condProbGivenPrevious(int position, Object label, Object[] prevLabels) {
return Math.exp(condLogProbGivenPrevious(position, label, prevLabels));
}
public ClassicCounter condLogProbsGivenPrevious(int position, int[] prevlabels) {
ClassicCounter c = new ClassicCounter();
for (int i = 0; i < classIndex.size(); i++) {
Object label = classIndex.get(i);
c.incrementCount(label, condLogProbGivenPrevious(position, i, prevlabels));
}
return c;
}
public ClassicCounter condLogProbsGivenPrevious(int position, Object[] prevlabels) {
ClassicCounter c = new ClassicCounter();
for (int i = 0; i < classIndex.size(); i++) {
Object label = classIndex.get(i);
c.incrementCount(label, condLogProbGivenPrevious(position, i, prevlabels));
}
return c;
}
//
// PROB OF TAG AT SINGLE POSITION CONDITIONED ON FOLLOWING SEQUENCE OF LABELS
//
public double condLogProbGivenNext(int position, int label, int[] nextLabels) {
position = position+nextLabels.length;
if (nextLabels.length + 1 == windowSize) {
return factorTables[position].conditionalLogProbGivenNext(nextLabels, label);
} else if (nextLabels.length + 1 < windowSize) {
FactorTable ft = factorTables[position].sumOutFront();
while (ft.windowSize() > nextLabels.length + 1) {
ft = ft.sumOutFront();
}
return ft.conditionalLogProbGivenPrevious(nextLabels, label);
} else {
int[] p = new int[windowSize-1];
System.arraycopy(nextLabels, 0, p, 0, p.length);
return factorTables[position].conditionalLogProbGivenPrevious(p, label);
}
}
public double condLogProbGivenNext(int position, Object label, Object[] nextLabels) {
return condLogProbGivenNext(position, classIndex.indexOf(label), objectArrayToIntArray(nextLabels));
}
public double condProbGivenNext(int position, int label, int[] nextLabels) {
return Math.exp(condLogProbGivenNext(position, label, nextLabels));
}
public double condProbGivenNext(int position, Object label, Object[] nextLabels) {
return Math.exp(condLogProbGivenNext(position, label, nextLabels));
}
public ClassicCounter condLogProbsGivenNext(int position, int[] nextlabels) {
ClassicCounter c = new ClassicCounter();
for (int i = 0; i < classIndex.size(); i++) {
Object label = classIndex.get(i);
c.incrementCount(label, condLogProbGivenNext(position, i, nextlabels));
}
return c;
}
public ClassicCounter condLogProbsGivenNext(int position, Object[] nextlabels) {
ClassicCounter c = new ClassicCounter();
for (int i = 0; i < classIndex.size(); i++) {
Object label = classIndex.get(i);
c.incrementCount(label, condLogProbGivenNext(position, i, nextlabels));
}
return c;
}
//
// PROB OF TAG AT SINGLE POSITION CONDITIONED ON PREVIOUS AND FOLLOWING SEQUENCE OF LABELS
//
// public double condProbGivenPreviousAndNext(int position, int label, int[] prevLabels, int[] nextLabels) {
// }
//
// JOINT CONDITIONAL PROBS
//
/**
* @return a new CRFCliqueTree for the weights on the data
*/
public static CRFCliqueTree getCalibratedCliqueTree(double[][] weights, int[][][] data, Index[] labelIndices, int numClasses, Index classIndex, String backgroundSymbol) {
FactorTable[] factorTables = new FactorTable[data.length];
FactorTable[] messages = new FactorTable[data.length - 1];
for (int i = 0; i < data.length; i++) {
factorTables[i] = getFactorTable(weights, data[i], labelIndices, numClasses);
if (i > 0) {
messages[i - 1] = factorTables[i - 1].sumOutFront();
factorTables[i].multiplyInFront(messages[i - 1]);
}
}
for (int i = factorTables.length - 2; i >= 0; i--) {
FactorTable summedOut = factorTables[i + 1].sumOutEnd();
summedOut.divideBy(messages[i]);
factorTables[i].multiplyInEnd(summedOut);
}
return new CRFCliqueTree(factorTables, classIndex, backgroundSymbol);
}
/**
* @return a new CRFCliqueTree for the weights on the data
*/
public static CRFCliqueTree getCalibratedCliqueTree(double[] weights,
double wscale,
int[][] weightIndices,
int[][][] data,
Index[] labelIndices,
int numClasses,
Index classIndex,
String backgroundSymbol) {
FactorTable[] factorTables = new FactorTable[data.length];
FactorTable[] messages = new FactorTable[data.length - 1];
for (int i = 0; i < data.length; i++) {
factorTables[i] = getFactorTable(weights, wscale, weightIndices, data[i], labelIndices, numClasses);
if (i > 0) {
messages[i - 1] = factorTables[i - 1].sumOutFront();
factorTables[i].multiplyInFront(messages[i - 1]);
}
}
for (int i = factorTables.length - 2; i >= 0; i--) {
FactorTable summedOut = factorTables[i + 1].sumOutEnd();
summedOut.divideBy(messages[i]);
factorTables[i].multiplyInEnd(summedOut);
}
return new CRFCliqueTree(factorTables, classIndex, backgroundSymbol);
}
private static FactorTable getFactorTable(double[] weights, double wscale, int[][] weightIndices, int[][] data, Index[] labelIndices, int numClasses) {
FactorTable factorTable = null;
for (int j = 0; j < labelIndices.length; j++) {
Index labelIndex = labelIndices[j];
FactorTable ft = new FactorTable(numClasses, j + 1);
// ... and each possible labeling for that clique
for (int k = 0, liSize = labelIndex.size(); k < liSize; k++) {
int[] label = ((CRFLabel) labelIndex.get(k)).getLabel();
double weight = 0.0;
for (int m = 0; m < data[j].length; m++) {
int wi = weightIndices[data[j][m]][k];
weight += wscale*weights[wi];
}
// try{
ft.setValue(label, weight);
// } catch (Exception e) {
// System.out.println("CRFCliqueTree::getFactorTable");
// System.out.println("NumClasses: " + numClasses + " j+1: " + (j+1));
// System.out.println("k: " + k+" label: " +label+" labelIndexSize: " + labelIndex.size());
// throw new RunTimeException(e.toString());
// }
}
if (j > 0) {
ft.multiplyInEnd(factorTable);
}
factorTable = ft;
}
return factorTable;
}
private static FactorTable getFactorTable(double[][] weights, int[][] data, Index[] labelIndices, int numClasses) {
FactorTable factorTable = null;
for (int j = 0; j < labelIndices.length; j++) {
Index labelIndex = labelIndices[j];
FactorTable ft = new FactorTable(numClasses, j + 1);
// ... and each possible labeling for that clique
for (int k = 0, liSize = labelIndex.size(); k < liSize; k++) {
int[] label = ((CRFLabel) labelIndex.get(k)).getLabel();
double weight = 0.0;
for (int m = 0; m < data[j].length; m++) {
weight += weights[data[j][m]][k];
}
// try{
ft.setValue(label, weight);
// } catch (Exception e) {
// System.out.println("CRFCliqueTree::getFactorTable");
// System.out.println("NumClasses: " + numClasses + " j+1: " + (j+1));
// System.out.println("k: " + k+" label: " +label+" labelIndexSize: " + labelIndex.size());
// throw new RunTimeException(e.toString());
// }
}
if (j > 0) {
ft.multiplyInEnd(factorTable);
}
factorTable = ft;
}
return factorTable;
}
// SEQUENCE MODEL METHODS
/**
* Computes the distribution over values of the element at position pos in the sequence,
* conditioned on the values of the elements in all other positions of the provided sequence.
*
* @param sequence the sequence containing the rest of the values to condition on
* @param position the position of the element to give a distribution for
* @return an array of type double, representing a probability distribution; sums to 1.0
*/
public double[] getConditionalDistribution(int[] sequence, int position) {
double[] result = scoresOf(sequence, position);
ArrayMath.logNormalize(result);
// System.out.println("marginal: " + ArrayMath.toString(marginal, nf));
// System.out.println("conditional: " + ArrayMath.toString(result, nf));
result = ArrayMath.exp(result);
// System.out.println("conditional: " + ArrayMath.toString(result, nf));
return result;
}
/**
* Informs this sequence model that the value of the element at position pos has changed.
* This allows this sequence model to update its internal model if desired.
*
*/
public void updateSequenceElement(int[] sequence, int pos, int oldVal) {
// do nothing; we don't change this model
}
/**
* Informs this sequence model that the value of the whole sequence is initialized to sequence
*
*/
public void setInitialSequence(int[] sequence) {
// do nothing
}
/**
* @return the number of possible values for each element; it is assumed
* to be the same for the element at each position
*/
public int getNumValues() {
return numClasses;
}
}