package edu.berkeley.cs.nlp.ocular.model.transition;
import static edu.berkeley.cs.nlp.ocular.util.Tuple2.Tuple2;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import tberg.murphy.arrays.a;
import edu.berkeley.cs.nlp.ocular.data.textreader.Charset;
import edu.berkeley.cs.nlp.ocular.gsm.GlyphChar;
import edu.berkeley.cs.nlp.ocular.gsm.GlyphChar.GlyphType;
import edu.berkeley.cs.nlp.ocular.lm.SingleLanguageModel;
import edu.berkeley.cs.nlp.ocular.model.TransitionStateType;
import edu.berkeley.cs.nlp.ocular.util.Tuple2;
/**
 * Sparse transition model whose states pair a character n-gram context with a
 * {@link TransitionStateType}, and whose transitions are scored by a single
 * {@link SingleLanguageModel}.
 *
 * <p>Every score produced by this class is a natural-log probability. Inside
 * {@code nextLineStartStates()} and {@code forwardTransitions()}, transitions
 * whose score is {@code Double.NEGATIVE_INFINITY} (probability zero) are never
 * emitted; {@link #startStates()} returns its states unfiltered.
 *
 * @author Taylor Berg-Kirkpatrick (tberg@eecs.berkeley.edu)
 */
public class CharacterNgramTransitionModel implements SparseTransitionModel {

  /**
   * A single state of the model: the n-gram context (character indices, oldest
   * first) together with the state type.
   *
   * <p>This is an inner (non-static) class on purpose: it reads the enclosing
   * model's language model, space/hyphen indices and punctuation table.
   */
  public class CharacterNgramTransitionState implements SparseTransitionModel.TransitionState {
    private final int[] context;
    private final TransitionStateType type;
    private final int lmCharIndex;

    /**
     * Creates a state and derives the character index it reports.
     *
     * <p>The reported character is the space character when the context is
     * empty or the type is one of LMRGN, LMRGN_HPHN, RMRGN, RMRGN_HPHN; the
     * hyphen character for RMRGN_HPHN_INIT; otherwise the last character of
     * the context.
     *
     * @param context character indices forming the n-gram history; stored by
     *                reference, not copied
     * @param type    the kind of state
     */
    public CharacterNgramTransitionState(int[] context, TransitionStateType type) {
      this.context = context;
      this.type = type;
      if (context.length == 0 || type == TransitionStateType.LMRGN || type == TransitionStateType.LMRGN_HPHN || type == TransitionStateType.RMRGN || type == TransitionStateType.RMRGN_HPHN) {
        this.lmCharIndex = spaceCharIndex;
      } else if (type == TransitionStateType.RMRGN_HPHN_INIT) {
        this.lmCharIndex = hyphenCharIndex;
      } else {
        this.lmCharIndex = context[context.length - 1];
      }
    }

    /**
     * Two states are equal when they have the same type and element-wise equal
     * contexts.
     */
    @Override
    public boolean equals(Object other) {
      if (!(other instanceof CharacterNgramTransitionState)) {
        return false;
      }
      CharacterNgramTransitionState that = (CharacterNgramTransitionState) other;
      return this.type == that.type && Arrays.equals(this.context, that.context);
    }

    /** Hash over the context contents and the type, consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
      return 1013 * Arrays.hashCode(context) + 1009 * this.type.ordinal();
    }

    /**
     * Enumerates the states in which the next line may begin, given that the
     * current line ended in this state.
     *
     * @return (state, log-probability) pairs; zero-probability successors are omitted
     */
    public Collection<Tuple2<TransitionState,Double>> nextLineStartStates() {
      List<Tuple2<TransitionState,Double>> result = new ArrayList<Tuple2<TransitionState,Double>>();
      TransitionStateType type = getType();
      int[] context = getContext();
      if (type == TransitionStateType.TMPL) {
        // A line that ends on a character implicitly generates a space before
        // the next line; if the LM forbids that space there is no successor.
        double scoreWithSpace = Math.log(lm.getCharNgramProb(context, spaceCharIndex));
        if (scoreWithSpace != Double.NEGATIVE_INFINITY) {
          int[] contextWithSpace = shrinkContext(a.append(context, spaceCharIndex));
          addState(result, contextWithSpace, TransitionStateType.LMRGN, LOG_LINE_MRGN_PROB + scoreWithSpace);
          addCharacterStates(result, contextWithSpace, LOG_NOT_LINE_MRGN_PROB + scoreWithSpace, false);
        }
      } else if (type == TransitionStateType.RMRGN) {
        addState(result, context, TransitionStateType.LMRGN, LOG_LINE_MRGN_PROB);
        addCharacterStates(result, context, LOG_NOT_LINE_MRGN_PROB, false);
      } else if (type == TransitionStateType.RMRGN_HPHN || type == TransitionStateType.RMRGN_HPHN_INIT) {
        // After a line-end hyphen the next character may not be a space or punctuation.
        addState(result, context, TransitionStateType.LMRGN_HPHN, LOG_LINE_MRGN_PROB);
        addCharacterStates(result, context, LOG_NOT_LINE_MRGN_PROB, true);
      } else if (type == TransitionStateType.LMRGN || type == TransitionStateType.LMRGN_HPHN) {
        // The line never left its initial margin state, so the history is
        // dropped: successors get an empty or single-character context. The
        // character is therefore scored against the empty context as well,
        // exactly as startStates() does for the same successor states.
        int[] emptyContext = new int[0];
        addState(result, emptyContext, TransitionStateType.LMRGN, LOG_LINE_MRGN_PROB);
        int numChars = lm.getCharacterIndexer().size();
        for (int c = 0; c < numChars; ++c) {
          double score = LOG_NOT_LINE_MRGN_PROB + Math.log(lm.getCharNgramProb(emptyContext, c));
          addState(result, new int[] {c}, TransitionStateType.TMPL, score);
        }
      }
      return result;
    }

    /** @return the log-probability of ending in this state; always {@code 0.0} */
    public double endLogProb() {
      return 0.0;
    }

    /**
     * Enumerates the states reachable from this state within the same line.
     *
     * @return (state, log-probability) pairs; zero-probability successors are omitted
     */
    public Collection<Tuple2<TransitionState,Double>> forwardTransitions() {
      int[] context = getContext();
      TransitionStateType type = getType();
      List<Tuple2<TransitionState,Double>> result = new ArrayList<Tuple2<TransitionState,Double>>();
      if (type == TransitionStateType.LMRGN) {
        addState(result, context, TransitionStateType.LMRGN, LOG_LINE_MRGN_PROB);
        addCharacterStates(result, context, LOG_NOT_LINE_MRGN_PROB, false);
      } else if (type == TransitionStateType.LMRGN_HPHN) {
        // Still continuing a hyphenated word: no space or punctuation may follow.
        addState(result, context, TransitionStateType.LMRGN_HPHN, LOG_LINE_MRGN_PROB);
        addCharacterStates(result, context, LOG_NOT_LINE_MRGN_PROB, true);
      } else if (type == TransitionStateType.RMRGN) {
        addState(result, context, TransitionStateType.RMRGN, LOG_LINE_MRGN_PROB);
      } else if (type == TransitionStateType.RMRGN_HPHN || type == TransitionStateType.RMRGN_HPHN_INIT) {
        addState(result, context, TransitionStateType.RMRGN_HPHN, LOG_LINE_MRGN_PROB);
      } else if (type == TransitionStateType.TMPL) {
        // Enter the plain RMRGN state: the LM must generate a space here.
        double rightMarginScore = LOG_LINE_MRGN_PROB + Math.log(1.0 - LINE_END_HYPHEN_PROB) + Math.log(lm.getCharNgramProb(context, spaceCharIndex));
        if (rightMarginScore != Double.NEGATIVE_INFINITY) {
          int[] contextWithSpace = shrinkContext(a.append(context, spaceCharIndex));
          addState(result, contextWithSpace, TransitionStateType.RMRGN, rightMarginScore);
        }
        // Enter the hyphenated variant: the context is left untouched.
        addState(result, context, TransitionStateType.RMRGN_HPHN_INIT, LOG_LINE_MRGN_PROB + Math.log(LINE_END_HYPHEN_PROB));
        // Continue with any character the LM allows.
        addCharacterStates(result, context, LOG_NOT_LINE_MRGN_PROB, false);
      }
      return result;
    }

    /** @return the character index derived in the constructor */
    public int getLmCharIndex() {
      return lmCharIndex;
    }

    public GlyphChar getGlyphChar() {
      // Always render the character proposed by the language model
      return new GlyphChar(lmCharIndex, GlyphType.NORMAL_CHAR);
    }

    /** Not supported by this state; always throws. */
    public int getOffset() {
      throw new Error("Method not implemented");
    }

    /** Not supported by this state; always throws. */
    public int getExposure() {
      throw new Error("Method not implemented");
    }

    /** @return the context array held by this state (the internal array, not a copy) */
    public int[] getContext() {
      return context;
    }

    /** @return the type of this state */
    public TransitionStateType getType() {
      return type;
    }

    /** @return always {@code -1} */
    public int getLanguageIndex() {
      return -1;
    }
  }

  /** Probability assigned to every transition into (or staying in) a margin state. */
  public static final double LINE_MRGN_PROB = 0.5;

  /** Probability that a transition from a TMPL state into a right-margin state takes the hyphenated variant. */
  public static final double LINE_END_HYPHEN_PROB = 1e-8;

  private static final double LOG_LINE_MRGN_PROB = Math.log(LINE_MRGN_PROB);
  private static final double LOG_NOT_LINE_MRGN_PROB = Math.log(1.0 - LINE_MRGN_PROB);

  private final SingleLanguageModel lm;
  private final int spaceCharIndex;
  private final int hyphenCharIndex;
  /** Indexed by character index; true where {@code Charset.isPunctuationChar} accepts the character. */
  private final boolean[] isPunc;

  /**
   * Builds the model around a language model, caching the space and hyphen
   * indices and a per-character punctuation table.
   *
   * @param lm the language model used to score characters and to shrink contexts
   */
  public CharacterNgramTransitionModel(SingleLanguageModel lm) {
    this.lm = lm;
    this.spaceCharIndex = lm.getCharacterIndexer().getIndex(Charset.SPACE);
    this.hyphenCharIndex = lm.getCharacterIndexer().getIndex(Charset.HYPHEN);
    // A new boolean[] is already all-false; only punctuation entries are set.
    this.isPunc = new boolean[lm.getCharacterIndexer().size()];
    for (String c : lm.getCharacterIndexer().getObjects()) {
      if (Charset.isPunctuationChar(c)) {
        isPunc[lm.getCharacterIndexer().getIndex(c)] = true;
      }
    }
  }

  /**
   * Enumerates the states in which the first line may begin: an LMRGN state
   * with empty context, plus one single-character TMPL state per character
   * scored against the empty context.
   *
   * @return (state, log-probability) pairs; zero-probability states are NOT filtered here
   */
  public Collection<Tuple2<TransitionState,Double>> startStates() {
    List<Tuple2<TransitionState,Double>> result = new ArrayList<Tuple2<TransitionState,Double>>();
    result.add(Tuple2((TransitionState) new CharacterNgramTransitionState(new int[0], TransitionStateType.LMRGN), LOG_LINE_MRGN_PROB));
    int numChars = lm.getCharacterIndexer().size();
    for (int c = 0; c < numChars; ++c) {
      double score = LOG_NOT_LINE_MRGN_PROB + Math.log(lm.getCharNgramProb(new int[0], c));
      result.add(Tuple2((TransitionState) new CharacterNgramTransitionState(new int[] {c}, TransitionStateType.TMPL), score));
    }
    return result;
  }

  /** Appends a (state, score) pair to {@code result} unless the score is log(0). */
  private void addState(List<Tuple2<TransitionState,Double>> result, int[] context, TransitionStateType type, double score) {
    if (score != Double.NEGATIVE_INFINITY) {
      result.add(Tuple2((TransitionState) new CharacterNgramTransitionState(context, type), score));
    }
  }

  /**
   * Appends one TMPL successor per character, scored as
   * {@code baseScore + log P(c | context)}, whose context is {@code context}
   * extended by the character and shrunk by the language model.
   *
   * @param skipSpaceAndPunc when true, the space character and every
   *                         punctuation character are excluded (and not scored)
   */
  private void addCharacterStates(List<Tuple2<TransitionState,Double>> result, int[] context, double baseScore, boolean skipSpaceAndPunc) {
    int numChars = lm.getCharacterIndexer().size();
    for (int c = 0; c < numChars; ++c) {
      if (skipSpaceAndPunc && (c == spaceCharIndex || isPunc[c])) {
        continue;
      }
      double score = baseScore + Math.log(lm.getCharNgramProb(context, c));
      // Checked here rather than in addState so the extended context is only
      // built for characters that are actually emitted.
      if (score != Double.NEGATIVE_INFINITY) {
        int[] nextContext = shrinkContext(a.append(context, c));
        result.add(Tuple2((TransitionState) new CharacterNgramTransitionState(nextContext, TransitionStateType.TMPL), score));
      }
    }
  }

  /** Delegates to the language model to trim a context to the length it supports. */
  private int[] shrinkContext(int[] context) {
    return lm.shrinkContext(context);
  }
}