package edu.cmu.sphinx.linguist.language.ngram;
import java.io.IOException;
import java.util.*;
import edu.cmu.sphinx.linguist.WordSequence;
import edu.cmu.sphinx.linguist.dictionary.Dictionary;
import edu.cmu.sphinx.linguist.dictionary.Word;
import edu.cmu.sphinx.util.LogMath;
import edu.cmu.sphinx.util.props.PropertyException;
import edu.cmu.sphinx.util.props.PropertySheet;
/**
* 3-gram language model that can change its content at runtime.
*
* @author Alexander Solovets
*
*/
public class DynamicTrigramModel implements LanguageModel {
private Dictionary dictionary;
private final Set<String> vocabulary;
private int maxDepth;
private float unigramWeight;
private List<String> sentences;
private Map<WordSequence, Float> logProbs;
private Map<WordSequence, Float> logBackoffs;
public DynamicTrigramModel() {
vocabulary = new HashSet<String>();
logProbs = new HashMap<WordSequence, Float>();
logBackoffs = new HashMap<WordSequence, Float>();
}
public DynamicTrigramModel(Dictionary dictionary) {
this();
this.dictionary = dictionary;
}
public void newProperties(PropertySheet ps) throws PropertyException {
dictionary = (Dictionary) ps.getComponent(PROP_DICTIONARY);
maxDepth = ps.getInt(PROP_MAX_DEPTH);
unigramWeight = ps.getFloat(PROP_UNIGRAM_WEIGHT);
}
public void allocate() throws IOException {
vocabulary.clear();
logProbs.clear();
logBackoffs.clear();
HashMap<WordSequence, Integer> unigrams = new HashMap<WordSequence, Integer>();
HashMap<WordSequence, Integer> bigrams = new HashMap<WordSequence, Integer>();
HashMap<WordSequence, Integer> trigrams = new HashMap<WordSequence, Integer>();
int wordCount = 0;
for (String sentence : sentences) {
String[] textWords = sentence.split("\\s+");
List<Word> words = new ArrayList<Word>();
words.add(dictionary.getSentenceStartWord());
for (String wordString : textWords) {
if (wordString.length() == 0) {
continue;
}
vocabulary.add(wordString);
Word word = dictionary.getWord(wordString);
if (word == null) {
words.add(Word.UNKNOWN);
} else {
words.add(word);
}
}
words.add(dictionary.getSentenceEndWord());
if (words.size() > 0) {
addSequence(unigrams, new WordSequence(words.get(0)));
wordCount++;
}
if (words.size() > 1) {
wordCount++;
addSequence(unigrams, new WordSequence(words.get(1)));
addSequence(bigrams, new WordSequence(words.get(0), words.get(1)));
}
for (int i = 2; i < words.size(); ++i) {
wordCount++;
addSequence(unigrams, new WordSequence(words.get(i)));
addSequence(bigrams, new WordSequence(words.get(i - 1), words.get(i)));
addSequence(trigrams, new WordSequence(words.get(i - 2), words.get(i - 1), words.get(i)));
}
}
float discount = .5f;
float deflate = 1 - discount;
Map<WordSequence, Float> uniprobs = new HashMap<WordSequence, Float>();
for (Map.Entry<WordSequence, Integer> e : unigrams.entrySet()) {
uniprobs.put(e.getKey(), (float) e.getValue() * deflate / wordCount);
}
LogMath lmath = LogMath.getLogMath();
float logUnigramWeight = lmath.linearToLog(unigramWeight);
float invLogUnigramWeight = lmath.linearToLog(1 - unigramWeight);
float logUniformProb = -lmath.linearToLog(uniprobs.size());
Set<WordSequence> sorted1grams = new TreeSet<WordSequence>(unigrams.keySet());
Iterator<WordSequence> iter = new TreeSet<WordSequence>(bigrams.keySet()).iterator();
WordSequence ws = iter.hasNext() ? iter.next() : null;
for (WordSequence unigram : sorted1grams) {
float p = lmath.linearToLog(uniprobs.get(unigram));
p += logUnigramWeight;
p = lmath.addAsLinear(p, logUniformProb + invLogUnigramWeight);
logProbs.put(unigram, p);
float sum = 0.f;
while (ws != null) {
int cmp = ws.getOldest().compareTo(unigram);
if (cmp > 0) {
break;
}
if (cmp == 0) {
sum += uniprobs.get(ws.getNewest());
}
ws = iter.hasNext() ? iter.next() : null;
}
logBackoffs.put(unigram, lmath.linearToLog(discount / (1 - sum)));
}
Map<WordSequence, Float> biprobs = new HashMap<WordSequence, Float>();
for (Map.Entry<WordSequence, Integer> entry : bigrams.entrySet()) {
int unigramCount = unigrams.get(entry.getKey().getOldest());
biprobs.put(entry.getKey(), entry.getValue() * deflate / unigramCount);
}
Set<WordSequence> sorted2grams = new TreeSet<WordSequence>(bigrams.keySet());
iter = new TreeSet<WordSequence>(trigrams.keySet()).iterator();
ws = iter.hasNext() ? iter.next() : null;
for (WordSequence biword : sorted2grams) {
logProbs.put(biword, lmath.linearToLog(biprobs.get(biword)));
float sum = 0.f;
while (ws != null) {
int cmp = ws.getOldest().compareTo(biword);
if (cmp > 0) {
break;
}
if (cmp == 0) {
sum += biprobs.get(ws.getNewest());
}
ws = iter.hasNext() ? iter.next() : null;
}
logBackoffs.put(biword, lmath.linearToLog(discount / (1 - sum)));
}
for (Map.Entry<WordSequence, Integer> e : trigrams.entrySet()) {
float p = e.getValue() * deflate;
p /= bigrams.get(e.getKey().getOldest());
logProbs.put(e.getKey(), lmath.linearToLog(p));
}
}
private void addSequence(HashMap<WordSequence, Integer> grams, WordSequence wordSequence) {
Integer count = grams.get(wordSequence);
if (count != null) {
grams.put(wordSequence, count + 1);
} else {
grams.put(wordSequence, 1);
}
}
public void deallocate() throws IOException {
}
public float getProbability(WordSequence wordSequence) {
float prob;
if (logProbs.containsKey(wordSequence)) {
prob = logProbs.get(wordSequence);
} else if (wordSequence.size() > 1) {
Float backoff = logBackoffs.get(wordSequence.getOldest());
if (backoff == null) {
prob = LogMath.LOG_ONE + getProbability(wordSequence.getNewest());
} else {
prob = backoff + getProbability(wordSequence.getNewest());
}
} else {
prob = LogMath.LOG_ZERO;
}
return prob;
}
public float getSmear(WordSequence wordSequence) {
// TODO: implement
return 0;
}
public Set<String> getVocabulary() {
return vocabulary;
}
public int getMaxDepth() {
return maxDepth;
}
@Override
public void onUtteranceEnd() {
//TODO not implemented
}
public void setText(List<String> sentences) {
this.sentences = sentences;
}
}