package edu.berkeley.cs.nlp.ocular.eval;
import java.util.List;
import edu.berkeley.cs.nlp.ocular.data.textreader.Charset;
import edu.berkeley.cs.nlp.ocular.lm.CodeSwitchLanguageModel;
import edu.berkeley.cs.nlp.ocular.util.CollectionHelper;
/**
* @author Hannah Alpert-Abrams (halperta@gmail.com)
* @author Dan Garrette (dhgarrette@gmail.com)
*/
public class LmPerplexity {
private CodeSwitchLanguageModel lm;
private final int spaceIndex;
public LmPerplexity(CodeSwitchLanguageModel lm) {
this.lm = lm;
this.spaceIndex = lm.getCharacterIndexer().getIndex(Charset.SPACE);
}
public double perplexity(List<Integer> viterbiNormalizedTranscriptionCharIndices, List<Integer> viterbiNormalizedTranscriptionLangIndices) {
double logTotalProbability = 0.0;
for (int i=0; i<viterbiNormalizedTranscriptionCharIndices.size(); ++i) {
int curC = viterbiNormalizedTranscriptionCharIndices.get(i);
int curL = getLangIndex(viterbiNormalizedTranscriptionLangIndices, i);
double langTransitionProb = getLangTransitionProb(i, curL, viterbiNormalizedTranscriptionCharIndices, viterbiNormalizedTranscriptionLangIndices);
double ngramProb = getNgramProb(i, curC, curL, viterbiNormalizedTranscriptionCharIndices, viterbiNormalizedTranscriptionLangIndices);
logTotalProbability += Math.log(langTransitionProb) + Math.log(ngramProb);
// StringBuilder ctxString = new StringBuilder();
// for (int c: viterbiNormalizedTranscriptionCharIndices.subList(findStartPoint(i, curL, viterbiNormalizedTranscriptionLangIndices), i))
// ctxString.append(lm.getCharacterIndexer().getObject(c));
// System.out.println(String.format("P_%d(%s | %s) = %s * %s", curL, lm.getCharacterIndexer().getObject(curC), ctxString, ngramProb, langTransitionProb));
}
return Math.exp(-logTotalProbability / viterbiNormalizedTranscriptionCharIndices.size());
}
private double getNgramProb(int i, int curC, int curL, List<Integer> viterbiNormalizedTranscriptionCharIndices, List<Integer> viterbiNormalizedTranscriptionLangIndices) {
int startPoint = findStartPoint(i, curL, viterbiNormalizedTranscriptionLangIndices);
int[] context = CollectionHelper.intListToArray(viterbiNormalizedTranscriptionCharIndices.subList(startPoint, i));
return lm.get(curL).getCharNgramProb(context, curC);
}
private int findStartPoint(int i, int curL, List<Integer> viterbiNormalizedTranscriptionLangIndices) {
int startPoint = i;
while (startPoint > 0 && getLangIndex(viterbiNormalizedTranscriptionLangIndices, startPoint-1) == curL && i-startPoint < lm.get(curL).getMaxOrder()-1) {
--startPoint;
}
return startPoint;
}
private double getLangTransitionProb(int i, int curL, List<Integer> viterbiNormalizedTranscriptionCharIndices, List<Integer> viterbiNormalizedTranscriptionLangIndices) {
if (i > 0) {
int prevC = viterbiNormalizedTranscriptionCharIndices.get(i-1);
int prevL = getLangIndex(viterbiNormalizedTranscriptionLangIndices, i-1);
if (prevC != spaceIndex) {
if (prevL != curL) throw new RuntimeException("Characters cannot change languages mid-word.");
return 1.0;
}
else {
return lm.languageTransitionProb(prevL, curL);
}
}
else {
return lm.languagePrior(curL);
}
}
private int getLangIndex(List<Integer> viterbiNormalizedTranscriptionLangIndices, int i) {
int curL = viterbiNormalizedTranscriptionLangIndices.get(i);
if (curL < 0) {
if (this.lm.getLanguageIndexer().size() == 1)
curL = 0;
else if (i > 0)
throw new RuntimeException("curl="+curL+", i="+i);
}
return curL;
}
}