package edu.berkeley.cs.nlp.ocular.gsm;
import static edu.berkeley.cs.nlp.ocular.data.textreader.Charset.makeAddTildeMap;
import static edu.berkeley.cs.nlp.ocular.data.textreader.Charset.makeCanBeElidedSet;
import static edu.berkeley.cs.nlp.ocular.data.textreader.Charset.makeCanBeReplacedSet;
import static edu.berkeley.cs.nlp.ocular.data.textreader.Charset.makeDiacriticDisregardMap;
import static edu.berkeley.cs.nlp.ocular.data.textreader.Charset.makeValidDoublableSet;
import static edu.berkeley.cs.nlp.ocular.data.textreader.Charset.makeValidSubstitutionCharsSet;
import static edu.berkeley.cs.nlp.ocular.util.CollectionHelper.makeSet;
import static edu.berkeley.cs.nlp.ocular.util.CollectionHelper.setUnion;
import java.util.List;
import java.util.Map;
import java.util.Set;
import edu.berkeley.cs.nlp.ocular.data.textreader.Charset;
import edu.berkeley.cs.nlp.ocular.gsm.GlyphChar.GlyphType;
import edu.berkeley.cs.nlp.ocular.model.DecodeState;
import edu.berkeley.cs.nlp.ocular.model.TransitionStateType;
import edu.berkeley.cs.nlp.ocular.model.transition.SparseTransitionModel.TransitionState;
import edu.berkeley.cs.nlp.ocular.util.ArrayHelper;
import edu.berkeley.cs.nlp.ocular.util.FileHelper;
import tberg.murphy.indexer.Indexer;
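/**
 * A glyph substitution model (GSM) giving the probability of rendering a language-model character
 * as a particular glyph, conditioned only on the language and that character:
 * {@code P(glyph | language, lmChar)}, raised to the power {@code gsmPower}.
 *
 * @author Dan Garrette (dhgarrette@gmail.com)
 */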
public class BasicGlyphSubstitutionModel implements GlyphSubstitutionModel {
private static final long serialVersionUID = -8473038413268727114L;
private Indexer<String> langIndexer;
private Indexer<String> charIndexer;
private int numChars;
private double[/*language*/][/*lmChar*/][/*glyph*/] probs;
private double gsmPower;
public BasicGlyphSubstitutionModel(double[][][] probs,
double gsmPower,
Indexer<String> langIndexer,
Indexer<String> charIndexer) {
this.langIndexer = langIndexer;
this.charIndexer = charIndexer;
this.numChars = charIndexer.size();
this.probs = probs;
this.gsmPower = gsmPower;
}
public double glyphProb(int language, int lmChar, GlyphChar glyphChar) {
GlyphType glyphType = glyphChar.glyphType;
int glyph = (glyphType == GlyphType.NORMAL_CHAR) ? glyphChar.templateCharIndex : (numChars + glyphType.ordinal());
double p = probs[language][lmChar][glyph];
return Math.pow(p, gsmPower);
}
public Indexer<String> getLanguageIndexer() {
return langIndexer;
}
public Indexer<String> getCharacterIndexer() {
return charIndexer;
}
public static class BasicGlyphSubstitutionModelFactory {
private double gsmSmoothingCount;
private double elisionSmoothingCountMultiplier;
private Indexer<String> langIndexer;
private Indexer<String> charIndexer;
private Set<Integer>[] activeCharacterSets;
private Set<Integer> canBeReplaced;
private Set<Integer> canBeDoubled;
private Set<Integer> validSubstitutionChars;
private Set<Integer> canBeElided;
private Map<Integer,Integer> addTilde;
private Map<Integer,Integer> diacriticDisregardMap;
private int sCharIndex;
private int longsCharIndex;
private int fCharIndex;
private int lCharIndex;
private int hyphenCharIndex;
private int spaceCharIndex;
private int numLanguages;
private int numChars;
private int numGlyphs;
public final int GLYPH_ELISION_TILDE;
public final int GLYPH_TILDE_ELIDED;
public final int GLYPH_FIRST_ELIDED;
public final int GLYPH_DOUBLED;
public final int GLYPH_ELIDED;
//public final int GLYPH_RMRGN_HPHN_DROP;
private double gsmPower;
private int minCountsForEvalGsm;
private String outputPath;
public BasicGlyphSubstitutionModelFactory(
double gsmSmoothingCount,
double elisionSmoothingCountMultiplier,
Indexer<String> langIndexer,
Indexer<String> charIndexer,
Set<Integer>[] activeCharacterSets,
double gsmPower, int minCountsForEvalGsm,
String outputPath) {
this.gsmSmoothingCount = gsmSmoothingCount;
this.elisionSmoothingCountMultiplier = elisionSmoothingCountMultiplier;
this.langIndexer = langIndexer;
this.charIndexer = charIndexer;
this.activeCharacterSets = activeCharacterSets;
this.gsmPower = gsmPower;
this.minCountsForEvalGsm = minCountsForEvalGsm;
this.canBeReplaced = makeCanBeReplacedSet(charIndexer);
this.canBeDoubled = makeValidDoublableSet(charIndexer);
this.validSubstitutionChars = makeValidSubstitutionCharsSet(charIndexer);
this.canBeElided = makeCanBeElidedSet(charIndexer);
this.addTilde = makeAddTildeMap(charIndexer);
this.diacriticDisregardMap = makeDiacriticDisregardMap(charIndexer);
this.sCharIndex = charIndexer.contains("s") ? charIndexer.getIndex("s") : -1;
this.longsCharIndex = charIndexer.getIndex(Charset.LONG_S);
this.fCharIndex = charIndexer.contains("f") ? charIndexer.getIndex("f") : -1;
this.lCharIndex = charIndexer.contains("l") ? charIndexer.getIndex("l") : -1;
this.hyphenCharIndex = charIndexer.getIndex(Charset.HYPHEN);
this.spaceCharIndex = charIndexer.getIndex(Charset.SPACE);
this.numLanguages = langIndexer.size();
this.numChars = charIndexer.size();
this.numGlyphs = numChars + GlyphType.values().length-1;
this.GLYPH_ELISION_TILDE = numChars + GlyphType.ELISION_TILDE.ordinal();
this.GLYPH_TILDE_ELIDED = numChars + GlyphType.TILDE_ELIDED.ordinal();
this.GLYPH_FIRST_ELIDED = numChars + GlyphType.FIRST_ELIDED.ordinal();
this.GLYPH_DOUBLED = numChars + GlyphType.DOUBLED.ordinal();
//this.GLYPH_RMRGN_HPHN_DROP = numChars + GlyphType.RMRGN_HPHN_DROP.ordinal();
this.GLYPH_ELIDED = numChars + GlyphType.ELIDED.ordinal();
this.outputPath = outputPath;
}
public GlyphSubstitutionModel uniform() {
return make(initializeNewCountsMatrix(), 0, 0);
}
/**
* Initialize the counts matrix. Add smoothing counts (and no counts for invalid options).
*/
public double[][][] initializeNewCountsMatrix() {
double[/*language*/][/*lmChar*/][/*glyph*/] counts = new double[numLanguages][numChars][numGlyphs];
for (int language = 0; language < numLanguages; ++language) {
for (int lmChar = 0; lmChar < numChars; ++lmChar) {
for (int glyph = 0; glyph < numGlyphs; ++glyph) {
counts[language][lmChar][glyph] = getSmoothingValue(language, lmChar, glyph);
}
}
}
return counts;
}
// private boolean isElided(int glyph) {
// return glyph == GLYPH_TILDE_ELIDED || glyph == GLYPH_FIRST_ELIDED;
// }
public double getSmoothingValue(int language, int lmChar, int glyph) {
// if (glyph != GLYPH_TILDE_ELIDED && prevLmChar != spaceCharIndex) return 0.0; // unless we are trying to elide the current char, the previous char must be marked as a "space" since we don't want to actually condition on it.
// if (prevGlyph == GlyphType.ELISION_TILDE && glyph != GLYPH_TILDE_ELIDED) return 0.0; // an elision-tilde-decorated char must be followed by an elision
// if (glyph == GLYPH_TILDE_ELIDED && !(prevGlyph == GlyphType.ELISION_TILDE || prevGlyph == GlyphType.TILDE_ELIDED)) return 0.0; // an elision must be preceded by an elision-tilde-decorated char
// if (prevGlyph == GlyphType.NORMAL_CHAR && glyph == GLYPH_TILDE_ELIDED) return 0.0; // a normal char may not be followed by an elision
// if (glyph == GLYPH_FIRST_ELIDED && !(prevGlyph == GlyphType.NORMAL_CHAR && prevLmChar == spaceCharIndex)) return 0.0; // for a glyph to be first_elided, it must come after a normal space char
// elided chars can be followed by anything
// if (prevGlyph == GlyphType.ELISION_TILDE && addTilde.get(prevLmChar) == null) return 0.0; // a previous elision-tilde-decorated char must be elision-tilde-decoratable
// if (prevGlyph == GlyphType.TILDE_ELIDED && glyph == GLYPH_TILDE_ELIDED && !canBeElided.contains(prevLmChar)) return 0.0; // an elided previous char must be elidable, if we are trying to elide the current char (since we are conditioning on the actual character)
// if (prevGlyph == GlyphType.TILDE_ELIDED && glyph != GLYPH_TILDE_ELIDED && prevLmChar != spaceCharIndex) return 0.0; // ... otherwise the previous state must be marked as a "space" since we don't want to condition on the actual character
//if (prevGlyph == GlyphType.FIRST_ELIDED && !canBeElided.contains(prevLmChar)) return 0.0; // an first-elided previous char must be elidable (it can't be followed by another elision)
if (!(activeCharacterSets[language].contains(lmChar) || lmChar == hyphenCharIndex)) return 0.0; // lm char must be valid for the language
if (glyph == GLYPH_ELISION_TILDE) {
if (addTilde.get(lmChar) == null) return 0.0; // an elision-tilde-decorated char must be elision-tilde-decoratable
return gsmSmoothingCount * elisionSmoothingCountMultiplier;
}
else if (glyph == GLYPH_TILDE_ELIDED) {
if (!canBeElided.contains(lmChar)) return 0.0; // an elided char must be elidable
return gsmSmoothingCount * elisionSmoothingCountMultiplier;
}
else if (glyph == GLYPH_FIRST_ELIDED) {
if (!canBeElided.contains(lmChar)) return 0.0; // an elided char must be elidable
return gsmSmoothingCount * elisionSmoothingCountMultiplier;
}
else if (glyph == GLYPH_DOUBLED) {
if (!canBeDoubled.contains(lmChar)) return 0.0; // a doubled character has to be doubleable
return gsmSmoothingCount;// * elisionSmoothingCountMultiplier;
}
// else if (glyph == GLYPH_RMRGN_HPHN_DROP) {
// if (lmChar != hyphenCharIndex) return 0.0; // only a hyphen can be hyphen-dropped
// return gsmSmoothingCount;
// }
else if (glyph == GLYPH_ELIDED) {
if (!canBeElided.contains(lmChar)) return 0.0; // an elided char must be elidable
return gsmSmoothingCount;
}
else { // glyph is a normal character
Integer baseChar = diacriticDisregardMap.get(lmChar);
if (baseChar != null && baseChar.equals(glyph))
return gsmSmoothingCount * elisionSmoothingCountMultiplier;
else if (lmChar == sCharIndex && glyph == longsCharIndex)
return gsmSmoothingCount;
else if (lmChar == sCharIndex && (glyph == fCharIndex || glyph == lCharIndex))
return 0.0;
else if (lmChar == hyphenCharIndex && glyph == spaceCharIndex) // so that line-break hyphens can be elided
return gsmSmoothingCount;
else if (canBeReplaced.contains(lmChar) && validSubstitutionChars.contains(glyph) && activeCharacterSets[language].contains(glyph))
return gsmSmoothingCount;
else if (lmChar == glyph)
return gsmSmoothingCount;
else
return 0.0;
}
}
/**
* Traverse the sequence of viterbi states, adding counts
*/
public void incrementCounts(double[/*language*/][/*lmChar*/][/*glyph*/] counts, List<DecodeState> fullViterbiStateSeq) {
for (int i = 0; i < fullViterbiStateSeq.size(); ++i) {
TransitionState currTs = fullViterbiStateSeq.get(i).ts;
TransitionStateType currType = currTs.getType();
if (currType == TransitionStateType.TMPL) {
int language = currTs.getLanguageIndex();
if (language >= 0) {
int lmChar = currTs.getLmCharIndex();
int glyph = glyphIndex(currTs.getGlyphChar());
counts[language][lmChar][glyph] += 1;
}
}
else if (currType == TransitionStateType.RMRGN_HPHN_INIT) {
int language = currTs.getLanguageIndex();
if (language >= 0) {
GlyphChar currGlyphChar = currTs.getGlyphChar();
if (currGlyphChar.templateCharIndex == spaceCharIndex) { // line-break hyphen was elided
int glyph = glyphIndex(currGlyphChar);
counts[language][hyphenCharIndex][glyph] += 1;
}
}
}
}
}
private int glyphIndex(GlyphChar glyphChar) {
return glyphChar.glyphType == GlyphType.NORMAL_CHAR ? glyphChar.templateCharIndex : (numChars + glyphChar.glyphType.ordinal());
}
public BasicGlyphSubstitutionModel make(double[/*language*/][/*lmChar*/][/*glyph*/] counts, int iter, int batchId) {
// Normalize counts to get probabilities
double[/*language*/][/*lmChar*/][/*glyph*/] probs = new double[numLanguages][numChars][numGlyphs];
for (int language = 0; language < numLanguages; ++language) {
for (int prevLmChar = 0; prevLmChar < numChars; ++prevLmChar) {
for (int lmChar = 0; lmChar < numChars; ++lmChar) {
double sum = ArrayHelper.sum(counts[language][lmChar]);
for (int glyph = 0; glyph < numGlyphs; ++glyph) {
double c = counts[language][lmChar][glyph];
double p = (c > 1e-9 ? (c / sum) : 0.0);
probs[language][lmChar][glyph] = p;
}
}
}
}
//System.out.println("Writing out GSM information.");
//synchronized (this) { printGsmProbs3(numLanguages, numChars, numGlyphs, counts, probs, iter, batchId, gsmPrintoutFilepath(iter, batchId)); }
return new BasicGlyphSubstitutionModel(probs, gsmPower, langIndexer, charIndexer);
}
public BasicGlyphSubstitutionModel makeForEval(double[/*language*/][/*lmChar*/][/*glyph*/] counts, int iter, int batchId) {
return makeForEval(counts, iter, batchId, minCountsForEvalGsm);
}
public BasicGlyphSubstitutionModel makeForEval(double[/*language*/][/*lmChar*/][/*glyph*/] counts, int iter, int batchId, double minCountsForEvalGsm) {
if (minCountsForEvalGsm < 1) {
System.out.println("Estimating parameters of a new Glyph Substitution Model. Iter: "+iter+", batch: "+batchId);
return make(counts, iter, batchId);
}
else {
// Normalize counts to get probabilities
double[/*language*/][/*lmChar*/][/*glyph*/] evalCounts = new double[numLanguages][numChars][numGlyphs];
double[/*language*/][/*lmChar*/][/*glyph*/] probs = new double[numLanguages][numChars][numGlyphs];
for (int language = 0; language < numLanguages; ++language) {
for (int lmChar = 0; lmChar < numChars; ++lmChar) {
for (int glyph = 0; glyph < numGlyphs; ++glyph) {
double trueCount = counts[language][lmChar][glyph] - gsmSmoothingCount;
if (trueCount < 1e-9)
evalCounts[language][lmChar][glyph] = 0;
else if (trueCount < minCountsForEvalGsm-1e-9)
evalCounts[language][lmChar][glyph] = 0;
else
evalCounts[language][lmChar][glyph] = trueCount;
}
double sum = ArrayHelper.sum(evalCounts[language][lmChar]);
for (int glyph = 0; glyph < numGlyphs; ++glyph) {
double c = evalCounts[language][lmChar][glyph];
double p = (c > 1e-9 ? (c / sum) : 0.0);
probs[language][lmChar][glyph] = p;
}
}
}
//System.out.println("Writing out GSM information.");
//synchronized (this) { printGsmProbs3(numLanguages, numChars, numGlyphs, counts, probs, iter, batchId, gsmPrintoutFilepath(iter, batchId)+"_eval"); }
return new BasicGlyphSubstitutionModel(probs, gsmPower, langIndexer, charIndexer);
}
}
private void printGsmProbs3(int numLanguages, int numChars, int numGlyphs, double[][][] counts, double[][][] probs, int iter, int batchId, String outputFilenameBase) {
Set<String> CHARS_TO_PRINT = setUnion(makeSet(" ","-","a","b","c","d",Charset.LONG_S));
StringBuffer sb = new StringBuffer();
sb.append("language\tlmChar\tglyph\tcount\tminProb\tprob\n");
for (int language = 0; language < numLanguages; ++language) {
String slanguage = langIndexer.getObject(language);
for (int lmChar = 0; lmChar < numChars; ++lmChar) {
String slmChar = charIndexer.getObject(lmChar);
// figure out what the lowest count is, and then exclude things with that count
double lowProb = ArrayHelper.min(probs[language][lmChar]);
for (int glyph = 0; glyph < numGlyphs; ++glyph) {
String sglyph = glyph < numChars ? charIndexer.getObject(glyph) : GlyphType.values()[glyph-numChars].toString();
double p = probs[language][lmChar][glyph];
double c = counts[language][lmChar][glyph];
if (c > gsmSmoothingCount || (CHARS_TO_PRINT.contains(slmChar) && (CHARS_TO_PRINT.contains(sglyph) || glyph >= numChars))) {
//System.out.println("c="+c+", lang="+langIndexer.getObject(language)+"("+language+"), prevGlyphType="+prevGlyph+ ", prevLmChar="+charIndexer.getObject(prevLmChar)+"("+prevLmChar+"), lmChar="+charIndexer.getObject(lmChar)+"("+lmChar+"), glyphChar="+(glyph < numChars ? charIndexer.getObject(glyph) : (glyph == numGlyphs ? "EpsilonTilde": "Elided"))+"("+glyph+"), p="+p+", logp="+Math.log(p));
sb.append(slanguage).append("\t");
sb.append(slmChar).append("\t");
sb.append(sglyph).append("\t");
sb.append(c).append("\t");
sb.append(lowProb).append("\t");
sb.append(p).append("\t");
sb.append("\n");
}
}
}
}
String outputFilename = outputFilenameBase + ".tsv";
System.out.println("Writing info about newly-trained GSM on iteration "+iter+", batch "+batchId+" out to ["+outputFilename+"]");
FileHelper.writeString(outputFilename, sb.toString());
}
private String gsmPrintoutFilepath(int iter, int batchId) {
String preext = "newGSM";
String outputFilenameBase = outputPath + "/gsm/" + preext;
if (iter > 0) outputFilenameBase += "_iter-" + iter;
if (batchId > 0) outputFilenameBase += "_batch-" + batchId;
return outputFilenameBase;
}
}
public Indexer<String> getLangIndexer() { return langIndexer; }
public Indexer<String> getCharIndexer() { return charIndexer; }
}