package edu.berkeley.nlp.lm;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import edu.berkeley.nlp.lm.collections.Counter;
/**
*
* Base interface for an n-gram language model, which exposes only inefficient
* convenience methods. See {@link ContextEncodedNgramLanguageModel} and
* {@link ArrayEncodedNgramLanguageModel} for more efficient accessors.
*
* @author adampauls
*
* @param <W> the type used to represent words of the language
*
*/
public interface NgramLanguageModel<W>
{
/**
 * Maximum size of n-grams stored by the model.
 *
 * @return the order of the model, i.e. the length of the longest n-gram it
 *         stores
 */
public int getLmOrder();
/**
 * Each LM must have a WordIndexer which assigns integer IDs to each word W
 * in the language.
 *
 * @return the word indexer used by this model
 */
public WordIndexer<W> getWordIndexer();
/**
 * Scores a complete sentence, taking appropriate care with the start- and
 * end-of-sentence symbols. This is a convenience method and will generally
 * be inefficient.
 *
 * @param sentence
 *            the words of the sentence to score
 * @return the model's score for the sentence. NOTE(review): presumably a
 *         log probability on the same scale as {@link #getLogProb(List)} --
 *         confirm against the implementations.
 */
public float scoreSentence(List<W> sentence);
/**
 * Scores an n-gram. This is a convenience method and will generally be
 * relatively inefficient. More efficient versions are available in
 * {@link ArrayEncodedNgramLanguageModel#getLogProb(int[], int, int)} and
 * {@link ContextEncodedNgramLanguageModel#getLogProb(long, int, int, edu.berkeley.nlp.lm.ContextEncodedNgramLanguageModel.LmContextInfo)}.
 *
 * @param ngram
 *            the words of the n-gram to score, in order
 * @return the log probability the model assigns to the n-gram. The helpers
 *         in {@link StaticMethods} convert this value to a probability with
 *         a factor of <code>Math.log(10)</code>, i.e. they treat it as a
 *         base-10 logarithm.
 */
public float getLogProb(List<W> ngram);
/**
 * Sets the (log) probability for an OOV (out-of-vocabulary) word. Note that
 * this is in general different from the log probability of the
 * <code>unk</code> tag.
 *
 * @param logProb
 *            the log probability to use for OOV words
 */
public void setOovWordLogProb(float logProb);
public static class StaticMethods
{
/**
 * Converts an n-gram of word objects into the integer IDs assigned by the
 * model's {@link WordIndexer}. Every word is looked up with
 * <code>getIndexPossiblyUnk</code> (presumably mapping words unknown to the
 * indexer to an "unknown word" ID, as the name suggests).
 *
 * @param ngram
 *            the words to convert, in order
 * @param lm
 *            the model whose word indexer supplies the IDs
 * @return a new array of the same length as <code>ngram</code> holding the
 *         ID of each word, in the same order
 */
public static <T> int[] toIntArray(final List<T> ngram, final ArrayEncodedNgramLanguageModel<T> lm) {
    final int[] ids = new int[ngram.size()];
    final WordIndexer<T> indexer = lm.getWordIndexer();
    int position = 0;
    for (final T word : ngram) {
        ids[position] = indexer.getIndexPossiblyUnk(word);
        position++;
    }
    return ids;
}
/**
 * Converts an n-gram of integer word IDs back into word objects using the
 * model's {@link WordIndexer}.
 *
 * @param ngram
 *            the word IDs to convert, in order
 * @param lm
 *            the model whose word indexer maps IDs to words
 * @return a new list of the same length as <code>ngram</code> holding the
 *         word for each ID, in the same order
 */
public static <T> List<T> toObjectList(final int[] ngram, final ArrayEncodedNgramLanguageModel<T> lm) {
    final List<T> words = new ArrayList<T>(ngram.length);
    final WordIndexer<T> indexer = lm.getWordIndexer();
    for (final int wordId : ngram) {
        final T word = indexer.getWord(wordId);
        words.add(word);
    }
    return words;
}
/**
 * Samples a sentence from the given language model using a sampling
 * temperature of <code>1.0</code>. This is not meant to be particularly
 * efficient.
 *
 * @param random
 *            the source of randomness used for sampling
 * @param lm
 *            the language model to sample from
 * @return the sampled words, as returned by
 *         {@link #sample(Random, NgramLanguageModel, double)}
 */
public static <W> List<W> sample(Random random, final NgramLanguageModel<W> lm) {
    final double defaultTemperature = 1.0;
    return sample(random, lm, defaultTemperature);
}
/**
 * Samples a sentence from the given language model, one word at a time.
 * <p>
 * Starting from the start symbol, every word known to the model's
 * {@link WordIndexer} (except the start symbol itself) is scored in the
 * context of the most recent <code>lm.getLmOrder() - 1</code> words, and the
 * next word is drawn from the resulting weights. The end symbol is excluded
 * as a candidate for the first word, so the returned sentence is never
 * empty. Sampling stops as soon as the end symbol is drawn.
 * <p>
 * This is not meant to be particularly efficient: every sampled word costs
 * one {@link NgramLanguageModel#getLogProb(List)} call per vocabulary entry.
 *
 * @param random
 *            the source of randomness, handed to the counter's
 *            <code>sample</code> method
 * @param lm
 *            the language model to sample from
 * @param sampleTemperature
 *            factor applied to each log probability before exponentiation;
 *            the weight of a candidate is
 *            <code>10^(sampleTemperature * logProb)</code>. A value of
 *            <code>1.0</code> uses the model's probabilities unchanged.
 * @return a new list holding the sampled words, without the leading start
 *         symbol and the trailing end symbol
 */
public static <W> List<W> sample(Random random, final NgramLanguageModel<W> lm, final double sampleTemperature) {
    // None of these depend on the sampling state, so look them up once
    // instead of once per vocabulary word per sampled token.
    final WordIndexer<W> wordIndexer = lm.getWordIndexer();
    final W startSymbol = wordIndexer.getStartSymbol();
    final W endSymbol = wordIndexer.getEndSymbol();
    final int numWords = wordIndexer.numWords();
    final int maxContextLength = lm.getLmOrder() - 1;
    // Factor used to turn the model's log probabilities into natural logs for Math.exp.
    final double ln10 = Math.log(10);

    final List<W> ret = new ArrayList<W>();
    ret.add(startSymbol);
    while (true) {
        final int contextEnd = ret.size();
        final int contextStart = Math.max(0, contextEnd - maxContextLength);
        final Counter<W> c = new Counter<W>();
        final List<W> ngram = new ArrayList<W>(ret.subList(contextStart, contextEnd));
        // Placeholder slot for the candidate word; overwritten for every candidate below.
        ngram.add(null);
        final int candidateSlot = ngram.size() - 1;
        // Only the start symbol has been emitted so far: forbid the end symbol
        // so that the sampled sentence contains at least one word.
        final boolean atSentenceStart = ret.size() <= 1;
        for (int index = 0; index < numWords; ++index) {
            final W word = wordIndexer.getWord(index);
            if (word.equals(startSymbol)) {
                continue;
            }
            if (atSentenceStart && word.equals(endSymbol)) {
                continue;
            }
            ngram.set(candidateSlot, word);
            // Multiplication order kept as (temperature * logProb) * ln10 so the
            // computed weights are unchanged.
            c.setCount(word, Math.exp(sampleTemperature * lm.getLogProb(ngram) * ln10));
        }
        final W sample = c.sample(random);
        ret.add(sample);
        if (sample.equals(endSymbol)) {
            break;
        }
    }
    // Drop the start and end symbols. Return a copy rather than a subList view so
    // the caller gets an ordinary list that is independent of the working list.
    return new ArrayList<W>(ret.subList(1, ret.size() - 1));
}
/**
 * Builds a distribution over next possible words given the context. Context
 * can be of any length, but only at most <code>lm.getLmOrder() - 1</code>
 * words (the most recent ones) are actually used.
 * <p>
 * If the context holds fewer than <code>lm.getLmOrder() - 1</code> words, a
 * single start symbol is prepended to it. Note that a short context which
 * already begins with the start symbol therefore ends up with two start
 * symbols.
 * <p>
 * The returned counts are not explicitly normalized: the count for a word is
 * <code>10^logProb</code>, where <code>logProb</code> is the value of
 * {@link NgramLanguageModel#getLogProb(List)} for the context followed by
 * that word.
 *
 * @param <W>
 *            the type of words in the language
 * @param lm
 *            the language model used to score candidate words
 * @param context
 *            the words preceding the position to predict, oldest first
 * @return a counter with one entry for every word known to the model's
 *         {@link WordIndexer} except the start symbol
 */
public static <W> Counter<W> getDistributionOverNextWords(final NgramLanguageModel<W> lm, List<W> context) {
    // Loop-invariant lookups, fetched once rather than once per vocabulary word.
    final WordIndexer<W> wordIndexer = lm.getWordIndexer();
    final W startSymbol = wordIndexer.getStartSymbol();
    final int numWords = wordIndexer.numWords();
    final int maxContextLength = lm.getLmOrder() - 1;
    // Factor used to turn the model's log probabilities into natural logs for Math.exp.
    final double ln10 = Math.log(10);

    // Number of trailing context words that fit into the model's context window.
    final int contextSize = context.size();
    final int usedContextLength = Math.max(0, Math.min(maxContextLength, contextSize));

    final List<W> ngram = new ArrayList<W>(usedContextLength + 2);
    // A context shorter than the window is padded with a single start symbol.
    if (usedContextLength < maxContextLength) {
        ngram.add(startSymbol);
    }
    ngram.addAll(context.subList(contextSize - usedContextLength, contextSize));
    // Placeholder slot for the candidate word; overwritten for every candidate below.
    ngram.add(null);
    final int candidateSlot = ngram.size() - 1;

    final Counter<W> c = new Counter<W>();
    for (int index = 0; index < numWords; ++index) {
        final W word = wordIndexer.getWord(index);
        if (word.equals(startSymbol)) {
            continue;
        }
        ngram.set(candidateSlot, word);
        c.setCount(word, Math.exp(lm.getLogProb(ngram) * ln10));
    }
    return c;
}
}
}