package edu.berkeley.cs.nlp.ocular.model.emission;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import edu.berkeley.cs.nlp.ocular.data.textreader.Charset;
import edu.berkeley.cs.nlp.ocular.image.ImageUtils.PixelType;
import edu.berkeley.cs.nlp.ocular.model.CharacterTemplate;
import edu.berkeley.cs.nlp.ocular.model.em.EmissionCacheInnerLoop;
import edu.berkeley.cs.nlp.ocular.model.transition.SparseTransitionModel.TransitionState;
import tberg.murphy.gpu.CudaUtil;
import tberg.murphy.indexer.Indexer;
import tberg.murphy.threading.BetterThreader;

/**
 * Emission model that precomputes, for every sequence {@code d}, start column {@code t},
 * character {@code c} and total width {@code w} (template width plus trailing pad width),
 * the best log probability over exposure, vertical offset and the template/pad split.
 *
 * <p>The cache is built by {@link #rebuildCache()}; the lookup methods
 * ({@link #allowedWidths(int)}, {@link #logProb(int, int, int, int)}, ...) read tables that
 * only exist after that call.
 *
 * @author Taylor Berg-Kirkpatrick (tberg@eecs.berkeley.edu)
 */
public class CachingEmissionModel implements EmissionModel {

    /** Positions inside the array returned by {@link #findBestAlignment(int, int, int, int)}. */
    private static final int BEST_EXPOSURE = 0;
    private static final int BEST_OFFSET = 1;
    private static final int BEST_PAD_WIDTH = 2;

    private final EmissionCacheInnerLoop innerLoop;

    private final int numChars;
    private final CharacterTemplate[] templates;
    private final PixelType[][][] observations;
    /** Per sequence: 1.0f where the pixel is WHITE, 0.0f elsewhere (flattened with CudaUtil.flatten). */
    private final float[][] whiteObservations;
    /** Per sequence: 1.0f where the pixel is BLACK, 0.0f elsewhere (flattened with CudaUtil.flatten). */
    private final float[][] blackObservations;

    // Width tables, all indexed by character and rebuilt by rebuildCache().
    private int[][] templateAllowedWidths;
    private int[] templateMinWidths;
    private int[] templateMaxWidths;
    private int[] padAndTemplateMinWidths;
    private int[] padAndTemplateMaxWidths;
    private int[][] padAndTemplateAllowedWidths;

    /** Indexed [d][t][c][w - padAndTemplateMinWidths[c]]; unfilled entries are -infinity. */
    private float[][][][] cachedLogProbs;

    private final int spaceIndex;
    private final int padMinWidth;
    private final int padMaxWidth;

    /**
     * Stores the model inputs and converts the observed pixels into white/black indicator arrays.
     *
     * @param templates one template per character index; entries {@code 0..charIndexer.size()-1} must be non-null
     * @param charIndexer maps characters to template indices; also used to look up the space character
     * @param observations pixels indexed as {@code [sequence][column][row]}
     * @param padMinWidth smallest pad width considered after a template
     * @param padMaxWidth largest pad width considered after a template
     * @param innerLoop backend that scores all templates against a sequence
     * @throws RuntimeException if a template for an indexed character is null
     */
    public CachingEmissionModel(CharacterTemplate[] templates, Indexer<String> charIndexer, PixelType[][][] observations, int padMinWidth, int padMaxWidth, EmissionCacheInnerLoop innerLoop) {
        this.innerLoop = innerLoop;
        this.numChars = charIndexer.size();
        this.spaceIndex = charIndexer.getIndex(Charset.SPACE);
        this.templates = templates;
        this.observations = observations;
        this.padMinWidth = padMinWidth;
        this.padMaxWidth = padMaxWidth;

        for (int c = 0; c < numChars; ++c) {
            if (templates[c] == null) {
                throw new RuntimeException("template for template["+c+"] ("+charIndexer.getObject(c)+") is null!");
            }
        }

        this.whiteObservations = new float[observations.length][];
        this.blackObservations = new float[observations.length][];
        for (int d = 0; d < observations.length; ++d) {
            final int length = sequenceLength(d);
            // Fresh float arrays are zero-filled, so only the 1.0f indicators need to be written.
            final float[] white = new float[length * CharacterTemplate.LINE_HEIGHT];
            final float[] black = new float[length * CharacterTemplate.LINE_HEIGHT];
            for (int t = 0; t < length; ++t) {
                for (int j = 0; j < CharacterTemplate.LINE_HEIGHT; ++j) {
                    final int index = CudaUtil.flatten(length, CharacterTemplate.LINE_HEIGHT, t, j);
                    final PixelType observation = observations[d][t][j];
                    if (observation == PixelType.BLACK) {
                        black[index] = 1.0f;
                    } else if (observation == PixelType.WHITE) {
                        white[index] = 1.0f;
                    }
                    // Any other pixel type contributes to neither indicator.
                }
            }
            this.whiteObservations[d] = white;
            this.blackObservations[d] = black;
        }
    }

    /** @return the number of characters known to the indexer passed at construction */
    public int numChars() {
        return numChars;
    }

    /** @return the number of observation sequences */
    public int numSequences() {
        return observations.length;
    }

    /** @return the number of columns in sequence {@code d} */
    public int sequenceLength(int d) {
        return observations[d].length;
    }

    /** @return the sorted total widths (template plus pad) reachable for character {@code c} */
    public int[] allowedWidths(int c) {
        return padAndTemplateAllowedWidths[c];
    }

    /** @return the allowed total widths for the template character of {@code ts} */
    public int[] allowedWidths(TransitionState ts) {
        return allowedWidths(ts.getGlyphChar().templateCharIndex);
    }

    /**
     * Looks up the cached log probability of character {@code c} occupying columns
     * {@code [t, t+w)} of sequence {@code d}.
     */
    public float logProb(int d, int t, int c, int w) {
        return cachedLogProbs[d][t][c][w - padAndTemplateMinWidths[c]];
    }

    /** Same as {@link #logProb(int, int, int, int)} using the template character of {@code ts}. */
    public float logProb(int d, int t, TransitionState ts, int w) {
        return logProb(d, t, ts.getGlyphChar().templateCharIndex, w);
    }

    /**
     * @return the exposure index of the best-scoring alignment for width {@code w} at column
     *         {@code t}, or {@code -1} if no pad width yields a valid template width
     */
    public int getExposure(int d, int t, TransitionState ts, int w) {
        return findBestAlignment(d, t, ts.getGlyphChar().templateCharIndex, w)[BEST_EXPOSURE];
    }

    /**
     * @return the vertical offset of the best-scoring alignment for width {@code w} at column
     *         {@code t}, or {@code Integer.MIN_VALUE} if no pad width yields a valid template width
     */
    public int getOffset(int d, int t, TransitionState ts, int w) {
        return findBestAlignment(d, t, ts.getGlyphChar().templateCharIndex, w)[BEST_OFFSET];
    }

    /**
     * @return the pad width of the best-scoring alignment for width {@code w} at column
     *         {@code t}, or {@code -1} if no pad width yields a valid template width
     */
    public int getPadWidth(int d, int t, TransitionState ts, int w) {
        return findBestAlignment(d, t, ts.getGlyphChar().templateCharIndex, w)[BEST_PAD_WIDTH];
    }

    /**
     * Exhaustively searches offset, exposure and pad width for the highest-scoring way to
     * explain columns {@code [t, t+w)} of sequence {@code d} with character {@code c}
     * followed by whitespace padding.
     *
     * <p>Candidates are visited offset-major, then exposure, then pad width, and only a
     * strictly greater score replaces the incumbent, so ties keep the first candidate seen.
     *
     * @return {@code {exposure, offset, padWidth}}; {@code {-1, Integer.MIN_VALUE, -1}} when
     *         no candidate scored above negative infinity
     */
    private int[] findBestAlignment(int d, int t, int c, int w) {
        double bestScore = Double.NEGATIVE_INFINITY;
        int bestExposure = -1;
        int bestOffset = Integer.MIN_VALUE;
        int bestPadWidth = -1;
        for (int offset = -CharacterTemplate.MAX_OFFSET; offset <= CharacterTemplate.MAX_OFFSET; ++offset) {
            for (int e = 0; e < CharacterTemplate.EXP_GAINS.length; ++e) {
                for (int pw = padMinWidth; pw <= padMaxWidth; ++pw) {
                    int tw = w - pw;
                    if (tw < templateMinWidths[c] || tw > templateMaxWidths[c]) {
                        continue;
                    }
                    double score = templates[c].widthLogProb(tw) + templates[c].emissionLogProb(observations[d], t, t+tw, e, offset) + padWidthLogProb(pw) + templates[spaceIndex].emissionLogProb(observations[d], t+tw, t+tw+pw, e, offset);
                    if (score > bestScore) {
                        bestScore = score;
                        bestExposure = e;
                        bestOffset = offset;
                        bestPadWidth = pw;
                    }
                }
            }
        }
        return new int[] {bestExposure, bestOffset, bestPadWidth};
    }

    /**
     * Log probability of a pad width under a uniform distribution over
     * {@code [padMinWidth, padMaxWidth]}; the argument does not affect the result.
     */
    public float padWidthLogProb(int pw) {
        return (float) Math.log(1.0 / ((padMaxWidth - padMinWidth) + 1.0));
    }

    /**
     * Recomputes the width tables from the current templates and refills the emission cache.
     * Prints the elapsed time and an estimate of the cache size to standard output.
     */
    public void rebuildCache() {
        long nanoTime = System.nanoTime();

        rebuildWidthTables();
        final float[][][] logColumnProbsWhitespace = computeWhitespaceColumnLogProbs();
        allocateCache();

        int maxTemplateWidthTmp = Integer.MIN_VALUE;
        int minTemplateWidthTmp = Integer.MAX_VALUE;
        for (int c = 0; c < numChars; ++c) {
            maxTemplateWidthTmp = Math.max(maxTemplateWidthTmp, templateMaxWidths[c]);
            minTemplateWidthTmp = Math.min(minTemplateWidthTmp, templateMinWidths[c]);
        }
        final int maxTemplateWidth = maxTemplateWidthTmp;
        final int minTemplateWidth = minTemplateWidthTmp;
        final int numTemplateWidths = (maxTemplateWidth-minTemplateWidth)+1;

        // templateIndices[tw-min][c][e][offset+MAX_OFFSET] is the position of that rendered
        // template within the list of all templates sharing width tw.
        final int[][][][] templateIndices = new int[numTemplateWidths][numChars][CharacterTemplate.EXP_GAINS.length][2*CharacterTemplate.MAX_OFFSET+1];
        @SuppressWarnings("unchecked")
        final List<float[]>[] whiteTemplatesList = new List[numTemplateWidths];
        @SuppressWarnings("unchecked")
        final List<float[]>[] blackTemplatesList = new List[numTemplateWidths];
        for (int tw = minTemplateWidth; tw <= maxTemplateWidth; ++tw) {
            whiteTemplatesList[tw-minTemplateWidth] = new ArrayList<float[]>();
            blackTemplatesList[tw-minTemplateWidth] = new ArrayList<float[]>();
        }

        final int[] templateNumIndices = new int[numTemplateWidths];
        for (int c = 0; c < numChars; ++c) {
            for (int tw : templateAllowedWidths[c]) {
                final int widthIndex = tw - minTemplateWidth;
                for (int e = 0; e < CharacterTemplate.EXP_GAINS.length; ++e) {
                    for (int offset = -CharacterTemplate.MAX_OFFSET; offset <= CharacterTemplate.MAX_OFFSET; ++offset) {
                        float[][] logWhiteProbsTemplate = templates[c].logWhiteProbs(e, offset, tw);
                        float[][] logBlackProbsTemplate = templates[c].logBlackProbs(e, offset, tw);
                        whiteTemplatesList[widthIndex].add(CudaUtil.flatten(logWhiteProbsTemplate));
                        blackTemplatesList[widthIndex].add(CudaUtil.flatten(logBlackProbsTemplate));
                        templateIndices[widthIndex][c][e][offset+CharacterTemplate.MAX_OFFSET] = templateNumIndices[widthIndex];
                        templateNumIndices[widthIndex]++;
                    }
                }
            }
        }

        // Running start position of each width's group in the combined template numbering.
        int totalTemplateNumIndices = 0;
        final int[] templateIndicesOffsets = new int[numTemplateWidths];
        for (int tw = minTemplateWidth; tw <= maxTemplateWidth; ++tw) {
            templateIndicesOffsets[tw-minTemplateWidth] = totalTemplateNumIndices;
            totalTemplateNumIndices += templateNumIndices[tw-minTemplateWidth];
        }

        float[][] whiteTemplates = new float[numTemplateWidths][];
        float[][] blackTemplates = new float[numTemplateWidths][];
        for (int tw = minTemplateWidth; tw <= maxTemplateWidth; ++tw) {
            whiteTemplates[tw-minTemplateWidth] = CudaUtil.flatten(whiteTemplatesList[tw-minTemplateWidth]);
            blackTemplates[tw-minTemplateWidth] = CudaUtil.flatten(blackTemplatesList[tw-minTemplateWidth]);
        }

        int maxSequenceLength = Integer.MIN_VALUE;
        for (int d = 0; d < numSequences(); ++d) {
            maxSequenceLength = Math.max(maxSequenceLength, sequenceLength(d));
        }

        innerLoop.startup(whiteTemplates, blackTemplates, templateNumIndices, templateIndicesOffsets, minTemplateWidth, maxTemplateWidth, maxSequenceLength, totalTemplateNumIndices);

        // One scratch buffer per outer thread, sized for the longest sequence.
        float[][] scores = new float[innerLoop.numOuterThreads()][maxSequenceLength*totalTemplateNumIndices];
        BetterThreader.Function<Integer,float[]> func = new BetterThreader.Function<Integer,float[]>() {
            public void call(Integer d, float[] threadScores) {
                Arrays.fill(threadScores, 0.0f);
                innerLoop.compute(threadScores, whiteObservations[d], blackObservations[d], sequenceLength(d));
                populate(d, threadScores, minTemplateWidth, logColumnProbsWhitespace, templateIndices, templateNumIndices, templateIndicesOffsets, innerLoop.numPopulateThreads());
            }
        };
        BetterThreader<Integer,float[]> threader = new BetterThreader<Integer,float[]>(func, innerLoop.numOuterThreads());
        for (int d = 0; d < numSequences(); ++d) {
            threader.addFunctionArgument(d);
        }
        for (int t = 0; t < innerLoop.numOuterThreads(); ++t) {
            threader.setThreadArgument(t, scores[t]);
        }
        threader.run();

        innerLoop.shutdown();

        System.out.println("Rebuild emission cache: " + (System.nanoTime() - nanoTime)/1000000 + "ms");
        System.out.printf("Estimated emission cache size: %.3fgb\n", estimateMemoryUsage());
    }

    /**
     * Reads each template's allowed, minimum and maximum widths and derives the matching
     * tables for template-plus-pad widths.
     */
    private void rebuildWidthTables() {
        templateAllowedWidths = new int[numChars][];
        templateMinWidths = new int[numChars];
        templateMaxWidths = new int[numChars];
        padAndTemplateMinWidths = new int[numChars];
        padAndTemplateMaxWidths = new int[numChars];
        padAndTemplateAllowedWidths = new int[numChars][];
        for (int c = 0; c < numChars; ++c) {
            templateAllowedWidths[c] = templates[c].allowedWidths();
            templateMinWidths[c] = templates[c].templateMinWidth();
            templateMaxWidths[c] = templates[c].templateMaxWidth();
            padAndTemplateMinWidths[c] = templates[c].templateMinWidth() + padMinWidth;
            padAndTemplateMaxWidths[c] = templates[c].templateMaxWidth() + padMaxWidth;

            // Mark every total width reachable as (allowed template width) + (pad width).
            boolean[] reachable = new boolean[padAndTemplateMaxWidths[c]+1];
            for (int tw : templateAllowedWidths[c]) {
                for (int pw = padMinWidth; pw <= padMaxWidth; ++pw) {
                    reachable[tw+pw] = true;
                }
            }

            int count = 0;
            for (boolean isReachable : reachable) {
                if (isReachable) {
                    ++count;
                }
            }
            int[] widths = new int[count];
            int next = 0;
            for (int w = 0; w < reachable.length; ++w) {
                if (reachable[w]) {
                    widths[next++] = w;
                }
            }
            padAndTemplateAllowedWidths[c] = widths;
        }
    }

    /**
     * Scores every single column of every sequence against the one-column-wide space
     * template (offset 0) for each exposure.
     *
     * @return log probabilities indexed {@code [d][t][e]}
     */
    private float[][][] computeWhitespaceColumnLogProbs() {
        final float[][][] logColumnProbsWhitespace = new float[observations.length][][];
        for (int d = 0; d < observations.length; ++d) {
            final int length = sequenceLength(d);
            logColumnProbsWhitespace[d] = new float[length][CharacterTemplate.EXP_GAINS.length];
            for (int e = 0; e < CharacterTemplate.EXP_GAINS.length; ++e) {
                float[] logWhiteProbsWhitespace = templates[spaceIndex].logWhiteProbs(e, 0, 1)[0];
                float[] logBlackProbsWhitespace = templates[spaceIndex].logBlackProbs(e, 0, 1)[0];
                for (int t = 0; t < length; ++t) {
                    // White terms are summed before black terms; keep this order so the float
                    // accumulation stays bit-for-bit stable.
                    float logProb = 0.0f;
                    for (int j = 0; j < CharacterTemplate.LINE_HEIGHT; ++j) {
                        logProb += logWhiteProbsWhitespace[j] * whiteObservations[d][CudaUtil.flatten(length, CharacterTemplate.LINE_HEIGHT, t, j)];
                    }
                    for (int j = 0; j < CharacterTemplate.LINE_HEIGHT; ++j) {
                        logProb += logBlackProbsWhitespace[j] * blackObservations[d][CudaUtil.flatten(length, CharacterTemplate.LINE_HEIGHT, t, j)];
                    }
                    logColumnProbsWhitespace[d][t][e] = logProb;
                }
            }
        }
        return logColumnProbsWhitespace;
    }

    /** Allocates {@link #cachedLogProbs} with every entry set to negative infinity. */
    private void allocateCache() {
        cachedLogProbs = new float[numSequences()][][][];
        for (int d = 0; d < numSequences(); ++d) {
            final int length = sequenceLength(d);
            cachedLogProbs[d] = new float[length][][];
            for (int t = 0; t < length; ++t) {
                cachedLogProbs[d][t] = new float[numChars][];
                for (int c = 0; c < numChars; ++c) {
                    cachedLogProbs[d][t][c] = new float[padAndTemplateMaxWidths[c]-padAndTemplateMinWidths[c]+1];
                    Arrays.fill(cachedLogProbs[d][t][c], Float.NEGATIVE_INFINITY);
                }
            }
        }
    }

    /**
     * Fills the cache for sequence {@code d} from the raw template scores, running one task
     * per start column on {@code numThreads} threads. Each task writes only to
     * {@code cachedLogProbs[d][t]} for its own {@code t}.
     *
     * @param scores template scores for sequence {@code d} as written by the inner loop;
     *               NOTE(review): the layout is taken from the indexing below, confirm against
     *               the EmissionCacheInnerLoop implementations
     */
    private void populate(final int d, final float[] scores, final int minTemplateWidth, final float[][][] logColumnProbsWhitespace, final int[][][][] templateIndices, final int[] templateNumIndices, final int[] templateIndicesOffsets, int numThreads) {
        BetterThreader.Function<Integer,Object> func = new BetterThreader.Function<Integer,Object>() {
            public void call(Integer t, Object ignore) {
                final int length = sequenceLength(d);
                for (int c = 0; c < numChars; ++c) {
                    for (int tw : templateAllowedWidths[c]) {
                        double templateWidthLogProb = templates[c].widthLogProb(tw);
                        if (t+tw+padMinWidth > length) {
                            continue; // template plus minimum pad would run past the sequence end
                        }
                        final int widthIndex = tw - minTemplateWidth;
                        for (int e = 0; e < CharacterTemplate.EXP_GAINS.length; ++e) {
                            // Best template score over vertical offsets for this exposure.
                            float templateLogProb = Float.NEGATIVE_INFINITY;
                            for (int offset = -CharacterTemplate.MAX_OFFSET; offset <= CharacterTemplate.MAX_OFFSET; ++offset) {
                                float logProb = (float) templateWidthLogProb + scores[templateIndicesOffsets[widthIndex]*length + CudaUtil.flatten(length, templateNumIndices[widthIndex], t, templateIndices[widthIndex][c][e][offset+CharacterTemplate.MAX_OFFSET])];
                                if (logProb > templateLogProb) {
                                    templateLogProb = logProb;
                                }
                            }
                            for (int pw = padMinWidth; pw <= padMaxWidth; ++pw) {
                                int w = tw + pw;
                                if (t+w > length) {
                                    continue;
                                }
                                float padLogProb = (float) padWidthLogProb(pw);
                                for (int tt = 0; tt < pw; ++tt) {
                                    padLogProb += logColumnProbsWhitespace[d][t+tw+tt][e];
                                }
                                // Keep the maximum over all (template width, exposure, pad) splits of w.
                                float candidate = templateLogProb + padLogProb;
                                float[] cacheRow = cachedLogProbs[d][t][c];
                                int slot = w - padAndTemplateMinWidths[c];
                                if (candidate > cacheRow[slot]) {
                                    cacheRow[slot] = candidate;
                                }
                            }
                        }
                    }
                }
            }
        };
        BetterThreader<Integer,Object> threader = new BetterThreader<Integer,Object>(func, numThreads);
        for (int t = 0; t < sequenceLength(d); ++t) {
            threader.addFunctionArgument(t);
        }
        threader.run();
    }

    /**
     * Adds {@code count} to the template of {@code ts} for the columns
     * {@code [startCol, endCol)} of sequence {@code d}, using the best-scoring exposure,
     * offset and pad width for that span. Does nothing when {@code count} is not positive.
     */
    public void incrementCount(int d, TransitionState ts, int startCol, int endCol, float count) {
        if (count > 0.0) {
            int c = ts.getGlyphChar().templateCharIndex;
            int w = endCol - startCol;
            // One search yields all three coordinates of the same best alignment.
            int[] best = findBestAlignment(d, startCol, c, w);
            int tw = w - best[BEST_PAD_WIDTH];
            templates[c].incrementCounts(count, observations[d], startCol, tw, best[BEST_EXPOSURE], best[BEST_OFFSET]);
        }
    }

    /**
     * Increments counts with weight 1 for a full segmentation of sequence {@code d}:
     * state {@code ts[i]} covers the next {@code widths[i]} columns, starting from column 0.
     */
    public void incrementCounts(int d, TransitionState[] ts, int[] widths) {
        int t = 0;
        for (int i = 0; i < ts.length; ++i) {
            int width = widths[i];
            incrementCount(d, ts[i], t, t+width, 1.0f);
            t += width;
        }
    }

    /** @return the size of the allocated cache entries in gigabytes (4 bytes per float, 1e9 bytes per gb) */
    private double estimateMemoryUsage() {
        double elementsOfCache = 0.0;
        for (float[][][] perSequence : cachedLogProbs) {
            if (perSequence == null) {
                continue;
            }
            for (float[][] perColumn : perSequence) {
                if (perColumn == null) {
                    continue;
                }
                for (float[] perChar : perColumn) {
                    if (perChar != null) {
                        elementsOfCache += perChar.length;
                    }
                }
            }
        }
        return 4 * elementsOfCache / 1e9;
    }

    /** Factory that binds the indexer, pad width range and inner loop shared by all models it makes. */
    public static class CachingEmissionModelFactory implements EmissionModel.EmissionModelFactory {
        Indexer<String> charIndexer;
        int padMinWidth;
        int padMaxWidth;
        EmissionCacheInnerLoop innerLoop;

        public CachingEmissionModelFactory(Indexer<String> charIndexer, int padMinWidth, int padMaxWidth, EmissionCacheInnerLoop innerLoop) {
            this.charIndexer = charIndexer;
            this.padMinWidth = padMinWidth;
            this.padMaxWidth = padMaxWidth;
            this.innerLoop = innerLoop;
        }

        /** @return a new {@link CachingEmissionModel} over the given templates and observations */
        public EmissionModel make(CharacterTemplate[] templates, PixelType[][][] observations) {
            return new CachingEmissionModel(templates, charIndexer, observations, padMinWidth, padMaxWidth, innerLoop);
        }
    }
}