package edu.berkeley.cs.nlp.ocular.model; import tberg.murphy.indexer.Indexer; import tberg.murphy.indexer.IntArrayIndexer; import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import tberg.murphy.math.m; import tberg.murphy.opt.DifferentiableFunction; import tberg.murphy.opt.LBFGSMinimizer; import tberg.murphy.opt.Minimizer; import tberg.murphy.tuple.Pair; import tberg.murphy.arrays.a; import edu.berkeley.cs.nlp.ocular.data.textreader.Charset; import edu.berkeley.cs.nlp.ocular.image.ImageUtils.PixelType; import edu.berkeley.cs.nlp.ocular.util.StringHelper; /** * @author Taylor Berg-Kirkpatrick (tberg@eecs.berkeley.edu) */ public class CharacterTemplate implements Serializable { private static final long serialVersionUID = 2L; public static final int LINE_HEIGHT = 30; public static final float[] EXP_GAINS = new float[] {1.0f, 0.5f, 0.25f}; public static final float[] EXP_STD_DEVS = new float[] {1.5f, 1.5f, 1.5f}; public static final float[] EXP_SPC_BLACK_PROBS = new float[] {5e-2f, 2e-2f, 1e-1f}; public static final int MAX_OFFSET = 5; public static final float EMIT_REG = 1e-2f; public static final float INIT_WIDTH_STD_THRESH = 2.5f; public static final float INIT_WIDTH_MIN_VAR = 1e-2f; public static final float LEARN_WIDTH_STD_THRESH = 2.5f; public static final float LEARN_WIDTH_MIN_VAR = 1e-2f; public static final float INIT_LBFGS_TOL = 1e-10f; public static final int INIT_LBFGS_ITERS = 1000; public static final float MSTEP_LBFGS_TOL = 1e-5f; public static final int MSTEP_LBFGS_ITERS = 20; private String character; private int templateMaxWidth; private int templateMinWidth; private float[][] templateWeights; private float[][] templateWeightsPriorMeans; private float[][][][] templateLogBlackProbs; private float[][][][] templateLogWhiteProbs; private boolean[][] templateCountSparsity; private boolean[][] templateLogProbsCached; private float[][][][] templateBlackCounts; private float[][][][] templateWhiteCounts; private float[] templateWidthProbs; private float[] templateWidthCounts; private Indexer<int[]> paramIndexer; private float[][][][] interpolationWeights; public float[][][][] getInterpolationWeights() { return interpolationWeights; } public CharacterTemplate(String character, float templateMaxWidthFraction, float templateMinWidthFraction) { this.templateMaxWidth = (int) Math.max(1, Math.floor(templateMaxWidthFraction*LINE_HEIGHT)); this.templateMinWidth = (int) Math.max(1, Math.floor(templateMinWidthFraction*LINE_HEIGHT)); int numTemplateWidths = (templateMaxWidth - templateMinWidth) + 1; this.templateWidthProbs = new float[numTemplateWidths]; for (int i=0; i<templateWidthProbs.length; ++i) templateWidthProbs[i] = 1.0f; a.normalizei(templateWidthProbs); this.character = character; this.templateWidthCounts = new float[templateWidthProbs.length]; if (!character.equals(Charset.SPACE)) { this.templateWeights = new float[templateMaxWidth][LINE_HEIGHT]; for (int i=0; i<templateMaxWidth; ++i) { Arrays.fill(templateWeights[i], 0.0f); }; this.templateWeightsPriorMeans = new float[templateMaxWidth][LINE_HEIGHT]; for (int i=0; i<templateMaxWidth; ++i) { Arrays.fill(templateWeightsPriorMeans[i], 0.0f); }; this.templateLogBlackProbs = new float[EXP_GAINS.length][templateWidthProbs.length][][]; this.templateLogWhiteProbs = new float[EXP_GAINS.length][templateWidthProbs.length][][]; this.templateLogProbsCached = new boolean[EXP_GAINS.length][templateWidthProbs.length]; this.templateCountSparsity = new boolean[EXP_GAINS.length][templateWidthProbs.length]; this.templateBlackCounts = new float[EXP_GAINS.length][templateWidthProbs.length][][]; this.templateWhiteCounts = new float[EXP_GAINS.length][templateWidthProbs.length][][]; this.interpolationWeights = new float[EXP_GAINS.length][templateWidthProbs.length][][]; for (int e=0; e<EXP_GAINS.length; ++e) { for (int w=0; w<templateWidthProbs.length; ++w) { int width = templateMinWidth+w; this.interpolationWeights[e][w] = new float[width][templateMaxWidth]; this.templateLogBlackProbs[e][w] = new float[width][LINE_HEIGHT]; this.templateLogWhiteProbs[e][w] = new float[width][LINE_HEIGHT]; this.templateBlackCounts[e][w] = new float[width][LINE_HEIGHT]; this.templateWhiteCounts[e][w] = new float[width][LINE_HEIGHT]; float interval = ((float) templateMaxWidth) / ((float) width); for (int i=0; i<width; ++i) { float emissionLocation = interval*(i+0.5f); for (int j=0; j<templateMaxWidth; ++j) { float templatePixelLocation = j+0.5f; this.interpolationWeights[e][w][i][j] = (float) Math.exp(m.gaussianLogProb((templatePixelLocation - emissionLocation)*(templatePixelLocation-emissionLocation), EXP_STD_DEVS[e]*interval)); } a.normalizei(this.interpolationWeights[e][w][i]); a.scalei(this.interpolationWeights[e][w][i], EXP_GAINS[e]); } } } this.paramIndexer = new IntArrayIndexer(); for (int i=0; i<this.templateWeights.length; ++i) { for (int j=0; j<this.templateWeights[i].length; ++j) { this.paramIndexer.getIndex(new int[] {i, j}); } } this.paramIndexer.lock(); } } public void initializeAndSetPriorFromFontData(PixelType[][][] fontData) { if (!character.equals(Charset.SPACE)) { System.out.println("Initializing "+character+" from font data..."); clearEmissionCounts(); clearWidthCounts(); for (PixelType[][] observations : fontData) { if (observations.length >= templateMinWidth() && observations.length <= templateMaxWidth()) { incrementWidthCounts(observations.length, 1.0f); for (int pos=0; pos<observations.length; ++pos) incrementEmissionCounts(0, 0, observations.length, pos, 1.0f, observations[pos]); } } updateWidthParameters(INIT_WIDTH_MIN_VAR, INIT_WIDTH_STD_THRESH); updateEmissionParameters(INIT_LBFGS_TOL, INIT_LBFGS_ITERS); templateWeightsPriorMeans = a.copy(templateWeights); System.out.println(toString()); } } public int[] allowedWidths() { List<Integer> allowedWidths = new ArrayList<Integer>(); for (int w=templateMinWidth(); w<=templateMaxWidth(); ++w) { if (widthProb(w) > 0.0f) { allowedWidths.add(w); } } return a.toIntArray(allowedWidths); } public float[][] blackProbs(int exposure, int offset, int width) { float[][] result = new float[width][LINE_HEIGHT]; if (!character.equals(Charset.SPACE)) { for (int i=0; i<width; ++i) { for (int j=0; j<LINE_HEIGHT; ++j) { result[i][j] = (float) Math.exp(templateLogProbs(width, exposure, true)[i][Math.min(LINE_HEIGHT-1, Math.max(0, j+offset))]); } } } else { for (int i=0; i<width; ++i) { for (int j=0; j<LINE_HEIGHT; ++j) { result[i][j] = EXP_SPC_BLACK_PROBS[exposure]; } } } return result; } public float[][] logBlackProbs(int exposure, int offset, int width) { float[][] result = new float[width][LINE_HEIGHT]; if (!character.equals(Charset.SPACE)) { for (int i=0; i<width; ++i) { for (int j=0; j<LINE_HEIGHT; ++j) { result[i][j] = (float) templateLogProbs(width, exposure, true)[i][Math.min(LINE_HEIGHT-1, Math.max(0, j+offset))]; } } } else { for (int i=0; i<width; ++i) { for (int j=0; j<LINE_HEIGHT; ++j) { result[i][j] = (float) Math.log(EXP_SPC_BLACK_PROBS[exposure]); } } } return result; } public float[][] logWhiteProbs(int exposure, int offset, int width) { float[][] result = new float[width][LINE_HEIGHT]; if (!character.equals(Charset.SPACE)) { for (int i=0; i<width; ++i) { for (int j=0; j<LINE_HEIGHT; ++j) { result[i][j] = (float) templateLogProbs(width, exposure, false)[i][Math.min(LINE_HEIGHT-1, Math.max(0, j+offset))]; } } } else { for (int i=0; i<width; ++i) { for (int j=0; j<LINE_HEIGHT; ++j) { result[i][j] = (float) Math.log(1.0 - EXP_SPC_BLACK_PROBS[exposure]); } } } return result; } public float emissionLogProb(PixelType[][] observations, int startCol, int endCol, int exposure, int offset) { int width = endCol - startCol; float logProb = 0.0f; for (int i=0; i<width; ++i) { logProb += columnEmissionLogProb(exposure, offset, width, i, observations[startCol+i]); } return logProb; } private float columnEmissionLogProb(int exposure, int offset, int width, int pos, PixelType[] observation) { float logProb = 0.0f; for (int j=0; j<LINE_HEIGHT; ++j) { logProb += pixelEmissionLogProb(exposure, offset, width, pos, j, observation[j]); } return logProb; } private float pixelEmissionLogProb(int exposure, int offset, int width, int pos, int j, PixelType observation) { if (!character.equals(Charset.SPACE)) { if (observation == PixelType.BLACK) { return templateLogProbs(width, exposure, true)[pos][Math.min(LINE_HEIGHT-1, Math.max(0, j+offset))]; } if (observation == PixelType.WHITE) { return templateLogProbs(width, exposure, false)[pos][Math.min(LINE_HEIGHT-1, Math.max(0, j+offset))]; } else { return 0.0f; } } else { if (observation == PixelType.BLACK) { return (float) Math.log(EXP_SPC_BLACK_PROBS[exposure]); } if (observation == PixelType.WHITE) { return (float) Math.log(1.0 - EXP_SPC_BLACK_PROBS[exposure]); } else { return 0.0f; } } } public float widthProb(int width) { return templateWidthProbs[width-templateMinWidth()]; } public float widthLogProb(int width) { return (float) Math.log(templateWidthProbs[width-templateMinWidth()]); } public void clearCounts() { clearEmissionCounts(); clearWidthCounts(); } public void incrementCounts(float count, PixelType[][] observations, int startCol, int width, int exposure, int offset) { for (int i=0; i<width; ++i) { incrementEmissionCounts(exposure, offset, width, i, count, observations[startCol+i]); } incrementWidthCounts(width, count); } public void updateParameters() { updateWidthParameters(LEARN_WIDTH_MIN_VAR, LEARN_WIDTH_STD_THRESH); updateEmissionParameters(MSTEP_LBFGS_TOL, MSTEP_LBFGS_ITERS); } public String getCharacter() { return character; } public String toString() { int bestWidth = -1; double bestWidthProb = Double.NEGATIVE_INFINITY; for (int width : allowedWidths()) { if (widthProb(width) > bestWidthProb) { bestWidthProb = widthProb(width); bestWidth = width; } } float[][] blackProbs = blackProbs(EXP_GAINS.length/2, 0, bestWidth); StringBuffer buf = new StringBuffer(); buf.append(character).append(" ").append(StringHelper.toUnicode(character)).append(":\n"); for (int j=0; j<LINE_HEIGHT; ++j) { for (int i=0; i<bestWidth; ++i) { float prob = blackProbs[i][j]; if (prob >= 0.0 && prob < 0.333) { buf.append(". "); } else if (prob >= 0.333 && prob < 0.666) { buf.append("o "); } else if (prob >= 0.666) { buf.append("O "); } } buf.append("\n"); } buf.append("Width probs: ").append(renderWidthProbs(templateWidthProbs, templateMinWidth())).append("\n"); return buf.toString(); } private String renderWidthProbs(float[] probs, int firstIndex) { if (probs.length <= 0) throw new RuntimeException("probs.length <= 0. was probs.length=" + probs.length); StringBuffer buf = new StringBuffer(); for (int i=0; i<probs.length; ++i) { buf.append(i+firstIndex).append(" = ").append(String.format("%.2f", probs[i])).append(", "); } buf.delete(buf.length() - 2, buf.length()); return buf.toString(); } public int templateMaxWidth() { return templateMaxWidth; } public int templateMinWidth() { return templateMinWidth; } private void clearWidthCounts() { Arrays.fill(templateWidthCounts, 0.0f); } private void incrementWidthCounts(int width, float count) { synchronized (templateWidthCounts) { templateWidthCounts[width-templateMinWidth] += count; } } private void updateWidthParameters(float widthMinVar, float widthStdThresh) { if (!character.equals(Charset.SPACE)) { if (a.sum(templateWidthCounts) > 0.0) { float mean = 0.0f; float totalCount = a.sum(templateWidthCounts); for (int width=templateMinWidth; width<=templateMaxWidth; ++width) { mean += width * (templateWidthCounts[width-templateMinWidth] / totalCount); } float var = 0.0f; for (int width=templateMinWidth; width<=templateMaxWidth; ++width) { var += (mean - width) * (mean - width) * (templateWidthCounts[width-templateMinWidth] / totalCount); } templateWidthProbs = buildGuassianWidthProbs(mean, Math.max(widthMinVar, var), templateMinWidth, templateMaxWidth, widthStdThresh); } } } private static float[] buildGuassianWidthProbs(float mean, float var, int min, int max, float guassianWidthStdMultThreshold) { float[] probs = new float[max-min+1]; for (int i=min; i<=max; ++i) { float sqrDistFromMean = (mean - i)*(mean - i); if (Math.sqrt(sqrDistFromMean) < guassianWidthStdMultThreshold*Math.sqrt(var)) { probs[i-min] = (float) Math.exp(-sqrDistFromMean/(2.0*var)); } } a.normalizei(probs); return probs; } private void clearEmissionCounts() { if (!character.equals(Charset.SPACE)) { for (int e=0; e<EXP_GAINS.length; ++e) { Arrays.fill(templateCountSparsity[e], false); for (int w=0; w<interpolationWeights[e].length; ++w) { for (int pos=0; pos<interpolationWeights[e][w].length; ++pos) { Arrays.fill(templateBlackCounts[e][w][pos], 0.0f); Arrays.fill(templateWhiteCounts[e][w][pos], 0.0f); } } } } } private void incrementEmissionCounts(int exposure, int offset, int width, int pos, float count, PixelType[] observation) { if (!character.equals(Charset.SPACE)) { synchronized (templateBlackCounts[exposure][width-templateMinWidth()][pos]) { for (int j=0; j<observation.length; ++j) { if (observation[j] == PixelType.BLACK) { templateBlackCounts[exposure][width-templateMinWidth()][pos][Math.min(LINE_HEIGHT-1, Math.max(0, j+offset))] += count; } else if (observation[j] == PixelType.WHITE) { templateWhiteCounts[exposure][width-templateMinWidth()][pos][Math.min(LINE_HEIGHT-1, Math.max(0, j+offset))] += count; } } } if (count > 0.0f) templateCountSparsity[exposure][width-templateMinWidth()] = true; } } private void updateEmissionParameters(float lbfgsTol, int iters) { if (!character.equals(Charset.SPACE)) { Minimizer minimizer = new LBFGSMinimizer(lbfgsTol, iters); double[] finalParams = minimizer.minimize(new NegExpectedLogLikelihoodFunc(), a.toDouble(getParamVector()), false, null); setParamVector(a.toFloat(finalParams)); } } private void invalidateTemplateLogProbsCache() { for (int e=0; e<EXP_GAINS.length; ++e) { Arrays.fill(templateLogProbsCached[e], false); } } private float[][] templateLogProbs(int width, int e, boolean black) { if (!templateLogProbsCached[e][width-templateMinWidth()]) { for (int pos=0; pos<width; ++pos) { for (int j=0; j<LINE_HEIGHT; ++j) { float innerProd = 0.0f; for (int tpos=0; tpos<templateMaxWidth(); ++tpos) { innerProd += interpolationWeights[e][width-templateMinWidth()][pos][tpos]*templateWeights[tpos][j]; } templateLogBlackProbs[e][width-templateMinWidth()][pos][j] = innerProd - (float) Math.log(1.0 + Math.exp(innerProd)); templateLogWhiteProbs[e][width-templateMinWidth()][pos][j] = (float) -Math.log(1.0 + Math.exp(innerProd)); } } templateLogProbsCached[e][width-templateMinWidth()] = true; } if (black) { return templateLogBlackProbs[e][width-templateMinWidth()]; } else { return templateLogWhiteProbs[e][width-templateMinWidth()]; } } private void setParamVector(float[] params) { for (int i=0; i<params.length; ++i) { int[]rowCol = paramIndexer.getObject(i); templateWeights[rowCol[0]][rowCol[1]] = params[i]; } invalidateTemplateLogProbsCache(); } private float[] getParamVector() { float[] params = new float[paramIndexer.size()]; for (int i=0; i<templateWeights.length; ++i) { for (int j=0; j<templateWeights[i].length; ++j) { params[paramIndexer.getIndex(new int[] {i,j})] = templateWeights[i][j]; } } return params; } private float[] getPriorMeanVector() { float[] prior = new float[paramIndexer.size()]; for (int i=0; i<templateWeightsPriorMeans.length; ++i) { for (int j=0; j<templateWeightsPriorMeans[i].length; ++j) { prior[paramIndexer.getIndex(new int[] {i,j})] = templateWeightsPriorMeans[i][j]; } } return prior; } private float getNegExpectedLogLikelihood() { float result = 0.0f; for (int e=0; e<EXP_GAINS.length; ++e) { for (int width=templateMinWidth(); width<=templateMaxWidth(); ++width) { if (templateCountSparsity[e][width-templateMinWidth()]) { for (int pos=0; pos<width; ++pos) { for (int j=0; j<LINE_HEIGHT; ++j) { result -= templateBlackCounts[e][width-templateMinWidth()][pos][j] * templateLogProbs(width, e, true)[pos][j] + templateWhiteCounts[e][width-templateMinWidth()][pos][j] * templateLogProbs(width, e, false)[pos][j]; } } } } } return result; } private float[] getNegExpectedLogLikelihoodGradient() { float[] result = new float[paramIndexer.size()]; for (int e=0; e<EXP_GAINS.length; ++e) { for (int width=templateMinWidth; width<=templateMaxWidth; ++width) { if (templateCountSparsity[e][width-templateMinWidth()]) { for (int pos=0; pos<width; ++pos) { for (int j=0; j<LINE_HEIGHT; ++j) { for (int tpos=0; tpos<templateMaxWidth; ++tpos) { int paramIndex = paramIndexer.getIndex(new int[] {tpos, j}); result[paramIndex] -= interpolationWeights[e][width-templateMinWidth()][pos][tpos] * (templateBlackCounts[e][width-templateMinWidth()][pos][j] - (templateBlackCounts[e][width-templateMinWidth()][pos][j] + templateWhiteCounts[e][width-templateMinWidth()][pos][j]) * Math.exp(templateLogProbs(width, e, true)[pos][j])); } } } } } } return result; } private class NegExpectedLogLikelihoodFunc implements DifferentiableFunction { float[] priorMeans = getPriorMeanVector(); public Pair<Double, double[]> calculate(double[] xDouble) { float[] x = a.toFloat(xDouble); setParamVector(x); float[] priorDelta = a.comb(x, 1.0f, priorMeans, -1.0f); float reg = EMIT_REG*a.innerProd(priorDelta, priorDelta); float[] regGrad = a.scale(priorDelta, EMIT_REG*2.0f); return Pair.makePair((double) getNegExpectedLogLikelihood()+reg, a.toDouble(a.comb(getNegExpectedLogLikelihoodGradient(), 1.0f, regGrad, 1.0f))); } } }