package edu.berkeley.cs.nlp.ocular.model;
import static edu.berkeley.cs.nlp.ocular.util.Tuple2.Tuple2;
import java.text.SimpleDateFormat;
import java.util.Calendar;
import edu.berkeley.cs.nlp.ocular.data.Document;
import edu.berkeley.cs.nlp.ocular.gsm.GlyphSubstitutionModel;
import edu.berkeley.cs.nlp.ocular.image.ImageUtils.PixelType;
import edu.berkeley.cs.nlp.ocular.lm.CodeSwitchLanguageModel;
import edu.berkeley.cs.nlp.ocular.lm.SingleLanguageModel;
import edu.berkeley.cs.nlp.ocular.model.em.BeamingSemiMarkovDP;
import edu.berkeley.cs.nlp.ocular.model.em.DenseBigramTransitionModel;
import edu.berkeley.cs.nlp.ocular.model.emission.EmissionModel;
import edu.berkeley.cs.nlp.ocular.model.emission.EmissionModel.EmissionModelFactory;
import edu.berkeley.cs.nlp.ocular.model.transition.CharacterNgramTransitionModel;
import edu.berkeley.cs.nlp.ocular.model.transition.CharacterNgramTransitionModelMarkovOffset;
import edu.berkeley.cs.nlp.ocular.model.transition.CodeSwitchTransitionModel;
import edu.berkeley.cs.nlp.ocular.model.transition.SparseTransitionModel;
import edu.berkeley.cs.nlp.ocular.model.transition.SparseTransitionModel.TransitionState;
import edu.berkeley.cs.nlp.ocular.util.Tuple2;
import tberg.murphy.threading.BetterThreader;
/**
* @author Taylor Berg-Kirkpatrick (tberg@eecs.berkeley.edu)
* @author Dan Garrette (dhgarrette@gmail.com)
*/
public class DecoderEM {
private EmissionModelFactory emissionModelFactory;
private boolean allowGlyphSubstitution;
private double noCharSubPrior;
private boolean elideAnything;
private boolean allowLanguageSwitchOnPunct;
private boolean markovVerticalOffset;
private int beamSize;
private int numDecodeThreads;
private int numMstepThreads;
private int decodeBatchSize;
public DecoderEM(EmissionModelFactory emissionModelFactory, boolean allowGlyphSubstitution, double noCharSubPrior, boolean elideAnything,
boolean allowLanguageSwitchOnPunct, boolean markovVerticalOffset,
int beamSize, int numDecodeThreads, int numMstepThreads, int decodeBatchSize) {
this.emissionModelFactory = emissionModelFactory;
this.allowGlyphSubstitution = allowGlyphSubstitution;
this.noCharSubPrior = noCharSubPrior;
this.elideAnything = elideAnything;
this.allowLanguageSwitchOnPunct = allowLanguageSwitchOnPunct;
this.markovVerticalOffset = markovVerticalOffset;
this.beamSize = beamSize;
this.numDecodeThreads = numDecodeThreads;
this.numMstepThreads = numMstepThreads;
this.decodeBatchSize = decodeBatchSize;
}
public Tuple2<DecodeState[][], Double> computeEStep(
Document doc, boolean updateFontParameterCounts,
CodeSwitchLanguageModel lm, GlyphSubstitutionModel gsm, final CharacterTemplate[] templates,
DenseBigramTransitionModel backwardTransitionModel) {
final PixelType[][][] pixels = doc.loadLineImages();
DecodeState[][] allDecodeStates = new DecodeState[pixels.length][0];
long totalDecodeNanoTime = 0;
long totalEmitNanoTime = 0;
double totalJointLogProb = 0.0;
int numBatches = (int) Math.ceil(pixels.length / (double) decodeBatchSize);
for (int b = 0; b < numBatches; ++b) {
System.gc();
System.gc();
System.gc();
System.out.println("Batch: " + b);
int startLine = b * decodeBatchSize;
int endLine = Math.min((b + 1) * decodeBatchSize, pixels.length);
PixelType[][][] batchPixels = new PixelType[endLine - startLine][][];
for (int line = startLine; line < endLine; ++line) {
batchPixels[line - startLine] = pixels[line];
}
System.out.println("Initializing EmissionModel " + (new SimpleDateFormat("yyyy/MM/dd HH:mm:ss").format(Calendar.getInstance().getTime())));
final EmissionModel batchEmissionModel = emissionModelFactory.make(templates, batchPixels);
System.out.println("Rebuilding cache " + (new SimpleDateFormat("yyyy/MM/dd HH:mm:ss").format(Calendar.getInstance().getTime())));
//long emissionCacheNanoTime = System.nanoTime();
long nanoTime = System.nanoTime();
batchEmissionModel.rebuildCache();
totalEmitNanoTime += (System.nanoTime() - nanoTime);
System.out.println("Done rebuilding cache " + (new SimpleDateFormat("yyyy/MM/dd HH:mm:ss").format(Calendar.getInstance().getTime())));
//overallEmissionCacheNanoTime += (System.nanoTime() - emissionCacheNanoTime);
nanoTime = System.nanoTime();
System.out.println("Constructing forwardTransitionModel");
SparseTransitionModel forwardTransitionModel = constructTransitionModel(lm, gsm);
BeamingSemiMarkovDP dp = new BeamingSemiMarkovDP(batchEmissionModel, forwardTransitionModel, backwardTransitionModel);
System.out.println("Ready to run decoder");
Tuple2<Tuple2<TransitionState[][], int[][]>, Double> decodeStatesAndWidthsAndJointLogProb = dp.decode(beamSize, numDecodeThreads);
System.out.println("Done running decoder");
totalDecodeNanoTime += (System.nanoTime() - nanoTime);
final TransitionState[][] batchDecodeStates = decodeStatesAndWidthsAndJointLogProb._1._1;
final int[][] batchDecodeWidths = decodeStatesAndWidthsAndJointLogProb._1._2;
totalJointLogProb += decodeStatesAndWidthsAndJointLogProb._2;
for (int batchLine = 0; batchLine < batchEmissionModel.numSequences(); ++batchLine) {
int line = startLine + batchLine;
TransitionState[] decodeStates = batchDecodeStates[batchLine];
int[] decodeWidths = batchDecodeWidths[batchLine];
allDecodeStates[line] = new DecodeState[decodeStates.length];
int stateStartCol = 0;
for (int di=0; di<decodeStates.length; ++di) {
int charAndPadWidth = decodeWidths[di];
int padWidth = batchEmissionModel.getPadWidth(batchLine, stateStartCol, decodeStates[di], charAndPadWidth);
int exposure = batchEmissionModel.getExposure(batchLine, stateStartCol, decodeStates[di], charAndPadWidth);
int verticalOffset = batchEmissionModel.getOffset(batchLine, stateStartCol, decodeStates[di], charAndPadWidth);
allDecodeStates[line][di] = new DecodeState(decodeStates[di], charAndPadWidth, padWidth, exposure, verticalOffset);
stateStartCol += charAndPadWidth;
}
}
if (updateFontParameterCounts) {
System.out.println("Ready to run increment counts");
incrementCounts(batchEmissionModel, batchDecodeStates, batchDecodeWidths);
}
}
System.out.println("Emission cache: " + (totalEmitNanoTime / 1000000) + "ms");
System.out.println("Decode: " + (totalDecodeNanoTime / 1000000) + "ms");
double avgLogProb = totalJointLogProb / numBatches;
return Tuple2(allDecodeStates, avgLogProb);
}
private SparseTransitionModel constructTransitionModel(CodeSwitchLanguageModel codeSwitchLM, GlyphSubstitutionModel codeSwitchGSM) {
SparseTransitionModel transitionModel;
boolean multilingual = codeSwitchLM.getLanguageIndexer().size() > 1;
if (multilingual || allowGlyphSubstitution) {
if (markovVerticalOffset) {
if (allowGlyphSubstitution)
throw new RuntimeException("Markov vertical offset transition model not currently supported with glyph substitution.");
else
throw new RuntimeException("Markov vertical offset transition model not currently supported for multiple languages.");
}
else {
transitionModel = new CodeSwitchTransitionModel(codeSwitchLM, allowLanguageSwitchOnPunct, codeSwitchGSM, allowGlyphSubstitution, noCharSubPrior, elideAnything);
System.out.println("Using CodeSwitchLanguageModel, GlyphSubstitutionModel, and CodeSwitchTransitionModel");
}
}
else { // only one language, default to original (monolingual) Ocular code because it will be faster.
SingleLanguageModel singleLm = codeSwitchLM.get(0);
if (markovVerticalOffset) {
transitionModel = new CharacterNgramTransitionModelMarkovOffset(singleLm);
System.out.println("Using OnlyOneLanguageCodeSwitchLM and CharacterNgramTransitionModelMarkovOffset");
} else {
transitionModel = new CharacterNgramTransitionModel(singleLm);
System.out.println("Using OnlyOneLanguageCodeSwitchLM and CharacterNgramTransitionModel");
}
}
return transitionModel;
}
private void incrementCounts(final EmissionModel emissionModel, final TransitionState[][] batchDecodeStates, final int[][] batchDecodeWidths) {
long nanoTime = System.nanoTime();
BetterThreader.Function<Integer, Object> func = new BetterThreader.Function<Integer, Object>() {
public void call(Integer line, Object ignore) {
emissionModel.incrementCounts(line, batchDecodeStates[line], batchDecodeWidths[line]);
}
};
BetterThreader<Integer, Object> threader = new BetterThreader<Integer, Object>(func, numMstepThreads);
for (int line = 0; line < emissionModel.numSequences(); ++line)
threader.addFunctionArgument(line);
threader.run();
System.out.println("Increment counts: " + ((System.nanoTime() - nanoTime) / 1000000) + "ms");
}
}