package edu.stanford.nlp.ie.crf;
import edu.stanford.nlp.util.logging.Redwood;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.math.SloppyMath;
import edu.stanford.nlp.util.Index;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
* Stores a factor table as a one dimensional array of doubles.
* This class supports a restricted form of factor table where each
* variable has the same set of values, but supports cliques of
* arbitrary size.
*
* @author Jenny Finkel
*/
@SuppressWarnings("UnusedDeclaration")
public class FactorTable {
/** A logger for this class */
private static Redwood.RedwoodChannels log = Redwood.channels(FactorTable.class);
private final int numClasses;
private final int windowSize;
private final double[] table;
public FactorTable(int numClasses, int windowSize) {
this.numClasses = numClasses;
this.windowSize = windowSize;
table = new double[SloppyMath.intPow(numClasses, windowSize)];
Arrays.fill(table, Double.NEGATIVE_INFINITY);
}
public FactorTable(FactorTable t) {
numClasses = t.numClasses();
windowSize = t.windowSize();
table = new double[t.size()];
System.arraycopy(t.table, 0, table, 0, t.size());
}
public boolean hasNaN() {
return ArrayMath.hasNaN(table);
}
public String toProbString() {
StringBuilder sb = new StringBuilder("{\n");
for (int i = 0; i < table.length; i++) {
sb.append(Arrays.toString(toArray(i)));
sb.append(": ");
sb.append(prob(toArray(i)));
sb.append('\n');
}
sb.append('}');
return sb.toString();
}
public String toNonLogString() {
StringBuilder sb = new StringBuilder("{\n");
for (int i = 0; i < table.length; i++) {
sb.append(Arrays.toString(toArray(i)));
sb.append(": ");
sb.append(Math.exp(getValue(i)));
sb.append('\n');
}
sb.append('}');
return sb.toString();
}
public <L> String toString(Index<L> classIndex) {
StringBuilder sb = new StringBuilder("{\n");
for (int i = 0; i < table.length; i++) {
sb.append(toString(toArray(i), classIndex));
sb.append(": ");
sb.append(getValue(i));
sb.append('\n');
}
sb.append('}');
return sb.toString();
}
@Override
public String toString() {
StringBuilder sb = new StringBuilder("{\n");
for (int i = 0; i < table.length; i++) {
sb.append(Arrays.toString(toArray(i)));
sb.append(": ");
sb.append(getValue(i));
sb.append('\n');
}
sb.append('}');
return sb.toString();
}
private static <L> String toString(int[] array, Index<L> classIndex) {
List<L> l = new ArrayList<>(array.length);
for (int item : array) {
l.add(classIndex.get(item));
}
return l.toString();
}
public int[] toArray(int index) {
int[] indices = new int[windowSize];
for (int i = indices.length - 1; i >= 0; i--) {
indices[i] = index % numClasses;
index /= numClasses;
}
return indices;
}
/* e.g., numClasses = 4
[2,3] -> 11
0 1 2 3
4 5 6 7
8 9 10 11
[0,2] -> 2
summary:
index % numClasses -> curr timestamp index
index / numClasses -> prev timestamp index
*/
private int indexOf(int[] entry) {
int index = 0;
for (int item : entry) {
index *= numClasses;
index += item;
}
// if (index < 0) throw new RuntimeException("index=" + index + " entry=" + Arrays.toString(entry)); // only if overflow
return index;
}
private int indexOf(int[] front, int end) {
int index = 0;
for (int item : front) {
index *= numClasses;
index += item;
}
index *= numClasses;
index += end;
return index;
}
private int indexOf(int front, int[] end) {
int index = front;
for (int item : end) {
index *= numClasses;
index += item;
}
return index;
}
private int[] indicesEnd(int[] entries) {
int index = 0;
for (int entry : entries) {
index *= numClasses;
index += entry;
}
int[] indices = new int[SloppyMath.intPow(numClasses, windowSize - entries.length)];
final int offset = SloppyMath.intPow(numClasses, entries.length);
for (int i = 0; i < indices.length; i++) {
indices[i] = index;
index += offset;
}
// log.info("indicesEnd returning: " + Arrays.toString(indices));
return indices;
}
/** This now returns the first index of the requested entries.
* The run of numClasses ^ (windowSize - entries.length)
* successive entries will give all of them.
*
* @param entries The class indices of size windowsSize
* @return First index of requested entries
*/
private int indicesFront(int[] entries) {
int start = 0;
for (int entry : entries) {
start *= numClasses;
start += entry;
}
int offset = SloppyMath.intPow(numClasses, windowSize - entries.length);
return start * offset;
}
public int windowSize() {
return windowSize;
}
public int numClasses() {
return numClasses;
}
public int size() {
return table.length;
}
public double totalMass() {
return ArrayMath.logSum(table);
}
/** Returns a single clique potential. */
public double unnormalizedLogProb(int[] label) {
return getValue(label);
}
public double logProb(int[] label) {
return unnormalizedLogProb(label) - totalMass();
}
public double prob(int[] label) {
return Math.exp(unnormalizedLogProb(label) - totalMass());
}
/**
* Computes the probability of the tag OF being at the end of the table given
* that the previous tag sequence in table is GIVEN. given is at the beginning,
* of is at the end.
*
* @return the probability of the tag OF being at the end of the table
*/
public double conditionalLogProbGivenPrevious(int[] given, int of) {
if (given.length != windowSize - 1) {
throw new IllegalArgumentException("conditionalLogProbGivenPrevious requires given one less than clique size (" +
windowSize + ") but was " + Arrays.toString(given));
}
// Note: other similar methods could be optimized like this one, but this is the one the CRF uses....
/*
int startIndex = indicesFront(given);
int numCellsToSum = SloppyMath.intPow(numClasses, windowSize - given.length);
double z = ArrayMath.logSum(table, startIndex, startIndex + numCellsToSum);
int i = indexOf(given, of);
System.err.printf("startIndex is %d, numCellsToSum is %d, i is %d (of is %d)%n", startIndex, numCellsToSum, i, of);
*/
int startIndex = indicesFront(given);
double z = ArrayMath.logSum(table, startIndex, startIndex + numClasses);
int i = startIndex + of;
// System.err.printf("startIndex is %d, numCellsToSum is %d, i is %d (of is %d)%n", startIndex, numClasses, i, of);
return table[i] - z;
}
// public double conditionalLogProbGivenPreviousForPartial(int[] given, int of) {
// if (given.length != windowSize - 1) {
// log.info("error computing conditional log prob");
// System.exit(0);
// }
// // int[] label = indicesFront(given);
// // double[] masses = new double[label.length];
// // for (int i = 0; i < masses.length; i++) {
// // masses[i] = table[label[i]];
// // }
// // double z = ArrayMath.logSum(masses);
//
// int i = indexOf(given, of);
// // if (SloppyMath.isDangerous(z) || SloppyMath.isDangerous(table[i])) {
// // log.info("z="+z);
// // log.info("t="+table[i]);
// // }
//
// return table[i];
// }
/**
* Computes the probabilities of the tag at the end of the table given that
* the previous tag sequence in table is GIVEN. given is at the beginning,
* position in question is at the end
*
* @return the probabilities of the tag at the end of the table
*/
public double[] conditionalLogProbsGivenPrevious(int[] given) {
if (given.length != windowSize - 1) {
throw new IllegalArgumentException("conditionalLogProbsGivenPrevious requires given one less than clique size (" +
windowSize + ") but was " + Arrays.toString(given));
}
double[] result = new double[numClasses];
for (int i = 0; i < numClasses; i++) {
int index = indexOf(given, i);
result[i] = table[index];
}
ArrayMath.logNormalize(result);
return result;
}
/**
* Computes the probability of the sequence OF being at the end of the table
* given that the first tag in table is GIVEN. given is at the beginning, of is
* at the end
*
* @return the probability of the sequence of being at the end of the table
*/
public double conditionalLogProbGivenFirst(int given, int[] of) {
if (of.length != windowSize - 1) {
throw new IllegalArgumentException("conditionalLogProbGivenFirst requires of one less than clique size (" +
windowSize + ") but was " + Arrays.toString(of));
}
// compute P(given, of)
int[] labels = new int[windowSize];
labels[0] = given;
System.arraycopy(of, 0, labels, 1, windowSize - 1);
// double probAll = logProb(labels);
double probAll = unnormalizedLogProb(labels);
// compute P(given)
// double probGiven = logProbFront(given);
double probGiven = unnormalizedLogProbFront(given);
// compute P(given, of) / P(given)
return probAll - probGiven;
}
/**
* Computes the probability of the sequence OF being at the end of the table
* given that the first tag in table is GIVEN. given is at the beginning, of is
* at the end.
*
* @return the probability of the sequence of being at the end of the table
*/
public double unnormalizedConditionalLogProbGivenFirst(int given, int[] of) {
if (of.length != windowSize - 1) {
throw new IllegalArgumentException("unnormalizedConditionalLogProbGivenFirst requires of one less than clique size (" +
windowSize + ") but was " + Arrays.toString(of));
}
// compute P(given, of)
int[] labels = new int[windowSize];
labels[0] = given;
System.arraycopy(of, 0, labels, 1, windowSize - 1);
// double probAll = logProb(labels);
double probAll = unnormalizedLogProb(labels);
// compute P(given)
// double probGiven = logProbFront(given);
// double probGiven = unnormalizedLogProbFront(given);
// compute P(given, of) / P(given)
// return probAll - probGiven;
return probAll;
}
/**
* Computes the probability of the tag OF being at the beginning of the table
* given that the tag sequence GIVEN is at the end of the table. given is at
* the end, of is at the beginning
*
* @return the probability of the tag of being at the beginning of the table
*/
public double conditionalLogProbGivenNext(int[] given, int of) {
if (given.length != windowSize - 1) {
throw new IllegalArgumentException("conditionalLogProbGivenNext requires given one less than clique size (" +
windowSize + ") but was " + Arrays.toString(given));
}
int[] label = indicesEnd(given);
double[] masses = new double[label.length];
for (int i = 0; i < masses.length; i++) {
masses[i] = table[label[i]];
}
double z = ArrayMath.logSum(masses);
return table[indexOf(of, given)] - z;
}
public double unnormalizedLogProbFront(int[] labels) {
int startIndex = indicesFront(labels);
int numCellsToSum = SloppyMath.intPow(numClasses, windowSize - labels.length);
// double[] masses = new double[labels.length];
// for (int i = 0; i < masses.length; i++) {
// masses[i] = table[labels[i]];
// }
return ArrayMath.logSum(table, startIndex, startIndex + numCellsToSum);
}
public double logProbFront(int[] label) {
return unnormalizedLogProbFront(label) - totalMass();
}
public double unnormalizedLogProbFront(int label) {
int[] labels = { label };
return unnormalizedLogProbFront(labels);
}
public double logProbFront(int label) {
return unnormalizedLogProbFront(label) - totalMass();
}
public double unnormalizedLogProbEnd(int[] labels) {
labels = indicesEnd(labels);
double[] masses = new double[labels.length];
for (int i = 0; i < masses.length; i++) {
masses[i] = table[labels[i]];
}
return ArrayMath.logSum(masses);
}
public double logProbEnd(int[] labels) {
return unnormalizedLogProbEnd(labels) - totalMass();
}
public double unnormalizedLogProbEnd(int label) {
int[] labels = { label };
return unnormalizedLogProbEnd(labels);
}
public double logProbEnd(int label) {
return unnormalizedLogProbEnd(label) - totalMass();
}
public double getValue(int index) {
return table[index];
}
public double getValue(int[] label) {
return table[indexOf(label)];
}
public void setValue(int index, double value) {
table[index] = value;
}
public void setValue(int[] label, double value) {
// try{
table[indexOf(label)] = value;
// } catch (Exception e) {
// e.printStackTrace();
// log.info("Table length: " + table.length + " indexOf(label): "
// + indexOf(label));
// throw new ArrayIndexOutOfBoundsException(e.toString());
// // System.exit(1);
// }
}
public void incrementValue(int[] label, double value) {
incrementValue(indexOf(label), value);
}
public void incrementValue(int index, double value) {
table[index] += value;
}
void logIncrementValue(int index, double value) {
table[index] = SloppyMath.logAdd(table[index], value);
}
public void logIncrementValue(int[] label, double value) {
logIncrementValue(indexOf(label), value);
}
public void multiplyInFront(FactorTable other) {
int divisor = SloppyMath.intPow(numClasses, windowSize - other.windowSize());
for (int i = 0; i < table.length; i++) {
table[i] += other.getValue(i / divisor);
}
}
public void multiplyInEnd(FactorTable other) {
int divisor = SloppyMath.intPow(numClasses, other.windowSize());
for (int i = 0; i < table.length; i++) {
table[i] += other.getValue(i % divisor);
}
}
public FactorTable sumOutEnd() {
FactorTable ft = new FactorTable(numClasses, windowSize - 1);
for (int i = 0, sz = ft.size(); i < sz; i++) {
ft.table[i] = ArrayMath.logSum(table, i * numClasses, (i+1) * numClasses);
}
/*
for (int i = 0; i < table.length; i++) {
ft.logIncrementValue(i / numClasses, table[i]);
}
*/
return ft;
}
public FactorTable sumOutFront() {
FactorTable ft = new FactorTable(numClasses, windowSize - 1);
int stride = ft.size();
for (int i = 0; i < stride; i++) {
ft.setValue(i, ArrayMath.logSum(table, i, table.length, stride));
}
return ft;
}
public void divideBy(FactorTable other) {
for (int i = 0; i < table.length; i++) {
if (table[i] != Double.NEGATIVE_INFINITY || other.table[i] != Double.NEGATIVE_INFINITY) {
table[i] -= other.table[i];
}
}
}
public static void main(String[] args) {
int numClasses = 6;
final int cliqueSize = 3;
System.err.printf("Creating factor table with %d classes and window (clique) size %d%n", numClasses, cliqueSize);
FactorTable ft = new FactorTable(numClasses, cliqueSize);
/**
* for (int i = 0; i < 2; i++) { for (int j = 0; j < 2; j++) { for (int k =
* 0; k < 2; k++) { int[] a = new int[]{i, j, k};
* System.out.print(ft.toString(a)+": "+ft.indexOf(a)); } } } for (int i =
* 0; i < 2; i++) { int[] b = new int[]{i};
* System.out.print(ft.toString(b)+": "+ft.toString(ft.indicesFront(b))); }
* for (int i = 0; i < 2; i++) { for (int j = 0; j < 2; j++) { int[] b = new
* int[]{i, j};
* System.out.print(ft.toString(b)+": "+ft.toString(ft.indicesFront(b))); }
* } for (int i = 0; i < 2; i++) { int[] b = new int[]{i};
* System.out.print(ft.toString(b)+": "+ft.toString(ft.indicesBack(b))); }
* for (int i = 0; i < 2; i++) { for (int j = 0; j < 2; j++) { int[] b = new
* int[]{i, j}; ft2.setValue(b, (i*2)+j); } } for (int i = 0; i < 2; i++) {
* for (int j = 0; j < 2; j++) { int[] b = new int[]{i, j};
* System.out.print(ft.toString(b)+": "+ft.toString(ft.indicesBack(b))); } }
*
* System.out.println("##########################################");
**/
for (int i = 0; i < numClasses; i++) {
for (int j = 0; j < numClasses; j++) {
for (int k = 0; k < numClasses; k++) {
int[] b = { i, j, k };
ft.setValue(b, (i * 4) + (j * 2) + k);
}
}
}
log.info(ft);
double normalization = 0.0;
for (int i = 0; i < numClasses; i++) {
for (int j = 0; j < numClasses; j++) {
for (int k = 0; k < numClasses; k++) {
normalization += ft.unnormalizedLogProb(new int[] {i, j, k});
}
}
}
log.info("Normalization Z = " + normalization);
log.info(ft.sumOutFront());
FactorTable ft2 = new FactorTable(numClasses, 2);
for (int i = 0; i < numClasses; i++) {
for (int j = 0; j < numClasses; j++) {
int[] b = { i, j };
ft2.setValue(b, i * numClasses + j);
}
}
log.info(ft2);
// FactorTable ft3 = ft2.sumOutFront();
// log.info(ft3);
for (int i = 0; i < numClasses; i++) {
for (int j = 0; j < numClasses; j++) {
int[] b = { i, j };
double t = 0;
for (int k = 0; k < numClasses; k++) {
t += Math.exp(ft.conditionalLogProbGivenPrevious(b, k));
System.err
.println(k + "|" + i + ',' + j + " : " + Math.exp(ft.conditionalLogProbGivenPrevious(b, k)));
}
log.info(t);
}
}
log.info("conditionalLogProbGivenFirst");
for (int j = 0; j < numClasses; j++) {
for (int k = 0; k < numClasses; k++) {
int[] b = { j, k };
double t = 0.0;
for (int i = 0; i < numClasses; i++) {
t += ft.unnormalizedConditionalLogProbGivenFirst(i, b);
System.err
.println(i + "|" + j + ',' + k + " : " + ft.unnormalizedConditionalLogProbGivenFirst(i, b));
}
log.info(t);
}
}
log.info("conditionalLogProbGivenFirst");
for (int i = 0; i < numClasses; i++) {
for (int j = 0; j < numClasses; j++) {
int[] b = { i, j };
double t = 0.0;
for (int k = 0; k < numClasses; k++) {
t += ft.conditionalLogProbGivenNext(b, k);
System.err
.println(i + "," + j + '|' + k + " : " + ft.conditionalLogProbGivenNext(b, k));
}
log.info(t);
}
}
numClasses = 2;
FactorTable ft3 = new FactorTable(numClasses, cliqueSize);
ft3.setValue(new int[] {0, 0, 0}, Math.log(0.25));
ft3.setValue(new int[] {0, 0, 1}, Math.log(0.35));
ft3.setValue(new int[] {0, 1, 0}, Math.log(0.05));
ft3.setValue(new int[] {0, 1, 1}, Math.log(0.07));
ft3.setValue(new int[] {1, 0, 0}, Math.log(0.08));
ft3.setValue(new int[] {1, 0, 1}, Math.log(0.16));
ft3.setValue(new int[] {1, 1, 0}, Math.log(1e-50));
ft3.setValue(new int[] {1, 1, 1}, Math.log(1e-50));
FactorTable ft4 = ft3.sumOutFront();
log.info(ft4.toNonLogString());
FactorTable ft5 = ft3.sumOutEnd();
log.info(ft5.toNonLogString());
} // end main
}