package edu.berkeley.cs.nlp.ocular.preprocessing;
import java.util.Arrays;
import java.util.Random;
import tberg.murphy.arrays.a;
import tberg.murphy.math.m;
/**
 * Segment model over the rows of a {@link VerticalProfile}.
 *
 * <p>The model has one state per {@link VerticalModelStateType}; the states follow each other
 * cyclically in declaration order (see {@link #getSuccessor(int)}). Every state owns
 * <ul>
 *   <li>an emission distribution: a discretized, renormalized Gaussian over the values
 *       {@code 0 .. imageWidth-1}, all states sharing one variance, and</li>
 *   <li>a size (segment length) distribution: a discretized, renormalized Gaussian over the
 *       lengths {@code minSize(state) .. maxSize(state)-1}.</li>
 * </ul>
 * Both are stored as log-probability tables.
 *
 * <p>NOTE(review): {@code profile.emissionsPerRow[i]} is presumably the amount of ink (dark pixels)
 * in row {@code i}; {@code VerticalProfile} is declared elsewhere, so confirm there.
 *
 * @author Taylor Berg-Kirkpatrick (tberg@eecs.berkeley.edu)
 */
public class VerticalModel {

  /** The states of the model, in cyclic order; the ordinal is used as the state index. */
  public static enum VerticalModelStateType {
    ASCENDER, BASE, DESCENDER;
  }

  /** Number of states, cached because {@code values()} allocates a new array on every call. */
  private static final int NUM_STATES = VerticalModelStateType.values().length;

  /**
   * Factor applied to every standard deviation each time {@link #updateMeansOnly} runs
   * (which includes construction).
   */
  private static final double STD_SHRINK_FACTOR = 0.8;

  /**
   * Per-state accumulators from which new emission and size means are computed.
   * All arrays are indexed by state index and start out at zero.
   */
  public static class SuffStats {
    /** Sum of segment masses. */
    public double[] totalMass = new double[NUM_STATES];
    /** Sum of {@code mass * segmentLength}; the denominator of the emission mean. */
    public double[] totalMassTimesLength = new double[NUM_STATES];
    /** Sum of {@code mass * emission} over every row of every segment. */
    public double[] totalEmissionMass = new double[NUM_STATES];
    /** Sum of {@code mass * segmentLength}; the numerator of the size mean. */
    public double[] totalSizeMass = new double[NUM_STATES];

    /** Creates empty statistics (Java zero-initializes the arrays). */
    public SuffStats() {
    }

    /**
     * Creates statistics for a single segment.
     *
     * @param profile source of the per-row emissions
     * @param state   state index the segment is attributed to
     * @param start   first row of the segment (inclusive)
     * @param stop    end row of the segment (exclusive)
     * @param mass    weight of this segment
     */
    public SuffStats(VerticalProfile profile, int state, int start, int stop, double mass) {
      final int length = stop - start;
      totalMass[state] += mass;
      totalMassTimesLength[state] += mass * length;
      for (int row = start; row < stop; row++) {
        totalEmissionMass[state] += mass * profile.emissionsPerRow[row];
      }
      totalSizeMass[state] += mass * length;
    }

    /**
     * Adds every accumulator of {@code other} into this object, element by element.
     *
     * @param other statistics to merge in; not modified
     */
    public void addIn(SuffStats other) {
      for (int s = 0; s < totalMass.length; s++) {
        totalMass[s] += other.totalMass[s];
      }
      for (int s = 0; s < totalMassTimesLength.length; s++) {
        totalMassTimesLength[s] += other.totalMassTimesLength[s];
      }
      for (int s = 0; s < totalEmissionMass.length; s++) {
        totalEmissionMass[s] += other.totalEmissionMass[s];
      }
      for (int s = 0; s < totalSizeMass.length; s++) {
        totalSizeMass[s] += other.totalSizeMass[s];
      }
    }

    /**
     * Returns the per-state mean emission, {@code totalEmissionMass / totalMassTimesLength}.
     * A state that never received any mass yields {@code NaN} (0/0).
     *
     * @return a new array with one mean per state
     */
    public double[] getEmissionMeans() {
      double[] means = new double[totalSizeMass.length];
      for (int s = 0; s < means.length; s++) {
        means[s] = totalEmissionMass[s] / totalMassTimesLength[s];
      }
      return means;
    }

    /**
     * Returns the per-state mean segment length, {@code totalSizeMass / totalMass}.
     * A state that never received any mass yields {@code NaN} (0/0).
     *
     * @return a new array with one mean per state
     */
    public double[] getSizeMeans() {
      double[] means = new double[totalSizeMass.length];
      for (int s = 0; s < means.length; s++) {
        means[s] = totalSizeMass[s] / totalMass[s];
      }
      return means;
    }
  }

  /** Length of every emission table; emission indices are clamped to {@code imageWidth - 1}. */
  public int imageWidth;
  /** {@code emissionLogProbs[state][value]}: normalized log-probability of an emission value. */
  public double[][] emissionLogProbs;
  /** Emission variance shared by all states; shrunk on every {@link #updateMeansOnly} call. */
  public double emissionVariance;
  /** {@code sizeLogProbs[state][length - minSize(state)]}: normalized log-probability of a length. */
  public double[][] sizeLogProbs;
  /** Per-state size variances; shrunk on every {@link #updateMeansOnly} call. */
  public double[] sizeVariances;

  /** Smallest segment length per state (inclusive). */
  public static final int[] minSizes = { 6, 6, 6 };
  /** Upper bound on segment length per state; exclusive, given how the size tables are sized. */
  public static final int[] maxSizes = { 30, 30, 30 };

  /**
   * Builds a model with randomly drawn means.
   *
   * <p>Two fractions are drawn uniformly from {@code [0, 0.8)}; the smaller one (times
   * {@code imageWidth}) becomes the emission mean of states 0 and 2, the larger one that of
   * state 1. States 0 and 1 share one randomly drawn size mean, state 2 gets its own. The
   * variances passed to the constructor are {@code (0.05 * imageWidth)^2} for emissions and
   * {@code 2^2} for every size; the constructor then applies its usual shrink step to them.
   *
   * @param imageWidth width used to scale the emission means and size the emission tables
   * @param rand       source of randomness; consumed in a fixed order (two doubles, two ints)
   * @return the newly constructed model
   */
  public static VerticalModel getRandomlyInitializedModel(int imageWidth, Random rand) {
    // An earlier fixed initialization (0.1, 0.3 and 0.0 times imageWidth) was replaced by this
    // random draw.
    double[] blackFractions = new double[2];
    for (int k = 0; k < blackFractions.length; k++) {
      blackFractions[k] = 0.8 * rand.nextDouble();
    }
    Arrays.sort(blackFractions);

    double[] emissionMeans = new double[NUM_STATES];
    emissionMeans[0] = blackFractions[0] * imageWidth;
    emissionMeans[1] = blackFractions[1] * imageWidth;
    emissionMeans[2] = blackFractions[0] * imageWidth;

    double emissionStd = 0.05;
    double emissionVariance = (emissionStd * imageWidth) * (emissionStd * imageWidth);

    // States 0 and 1 share a mean drawn from the size range valid for both of them.
    int sharedMin = Math.max(minSizes[0], minSizes[1]);
    int sharedMax = Math.min(maxSizes[0], maxSizes[1]);
    double nonSpaceMean = rand.nextInt(sharedMax - sharedMin) + sharedMin;
    double spaceMean = rand.nextInt(maxSizes[2] - minSizes[2]) + minSizes[2];

    double[] sizeMeans = new double[NUM_STATES];
    sizeMeans[0] = nonSpaceMean;
    sizeMeans[1] = nonSpaceMean;
    sizeMeans[2] = spaceMean;

    double[] sizeVariances = new double[NUM_STATES];
    Arrays.fill(sizeVariances, 2.0 * 2.0);

    return new VerticalModel(imageWidth, emissionMeans, emissionVariance, sizeMeans, sizeVariances);
  }

  /**
   * Creates a model from explicit parameters and immediately runs {@link #updateMeansOnly},
   * so the stored variances are already shrunk once relative to the arguments.
   *
   * @param imageWidth       length of each emission table
   * @param emissionMeans    per-state emission means
   * @param emissionVariance emission variance shared by all states
   * @param sizeMeans        per-state size means
   * @param sizeVariances    per-state size variances; copied, the caller's array is not modified
   */
  public VerticalModel(int imageWidth, double[] emissionMeans, double emissionVariance, double[] sizeMeans, double[] sizeVariances) {
    this.imageWidth = imageWidth;
    this.emissionVariance = emissionVariance;
    // Defensive copy: updateMeansOnly rescales this array in place, which must not leak back
    // into the array owned by the caller.
    this.sizeVariances = sizeVariances.clone();
    updateMeansOnly(emissionMeans, sizeMeans);
  }

  /**
   * Installs new means and rebuilds both log-probability tables.
   *
   * <p>Despite the name, the variances change too: before the tables are rebuilt every standard
   * deviation (sizes and emission) is multiplied by 0.8, so repeated calls progressively sharpen
   * the distributions.
   *
   * @param emissionMeans new per-state emission means
   * @param sizeMeans     new per-state size means
   */
  public void updateMeansOnly(double[] emissionMeans, double[] sizeMeans) {
    for (int s = 0; s < sizeVariances.length; s++) {
      sizeVariances[s] = Math.pow(Math.sqrt(sizeVariances[s]) * STD_SHRINK_FACTOR, 2.0);
    }
    emissionVariance = Math.pow(Math.sqrt(emissionVariance) * STD_SHRINK_FACTOR, 2.0);
    setEmissionParams(emissionMeans, emissionVariance);
    setSizeParams(sizeMeans, sizeVariances);
  }

  /**
   * Replaces each state's size distribution with a uniform distribution over a window centred
   * on its current most likely size; every size outside the window becomes impossible.
   *
   * @param flexibilityRadius number of table entries kept on each side of the most likely size
   */
  public void freezeSizeParams(int flexibilityRadius) {
    for (int s = 0; s < sizeLogProbs.length; s++) {
      double[] logProbs = sizeLogProbs[s];

      // Locate the mode; the first maximum wins on ties.
      int peak = -1;
      double peakLogProb = Double.NEGATIVE_INFINITY;
      for (int j = 0; j < logProbs.length; j++) {
        if (logProbs[j] > peakLogProb) {
          peak = j;
          peakLogProb = logProbs[j];
        }
      }

      // Window bounds (inclusive), clamped to the valid index range of the table.
      int lo = Math.max(0, peak - flexibilityRadius);
      int hi = Math.min(logProbs.length - 1, peak + flexibilityRadius);
      for (int j = 0; j < logProbs.length; j++) {
        logProbs[j] = (j >= lo && j <= hi) ? 0.0 : Double.NEGATIVE_INFINITY;
      }

      // Renormalize so the entries inside the window share the probability mass equally.
      sizeLogProbs[s] = a.log(a.normalize(a.exp(logProbs)));
    }
  }

  /**
   * Rebuilds {@link #emissionLogProbs}: for each state, a Gaussian log-density evaluated at
   * {@code 0 .. imageWidth-1} and renormalized over that range.
   */
  private void setEmissionParams(double[] emissionMeans, double emissionVariance) {
    this.emissionLogProbs = new double[emissionMeans.length][];
    for (int s = 0; s < emissionMeans.length; s++) {
      double[] table = new double[imageWidth];
      for (int value = 0; value < table.length; value++) {
        table[value] = m.gaussianLogProb(emissionMeans[s], emissionVariance, value);
      }
      // No additional smoothing is applied; that would only be needed if the variance were set
      // too low.
      this.emissionLogProbs[s] = a.log(a.normalize(a.exp(table)));
    }
  }

  /**
   * Rebuilds {@link #sizeLogProbs}: for each state, a Gaussian log-density evaluated at the
   * lengths {@code minSize(state) .. maxSize(state)-1} and renormalized over that range.
   */
  private void setSizeParams(double[] sizeMeans, double[] sizeVariances) {
    this.sizeLogProbs = new double[sizeMeans.length][];
    for (int s = 0; s < sizeMeans.length; s++) {
      double[] table = new double[maxSize(s) - minSize(s)];
      for (int j = 0; j < table.length; j++) {
        table[j] = m.gaussianLogProb(sizeMeans[s], sizeVariances[s], minSize(s) + j);
      }
      this.sizeLogProbs[s] = a.log(a.normalize(a.exp(table)));
    }
  }

  /** @return the number of states, i.e. the number of {@link VerticalModelStateType} constants */
  public int numStates() {
    return NUM_STATES;
  }

  /** @return the index of the state preceding {@code stateIdx} in the cycle (wraps around) */
  public int getPredecessor(int stateIdx) {
    return (stateIdx + numStates() - 1) % numStates();
  }

  /** @return the index of the state following {@code stateIdx} in the cycle (wraps around) */
  public int getSuccessor(int stateIdx) {
    return (stateIdx + 1) % numStates();
  }

  /** @return the smallest segment length allowed for {@code stateType} */
  public int minSize(int stateType) {
    return minSizes[stateType];
  }

  /** @return the (exclusive) upper bound on segment length for {@code stateType} */
  public int maxSize(int stateType) {
    return maxSizes[stateType];
  }

  /**
   * Returns the emission log-probability of a single row under a state.
   *
   * @param profile  source of the per-row emissions
   * @param stateIdx state index
   * @param posn     row index into {@code profile.emissionsPerRow}
   * @return the table entry for that row's emission, with the emission clamped to the last
   *         table index when it is larger
   */
  public double getLogProb(VerticalProfile profile, int stateIdx, int posn) {
    int emissionIdx = (int)(Math.min(profile.emissionsPerRow[posn], emissionLogProbs[stateIdx].length - 1));
    return emissionLogProbs[stateIdx][emissionIdx];
  }

  /**
   * Returns the log-probability of the segment {@code [start, stop)} under a state: the size
   * log-probability of its length plus the emission log-probability of each of its rows.
   *
   * @param profile  source of the per-row emissions
   * @param stateIdx state index
   * @param start    first row (inclusive)
   * @param stop     end row (exclusive); {@code stop - start} must lie in
   *                 {@code minSize(stateIdx) .. maxSize(stateIdx)-1}
   * @return the total log-probability of the segment
   * @throws ArrayIndexOutOfBoundsException if the segment length is outside the size table
   */
  public double getLogProb(VerticalProfile profile, int stateIdx, int start, int stop) {
    double totalLogProb = this.sizeLogProbs[stateIdx][stop - start - minSize(stateIdx)];
    for (int row = start; row < stop; row++) {
      totalLogProb += getLogProb(profile, stateIdx, row);
    }
    return totalLogProb;
  }
}