package edu.cmu.sphinx.linguist.language.ngram.trie;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.PrintWriter;
import java.net.URL;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;
import edu.cmu.sphinx.linguist.WordSequence;
import edu.cmu.sphinx.linguist.dictionary.Dictionary;
import edu.cmu.sphinx.linguist.dictionary.Word;
import edu.cmu.sphinx.linguist.language.ngram.LanguageModel;
import edu.cmu.sphinx.linguist.util.LRUCache;
import edu.cmu.sphinx.util.LogMath;
import edu.cmu.sphinx.util.TimerPool;
import edu.cmu.sphinx.util.props.ConfigurationManagerUtils;
import edu.cmu.sphinx.util.props.PropertyException;
import edu.cmu.sphinx.util.props.PropertySheet;
import edu.cmu.sphinx.util.props.S4Boolean;
import edu.cmu.sphinx.util.props.S4Double;
import edu.cmu.sphinx.util.props.S4Integer;
import edu.cmu.sphinx.util.props.S4String;
/**
* Language model that uses a binary N-gram language model file ("binary trie file")
* generated by the SphinxBase tool sphinx_lm_convert.
*/
public class NgramTrieModel implements LanguageModel {
/**
* The property for the name of the file that logs all the queried N-grams.
* The property is optional; when it is left unset (null), the queried
* N-grams are not logged.
*/
@S4String(mandatory = false)
public static final String PROP_QUERY_LOG_FILE = "queryLogFile";
/** The property that defines the maximum number of N-grams to be cached. */
@S4Integer(defaultValue = 100000)
public static final String PROP_NGRAM_CACHE_SIZE = "ngramCacheSize";
/**
* The property that controls whether the N-gram probability cache is
* replaced by an empty one after every utterance.
*/
@S4Boolean(defaultValue = false)
public static final String PROP_CLEAR_CACHES_AFTER_UTTERANCE = "clearCachesAfterUtterance";
/** The property that defines the language weight for the search. */
@S4Double(defaultValue = 1.0f)
public final static String PROP_LANGUAGE_WEIGHT = "languageWeight";
/**
* The property that controls whether or not the language model will apply
* the language weight and word insertion probability to the scores it
* returns.
*/
@S4Boolean(defaultValue = false)
public final static String PROP_APPLY_LANGUAGE_WEIGHT_AND_WIP = "applyLanguageWeightAndWip";
/**
* The property that defines the word insertion probability. The value is
* given in the linear domain and converted to the log domain when read.
*/
@S4Double(defaultValue = 1.0f)
public final static String PROP_WORD_INSERTION_PROBABILITY = "wordInsertionProbability";
// ------------------------------
// Configuration data
// ------------------------------
/** Location of the binary trie file that allocate() loads. */
URL location;
/** Logger used for load-time messages, dictionary warnings and cache statistics. */
protected Logger logger;
/** Used to convert the linear word insertion probability to the log domain. */
protected LogMath logMath;
/**
* Maximum N-gram order that can be queried. A value that is not positive or
* that exceeds the number of orders in the file is replaced by the number
* of orders in the file during allocate().
*/
protected int maxDepth;
/**
* Order of the longest N-gram matched by the raw lookup in progress. Reset
* to 1 by getProbabilityRaw and incremented by getAvailableProb for every
* higher order found; read by getAvailableBackoff.
*/
protected int curDepth;
/** Counts read from the file header; counts[0] is used as the number of unigrams and words. */
protected int[] counts;
/** Capacity passed to the N-gram probability cache. */
protected int ngramCacheSize;
/** Whether the probability cache is replaced by an empty one at the end of each utterance. */
protected boolean clearCacheAfterUtterance;
/** Dictionary used to resolve the model's word strings to Word objects. */
protected Dictionary dictionary;
/** Format name supplied to the constructor; not read by the code in this class. */
protected String format;
/** Whether applyWeights scales scores by the language weight and adds the log WIP. */
protected boolean applyLanguageWeightAndWip;
/** Multiplier applied to raw scores when weighting is enabled. */
protected float languageWeight;
/** Configured unigram weight; currently not applied to any score in this class. */
protected float unigramWeight;
/** Word insertion probability in the log domain, added to scores when weighting is enabled. */
protected float logWip;
// -------------------------------
// Statistics
// -------------------------------
/** Name of the query log file, or null when queries are not logged. */
protected String ngramLogFile;
/** Number of maximum-order queries that were not found in the probability cache. */
private int ngramMisses;
/** Number of maximum-order queries that were answered from the probability cache. */
private int ngramHits;
// -------------------------------
// subcomponents
// --------------------------------
/** Writer for the query log; null when logging is disabled. */
private PrintWriter logFile;
//-----------------------------
// Trie structure
//-----------------------------
/** Unigram records read from the file; the array index is the word id. */
protected TrieUnigram[] unigrams;
/** Word strings read from the file; the array index is the word id. */
protected String[] words;
/** Quantization data; only read from the file when maxDepth is greater than 1. */
protected NgramTrieQuant quant;
/** Trie holding the N-grams of order two and above; only built when maxDepth is greater than 1. */
protected NgramTrie trie;
//-----------------------------
// Working data
//-----------------------------
/** Maps a dictionary word to its word id (index into unigrams and words). */
protected Map<Word, Integer> unigramIDMap;
/** Cache of weighted probabilities for word sequences of length maxDepth. */
private LRUCache<WordSequence, Float> ngramProbCache;
/**
 * Creates a model from explicit settings instead of a property sheet.
 *
 * @param format                    format name of the model file (stored only)
 * @param location                  location of the binary trie file
 * @param ngramLogFile              name of the query log file, or null to disable logging
 * @param maxNGramCacheSize         capacity of the N-gram probability cache
 * @param clearCacheAfterUtterance  whether the cache is emptied after each utterance
 * @param maxDepth                  maximum N-gram order to serve
 * @param dictionary                dictionary used to resolve word strings
 * @param applyLanguageWeightAndWip whether scores are weighted before being returned
 * @param languageWeight            language weight
 * @param wip                       word insertion probability in the linear domain
 * @param unigramWeight             unigram weight (stored only)
 */
public NgramTrieModel(String format, URL location, String ngramLogFile,
        int maxNGramCacheSize, boolean clearCacheAfterUtterance,
        int maxDepth, Dictionary dictionary,
        boolean applyLanguageWeightAndWip, float languageWeight,
        double wip, float unigramWeight) {
    // Infrastructure first: the log-domain conversion below needs logMath.
    this.logger = Logger.getLogger(getClass().getName());
    this.logMath = LogMath.getLogMath();

    // Where the model comes from and how it is queried.
    this.format = format;
    this.location = location;
    this.maxDepth = maxDepth;
    this.dictionary = dictionary;

    // Query logging and caching.
    this.ngramLogFile = ngramLogFile;
    this.ngramCacheSize = maxNGramCacheSize;
    this.clearCacheAfterUtterance = clearCacheAfterUtterance;

    // Score weighting.
    this.applyLanguageWeightAndWip = applyLanguageWeightAndWip;
    this.languageWeight = languageWeight;
    this.unigramWeight = unigramWeight;
    this.logWip = this.logMath.linearToLog(wip);
}

/**
 * Creates an unconfigured model; the settings are supplied later through
 * {@link #newProperties(PropertySheet)}.
 */
public NgramTrieModel() {
    // intentionally empty
}
/**
 * Reads the model configuration from the given property sheet.
 *
 * @param ps property sheet that supplies the settings
 * @throws PropertyException if a property cannot be read
 * @see edu.cmu.sphinx.util.props.Configurable#newProperties(edu.cmu.sphinx.util.props.PropertySheet)
 */
@Override
public void newProperties(PropertySheet ps) throws PropertyException {
    this.logger = ps.getLogger();
    this.logMath = LogMath.getLogMath();

    this.location = ConfigurationManagerUtils.getResource(PROP_LOCATION, ps);
    this.ngramLogFile = ps.getString(PROP_QUERY_LOG_FILE);
    this.maxDepth = ps.getInt(LanguageModel.PROP_MAX_DEPTH);
    this.ngramCacheSize = ps.getInt(PROP_NGRAM_CACHE_SIZE);
    this.clearCacheAfterUtterance = ps.getBoolean(PROP_CLEAR_CACHES_AFTER_UTTERANCE);
    this.dictionary = (Dictionary) ps.getComponent(PROP_DICTIONARY);

    this.applyLanguageWeightAndWip = ps.getBoolean(PROP_APPLY_LANGUAGE_WEIGHT_AND_WIP);
    this.languageWeight = ps.getFloat(PROP_LANGUAGE_WEIGHT);
    // The property holds a linear probability; scores are kept in the log domain.
    this.logWip = this.logMath.linearToLog(ps.getDouble(PROP_WORD_INSERTION_PROBABILITY));
    this.unigramWeight = ps.getFloat(PROP_UNIGRAM_WEIGHT);
}
/**
 * Builds the map from dictionary word to word id, where the word id is the
 * index of the word in {@code words}.
 *
 * Words that the dictionary cannot resolve are reported and left out of the
 * map. Registering them would store every such word under the same
 * {@code null} key, so a lookup with a null word would silently resolve to
 * the id of an unrelated word.
 */
private void buildUnigramIDMap() {
    if (unigramIDMap == null) {
        unigramIDMap = new HashMap<Word, Integer>();
    }
    int missingWords = 0;
    for (int wordId = 0; wordId < words.length; wordId++) {
        Word word = dictionary.getWord(words[wordId]);
        if (word == null) {
            logger.warning("The dictionary is missing a phonetic transcription for the word '"
                    + words[wordId] + "'");
            missingWords++;
            // There is no key to register for this id.
            continue;
        }
        unigramIDMap.put(word, wordId);
        if (logger.isLoggable(Level.FINE)) {
            logger.fine("Word: " + word);
        }
    }
    if (missingWords > 0) {
        logger.warning("Dictionary is missing " + missingWords
                + " words that are contained in the language model.");
    }
}
/**
 * Loads the binary trie language model from {@code location}.
 *
 * Opens the query log if one is configured, then reads the header, counts,
 * quantization data, unigrams, trie memory and word strings, builds the
 * word-to-id map and creates the probability cache.
 *
 * The loader is closed and the "Load LM" timer is stopped even when one of
 * the read steps fails, so a failed load does not leak the underlying
 * stream or leave the timer running.
 *
 * @throws IOException if the model or the query log file cannot be opened or read
 * @see edu.cmu.sphinx.linguist.language.ngram.LanguageModel#allocate()
 */
public void allocate() throws IOException {
    TimerPool.getTimer(this, "Load LM").start();
    try {
        logger.info("Loading n-gram language model from: " + location);
        // create the log file if specified
        if (ngramLogFile != null) {
            logFile = new PrintWriter(new FileOutputStream(ngramLogFile));
        }
        BinaryLoader loader;
        if (location.getProtocol() == null
                || location.getProtocol().equals("file")) {
            try {
                loader = new BinaryLoader(new File(location.toURI()));
            } catch (Exception ex) {
                // Best effort: the URL could not be used as a file URI,
                // so fall back to its raw path component.
                loader = new BinaryLoader(new File(location.getPath()));
            }
        } else {
            loader = new BinaryLoader(location);
        }
        try {
            loader.verifyHeader();
            counts = loader.readCounts();
            // A non-positive or too large depth means "use every order in the file".
            if (maxDepth <= 0 || maxDepth > counts.length) {
                maxDepth = counts.length;
            }
            // The read order below follows the original load sequence.
            if (maxDepth > 1) {
                quant = loader.readQuant(maxDepth);
            }
            unigrams = loader.readUnigrams(counts[0]);
            if (maxDepth > 1) {
                trie = new NgramTrie(counts, quant.getProbBoSize(), quant.getProbSize());
                loader.readTrieByteArr(trie.getMem());
            }
            //string words can be read here
            words = loader.readWords(counts[0]);
            buildUnigramIDMap();
            ngramProbCache = new LRUCache<WordSequence, Float>(ngramCacheSize);
        } finally {
            loader.close();
        }
    } finally {
        TimerPool.getTimer(this, "Load LM").stop();
    }
}
/**
 * Releases the query log, if one was opened.
 *
 * The writer is flushed and closed so that the underlying file handle is
 * released, and the field is reset so that later calls and later queries
 * do not touch a closed writer.
 *
 * @throws IOException declared by the interface; not thrown by this implementation
 * @see edu.cmu.sphinx.linguist.language.ngram.LanguageModel#deallocate()
 */
@Override
public void deallocate() throws IOException {
    if (logFile != null) {
        logFile.flush();
        logFile.close();
        logFile = null;
    }
}
/**
 * Finds the highest-order N-gram available for the word sequence and
 * returns its probability.
 *
 * Starting from the word just before the last one and moving towards the
 * beginning of the sequence, each history word extends the N-gram by one
 * order. The walk stops at the first order that is not found, when the
 * range can no longer be searched, or when the maximum depth is reached.
 * {@code curDepth} is incremented for every order that is found.
 *
 * @param wordSequence - word sequence to score
 * @param range - range to look the bigram up in; narrowed during the walk
 * @param prob - probability of the unigram, returned when nothing longer is found
 * @return probability of the highest order ngram available
 */
private float getAvailableProb(WordSequence wordSequence, TrieRange range, float prob) {
    if (!range.isSearchable()) {
        return prob;
    }
    float best = prob;
    for (int historyIdx = wordSequence.size() - 2, orderMinusTwo = 0;
            historyIdx >= 0 && orderMinusTwo + 1 != maxDepth;
            historyIdx--, orderMinusTwo++) {
        int wordId = unigramIDMap.get(wordSequence.getWord(historyIdx));
        float candidate = trie.readNgramProb(wordId, orderMinusTwo, range, quant);
        if (!range.getFound()) {
            break;
        }
        best = candidate;
        curDepth++;
        if (!range.isSearchable()) {
            break;
        }
    }
    return best;
}
/**
 * Sums the back-off weights for the part of the history that
 * {@link #getAvailableProb(WordSequence, TrieRange, float) getAvailableProb}
 * could not use.
 *
 * The matched N-gram has order {@code curDepth} (a field of this class), so
 * its context has {@code curDepth - 1} words. Only contexts of length
 * {@code curDepth} or more contribute a back-off weight. Shorter contexts
 * are still traversed, because the trie range has to be narrowed one order
 * at a time, but their weights are not added. Adding them would
 * over-penalise scores in models of order four and above.
 *
 * @param wordSequence - full word sequence that is scored; must hold at least two words
 * @return accumulated backoff
 */
private float getAvailableBackoff(WordSequence wordSequence) {
    float backoff = 0.0f;
    int wordsNum = wordSequence.size();
    int wordId = unigramIDMap.get(wordSequence.getWord(wordsNum - 2));
    TrieRange range = new TrieRange(unigrams[wordId].next, unigrams[wordId + 1].next);
    // One-word context: needed only when nothing above the unigram matched.
    if (curDepth == 1) {
        backoff += unigrams[wordId].backoff;
    }
    for (int sequenceIdx = wordsNum - 3, orderMinusTwo = 0;
            sequenceIdx >= 0;
            sequenceIdx--, orderMinusTwo++) {
        int tmpWordId = unigramIDMap.get(wordSequence.getWord(sequenceIdx));
        float tmpBackoff = trie.readNgramBackoff(tmpWordId, orderMinusTwo, range, quant);
        if (!range.getFound()) {
            break;
        }
        // This node is a context of (orderMinusTwo + 2) words; skip the
        // weight when the matched N-gram already covers that context.
        int contextLength = orderMinusTwo + 2;
        if (contextLength >= curDepth) {
            backoff += tmpBackoff;
        }
        if (!range.isSearchable()) {
            break;
        }
    }
    return backoff;
}
/**
 * Computes the unweighted score of a word sequence straight from the trie,
 * bypassing the cache.
 *
 * The score starts as the unigram probability of the last word, is replaced
 * by the probability of the longest N-gram found, and receives the back-off
 * weights when that N-gram is shorter than the sequence.
 *
 * @param wordSequence - sequence of words to get probability for
 * @return probability of the specified sequence of words
 */
private float getProbabilityRaw(WordSequence wordSequence) {
    final int wordsNum = wordSequence.size();
    final int lastWordId = unigramIDMap.get(wordSequence.getWord(wordsNum - 1));
    final TrieRange range =
            new TrieRange(unigrams[lastWordId].next, unigrams[lastWordId + 1].next);
    float score = unigrams[lastWordId].prob;
    curDepth = 1;
    if (wordsNum != 1) {
        // find prob of ngrams of higher order if any
        score = getAvailableProb(wordSequence, range, score);
        if (curDepth < wordsNum) {
            // use backoff for rest of ngram
            score += getAvailableBackoff(wordSequence);
        }
    }
    return score;
}
/**
 * Applies the language weight and the log word insertion probability to a
 * raw score when weighting is enabled; otherwise returns the score as is.
 *
 * @param score - raw score
 * @return weighted score
 */
private float applyWeights(float score) {
    //TODO ignores unigram weight. Apply or remove from properties
    return applyLanguageWeightAndWip
            ? score * languageWeight + logWip
            : score;
}
/**
 * Returns the (optionally weighted) N-gram probability of the word
 * sequence.
 *
 * Only sequences whose length equals the maximum depth go through the
 * cache; shorter sequences are always scored from the trie. Every scored
 * sequence is written to the query log when one is open.
 *
 * @param wordSequence - the word sequence
 * @return the probability of the word sequence.
 * Probability is in logMath log base
 */
@Override
public float getProbability(WordSequence wordSequence) {
    final int numberWords = wordSequence.size();
    if (numberWords > maxDepth) {
        throw new Error("Unsupported NGram: " + wordSequence.size());
    }
    final boolean fullOrder = numberWords == maxDepth;
    if (fullOrder) {
        Float cached = ngramProbCache.get(wordSequence);
        if (cached != null) {
            ngramHits++;
            return cached;
        }
        ngramMisses++;
    }
    float probability = applyWeights(getProbabilityRaw(wordSequence));
    if (fullOrder) {
        ngramProbCache.put(wordSequence, probability);
    }
    if (logFile != null) {
        String readableSequence = wordSequence.toString().replace("][", " ");
        logFile.println(readableSequence + " : " + Float.toString(probability));
    }
    return probability;
}
/**
 * Returns the smear term for the given word sequence. Smearing is not
 * implemented for this model, so the result is always zero.
 *
 * @param wordSequence - the word sequence
 * @return the smear term associated with this word sequence (always 0)
 */
@Override
public float getSmear(WordSequence wordSequence) {
    //TODO not implemented
    return 0.0f;
}
/**
 * Returns the words of the language model as an unmodifiable set. A new
 * set is built from the word array on every call.
 *
 * @return the unmodifiable set of words
 */
@Override
public Set<String> getVocabulary() {
    return Collections.unmodifiableSet(new HashSet<String>(Arrays.asList(words)));
}
/**
 * Returns the number of cache misses: queries for a word sequence of
 * maximum depth whose probability was not in the cache and had to be
 * computed from the trie.
 *
 * @return the number of NGram misses
 */
public int getNGramMisses() {
    return this.ngramMisses;
}
/**
 * Returns the number of cache hits: queries for a word sequence of maximum
 * depth that were answered from the probability cache.
 *
 * @return the number of NGram hits
 */
public int getNGramHits() {
    return this.ngramHits;
}
/**
 * Returns the maximum N-gram order this model serves.
 *
 * @return the maximum depth of the language model
 */
@Override
public int getMaxDepth() {
    return this.maxDepth;
}
/**
 * Logs the cache statistics and, when configured to do so, replaces the
 * N-gram probability cache with an empty one. The hit and miss counters
 * are left untouched.
 */
private void clearCache() {
    logger.info("LM Cache Size: " + ngramProbCache.size() + " Hits: "
            + ngramHits + " Misses: " + ngramMisses);
    if (!clearCacheAfterUtterance) {
        return;
    }
    ngramProbCache = new LRUCache<WordSequence, Float>(ngramCacheSize);
}
/**
 * Called by lexicon after recognition.
 * Reports and optionally clears the cache, then marks the end of the
 * utterance in the query log when one is open.
 */
public void onUtteranceEnd() {
    clearCache();
    if (logFile == null) {
        return;
    }
    logFile.println("<END_UTT>");
    logFile.flush();
}
/**
* Structure that keeps unigram instance data in trie.
* Language model contains sorted array of TrieUnigram,
* where index in array is wordId
*/
public static class TrieUnigram {
/** Probability of the unigram; the starting score of every lookup. */
public float prob;
/** Back-off weight of the unigram, added when no longer N-gram ending in the scored word is found. */
public float backoff;
/**
* Start of the range searched for N-grams that extend this unigram; the
* range ends at the next field of the following array element.
*/
public int next;
}
/**
 * Index range used while walking the trie, together with a flag that tells
 * whether the most recent lookup succeeded.
 */
public static class TrieRange {
    /** First index of the range. */
    int begin;
    /** Index that bounds the range; the width is {@code end - begin}. */
    int end;
    /** Result of the latest lookup; a new range starts out as found. */
    boolean found = true;

    /**
     * Creates a range over the given bounds, marked as found.
     *
     * @param begin first index of the range
     * @param end   bounding index of the range
     */
    TrieRange(int begin, int end) {
        this.begin = begin;
        this.end = end;
    }

    /** @return the difference between the end and the begin index */
    int getWidth() {
        return this.end - this.begin;
    }

    /** @param found whether the latest lookup succeeded */
    void setFound(boolean found) {
        this.found = found;
    }

    /** @return whether the latest lookup succeeded */
    boolean getFound() {
        return this.found;
    }

    /** @return true when the range has a positive width */
    boolean isSearchable() {
        return 0 < getWidth();
    }
}
}