/* This file is part of the Joshua Machine Translation System.
 * 
 * Joshua is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2.1
 * of the License, or (at your option) any later version.
 * 
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Lesser General Public License for more details.
 * 
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library; if not, write to the Free
 * Software Foundation, Inc., 59 Temple Place, Suite 330, Boston,
 * MA 02111-1307 USA
 */
package joshua.decoder.ff.lm.buildin_lm;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.Map;
import java.util.Scanner;
import java.util.logging.Level;
import java.util.logging.Logger;

import joshua.corpus.vocab.SymbolTable;
import joshua.decoder.JoshuaConfiguration;
import joshua.decoder.ff.lm.AbstractLM;
import joshua.decoder.ff.lm.ArpaFile;
import joshua.decoder.ff.lm.ArpaNgram;
import joshua.util.Bits;
import joshua.util.Regex;

/**
 * Relatively memory-compact language model
 * stored as a reversed-word-order trie.
 * <p>
 * The trie itself represents language model context.
 * <p>
 * Conceptually, each node in the trie stores a map
 * from conditioning word to log probability.
 * <p>
 * Additionally, each node in the trie stores
 * the backoff weight for that context.
 * 
 * @author Lane Schwartz
 * @see <a href="http://www.speech.sri.com/projects/srilm/manpages/ngram-discount.7.html">SRILM ngram-discount documentation</a>
 */
public class TrieLM extends AbstractLM {

  /** Logger for this class. */
  private static final Logger logger =
      Logger.getLogger(TrieLM.class.getName());

  /** Node ID for the root node. */
  private static final int ROOT_NODE_ID = 0;

  /**
   * Maps from (node id, word id for child) --> node id of child.
   */
  private final Map<Long,Integer> children;

  /**
   * Maps from (node id, word id for lookup word) -->
   * log prob of lookup word given context
   * 
   * (the context is defined by where you are in the tree).
   */
  private final Map<Long,Float> logProbs;

  /**
   * Maps from (node id) -->
   * backoff weight for that context
   * 
   * (the context is defined by where you are in the tree).
   */
  private final Map<Integer,Float> backoffs;
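
  /*
   * Note on map keys: each (node id, word id) pair is packed into a
   * single long via Bits.encodeAsLong(int, int), so one flat HashMap
   * can serve as the child (or probability) table for every node in
   * the trie. The sketch below is illustrative only; it assumes the
   * two ints occupy the high and low 32 bits of the long, which is
   * the usual packing but should be confirmed against
   * joshua.util.Bits:
   * 
   *   // hypothetical illustration, not part of this class:
   *   // long key = ((long) nodeID << 32) | (wordID & 0xFFFFFFFFL);
   *   // children.get(key) == node id of the child reached from
   *   //                      nodeID along the edge labeled wordID
   */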
  /**
   * Constructs a trie language model from the ARPA file
   * with the given name, using the given symbol table.
   * 
   * @param vocab symbol table mapping between words and integer IDs
   * @param file name of an ARPA-format language model file
   * @throws FileNotFoundException if the ARPA file cannot be found
   */
  public TrieLM(SymbolTable vocab, String file) throws FileNotFoundException {
    this(new ArpaFile(file, vocab));
  }

  /**
   * Constructs a language model object from the specified ARPA file.
   * 
   * @param arpaFile ARPA-format language model file
   * @throws FileNotFoundException if the ARPA file cannot be found
   */
  public TrieLM(ArpaFile arpaFile) throws FileNotFoundException {
    super(arpaFile.getVocab(), arpaFile.getOrder());

    int ngramCounts = arpaFile.size();
    if (logger.isLoggable(Level.FINE)) logger.fine("ARPA file contains " + ngramCounts + " n-grams");

    this.children = new HashMap<Long,Integer>(ngramCounts);
    this.logProbs = new HashMap<Long,Float>(ngramCounts);
    this.backoffs = new HashMap<Integer,Float>(ngramCounts);

    int nodeCounter = 0;

    int lineNumber = 0;
    for (ArpaNgram ngram : arpaFile) {
      lineNumber += 1;
      if (lineNumber % 100000 == 0) logger.info("Line: " + lineNumber);

      if (logger.isLoggable(Level.FINEST)) logger.finest(
          ngram.order() + "-gram: (" + ngram.getWord() + " | "
          + Arrays.toString(ngram.getContext()) + ")");

      int word = ngram.getWord();
      int[] context = ngram.getContext();

      {
        // Find the node where the log prob should be stored:
        // walk the context in reverse word order, creating
        // child nodes as needed.
        int contextNodeID = ROOT_NODE_ID;

        for (int i = context.length - 1; i >= 0; i--) {
          long key = Bits.encodeAsLong(contextNodeID, context[i]);
          int childID;
          if (children.containsKey(key)) {
            childID = children.get(key);
          } else {
            childID = ++nodeCounter;
            if (logger.isLoggable(Level.FINEST)) logger.finest(
                "children.put(" + contextNodeID + ":" + context[i]
                + " , " + childID + ")");
            children.put(key, childID);
          }
          contextNodeID = childID;
        }

        // Store the log prob for this n-gram at this node in the trie
        long key = Bits.encodeAsLong(contextNodeID, word);
        float logProb = ngram.getValue();
        if (logger.isLoggable(Level.FINEST)) logger.finest(
            "logProbs.put(" + contextNodeID + ":" + word + " , " + logProb + ")");
        this.logProbs.put(key, logProb);
      }

      {
        // Find the node where the backoff should be stored:
        // first follow the edge for the word itself, then walk
        // the context in reverse word order.
        int backoffNodeID = ROOT_NODE_ID;

        long backoffNodeKey = Bits.encodeAsLong(backoffNodeID, word);
        int wordChildID;
        if (children.containsKey(backoffNodeKey)) {
          wordChildID = children.get(backoffNodeKey);
        } else {
          wordChildID = ++nodeCounter;
          if (logger.isLoggable(Level.FINEST)) logger.finest(
              "children.put(" + backoffNodeID + ":" + word
              + " , " + wordChildID + ")");
          children.put(backoffNodeKey, wordChildID);
        }
        backoffNodeID = wordChildID;

        for (int i = context.length - 1; i >= 0; i--) {
          long key = Bits.encodeAsLong(backoffNodeID, context[i]);
          int childID;
          if (children.containsKey(key)) {
            childID = children.get(key);
          } else {
            childID = ++nodeCounter;
            if (logger.isLoggable(Level.FINEST)) logger.finest(
                "children.put(" + backoffNodeID + ":" + context[i]
                + " , " + childID + ")");
            children.put(key, childID);
          }
          backoffNodeID = childID;
        }

        // Store the backoff for this n-gram at this node in the trie
        float backoff = ngram.getBackoff();
        if (logger.isLoggable(Level.FINEST)) logger.finest(
            "backoffs.put(" + backoffNodeID + ":" + word + " , " + backoff + ")");
        this.backoffs.put(backoffNodeID, backoff);
      }
    }
  }

  @Override
  protected double logProbabilityOfBackoffState_helper(
      int[] ngram, int order, int qtyAdditionalBackoffWeight) {
    throw new UnsupportedOperationException(
        "logProbabilityOfBackoffState_helper undefined for TrieLM");
  }
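
  /*
   * Scoring sketch (the numbers are illustrative, not from any real
   * model): for the trigram query P(c | a b), the loop below first
   * reads the unigram entry logProbs[(root, c)], then follows the
   * child edge for b and, if logProbs[(node_b, c)] exists, replaces
   * the unigram score with log P(c | b) and resets the accumulated
   * backoff. If no explicit trigram entry exists but the context node
   * for "a b" does, its backoff weight is added, which is the
   * standard Katz backoff rule in log space:
   * 
   *   log P(c | a b) = log P(c | b) + log alpha(a b)
   * 
   * e.g. with log P(c | b) = -1.2 and log alpha(a b) = -0.3, the
   * returned score would be -1.5.
   */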
  @Override
  protected double ngramLogProbability_helper(int[] ngram, int order) {

    // Effectively log(0.0f), capped at the ceiling cost
    float logProb = (float) -JoshuaConfiguration.lm_ceiling_cost;
    float backoff = 0.0f; // log(1.0f)

    int i = ngram.length - 1;
    int word = ngram[i];
    i -= 1;

    int nodeID = ROOT_NODE_ID;

    while (true) {
      {
        long key = Bits.encodeAsLong(nodeID, word);
        if (logProbs.containsKey(key)) {
          // An explicit probability exists for this context;
          // use it and discard any accumulated backoff.
          logProb = logProbs.get(key);
          backoff = 0.0f; // log(1.0f)
        }
      }

      if (i < 0) {
        break;
      }

      {
        long key = Bits.encodeAsLong(nodeID, ngram[i]);

        if (children.containsKey(key)) {
          nodeID = children.get(key);

          // Guard against a context node with no stored backoff;
          // a well-formed ARPA file stores one for every context.
          Float contextBackoff = backoffs.get(nodeID);
          if (contextBackoff != null) {
            backoff += contextBackoff;
          }

          i -= 1;
        } else {
          break;
        }
      }
    }

    double result = logProb + backoff;
    if (result < -JoshuaConfiguration.lm_ceiling_cost) {
      result = -JoshuaConfiguration.lm_ceiling_cost;
    }

    return result;
  }

  public Map<Long,Integer> getChildren() {
    return this.children;
  }

  public static void main(String[] args) throws IOException {

    logger.info("Constructing ARPA file");
    ArpaFile arpaFile = new ArpaFile(args[0]);

    logger.info("Getting symbol table");
    SymbolTable vocab = arpaFile.getVocab();

    logger.info("Constructing TrieLM");
    TrieLM lm = new TrieLM(arpaFile);

    int n = Integer.valueOf(args[2]);
    logger.info("N-gram order will be " + n);

    Scanner scanner = new Scanner(new File(args[1]));

    LinkedList<String> wordList = new LinkedList<String>();
    LinkedList<String> window = new LinkedList<String>();

    logger.info("Starting to scan " + args[1]);
    while (scanner.hasNext()) {

      logger.info("Getting next line...");
      String line = scanner.nextLine();
      logger.info("Line: " + line);

      String[] words = Regex.spaces.split(line);
      wordList.clear();

      wordList.add("<s>");
      for (String word : words) {
        wordList.add(word);
      }
      wordList.add("</s>");

      ArrayList<Integer> sentence = new ArrayList<Integer>();
      for (int i = 0, size = wordList.size(); i < size; i++) {
        sentence.add(vocab.getID(wordList.get(i)));
      }

      while (!wordList.isEmpty()) {
        window.clear();

        {
          int i = 0;
          for (String word : wordList) {
            if (i >= n) break;
            window.add(word);
            i++;
          }
          wordList.remove();
        }

        {
          int i = 0;
          int[] wordIDs = new int[window.size()];
          for (String word : window) {
            wordIDs[i] = vocab.getID(word);
            i++;
          }

          logger.info("logProb " + window.toString() + " = "
              + lm.ngramLogProbability(wordIDs, n));
        }
      }

      double logProb = lm.sentenceLogProbability(sentence, n, 2);
      double prob = Math.exp(logProb);

      logger.info("Total logProb = " + logProb);
      logger.info("Total prob = " + prob);
    }

    scanner.close();
  }

}
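
/*
 * Example invocation of the main() smoke test above (the file names
 * are hypothetical placeholders):
 * 
 *   java joshua.decoder.ff.lm.buildin_lm.TrieLM model.arpa sentences.txt 3
 * 
 * args[0] is the ARPA-format language model, args[1] a plain-text file
 * with one sentence per line, and args[2] the n-gram order to score
 * with. Each sentence is wrapped in <s> ... </s>, every n-word window
 * is scored individually, and the total sentence log probability is
 * reported at the end.
 */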