package edu.berkeley.nlp.lm;

import java.io.Serializable;
import java.util.List;

import edu.berkeley.nlp.lm.map.ContextEncodedNgramMap;
import edu.berkeley.nlp.lm.map.NgramMap;
import edu.berkeley.nlp.lm.values.ProbBackoffPair;
import edu.berkeley.nlp.lm.values.ProbBackoffValueContainer;

/**
 * Language model implementation which uses Kneser-Ney-style backoff
 * computation.
 * 
 * Note that, unlike the description in Pauls and Klein (2011), this particular
 * implementation stores a trie in which the first word of an n-gram points to
 * its prefix. This is in contrast to {@link ContextEncodedProbBackoffLm},
 * which stores a trie in which the last word of an n-gram points to its
 * suffix. This was done because it simplifies the code significantly, without
 * significantly changing speed or memory usage.
 * 
 * @author adampauls
 * 
 * @param <W>
 */
public class ArrayEncodedProbBackoffLm<W> extends AbstractArrayEncodedNgramLanguageModel<W> implements ArrayEncodedNgramLanguageModel<W>, Serializable {

	private static final long serialVersionUID = 1L;

	private final NgramMap<ProbBackoffPair> map;

	private final ProbBackoffValueContainer values;

	private final boolean useScratchValues;

	private final long numWords;

	public ArrayEncodedProbBackoffLm(final int lmOrder, final WordIndexer<W> wordIndexer, final NgramMap<ProbBackoffPair> map, final ConfigOptions opts) {
		super(lmOrder, wordIndexer, (float) opts.unknownWordLogProb);
		this.map = map;
		this.values = (ProbBackoffValueContainer) map.getValues();
		// context-encoded maps can return values directly; other maps write
		// their values into a scratch object
		useScratchValues = !(map instanceof ContextEncodedNgramMap);
		numWords = map.getNumNgrams(0);
	}

	/*
	 * (non-Javadoc)
	 * 
	 * @see edu.berkeley.nlp.lm.AbstractArrayEncodedNgramLanguageModel#getLogProb(int[], int, int)
	 */
	@Override
	public float getLogProb(final int[] ngram, final int startPos, final int endPos) {
		// truncate the context if the query is longer than the LM order
		int startPos_ = startPos;
		while (endPos - startPos_ > lmOrder) {
			startPos_++;
		}
		final NgramMap<ProbBackoffPair> localMap = map;
		if (endPos - startPos_ < 1) return 0.0f;
		final ProbBackoffPair scratch = !useScratchValues ? null : new ProbBackoffPair(Float.NaN, Float.NaN);
		final int unigramWord = ngram[endPos - 1];
		if (unigramWord < 0 || unigramWord >= numWords) return oovWordLogProb;
		// starting from the last word, walk the trie backwards, extending the
		// matched n-gram one context word at a time until a lookup fails
		long matchedProbContext = unigramWord;
		int matchedProbContextOrder = -1;
		for (int i = endPos - 2; i >= startPos_; --i) {
			final int probContextOrder = endPos - i - 2;
			final long probContext = localMap.getValueAndOffset(matchedProbContext, probContextOrder, ngram[i], scratch);
			if (probContext < 0) break;
			matchedProbContext = probContext;
			matchedProbContextOrder = probContextOrder;
		}
		float logProb = scratch == null ? values.getProb(matchedProbContextOrder + 1, matchedProbContext) : scratch.prob;
		if (Float.isNaN(logProb)) {
			// the longest match was a fake entry (an n-gram present only as the
			// prefix of a longer n-gram, marked with a NaN probability); do the
			// walk again, this time keeping track of the longest match which was
			// not fake. A separate cursor is needed here so that the walk
			// continues through intermediate fake entries.
			long probContext = 0;
			matchedProbContext = 0;
			matchedProbContextOrder = -1;
			for (int i = endPos - 1; i >= startPos_; --i) {
				final int probContextOrder = endPos - i - 2;
				probContext = localMap.getValueAndOffset(probContext, probContextOrder, ngram[i], scratch);
				if (probContext < 0) break;
				final float tmpProb = scratch == null ? values.getProb(probContextOrder + 1, probContext) : scratch.prob;
				if (!Float.isNaN(tmpProb)) {
					logProb = tmpProb;
					matchedProbContext = probContext;
					matchedProbContextOrder = probContextOrder;
				}
			}
		}
		// no backoff applies if the full n-gram matched, or if there was no
		// context at all
		final float backoff = matchedProbContextOrder == endPos - startPos_ - 2 || endPos - startPos_ <= 1 ? 0.0f
			: getBackoffSum(ngram, startPos_, endPos, localMap, matchedProbContextOrder, scratch);
		return logProb + backoff;
	}
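	// Worked example (illustrative, not from the original source): for a
	// trigram query log p(c | a b), if the longest n-gram found above was the
	// bigram "b c" (matchedProbContextOrder == 0), the returned value is
	//
	//     logProb("b c") + backoff("a b")
	//
	// and if only the unigram "c" was found (matchedProbContextOrder == -1):
	//
	//     logProb("c") + backoff("b") + backoff("a b")
	//
	// Backoff weights absent from the model contribute 0, i.e. a factor of 1
	// in probability space. getBackoffSum below computes this sum of backoff
	// weights.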
	/**
	 * Sums the backoff weights of all contexts longer than the longest context
	 * for which an explicit probability was found.
	 * 
	 * @param ngram
	 *            word ids of the query n-gram
	 * @param startPos
	 *            inclusive start index into <code>ngram</code>
	 * @param endPos
	 *            exclusive end index into <code>ngram</code>
	 * @param localMap
	 *            local reference to the n-gram trie
	 * @param matchedProbContextOrder
	 *            order of the longest context whose probability was matched
	 * @param scratch
	 *            scratch value for maps which cannot return offsets directly,
	 *            or <code>null</code>
	 * @return sum of the applicable backoff weights (in log space)
	 */
	private float getBackoffSum(final int[] ngram, final int startPos, final int endPos, final NgramMap<ProbBackoffPair> localMap,
		final int matchedProbContextOrder, final ProbBackoffPair scratch) {
		final long unigramWord = ngram[endPos - 2];
		if (unigramWord < 0 || unigramWord >= numWords) return 0.0f;
		long backoffContext = unigramWord;
		float backoff = 0.0f;
		// if nothing longer than a unigram matched, the unigram backoff must
		// also be included
		if (matchedProbContextOrder < 0) {
			if (scratch != null) {
				localMap.getValueAndOffset(0, -1, ngram[endPos - 2], scratch);
				backoff = scratch.backoff;
			} else {
				backoff = values.getBackoff(0, backoffContext);
			}
		}
		int i = 1;
		// walk past the contexts whose backoffs do not apply because the
		// probability was already matched at that order
		for (; i <= matchedProbContextOrder && backoffContext >= 0; ++i) {
			backoffContext = localMap.getValueAndOffset(backoffContext, i - 1, ngram[endPos - i - 2], null);
		}
		// accumulate the backoff weights of the remaining, longer contexts
		for (; i < endPos - startPos - 1 && backoffContext >= 0; ++i) {
			final int backoffContextOrder = i - 1;
			backoffContext = localMap.getValueAndOffset(backoffContext, backoffContextOrder, ngram[endPos - i - 2], scratch);
			if (backoffContext < 0) break;
			assert i > matchedProbContextOrder;
			final float currBackoff = scratch == null ? values.getBackoff(backoffContextOrder + 1, backoffContext) : scratch.backoff;
			// fake entries carry NaN backoffs, which contribute nothing
			backoff += Float.isNaN(currBackoff) ? 0.0f : currBackoff;
		}
		return backoff;
	}

	/*
	 * (non-Javadoc)
	 * 
	 * @see edu.berkeley.nlp.lm.AbstractArrayEncodedNgramLanguageModel#getLogProb(int[])
	 */
	@Override
	public float getLogProb(final int[] ngram) {
		return ArrayEncodedNgramLanguageModel.DefaultImplementations.getLogProb(ngram, this);
	}

	/*
	 * (non-Javadoc)
	 * 
	 * @see edu.berkeley.nlp.lm.AbstractArrayEncodedNgramLanguageModel#getLogProb(java.util.List)
	 */
	@Override
	public float getLogProb(final List<W> ngram) {
		return ArrayEncodedNgramLanguageModel.DefaultImplementations.getLogProb(ngram, this);
	}

	public NgramMap<ProbBackoffPair> getNgramMap() {
		return map;
	}
}
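// Usage sketch (illustrative, not part of the original source). Instances of
// this class are normally built by the readers in edu.berkeley.nlp.lm.io
// rather than constructed directly; the call below assumes the
// LmReaders.readArrayEncodedLmFromArpa(String, boolean) entry point, and the
// file name is hypothetical.
//
//     final ArrayEncodedProbBackoffLm<String> lm =
//         LmReaders.readArrayEncodedLmFromArpa("lm.arpa.gz", false /* compress */);
//     // log p(fox | the quick); words unknown to the indexer score as
//     // oovWordLogProb
//     final float score = lm.getLogProb(java.util.Arrays.asList("the", "quick", "fox"));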