package edu.berkeley.cs.nlp.ocular.lm;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import edu.berkeley.cs.nlp.ocular.data.textreader.CharIndexer;
import edu.berkeley.cs.nlp.ocular.data.textreader.Charset;
import edu.berkeley.cs.nlp.ocular.data.textreader.TextReader;
import edu.berkeley.cs.nlp.ocular.util.ArrayHelper;
import edu.berkeley.cs.nlp.ocular.util.CollectionHelper;
import tberg.murphy.indexer.Indexer;
/**
* @author Taylor Berg-Kirkpatrick (tberg@eecs.berkeley.edu)
*/
public class NgramLanguageModel implements SingleLanguageModel {
private static final long serialVersionUID = 873286328149782L;
private Indexer<String> charIndexer;
private CountDbBig[] countDbs;
private int maxOrder;
private LMType type;
private double lmPower;
private Set<LongArrWrapper> allContextsSet;
private List<int[]> allContexts;
public static enum LMType { MLE, ABS_DISC, KNESER_NEY }
private Set<Integer> activeCharacters;
public Set<Integer> getActiveCharacters() { return activeCharacters; }
public NgramLanguageModel(Indexer<String> charIndexer, CountDbBig[] countDbs, Set<Integer> activeCharacters, LMType type, double lmPower) {
this.charIndexer = charIndexer;
this.countDbs = countDbs;
this.maxOrder = countDbs.length;
if (maxOrder <= 0) throw new RuntimeException("maxOrder must be greater than zero.");
this.type = type;
this.lmPower = lmPower;
this.allContextsSet = new HashSet<LongArrWrapper>();
this.allContexts = new ArrayList<int[]>();
for (int i = 0; i < this.maxOrder - 1; i++) {
for (long[] key : countDbs[i].getKeys()) {
if (key != null && countDbs[i].getCount(key, CountType.HISTORY_TYPE_INDEX) > 0) {
allContextsSet.add(new LongArrWrapper(key));
allContexts.add(LongNgram.convertToIntArr(key));
}
}
}
if (activeCharacters == null) throw new RuntimeException("activeCharacters is null!");
this.activeCharacters = activeCharacters;
}
public static NgramLanguageModel buildFromText(String fileName, int maxNumLines, int maxOrder, LMType type, double lmPower, TextReader textReader) {
return buildFromText(CollectionHelper.makeList(fileName), maxNumLines, maxOrder, type, lmPower, textReader);
}
public static NgramLanguageModel buildFromText(List<String> fileNames, int maxNumLines, int maxOrder, LMType type, double lmPower, TextReader textReader) {
CorpusCounter counter = new CorpusCounter(maxOrder);
Set<Integer> activeCharacters = counter.getActiveCharacters();
Indexer<String> charIndexer = new CharIndexer();
for (String fileName : fileNames) {
counter.countRecursive(fileName, maxNumLines, charIndexer, textReader);
}
activeCharacters.add(charIndexer.getIndex(Charset.SPACE));
charIndexer.lock();
counter.printStats(-1);
return new NgramLanguageModel(charIndexer, counter.getCounts(), activeCharacters, type, lmPower);
}
public void checkNormalizes(int[] context) {
double totalProb = 0;
for (int i = 0; i < charIndexer.size(); i++) {
totalProb += getCharNgramProb(context, i);
}
System.out.println("Total prob for context " + LongNgram.toString(context, charIndexer) + ": " + totalProb);
}
public Indexer<String> getCharacterIndexer() {
return charIndexer;
}
public int getMaxOrder() {
return maxOrder;
}
public double getLmPower() {
return lmPower;
}
public int[] shrinkContext(int[] originalContext) {
int[] newContext = originalContext;
if (newContext.length > maxOrder - 1) {
newContext = ArrayHelper.takeRight(newContext, maxOrder - 1);
}
while (!containsContext(newContext) && newContext.length > 0) {
newContext = ArrayHelper.takeRight(newContext, newContext.length - 1);
}
return newContext;
}
public boolean containsContext(int[] context) {
if (context.length == 0)
return true;
else
return allContextsSet.contains(new LongArrWrapper(LongNgram.convertToLong(context)));
}
public double getCharNgramProb(int[] context, int c) {
// Uncomment this to renormalize the distribution after exponentiating
// double normalizer = 0.0;
// for (int i = 0; i < charIndexer.size(); i++) {
// normalizer += getCharNgramProbRaw(context, i);
// }
// assert normalizer > 0;
// return getCharNgramProbRaw(context, c)/normalizer;
return getCharNgramProbRaw(context, c);
}
/**
* Returns an exponentiated probability, which won't necessarily
* sum to one
* @param context
* @param c
* @return
*/
private double getCharNgramProbRaw(int[] context, int c) {
int[] intNgram = new int[context.length+1];
System.arraycopy(context, 0, intNgram, 0, context.length);
intNgram[intNgram.length-1] = c;
NgramWrapper ngram = NgramWrapper.getNew(intNgram, 0, intNgram.length);
double prob = 0.0;
switch (type) {
case MLE: prob = new NgramCounts(ngram, countDbs).getTokenMle();
break;
case ABS_DISC: prob = new NgramCounts(ngram, countDbs).getAbsoluteDiscounting();
break;
case KNESER_NEY: prob = new NgramCounts(ngram, countDbs).getKneserNey();
break;
default: throw new RuntimeException("Bad type: " + type);
}
return Math.pow(prob, lmPower);
}
}