package edu.berkeley.nlp.lm;

import java.io.Serializable;
import java.util.List;

import edu.berkeley.nlp.lm.map.NgramMap;
import edu.berkeley.nlp.lm.util.LongRef;
import edu.berkeley.nlp.lm.values.CountValueContainer;

/**
 * Language model implementation which uses stupid backoff (Brants et al.,
 * 2007) computation. Note that stupid backoff does not properly normalize, so
 * the scores this LM computes are not in fact probabilities. Also, unlike LMs
 * estimated using {@link LmReaders#createKneserNeyLmFromTextFiles}, this model
 * returns natural logarithms instead of log10.
 *
 * @author adampauls
 *
 * @param <W>
 *            the word type
 */
public class StupidBackoffLm<W> extends AbstractArrayEncodedNgramLanguageModel<W> implements ArrayEncodedNgramLanguageModel<W>, Serializable {

	private static final long serialVersionUID = 1L;

	protected final NgramMap<LongRef> map;

	private final float alpha;

	public StupidBackoffLm(final int lmOrder, final WordIndexer<W> wordIndexer, final NgramMap<LongRef> map, final ConfigOptions opts) {
		super(lmOrder, wordIndexer, (float) opts.unknownWordLogProb);
		this.map = map;
		this.alpha = (float) opts.stupidBackoffAlpha;
	}

	/*
	 * (non-Javadoc)
	 * 
	 * @see
	 * edu.berkeley.nlp.lm.AbstractArrayEncodedNgramLanguageModel#getLogProb
	 * (int[], int, int)
	 */
	@Override
	public float getLogProb(final int[] ngram, final int startPos, final int endPos) {
		final NgramMap<LongRef> localMap = map;
		float logProb = oovWordLogProb;
		long probContext = 0L;
		int probContextOrder = -1;
		long backoffContext = 0L;
		int backoffContextOrder = -1;
		final LongRef scratch = new LongRef(-1L);
		// Walk backwards from the last word, extending the matched n-gram one
		// word at a time; the longest match found determines the score.
		for (int i = endPos - 1; i >= startPos; --i) {
			assert (probContext >= 0);
			probContext = localMap.getValueAndOffset(probContext, probContextOrder, ngram[i], scratch);
			if (probContext < 0) {
				// The longer n-gram is unseen: return the score of the longest
				// match, which already includes its backoff penalties.
				return logProb;
			} else {
				final long currCount = scratch.value;
				long backoffCount = -1L;
				if (i == endPos - 1) {
					// Unigram case: the denominator is the total token count.
					backoffCount = ((CountValueContainer) map.getValues()).getUnigramSum();
				} else {
					// Denominator is the count of the context ngram[i..endPos-1).
					backoffContext = localMap.getValueAndOffset(backoffContext, backoffContextOrder++, ngram[i], scratch);
					backoffCount = scratch.value;
				}
				// log count(ngram[i..endPos)) - log count(context), plus one
				// log(alpha) penalty per backoff step below the full order.
				logProb = (float) Math.log(currCount) - (float) Math.log(backoffCount) + (i - startPos) * (float) Math.log(alpha);
				probContextOrder++;
			}
		}
		return logProb;
	}
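	// Worked example (illustrative counts, not from any corpus): score the
	// trigram [a, b, c] with alpha = 0.4 when "a b c" is unseen but
	// count(b c) = 2 and count(b) = 5. The loop above matches "c" and "b c",
	// then fails on "a b c" and returns
	//     log(2) - log(5) + 1 * log(0.4),
	// i.e. log(alpha * count(b c) / count(b)): one backoff step below the
	// requested trigram order, which is exactly the stupid backoff recursion
	// of Brants et al. (2007) expressed in natural-log space.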
	/**
	 * Gets the raw count of an n-gram.
	 * 
	 * @param ngram
	 *            array of word indices
	 * @param startPos
	 *            start of the n-gram (inclusive)
	 * @param endPos
	 *            end of the n-gram (exclusive)
	 * @return count of the n-gram, or -1 if the n-gram is not in the map.
	 */
	public long getRawCount(final int[] ngram, final int startPos, final int endPos) {
		final NgramMap<LongRef> localMap = map;
		long probContext = 0L;
		final LongRef scratch = new LongRef(-1L);
		// Look the n-gram up one word at a time, from the last word backwards.
		for (int probContextOrder = -1; probContextOrder < endPos - startPos - 1; ++probContextOrder) {
			assert (probContext >= 0);
			probContext = localMap.getValueAndOffset(probContext, probContextOrder, ngram[endPos - probContextOrder - 2], scratch);
			if (probContext < 0) { return -1; }
		}
		return scratch.value;
	}

	// Computes alpha^n by repeated multiplication (helper, currently unused
	// in this file).
	private static float pow(final float alpha, final int n) {
		float ret = 1.0f;
		for (int i = 0; i < n; ++i)
			ret *= alpha;
		return ret;
	}

	/*
	 * (non-Javadoc)
	 * 
	 * @see
	 * edu.berkeley.nlp.lm.AbstractArrayEncodedNgramLanguageModel#getLogProb
	 * (int[])
	 */
	@Override
	public float getLogProb(final int[] ngram) {
		return ArrayEncodedNgramLanguageModel.DefaultImplementations.getLogProb(ngram, this);
	}

	/*
	 * (non-Javadoc)
	 * 
	 * @see
	 * edu.berkeley.nlp.lm.AbstractArrayEncodedNgramLanguageModel#getLogProb
	 * (java.util.List)
	 */
	@Override
	public float getLogProb(final List<W> ngram) {
		return ArrayEncodedNgramLanguageModel.DefaultImplementations.getLogProb(ngram, this);
	}

	public NgramMap<LongRef> getNgramMap() {
		return map;
	}

}
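// Usage sketch (illustrative, not part of this file): instances are normally
// obtained through LmReaders rather than constructed directly. Here,
// "countMap" stands for an assumed, already-populated NgramMap<LongRef> of
// n-gram counts, and the word type is String.
//
//     ConfigOptions opts = new ConfigOptions();
//     WordIndexer<String> indexer = new StringWordIndexer();
//     StupidBackoffLm<String> lm = new StupidBackoffLm<String>(3, indexer, countMap, opts);
//     // Natural-log stupid backoff score; not a normalized log probability.
//     float score = lm.getLogProb(java.util.Arrays.asList("the", "quick", "fox"));
//     // Raw count of the unigram "the"; -1 if it is not in the map.
//     long count = lm.getRawCount(new int[] { indexer.getIndexPossiblyUnk("the") }, 0, 1);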