package edu.berkeley.cs.nlp.ocular.model;
import static org.junit.Assert.assertEquals;
import java.util.Collection;
import java.util.List;
import org.junit.Test;
import edu.berkeley.cs.nlp.ocular.gsm.GlyphChar;
import edu.berkeley.cs.nlp.ocular.gsm.GlyphChar.GlyphType;
import edu.berkeley.cs.nlp.ocular.model.transition.SparseTransitionModel.TransitionState;
import edu.berkeley.cs.nlp.ocular.train.FontTrainer;
import edu.berkeley.cs.nlp.ocular.util.Tuple2;
import static edu.berkeley.cs.nlp.ocular.util.CollectionHelper.*;
import tberg.murphy.indexer.HashMapIndexer;
import tberg.murphy.indexer.Indexer;
import edu.berkeley.cs.nlp.ocular.model.DecodeState;
import static edu.berkeley.cs.nlp.ocular.model.TransitionStateType.*;
/**
* @author Dan Garrette (dhgarrette@gmail.com)
*/
public class FontTrainEMTests {
class TS implements TransitionState {
public final int id;
private int languageIndex;
private int lmCharIndex;
private TransitionStateType type;
private GlyphChar glyphChar;
public TS(int id, int languageIndex, int lmCharIndex, TransitionStateType type, GlyphChar glyphChar) {
this.id = id;
this.languageIndex = languageIndex;
this.lmCharIndex = lmCharIndex;
this.type = type;
this.glyphChar = glyphChar;
}
@Override public int getLanguageIndex() { return languageIndex; }
@Override public int getLmCharIndex() { return lmCharIndex; }
@Override public TransitionStateType getType() { return type; }
@Override public GlyphChar getGlyphChar() { return glyphChar; }
@Override public int getOffset() { return -1; }
@Override public int getExposure() { return -1; }
@Override public Collection<Tuple2<TransitionState, Double>> forwardTransitions() { return null; }
@Override public Collection<Tuple2<TransitionState, Double>> nextLineStartStates() { return null; }
@Override public double endLogProb() { return -1; }
@Override public String toString() {
return "TS("+id+", "+languageIndex+", "+lmCharIndex+", "+type+", "+glyphChar+")";
}
}
private DecodeState DS(TS ts) {
return new DecodeState(ts, 0, 0, 0, 0);
}
@Test
public void test_makeFullViterbiStateSeq() {
Indexer<String> charIndexer = new HashMapIndexer<String>();
charIndexer.index(new String[] { " ", "-", "a", "b", "c" });
DecodeState[][] decodeStates = new DecodeState[][] {
new DecodeState[]{ DS(new TS(1, -1, 0, LMRGN, new GlyphChar(0, GlyphType.NORMAL_CHAR))),
DS(new TS(2, -1, 0, LMRGN, new GlyphChar(0, GlyphType.NORMAL_CHAR))),
DS(new TS(3, -1, 0, TMPL, new GlyphChar(0, GlyphType.NORMAL_CHAR))),
DS(new TS(4, 1, 2, TMPL, new GlyphChar(2, GlyphType.NORMAL_CHAR))),
DS(new TS(5, 1, 3, TMPL, new GlyphChar(3, GlyphType.NORMAL_CHAR))),
DS(new TS(6, 1, 4, TMPL, new GlyphChar(4, GlyphType.NORMAL_CHAR))),
DS(new TS(7, 1, 1, RMRGN_HPHN_INIT, new GlyphChar(1, GlyphType.NORMAL_CHAR))),
DS(new TS(8, 1, 0, RMRGN_HPHN, new GlyphChar(0, GlyphType.NORMAL_CHAR))),
DS(new TS(9, 1, 0, RMRGN_HPHN, new GlyphChar(0, GlyphType.NORMAL_CHAR))) },
new DecodeState[]{ DS(new TS(10, 1, 0, LMRGN_HPHN, new GlyphChar(0, GlyphType.NORMAL_CHAR))),
DS(new TS(11, 1, 0, LMRGN_HPHN, new GlyphChar(0, GlyphType.NORMAL_CHAR))),
DS(new TS(12, 1, 0, TMPL, new GlyphChar(0, GlyphType.NORMAL_CHAR))),
DS(new TS(13, 1, 2, TMPL, new GlyphChar(2, GlyphType.NORMAL_CHAR))),
DS(new TS(14, 1, 3, TMPL, new GlyphChar(3, GlyphType.NORMAL_CHAR))),
DS(new TS(15, 1, 4, TMPL, new GlyphChar(4, GlyphType.NORMAL_CHAR))),
DS(new TS(16, 1, 0, RMRGN, new GlyphChar(0, GlyphType.NORMAL_CHAR))),
DS(new TS(17, 1, 0, RMRGN, new GlyphChar(0, GlyphType.NORMAL_CHAR))) }
};
List<DecodeState> tsSeq = FontTrainer.makeFullViterbiStateSeq(decodeStates, charIndexer);
List<Integer> expectedIds = makeList(2, 3, 4, 1);
for (int i = 0; i < expectedIds.size(); ++i) {
assertEquals(expectedIds.get(i).intValue(), ((TS)tsSeq.get(i).ts).id);
}
}
}