package edu.berkeley.cs.nlp.ocular.model.em; import static edu.berkeley.cs.nlp.ocular.util.Tuple2.Tuple2; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import tberg.murphy.threading.BetterThreader; import tberg.murphy.util.GeneralPriorityQueue; import tberg.murphy.arrays.a; import edu.berkeley.cs.nlp.ocular.model.emission.EmissionModel; import edu.berkeley.cs.nlp.ocular.model.transition.SparseTransitionModel; import edu.berkeley.cs.nlp.ocular.model.transition.SparseTransitionModel.TransitionState; import edu.berkeley.cs.nlp.ocular.util.Tuple2; /** * @author Taylor Berg-Kirkpatrick (tberg@eecs.berkeley.edu) */ public class BeamingSemiMarkovDP { private static class BeamState { private final TransitionState transState; public double score = Double.NEGATIVE_INFINITY; public Tuple2<Integer,TransitionState> backPointer = null; public BeamState(TransitionState transState) { this.transState = transState; } public int hashCode() { return transState.hashCode(); } public boolean equals(Object obj) { if (obj instanceof BeamState) { return transState.equals(((BeamState) obj).transState); } else { return false; } } } private GeneralPriorityQueue<BeamState>[][] alphas; double[][][] betas; private SparseTransitionModel forwardTransitionModel; private DenseBigramTransitionModel backwardTransitionModel; private EmissionModel emissionModel; @SuppressWarnings("unchecked") public BeamingSemiMarkovDP(EmissionModel emissionModel, SparseTransitionModel forwardTransitionModel, DenseBigramTransitionModel backwardTransitionModel) { this.emissionModel = emissionModel; this.forwardTransitionModel = forwardTransitionModel; this.backwardTransitionModel = backwardTransitionModel; this.alphas = new GeneralPriorityQueue[emissionModel.numSequences()][]; for (int d=0; d<emissionModel.numSequences(); ++d) { this.alphas[d] = new GeneralPriorityQueue[emissionModel.sequenceLength(d)+1]; for (int t=0; t<emissionModel.sequenceLength(d)+1; ++t) { this.alphas[d][t] = new GeneralPriorityQueue<BeamState>(); } } this.betas = new double[emissionModel.numSequences()][][]; for (int d=0; d<emissionModel.numSequences(); ++d) { this.betas[d] = new double[emissionModel.sequenceLength(d)+1][emissionModel.numChars()]; } } public Tuple2<Tuple2<TransitionState[][],int[][]>,Double> decode(final int beamSize, int numThreads) { System.out.print("Decoding"); if (numThreads == 1) return decodeSingleThread(beamSize); else return decodeMultipleThreads(beamSize, numThreads); } private Tuple2<Tuple2<TransitionState[][],int[][]>,Double> decodeSingleThread(int beamSize) { Collection<BeamState> startStates = null; double logJointProb = Double.NEGATIVE_INFINITY; for (int d = 0; d < emissionModel.numSequences(); ++d) { Tuple2<Double,Collection<BeamState>> logJointProbAndNextStartStates = doForwardPassLogSpace(d, beamSize, startStates); logJointProb = logJointProbAndNextStartStates._1; startStates = logJointProbAndNextStartStates._2; } TransitionState[][] decodeStates = new TransitionState[emissionModel.numSequences()][]; int[][] decodeWidths = new int[emissionModel.numSequences()][]; TransitionState finalState = null; for (int d = emissionModel.numSequences()-1; d >= 0; --d) { Tuple2<Tuple2<TransitionState[],int[]>,TransitionState> statesAndWidthsAndNextFinalState = followBackpointers(d, finalState); decodeStates[d] = statesAndWidthsAndNextFinalState._1._1; decodeWidths[d] = statesAndWidthsAndNextFinalState._1._2; finalState = statesAndWidthsAndNextFinalState._2; } return Tuple2(Tuple2(decodeStates, decodeWidths), logJointProb); } private Tuple2<Tuple2<TransitionState[][],int[][]>,Double> decodeMultipleThreads(final int beamSize, int numThreads) { final TransitionState[][] decodeStates = new TransitionState[emissionModel.numSequences()][]; final int[][] decodeWidths = new int[emissionModel.numSequences()][]; final int blockSize = (int) Math.ceil(((double) emissionModel.numSequences()) / ((double)numThreads)); final double[] logJointProb = new double[] {0.0}; { BetterThreader.Function<Integer,Object> func = new BetterThreader.Function<Integer,Object>(){public void call(Integer b, Object ignore) { double blockLogJointProb = Double.NEGATIVE_INFINITY; Collection<BeamState> startStates = null; for (int d=b*blockSize; d<(b+1)*blockSize; ++d) { if (d < emissionModel.numSequences()) { Tuple2<Double,Collection<BeamState>> logJointProbAndNextStartStates = doForwardPassLogSpace(d, beamSize, startStates); blockLogJointProb = logJointProbAndNextStartStates._1; startStates = logJointProbAndNextStartStates._2; } } TransitionState finalState = null; for (int d=(b+1)*blockSize-1; d>=b*blockSize; --d) { if (d < emissionModel.numSequences()) { Tuple2<Tuple2<TransitionState[],int[]>,TransitionState> statesAndWidthsAndNextFinalState = followBackpointers(d, finalState); decodeStates[d] = statesAndWidthsAndNextFinalState._1._1; decodeWidths[d] = statesAndWidthsAndNextFinalState._1._2; finalState = statesAndWidthsAndNextFinalState._2; } } synchronized (logJointProb) { logJointProb[0] += blockLogJointProb; } }}; BetterThreader<Integer,Object> threader = new BetterThreader<Integer,Object>(func, numThreads); for (int b=0; b<numThreads; ++b) threader.addFunctionArgument(b); threader.run(); } System.out.println(); return Tuple2(Tuple2(decodeStates, decodeWidths), logJointProb[0]); } private Tuple2<Double,Collection<BeamState>> doForwardPassLogSpace(int d, int beamSize, Collection<BeamState> startStates) { System.out.print("."); // System.out.printf("Backward pass: %d%n", d); doDenseCoarseBackwardPassLogSpace(d, betas[d]); // System.out.printf("Forward pass: %d%n", d); for (GeneralPriorityQueue<BeamState> queue : alphas[d]) queue.clear(); for (int t=0; t<emissionModel.sequenceLength(d)+1; ++t) { if (t == 0) { if (startStates == null || startStates.isEmpty()) { startStates = addNullBackpointers(forwardTransitionModel.startStates()); //if (startStates.isEmpty()) new EmptyBeamException("The forwardTransitionModel has no possible start states."); } for (BeamState startBeamState : startStates) { TransitionState nextTs = startBeamState.transState; double startLogProb = startBeamState.score; if (startLogProb != Double.NEGATIVE_INFINITY) { for (int w : emissionModel.allowedWidths(nextTs)) { if (t + w < emissionModel.sequenceLength(d)+1) { int nextT = t + w; double emissionLogProb = emissionModel.logProb(d, t, nextTs, nextT-t); double score = startLogProb + emissionLogProb; if (score != Double.NEGATIVE_INFINITY) { addToBeam(alphas[d][nextT], nextTs, score, betas[d][nextT][nextTs.getGlyphChar().templateCharIndex], new Tuple2<Integer,TransitionState>(0, startBeamState.backPointer._2), beamSize); } } } } } } else { for (BeamState beamState : alphas[d][t].getObjects()) { Collection<Tuple2<TransitionState,Double>> allowedTrans = beamState.transState.forwardTransitions(); for (Tuple2<TransitionState,Double> trans : allowedTrans) { TransitionState nextTs = trans._1; double transLogProb = trans._2; for (int w : emissionModel.allowedWidths(nextTs)) { if (t + w < emissionModel.sequenceLength(d)+1) { int nextT = t + w; double emissionLogProb = emissionModel.logProb(d, t, nextTs, nextT-t); double score = beamState.score + transLogProb + emissionLogProb; if (score != Double.NEGATIVE_INFINITY) { addToBeam(alphas[d][nextT], nextTs, score, betas[d][nextT][nextTs.getGlyphChar().templateCharIndex], Tuple2(t, beamState.transState), beamSize); } } } } } } } double bestFinalScore = Double.NEGATIVE_INFINITY; Map<TransitionState,BeamState> wrappedStartStatesMap = new HashMap<TransitionState,BeamState>(); for (BeamState endBeamState : alphas[d][emissionModel.sequenceLength(d)].getObjects()) { double endScore = endBeamState.score + endBeamState.transState.endLogProb(); if (endScore != Double.NEGATIVE_INFINITY) { if (endScore > bestFinalScore) { bestFinalScore = endScore; } for (Tuple2<TransitionState,Double> startTransitionPair : endBeamState.transState.nextLineStartStates()) { double score = endScore + startTransitionPair._2; if (score != Double.NEGATIVE_INFINITY) { BeamState startBeamState = wrappedStartStatesMap.get(startTransitionPair._1); if (startBeamState == null) { startBeamState = new BeamState(startTransitionPair._1); startBeamState.score = Double.NEGATIVE_INFINITY; startBeamState.backPointer = new Tuple2<Integer, TransitionState>(-1, null); wrappedStartStatesMap.put(startTransitionPair._1, startBeamState); } if (score > startBeamState.score) { startBeamState.score = score; startBeamState.backPointer = Tuple2(-1, endBeamState.transState); } } } } } Collection<BeamState> wrappedStartStates = new ArrayList<BeamState>(); for (Map.Entry<TransitionState, BeamState> entry : wrappedStartStatesMap.entrySet()) { wrappedStartStates.add(entry.getValue()); } return Tuple2(bestFinalScore, wrappedStartStates); } private static void addToBeam(GeneralPriorityQueue<BeamState> queue, TransitionState nextTs, double score, double forwardScore, Tuple2<Integer,TransitionState> backPointer, int beamSize) { double priority = -(score+forwardScore); if (queue.isEmpty() || priority < queue.getPriority()) { BeamState key = new BeamState(nextTs); if (queue.containsKey(key)) { queue.decreasePriority(key, priority); } else { queue.setPriority(key, priority); } BeamState object = queue.getObject(key); if (object.score < score) { object.score = score; object.backPointer = backPointer; } while (queue.size() > beamSize) { queue.removeFirst(); } } } private static Collection<BeamState> addNullBackpointers(Collection<Tuple2<TransitionState,Double>> without) { List<BeamState> with = new ArrayList<BeamState>(); for (Tuple2<TransitionState,Double> startPair : without) { BeamState beamState = new BeamState(startPair._1); beamState.score = startPair._2; beamState.backPointer = Tuple2(-1, null); with.add(beamState); } return with; } private Tuple2<Tuple2<TransitionState[],int[]>,TransitionState> followBackpointers(int d, TransitionState finalTs) { List<TransitionState> transStateDecodeList = new ArrayList<TransitionState>(); List<Integer> widthsDecodeList = new ArrayList<Integer>(); TransitionState nextFinalTs = null; try { TransitionState bestFinalTs = null; if (finalTs == null) { double bestFinalScore = Double.NEGATIVE_INFINITY; Collection<BeamState> possibleBeamStates = alphas[d][emissionModel.sequenceLength(d)].getObjects(); if (possibleBeamStates.isEmpty()) throw new EmptyBeamException("No possible final states found for this line. Consider increasing -beamSize."); for (BeamState beamState : possibleBeamStates) { double score = beamState.score + beamState.transState.endLogProb(); if (score > bestFinalScore) { bestFinalScore = score; bestFinalTs = beamState.transState; } } if (bestFinalTs == null) throw new EmptyBeamException("No final-state possibilities with non-zero probabilities for this line. Consider increasing -beamSize."); } else { bestFinalTs = finalTs; } int currentT = emissionModel.sequenceLength(d); TransitionState currentTs = bestFinalTs; //System.out.print("Line "+d+" Backward decode: "); while (true) { if (currentTs == null) throw new EmptyBeamException("No current-state possiblities with non-zero probabilities when following backpointers. Consider increasing -beamSize."); Tuple2<Integer,TransitionState> backpointer = alphas[d][currentT].getObject(new BeamState(currentTs)).backPointer; int width = currentT - backpointer._1; transStateDecodeList.add(currentTs); widthsDecodeList.add(width); currentT = backpointer._1; currentTs = backpointer._2; //System.out.println("Horizontal pixel "+currentT+", character ["+emissionModel.getCharIndexer().getObject(currentTs.getCharIndex())+"]"); //System.out.print(currentTs != null ? emissionModel.getCharIndexer().getObject(currentTs.getCharIndex()) : "#" ); if (currentT == 0) { nextFinalTs = currentTs; break; } } //System.out.println(); } catch (EmptyBeamException e) { System.out.println("ERRROR: Line "+d+": "+e.getMessage()); nextFinalTs = null; } Collections.reverse(transStateDecodeList); Collections.reverse(widthsDecodeList); int[] widthsDecode = a.toIntArray(widthsDecodeList); return Tuple2(Tuple2(transStateDecodeList.toArray(new TransitionState[0]), widthsDecode), nextFinalTs); } private void doDenseCoarseBackwardPassLogSpace(int d, double[][] betas) { int numChars = emissionModel.numChars(); for (int t=emissionModel.sequenceLength(d); t>=0; --t) { Arrays.fill(betas[t], Double.NEGATIVE_INFINITY); if (t==emissionModel.sequenceLength(d)) { for (int c=0; c<numChars; ++c) { betas[t][c] = backwardTransitionModel.endLogProb(c); } } else { for (int nextC=0; nextC<numChars; ++nextC) { double betaWithoutTrans = Double.NEGATIVE_INFINITY; int[] allowedWidths = emissionModel.allowedWidths(nextC); for (int w : allowedWidths) { if (t + w <= emissionModel.sequenceLength(d)) { double emissionLogProb = emissionModel.logProb(d, t, nextC, w); betaWithoutTrans = Math.max(betaWithoutTrans, emissionLogProb + betas[t+w][nextC]); } } double[] betasCol = betas[t]; double[] logTransProbs = backwardTransitionModel.backwardTransitions(nextC); for (int c=0; c<numChars; ++c) { betasCol[c] = Math.max(betasCol[c], logTransProbs[c] + betaWithoutTrans); } } } } } }