package edu.berkeley.cs.nlp.ocular.eval;
import static org.junit.Assert.*;
import tberg.murphy.indexer.HashMapIndexer;
import tberg.murphy.indexer.Indexer;
import java.util.Arrays;
import java.util.Set;
import org.junit.Test;
import edu.berkeley.cs.nlp.ocular.data.textreader.CharIndexer;
import edu.berkeley.cs.nlp.ocular.lm.CodeSwitchLanguageModel;
import edu.berkeley.cs.nlp.ocular.lm.SingleLanguageModel;
import static edu.berkeley.cs.nlp.ocular.util.CollectionHelper.intArrayToList;
/**
 * Unit tests for {@link LmPerplexity}, which scores a character sequence whose
 * characters are each tagged with a language.
 *
 * <p>The tests wire up hand-written stub language models. Each stub returns a
 * fixed probability for every (context, character) pair the test expects to be
 * queried and throws for anything else. As a result, the tests also check exactly
 * which contexts {@link LmPerplexity} asks each model for.
 *
 * <p>Context array names spell out their characters, with {@code _} standing for a
 * space: {@code ctx_ab_} is {@code [a, b, ' ']}, {@code ctx__ab} is
 * {@code [' ', a, b]}, and {@code ctx_} alone is the empty context.
 *
 * @author Hannah Alpert-Abrams (halperta@gmail.com)
 * @author Dan Garrette (dhgarrette@gmail.com)
 */
public class LmPerplexityTests {

	/** Absolute tolerance used when comparing computed perplexities. */
	private static final double EPSILON = 1e-14;

	/**
	 * Checks perplexity when both language models have max order 4, so each model
	 * sees at most three characters of preceding context.
	 */
	@SuppressWarnings("serial")
	@Test
	public void test_calculatePerplexity() {
		final CharIndexer charIndexer = new CharIndexer();
		charIndexer.index(new String[] { "a", "b", "x", "y", " " });
		final Indexer<String> langIndexer = new HashMapIndexer<String>();
		langIndexer.index(new String[] { "Lang1", "Lang2" });

		final int a = charIndexer.getIndex("a");
		final int b = charIndexer.getIndex("b");
		final int x = charIndexer.getIndex("x");
		final int y = charIndexer.getIndex("y");
		final int s = charIndexer.getIndex(" ");
		final int l1 = langIndexer.getIndex("Lang1");
		final int l2 = langIndexer.getIndex("Lang2");

		// Contexts the order-4 Lang1 model is expected to be queried with.
		final int[] ctx_ = new int[] {};
		final int[] ctx_a = new int[] { a };
		final int[] ctx_ab = new int[] { a, b };
		final int[] ctx_ab_ = new int[] { a, b, s };
		final int[] ctx_b_a = new int[] { b, s, a };
		final int[] ctx__ab = new int[] { s, a, b };

		final SingleLanguageModel lang1Lm = new SingleLanguageModel() {
			@Override public int getMaxOrder() { return 4; }
			@Override public double getCharNgramProb(int[] context, int c) {
				if (c == a && Arrays.equals(context, ctx_)) return 0.11;
				if (c == a && Arrays.equals(context, ctx_ab_)) return 0.12;
				if (c == b && Arrays.equals(context, ctx_a)) return 0.13;
				if (c == b && Arrays.equals(context, ctx_b_a)) return 0.14;
				if (c == s && Arrays.equals(context, ctx_ab)) return 0.15;
				if (c == s && Arrays.equals(context, ctx__ab)) return 0.16;
				// Any other query means LmPerplexity built an unexpected context.
				throw new RuntimeException("getCharNgramProb(" + intArrayToList(context) + ", " + c + ")");
			}
			@Override public Indexer<String> getCharacterIndexer() { return charIndexer; }
			@Override public Set<Integer> getActiveCharacters() { throw new RuntimeException(); }
			@Override public int[] shrinkContext(int[] originalContext) { throw new RuntimeException(); }
			@Override public boolean containsContext(int[] context) { throw new RuntimeException(); }
		};
		final SingleLanguageModel lang2Lm = newLang2Lm(charIndexer);
		final CodeSwitchLanguageModel csLm = newCodeSwitchLm(charIndexer, langIndexer, lang1Lm, lang2Lm);
		LmPerplexity lmPerplexity = new LmPerplexity(csLm);

		/*
		 * "ab "
		 *
		 * The expected values assume the language can only change after a space.
		 * Within a word the transition term is therefore 1.00.
		 *
		 *   P1(a|[])     * P(L1)      = 0.11 * 0.31
		 *   P1(b|[a])    * P(L1|L1)   = 0.13 * 1.00
		 *   P1( |[ab])   * P(L1|L1)   = 0.15 * 1.00
		 *   --------------------
		 *   = 0.00066495 ^(-1/3)
		 *   = 11.456984790348551
		 */
		double p1 = lmPerplexity.perplexity(Arrays.asList(a, b, s), Arrays.asList(l1, l1, l1));
		assertEquals(11.456984790348551, p1, EPSILON);

		/*
		 * Lang1: a,b
		 * Lang2: x,y
		 *
		 * "ab ab xy ab"
		 *
		 *   P1(a|[])     * P(L1)      = 0.11 * 0.31
		 *   P1(b|[a])    * P(L1|L1)   = 0.13 * 1.00
		 *   P1( |[ab])   * P(L1|L1)   = 0.15 * 1.00
		 *   P1(a|[ab ])  * P(L1|L1)   = 0.12 * 0.32
		 *   P1(b|[b a])  * P(L1|L1)   = 0.14 * 1.00
		 *   P1( |[ ab])  * P(L1|L1)   = 0.16 * 1.00
		 *   P2(x|[])     * P(L2|L1)   = 0.21 * 0.33
		 *   P2(y|[x])    * P(L2|L2)   = 0.22 * 1.00
		 *   P2( |[xy])   * P(L2|L2)   = 0.23 * 1.00
		 *   P1(a|[])     * P(L1|L2)   = 0.11 * 0.35
		 *   P1(b|[a])    * P(L1|L1)   = 0.13 * 1.00
		 *   --------------------
		 *   = 1.0038205132552398E-11 ^(-1/11)
		 *   = 9.996534024760905
		 */
		double p2 = lmPerplexity.perplexity(Arrays.asList(a, b, s, a, b, s, x, y, s, a, b), Arrays.asList(l1, l1, l1, l1, l1, l1, l2, l2, l2, l1, l1));
		assertEquals(9.996534024760905, p2, EPSILON);
	}

	/**
	 * Checks perplexity when the two language models have different max orders.
	 * Lang1 has order 5 and sees up to four characters of context. Lang2 keeps
	 * order 4.
	 */
	@SuppressWarnings("serial")
	@Test
	public void test_calculatePerplexity_differentMaxOrders() {
		final CharIndexer charIndexer = new CharIndexer();
		charIndexer.index(new String[] { "a", "b", "x", "y", " " });
		final Indexer<String> langIndexer = new HashMapIndexer<String>();
		langIndexer.index(new String[] { "Lang1", "Lang2" });

		final int a = charIndexer.getIndex("a");
		final int b = charIndexer.getIndex("b");
		final int x = charIndexer.getIndex("x");
		final int y = charIndexer.getIndex("y");
		final int s = charIndexer.getIndex(" ");
		final int l1 = langIndexer.getIndex("Lang1");
		final int l2 = langIndexer.getIndex("Lang2");

		// Contexts the order-5 Lang1 model is expected to be queried with.
		final int[] ctx_ = new int[] {};
		final int[] ctx_a = new int[] { a };
		final int[] ctx_ab = new int[] { a, b };
		final int[] ctx_ab_ = new int[] { a, b, s };
		final int[] ctx_ab_a = new int[] { a, b, s, a };
		final int[] ctx_b_ab = new int[] { b, s, a, b };

		final SingleLanguageModel lang1Lm = new SingleLanguageModel() {
			@Override public int getMaxOrder() { return 5; }
			@Override public double getCharNgramProb(int[] context, int c) {
				if (c == a && Arrays.equals(context, ctx_)) return 0.11;
				if (c == a && Arrays.equals(context, ctx_ab_)) return 0.12;
				if (c == b && Arrays.equals(context, ctx_a)) return 0.13;
				if (c == b && Arrays.equals(context, ctx_ab_a)) return 0.14;
				if (c == s && Arrays.equals(context, ctx_ab)) return 0.15;
				if (c == s && Arrays.equals(context, ctx_b_ab)) return 0.16;
				// Any other query means LmPerplexity built an unexpected context.
				throw new RuntimeException("getCharNgramProb(" + intArrayToList(context) + ", " + c + ")");
			}
			@Override public Indexer<String> getCharacterIndexer() { return charIndexer; }
			@Override public Set<Integer> getActiveCharacters() { throw new RuntimeException(); }
			@Override public int[] shrinkContext(int[] originalContext) { throw new RuntimeException(); }
			@Override public boolean containsContext(int[] context) { throw new RuntimeException(); }
		};
		final SingleLanguageModel lang2Lm = newLang2Lm(charIndexer);
		final CodeSwitchLanguageModel csLm = newCodeSwitchLm(charIndexer, langIndexer, lang1Lm, lang2Lm);
		LmPerplexity lmPerplexity = new LmPerplexity(csLm);

		/*
		 * Lang1: a,b
		 * Lang2: x,y
		 *
		 * "ab ab xy ab"
		 *
		 *   P1(a|[])     * P(L1)      = 0.11 * 0.31
		 *   P1(b|[a])    * P(L1|L1)   = 0.13 * 1.00
		 *   P1( |[ab])   * P(L1|L1)   = 0.15 * 1.00
		 *   P1(a|[ab ])  * P(L1|L1)   = 0.12 * 0.32
		 *   P1(b|[ab a]) * P(L1|L1)   = 0.14 * 1.00
		 *   P1( |[b ab]) * P(L1|L1)   = 0.16 * 1.00
		 *   P2(x|[])     * P(L2|L1)   = 0.21 * 0.33
		 *   P2(y|[x])    * P(L2|L2)   = 0.22 * 1.00
		 *   P2( |[xy])   * P(L2|L2)   = 0.23 * 1.00
		 *   P1(a|[])     * P(L1|L2)   = 0.11 * 0.35
		 *   P1(b|[a])    * P(L1|L1)   = 0.13 * 1.00
		 *   --------------------
		 *   = 1.0038205132552398E-11 ^(-1/11)
		 *   = 9.996534024760905
		 */
		double p2 = lmPerplexity.perplexity(Arrays.asList(a, b, s, a, b, s, x, y, s, a, b), Arrays.asList(l1, l1, l1, l1, l1, l1, l2, l2, l2, l1, l1));
		assertEquals(9.996534024760905, p2, EPSILON);
	}

	/**
	 * Builds the stub "Lang2" model shared by both tests. It has max order 4 and
	 * only answers queries for the word "xy " read from the start of a word.
	 *
	 * @param charIndexer indexer already containing "x", "y" and " "
	 * @return a model that throws on any query it was not set up for
	 */
	@SuppressWarnings("serial")
	private static SingleLanguageModel newLang2Lm(final CharIndexer charIndexer) {
		final int x = charIndexer.getIndex("x");
		final int y = charIndexer.getIndex("y");
		final int s = charIndexer.getIndex(" ");
		final int[] ctx_ = new int[] {};
		final int[] ctx_x = new int[] { x };
		final int[] ctx_xy = new int[] { x, y };
		return new SingleLanguageModel() {
			@Override public int getMaxOrder() { return 4; }
			@Override public double getCharNgramProb(int[] context, int c) {
				if (c == x && Arrays.equals(context, ctx_)) return 0.21;
				if (c == y && Arrays.equals(context, ctx_x)) return 0.22;
				if (c == s && Arrays.equals(context, ctx_xy)) return 0.23;
				throw new RuntimeException("getCharNgramProb(" + intArrayToList(context) + ", " + c + ")");
			}
			@Override public Indexer<String> getCharacterIndexer() { return charIndexer; }
			@Override public Set<Integer> getActiveCharacters() { throw new RuntimeException(); }
			@Override public int[] shrinkContext(int[] originalContext) { throw new RuntimeException(); }
			@Override public boolean containsContext(int[] context) { throw new RuntimeException(); }
		};
	}

	/**
	 * Builds the stub code-switch model shared by both tests.
	 *
	 * <p>It dispatches "Lang1" and "Lang2" to the given single-language models. Only
	 * "Lang1" has a prior (0.31). The transition probabilities are
	 * L1->L1 0.32, L1->L2 0.33, L2->L1 0.35 and L2->L2 0.34. Every other method or
	 * query throws, so an unexpected call fails the test.
	 *
	 * @param charIndexer the shared character indexer
	 * @param langIndexer indexer already containing "Lang1" and "Lang2"
	 * @param lang1Lm model returned for "Lang1"
	 * @param lang2Lm model returned for "Lang2"
	 */
	@SuppressWarnings("serial")
	private static CodeSwitchLanguageModel newCodeSwitchLm(final CharIndexer charIndexer, final Indexer<String> langIndexer,
			final SingleLanguageModel lang1Lm, final SingleLanguageModel lang2Lm) {
		// Both names are already indexed, so these lookups just resolve their ids once.
		final int l1 = langIndexer.getIndex("Lang1");
		final int l2 = langIndexer.getIndex("Lang2");
		return new CodeSwitchLanguageModel() {
			@Override public double getCharNgramProb(int[] context, int c) { throw new RuntimeException(); }
			@Override public Indexer<String> getCharacterIndexer() { return charIndexer; }
			@Override public Indexer<String> getLanguageIndexer() { return langIndexer; }
			@Override public SingleLanguageModel get(int language) {
				if (language == l1) return lang1Lm;
				if (language == l2) return lang2Lm;
				throw new RuntimeException();
			}
			@Override public double languagePrior(int language) {
				if (language == l1) return 0.31;
				throw new RuntimeException();
			}
			@Override public double languageTransitionProb(int fromLanguage, int destLanguage) {
				if (fromLanguage == l1) {
					if (destLanguage == l1) return 0.32;
					if (destLanguage == l2) return 0.33;
				}
				if (fromLanguage == l2) {
					if (destLanguage == l1) return 0.35;
					if (destLanguage == l2) return 0.34;
				}
				throw new RuntimeException();
			}
			@Override public double getProbKeepSameLanguage() { throw new RuntimeException(); }
		};
	}
}