package edu.berkeley.nlp.lm;
import java.util.List;
import edu.berkeley.nlp.lm.collections.BoundedList;
/**
* Top-level interface for an n-gram language model which accepts n-gram in an
* array-of-integers encoding. The integers represent words of type
* <code>W</code> in the vocabulary, and the mapping from the vocabulary to
* integers is managed by an instance of the {@link WordIndexer} class.
*
* @author adampauls
*/
public interface ArrayEncodedNgramLanguageModel<W> extends NgramLanguageModel<W>
{
/**
* Calculate language model score of an n-gram. <b>Warning:</b> if you
* pass in an n-gram of length greater than <code>getLmOrder()</code>,
* this call will silently ignore the extra words of context. In other
* words, if you pass in a 5-gram (<code>endPos-startPos == 5</code>) to
* a 3-gram model, it will only score the words from <code>startPos + 2</code>
* to <code>endPos</code>.
*
* @param ngram
* array of words in integer representation
* @param startPos
* start of the portion of the array to be read
* @param endPos
* end of the portion of the array to be read.
* @return
*/
public float getLogProb(int[] ngram, int startPos, int endPos);
/**
* Equivalent to <code>getLogProb(ngram, 0, ngram.length)</code>
*
* @see #getLogProb(int[], int, int)
*/
public float getLogProb(int[] ngram);
public static class DefaultImplementations
{
public static <T> float scoreSentence(final List<T> sentence, final ArrayEncodedNgramLanguageModel<T> lm) {
final List<T> sentenceWithBounds = new BoundedList<T>(sentence, lm.getWordIndexer().getStartSymbol(), lm.getWordIndexer().getEndSymbol());
final int lmOrder = lm.getLmOrder();
float sentenceScore = 0.0f;
for (int i = 1; i < lmOrder - 1 && i <= sentenceWithBounds.size() + 1; ++i) {
final List<T> ngram = sentenceWithBounds.subList(-1, i);
final float scoreNgram = lm.getLogProb(ngram);
sentenceScore += scoreNgram;
}
for (int i = lmOrder - 1; i < sentenceWithBounds.size() + 2; ++i) {
final List<T> ngram = sentenceWithBounds.subList(i - lmOrder, i);
final float scoreNgram = lm.getLogProb(ngram);
sentenceScore += scoreNgram;
}
return sentenceScore;
}
public static <T> float getLogProb(final int[] ngram, final ArrayEncodedNgramLanguageModel<T> lm) {
return lm.getLogProb(ngram, 0, ngram.length);
}
public static <T> float getLogProb(final List<T> ngram, final ArrayEncodedNgramLanguageModel<T> lm) {
final int[] ints = StaticMethods.toIntArray(ngram, lm);
return lm.getLogProb(ints, 0, ints.length);
}
}
}