package edu.berkeley.cs.nlp.ocular.model.transition; import static edu.berkeley.cs.nlp.ocular.util.Tuple2.Tuple2; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.List; import tberg.murphy.arrays.a; import edu.berkeley.cs.nlp.ocular.data.textreader.Charset; import edu.berkeley.cs.nlp.ocular.gsm.GlyphChar; import edu.berkeley.cs.nlp.ocular.gsm.GlyphChar.GlyphType; import edu.berkeley.cs.nlp.ocular.lm.SingleLanguageModel; import edu.berkeley.cs.nlp.ocular.model.CharacterTemplate; import edu.berkeley.cs.nlp.ocular.model.TransitionStateType; import edu.berkeley.cs.nlp.ocular.util.Tuple2; /** * @author Taylor Berg-Kirkpatrick (tberg@eecs.berkeley.edu) */ public class CharacterNgramTransitionModelMarkovOffset implements SparseTransitionModel { public class CharacterNgramTransitionState implements SparseTransitionModel.TransitionState { private final int[] context; private final TransitionStateType type; private final int offset; private final int charIndex; private final int hashCode; public CharacterNgramTransitionState(int[] context, int offset, TransitionStateType type) { this.context = context; this.offset = offset; this.type = type; if (context.length == 0 || type == TransitionStateType.LMRGN || type == TransitionStateType.LMRGN_HPHN || type == TransitionStateType.RMRGN || type == TransitionStateType.RMRGN_HPHN) { this.charIndex = spaceCharIndex; } else if (type == TransitionStateType.RMRGN_HPHN_INIT) { this.charIndex = hyphenCharIndex; } else { this.charIndex = context[context.length-1]; } this.hashCode = 1013 * Arrays.hashCode(context) + 1009 * this.offset + 997 * this.type.ordinal(); } public boolean equals(Object other) { if (other instanceof CharacterNgramTransitionState) { CharacterNgramTransitionState that = (CharacterNgramTransitionState) other; if (this.type != that.type) { return false; } else if (this.offset != that.offset) { return false; } else if (!Arrays.equals(this.context, that.context)) { return false; } else { return true; } } else { return false; } } public int hashCode() { return hashCode; } public Collection<Tuple2<TransitionState,Double>> nextLineStartStates() { List<Tuple2<TransitionState,Double>> result = new ArrayList<Tuple2<TransitionState,Double>>(); TransitionStateType type = getType(); int[] context = getContext(); if (type == TransitionStateType.TMPL) { double scoreWithSpace = Math.log(lm.getCharNgramProb(context, spaceCharIndex)); if (scoreWithSpace != Double.NEGATIVE_INFINITY) { int[] contextWithSpace = shrinkContext(a.append(context, spaceCharIndex)); { double score = Math.log(LINE_MRGN_PROB) + scoreWithSpace; if (score != Double.NEGATIVE_INFINITY) { result.add(Tuple2((TransitionState) new CharacterNgramTransitionState(contextWithSpace, 0, TransitionStateType.LMRGN), score)); } } for (int c=0; c<lm.getCharacterIndexer().size(); ++c) { double intermediateScore = Math.log((1.0 - LINE_MRGN_PROB)) + scoreWithSpace + Math.log(lm.getCharNgramProb(contextWithSpace, c)); if (intermediateScore != Double.NEGATIVE_INFINITY) { int[] nextContext = shrinkContext(a.append(contextWithSpace, c)); for (int offset=-CharacterTemplate.MAX_OFFSET; offset<=CharacterTemplate.MAX_OFFSET; ++offset) { double score = intermediateScore + LOG_OFFSET_START_PROBS[offset+CharacterTemplate.MAX_OFFSET]; if (score != Double.NEGATIVE_INFINITY) { result.add(Tuple2((TransitionState) new CharacterNgramTransitionState(nextContext, offset, TransitionStateType.TMPL), score)); } } } } } } else if (type == TransitionStateType.RMRGN) { { double score = Math.log(LINE_MRGN_PROB); if (score != Double.NEGATIVE_INFINITY) { result.add(Tuple2((TransitionState) new CharacterNgramTransitionState(context, 0, TransitionStateType.LMRGN), score)); } } for (int c=0; c<lm.getCharacterIndexer().size(); ++c) { double intermediateScore = Math.log((1.0 - LINE_MRGN_PROB)) + Math.log(lm.getCharNgramProb(context, c)); if (intermediateScore != Double.NEGATIVE_INFINITY) { int[] nextContext = shrinkContext(a.append(context, c)); for (int offset=-CharacterTemplate.MAX_OFFSET; offset<=CharacterTemplate.MAX_OFFSET; ++offset) { double score = intermediateScore + LOG_OFFSET_START_PROBS[offset+CharacterTemplate.MAX_OFFSET]; if (score != Double.NEGATIVE_INFINITY) { result.add(Tuple2((TransitionState) new CharacterNgramTransitionState(nextContext, offset, TransitionStateType.TMPL), score)); } } } } } else if (type == TransitionStateType.RMRGN_HPHN || type == TransitionStateType.RMRGN_HPHN_INIT) { { double score = Math.log(LINE_MRGN_PROB); if (score != Double.NEGATIVE_INFINITY) { result.add(Tuple2((TransitionState) new CharacterNgramTransitionState(context, 0, TransitionStateType.LMRGN_HPHN), score)); } } for (int c=0; c<lm.getCharacterIndexer().size(); ++c) { if (c != spaceCharIndex && !isPunc[c]) { double intermedateScore = Math.log((1.0 - LINE_MRGN_PROB)) + Math.log(lm.getCharNgramProb(context, c)); if (intermedateScore != Double.NEGATIVE_INFINITY) { int[] nextContext = shrinkContext(a.append(context, c)); for (int offset=-CharacterTemplate.MAX_OFFSET; offset<=CharacterTemplate.MAX_OFFSET; ++offset) { double score = intermedateScore + LOG_OFFSET_START_PROBS[offset+CharacterTemplate.MAX_OFFSET]; result.add(Tuple2((TransitionState) new CharacterNgramTransitionState(nextContext, offset, TransitionStateType.TMPL), score)); } } } } } else if (type == TransitionStateType.LMRGN || type == TransitionStateType.LMRGN_HPHN) { { double score = Math.log(LINE_MRGN_PROB); if (score != Double.NEGATIVE_INFINITY) { result.add(Tuple2((TransitionState) new CharacterNgramTransitionState(new int[0], 0, TransitionStateType.LMRGN), score)); } } for (int c=0; c<lm.getCharacterIndexer().size(); ++c) { double intermediateScore = Math.log((1.0 - LINE_MRGN_PROB)) + Math.log(lm.getCharNgramProb(context, c)); if (intermediateScore != Double.NEGATIVE_INFINITY) { for (int offset=-CharacterTemplate.MAX_OFFSET; offset<=CharacterTemplate.MAX_OFFSET; ++offset) { double score = intermediateScore + LOG_OFFSET_START_PROBS[offset+CharacterTemplate.MAX_OFFSET]; if (score != Double.NEGATIVE_INFINITY) { result.add(Tuple2((TransitionState) new CharacterNgramTransitionState(new int[] {c}, offset, TransitionStateType.TMPL), score)); } } } } } return result; } public double endLogProb() { return 0.0; } public Collection<Tuple2<TransitionState,Double>> forwardTransitions() { int[] context = getContext(); TransitionStateType type = getType(); List<Tuple2<TransitionState,Double>> result = new ArrayList<Tuple2<TransitionState,Double>>(); if (type == TransitionStateType.LMRGN) { { double score = Math.log(LINE_MRGN_PROB); if (score != Double.NEGATIVE_INFINITY) { result.add(Tuple2((TransitionState) new CharacterNgramTransitionState(context, 0, TransitionStateType.LMRGN), score)); } } for (int c=0; c<lm.getCharacterIndexer().size(); ++c) { double intermediateScore = Math.log((1.0 - LINE_MRGN_PROB)) + Math.log(lm.getCharNgramProb(context, c)); if (intermediateScore != Double.NEGATIVE_INFINITY) { int[] nextContext = shrinkContext(a.append(context, c)); for (int offset=-CharacterTemplate.MAX_OFFSET; offset<=CharacterTemplate.MAX_OFFSET; ++offset) { double score = intermediateScore + LOG_OFFSET_START_PROBS[offset+CharacterTemplate.MAX_OFFSET]; if (score != Double.NEGATIVE_INFINITY) { result.add(Tuple2((TransitionState) new CharacterNgramTransitionState(nextContext, offset, TransitionStateType.TMPL), score)); } } } } } else if (type == TransitionStateType.LMRGN_HPHN) { { double score = Math.log(LINE_MRGN_PROB); if (score != Double.NEGATIVE_INFINITY) { result.add(Tuple2((TransitionState) new CharacterNgramTransitionState(context, 0, TransitionStateType.LMRGN_HPHN), score)); } } for (int c=0; c<lm.getCharacterIndexer().size(); ++c) { if (c != spaceCharIndex && !isPunc[c]) { double intermediateScore = Math.log((1.0 - LINE_MRGN_PROB)) + Math.log(lm.getCharNgramProb(context, c)); if (intermediateScore != Double.NEGATIVE_INFINITY) { int[] nextContext = shrinkContext(a.append(context, c)); for (int offset=-CharacterTemplate.MAX_OFFSET; offset<=CharacterTemplate.MAX_OFFSET; ++offset) { double score = intermediateScore + LOG_OFFSET_START_PROBS[offset+CharacterTemplate.MAX_OFFSET]; if (score != Double.NEGATIVE_INFINITY) { result.add(Tuple2((TransitionState) new CharacterNgramTransitionState(nextContext, offset, TransitionStateType.TMPL), score)); } } } } } } else if (type == TransitionStateType.RMRGN) { double score = Math.log(LINE_MRGN_PROB); if (score != Double.NEGATIVE_INFINITY) { result.add(Tuple2((TransitionState) new CharacterNgramTransitionState(context, 0, TransitionStateType.RMRGN), score)); } } else if (type == TransitionStateType.RMRGN_HPHN) { double score = Math.log(LINE_MRGN_PROB); if (score != Double.NEGATIVE_INFINITY) { result.add(Tuple2((TransitionState) new CharacterNgramTransitionState(context, 0, TransitionStateType.RMRGN_HPHN), score)); } } else if (type == TransitionStateType.RMRGN_HPHN_INIT) { double score = Math.log(LINE_MRGN_PROB); if (score != Double.NEGATIVE_INFINITY) { result.add(Tuple2((TransitionState) new CharacterNgramTransitionState(context, 0, TransitionStateType.RMRGN_HPHN), score)); } } else if (type == TransitionStateType.TMPL) { { double score = Math.log(LINE_MRGN_PROB) + Math.log(1.0 - LINE_END_HYPHEN_PROB) + Math.log(lm.getCharNgramProb(context, spaceCharIndex)); if (score != Double.NEGATIVE_INFINITY) { int[] nextContext = shrinkContext(a.append(context, spaceCharIndex)); result.add(Tuple2((TransitionState) new CharacterNgramTransitionState(nextContext, 0, TransitionStateType.RMRGN), score)); } } double[] logOffsetTransProbs = LOG_OFFSET_TRANS_PROBS[getOffset()+CharacterTemplate.MAX_OFFSET]; { double intermediateScore = Math.log(LINE_MRGN_PROB) + Math.log(LINE_END_HYPHEN_PROB); if (intermediateScore != Double.NEGATIVE_INFINITY) { for (int offset=Math.max(getOffset()-MAX_OFFSET_DIFF,-CharacterTemplate.MAX_OFFSET); offset<=Math.min(getOffset()+MAX_OFFSET_DIFF, CharacterTemplate.MAX_OFFSET); ++offset) { double score = intermediateScore + logOffsetTransProbs[offset+CharacterTemplate.MAX_OFFSET]; if (score != Double.NEGATIVE_INFINITY) { result.add(Tuple2((TransitionState) new CharacterNgramTransitionState(context, offset, TransitionStateType.RMRGN_HPHN_INIT), score)); } } } } for (int nextC=0; nextC<lm.getCharacterIndexer().size(); ++nextC) { double intermediateScore = Math.log((1.0 - LINE_MRGN_PROB)) + Math.log(lm.getCharNgramProb(context, nextC)); if (intermediateScore != Double.NEGATIVE_INFINITY) { int[] nextContext = shrinkContext(a.append(context, nextC)); for (int offset=Math.max(getOffset()-MAX_OFFSET_DIFF,-CharacterTemplate.MAX_OFFSET); offset<=Math.min(getOffset()+MAX_OFFSET_DIFF, CharacterTemplate.MAX_OFFSET); ++offset) { double score = intermediateScore + logOffsetTransProbs[offset+CharacterTemplate.MAX_OFFSET]; if (score != Double.NEGATIVE_INFINITY) { result.add(Tuple2((TransitionState) new CharacterNgramTransitionState(nextContext, offset, TransitionStateType.TMPL), score)); } } } } } return result; } public int getLmCharIndex() { return charIndex; } public GlyphChar getGlyphChar() { return new GlyphChar(charIndex, GlyphType.NORMAL_CHAR); } public int getOffset() { return offset; } public int getExposure() { throw new Error("Method not implemented"); } public int[] getContext() { return context; } public TransitionStateType getType() { return type; } public int getLanguageIndex() { return -1; } } public static final double LINE_MRGN_PROB = 0.5; public static final double LINE_END_HYPHEN_PROB = 1e-8; public static int MAX_OFFSET_DIFF = 2; public static double MAX_OFFSET_TRANS_PROB_VAR = 0.05; public static final double[] LOG_OFFSET_START_PROBS = logOffsetStartProbs(); public static final double[][] LOG_OFFSET_TRANS_PROBS = logOffsetTransProbs(); private static double[] logOffsetStartProbs() { double[] offsetStartProbs = new double[CharacterTemplate.MAX_OFFSET*2+1]; for (int offset0=-CharacterTemplate.MAX_OFFSET; offset0<=CharacterTemplate.MAX_OFFSET; ++offset0) { offsetStartProbs[offset0+CharacterTemplate.MAX_OFFSET] = 1.0; } a.logi(offsetStartProbs); return offsetStartProbs; } private static double[][] logOffsetTransProbs() { double[][] offsetTransProbs = new double[CharacterTemplate.MAX_OFFSET*2+1][CharacterTemplate.MAX_OFFSET*2+1]; for (int offset0=-CharacterTemplate.MAX_OFFSET; offset0<=CharacterTemplate.MAX_OFFSET; ++offset0) { for (int offset1=-CharacterTemplate.MAX_OFFSET; offset1<=CharacterTemplate.MAX_OFFSET; ++offset1) { if (Math.abs(offset0 - offset1) <= MAX_OFFSET_DIFF) { double sqrDistFromMean = (offset0 - offset1)*(offset0 - offset1); offsetTransProbs[offset0+CharacterTemplate.MAX_OFFSET][offset1+CharacterTemplate.MAX_OFFSET] = Math.exp(-sqrDistFromMean/(2.0*MAX_OFFSET_TRANS_PROB_VAR)); } } } a.normalizecoli(offsetTransProbs); a.logi(offsetTransProbs); return offsetTransProbs; } private SingleLanguageModel lm; private int spaceCharIndex; private int hyphenCharIndex; private boolean[] isPunc; public CharacterNgramTransitionModelMarkovOffset(SingleLanguageModel lm) { this.lm = lm; this.spaceCharIndex = lm.getCharacterIndexer().getIndex(Charset.SPACE); this.hyphenCharIndex = lm.getCharacterIndexer().getIndex(Charset.HYPHEN); this.isPunc = new boolean[lm.getCharacterIndexer().size()]; Arrays.fill(this.isPunc, false); for (String c : lm.getCharacterIndexer().getObjects()) { if(Charset.isPunctuationChar(c)) isPunc[lm.getCharacterIndexer().getIndex(c)] = true; } } public Collection<Tuple2<TransitionState,Double>> startStates() { List<Tuple2<TransitionState,Double>> result = new ArrayList<Tuple2<TransitionState,Double>>(); result.add(Tuple2((TransitionState) new CharacterNgramTransitionState(new int[0], 0, TransitionStateType.LMRGN), Math.log(LINE_MRGN_PROB))); for (int c=0; c<lm.getCharacterIndexer().size(); ++c) { double intermediateScore = Math.log((1.0 - LINE_MRGN_PROB)) + Math.log(lm.getCharNgramProb(new int[0], c)); if (intermediateScore != Double.NEGATIVE_INFINITY) { int[] nextContext = new int[] {c}; for (int offset=-CharacterTemplate.MAX_OFFSET; offset<=CharacterTemplate.MAX_OFFSET; ++offset) { double score = intermediateScore + LOG_OFFSET_START_PROBS[offset+CharacterTemplate.MAX_OFFSET]; if (score != Double.NEGATIVE_INFINITY) { result.add(Tuple2((TransitionState) new CharacterNgramTransitionState(nextContext, offset, TransitionStateType.TMPL), score)); } } } } return result; } private int[] shrinkContext(int[] context) { return lm.shrinkContext(context); } }