package edu.stanford.nlp.ie.crf; import edu.stanford.nlp.util.logging.Redwood; import edu.stanford.nlp.math.ArrayMath; import edu.stanford.nlp.math.SloppyMath; import edu.stanford.nlp.util.Index; import java.util.ArrayList; import java.util.Arrays; import java.util.List; /** Stores a factor table as a one dimensional array of floats. * * @author Jenny Finkel */ public class FloatFactorTable { /** A logger for this class */ private static Redwood.RedwoodChannels log = Redwood.channels(FloatFactorTable.class); private final int numClasses; private final int windowSize; private final float[] table; public FloatFactorTable(int numClasses, int windowSize) { this.numClasses = numClasses; this.windowSize = windowSize; table = new float[SloppyMath.intPow(numClasses, windowSize)]; Arrays.fill(table, Float.NEGATIVE_INFINITY); } public boolean hasNaN() { return ArrayMath.hasNaN(table); } public String toProbString() { StringBuilder sb = new StringBuilder("{\n"); for (int i = 0; i < table.length; i++) { sb.append(Arrays.toString(toArray(i))); sb.append(": "); sb.append(prob(toArray(i))); sb.append("\n"); } sb.append("}"); return sb.toString(); } public String toString(Index classIndex) { StringBuilder sb = new StringBuilder("{\n"); for (int i = 0; i < table.length; i++) { sb.append(toString(toArray(i), classIndex)); sb.append(": "); sb.append(getValue(i)); sb.append("\n"); } sb.append("}"); return sb.toString(); } @Override public String toString() { StringBuilder sb = new StringBuilder("{\n"); for (int i = 0; i < table.length; i++) { sb.append(Arrays.toString(toArray(i))); sb.append(": "); sb.append(getValue(i)); sb.append("\n"); } sb.append("}"); return sb.toString(); } private String toString(int[] array, Index classIndex) { List l = new ArrayList(); for (int anArray : array) { l.add(classIndex.get(anArray)); } return l.toString(); } private int[] toArray(int index) { int[] indices = new int[windowSize]; for (int i = indices.length - 1; i >= 0; i--) { indices[i] = index % numClasses; index /= numClasses; } return indices; } private int indexOf(int[] entry) { int index = 0; for (int anEntry : entry) { index *= numClasses; index += anEntry; } return index; } private int indexOf(int[] front, int end) { int index = 0; for (int aFront : front) { index *= numClasses; index += aFront; } index *= numClasses; index += end; return index; } private int[] indicesEnd(int[] entries) { int[] indices = new int[SloppyMath.intPow(numClasses, windowSize - entries.length)]; int offset = SloppyMath.intPow(numClasses, entries.length); int index = 0; for (int entry : entries) { index *= numClasses; index += entry; } for (int i = 0; i < indices.length; i++) { indices[i] = index; index += offset; } return indices; } private int[] indicesFront(int[] entries) { int[] indices = new int[SloppyMath.intPow(numClasses, windowSize - entries.length)]; int offset = SloppyMath.intPow(numClasses, windowSize - entries.length); int start = 0; for (int entry : entries) { start *= numClasses; start += entry; } start *= offset; int end = 0; for (int i = 0; i < entries.length; i++) { end *= numClasses; end += entries[i]; if (i == entries.length - 1) { end += 1; } } end *= offset; for (int i = start; i < end; i++) { indices[i - start] = i; } return indices; } public int windowSize() { return windowSize; } public int numClasses() { return numClasses; } private int size() { return table.length; } public float totalMass() { return ArrayMath.logSum(table); } public float unnormalizedLogProb(int[] label) { return getValue(label); } public float logProb(int[] label) { return unnormalizedLogProb(label) - totalMass(); } public float prob(int[] label) { return (float) Math.exp(unnormalizedLogProb(label) - totalMass()); } // given is at the begining, of is at the end public float conditionalLogProb(int[] given, int of) { if (given.length != windowSize - 1) { log.info("error computing conditional log prob"); System.exit(0); } int[] label = indicesFront(given); float[] masses = new float[label.length]; for (int i = 0; i < masses.length; i++) { masses[i] = table[label[i]]; } float z = ArrayMath.logSum(masses); return table[indexOf(given, of)] - z; } public float unnormalizedLogProbFront(int[] label) { label = indicesFront(label); float[] masses = new float[label.length]; for (int i = 0; i < masses.length; i++) { masses[i] = table[label[i]]; } return ArrayMath.logSum(masses); } public float logProbFront(int[] label) { return unnormalizedLogProbFront(label) - totalMass(); } public float unnormalizedLogProbEnd(int[] label) { label = indicesEnd(label); float[] masses = new float[label.length]; for (int i = 0; i < masses.length; i++) { masses[i] = table[label[i]]; } return ArrayMath.logSum(masses); } public float logProbEnd(int[] label) { return unnormalizedLogProbEnd(label) - totalMass(); } public float unnormalizedLogProbEnd(int label) { int[] l = {label}; l = indicesEnd(l); float[] masses = new float[l.length]; for (int i = 0; i < masses.length; i++) { masses[i] = table[l[i]]; } return ArrayMath.logSum(masses); } public float logProbEnd(int label) { return unnormalizedLogProbEnd(label) - totalMass(); } private float getValue(int index) { return table[index]; } public float getValue(int[] label) { return table[indexOf(label)]; } private void setValue(int index, float value) { table[index] = value; } public void setValue(int[] label, float value) { table[indexOf(label)] = value; } public void incrementValue(int[] label, float value) { table[indexOf(label)] += value; } private void logIncrementValue(int index, float value) { table[index] = SloppyMath.logAdd(table[index], value); } public void logIncrementValue(int[] label, float value) { int index = indexOf(label); table[index] = SloppyMath.logAdd(table[index], value); } public void multiplyInFront(FloatFactorTable other) { int divisor = SloppyMath.intPow(numClasses, windowSize - other.windowSize()); for (int i = 0; i < table.length; i++) { table[i] += other.getValue(i / divisor); } } public void multiplyInEnd(FloatFactorTable other) { int divisor = SloppyMath.intPow(numClasses, other.windowSize()); for (int i = 0; i < table.length; i++) { table[i] += other.getValue(i % divisor); } } public FloatFactorTable sumOutEnd() { FloatFactorTable ft = new FloatFactorTable(numClasses, windowSize - 1); for (int i = 0; i < table.length; i++) { ft.logIncrementValue(i / numClasses, table[i]); } return ft; } public FloatFactorTable sumOutFront() { FloatFactorTable ft = new FloatFactorTable(numClasses, windowSize - 1); int mod = SloppyMath.intPow(numClasses, windowSize - 1); for (int i = 0; i < table.length; i++) { ft.logIncrementValue(i % mod, table[i]); } return ft; } public void divideBy(FloatFactorTable other) { for (int i = 0; i < table.length; i++) { if (table[i] != Float.NEGATIVE_INFINITY || other.table[i] != Float.NEGATIVE_INFINITY) { table[i] -= other.table[i]; } } } public static void main(String[] args) { FloatFactorTable ft = new FloatFactorTable(6, 3); /** for (int i = 0; i < 2; i++) { for (int j = 0; j < 2; j++) { for (int k = 0; k < 2; k++) { int[] a = new int[]{i, j, k}; System.out.print(ft.toString(a)+": "+ft.indexOf(a)); } } } for (int i = 0; i < 2; i++) { int[] b = new int[]{i}; System.out.print(ft.toString(b)+": "+ft.toString(ft.indicesFront(b))); } for (int i = 0; i < 2; i++) { for (int j = 0; j < 2; j++) { int[] b = new int[]{i, j}; System.out.print(ft.toString(b)+": "+ft.toString(ft.indicesFront(b))); } } for (int i = 0; i < 2; i++) { int[] b = new int[]{i}; System.out.print(ft.toString(b)+": "+ft.toString(ft.indicesBack(b))); } for (int i = 0; i < 2; i++) { for (int j = 0; j < 2; j++) { int[] b = new int[]{i, j}; ft2.setValue(b, (i*2)+j); } } for (int i = 0; i < 2; i++) { for (int j = 0; j < 2; j++) { int[] b = new int[]{i, j}; System.out.print(ft.toString(b)+": "+ft.toString(ft.indicesBack(b))); } } System.out.println("##########################################"); **/ for (int i = 0; i < 6; i++) { for (int j = 0; j < 6; j++) { for (int k = 0; k < 6; k++) { int[] b = new int[]{i, j, k}; ft.setValue(b, (i * 4) + (j * 2) + k); } } } //System.out.println(ft); //System.out.println(ft.sumOutFront()); FloatFactorTable ft2 = new FloatFactorTable(6, 2); for (int i = 0; i < 6; i++) { for (int j = 0; j < 6; j++) { int[] b = new int[]{i, j}; ft2.setValue(b, i * 6 + j); } } System.out.println(ft); //FloatFactorTable ft3 = ft2.sumOutFront(); //System.out.println(ft3); for (int i = 0; i < 6; i++) { for (int j = 0; j < 6; j++) { int[] b = new int[]{i, j}; float t = 0; for (int k = 0; k < 6; k++) { t += Math.exp(ft.conditionalLogProb(b, k)); log.info(k + "|" + i + "," + j + " : " + Math.exp(ft.conditionalLogProb(b, k))); } System.out.println(t); } } } }