package joshua.decoder.ff.lm;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import joshua.corpus.vocab.SymbolTable;
import joshua.decoder.ff.state_maintenance.DPState;
import joshua.decoder.ff.state_maintenance.NgramDPState;
import joshua.decoder.ff.tm.Rule;
import joshua.decoder.hypergraph.HGNode;
import joshua.decoder.hypergraph.HyperEdge;
import joshua.util.Ngram;

public class NgramExtractor {

  private int ngramStateID;
  private int baselineLMOrder;
  private SymbolTable symbolTable;
  private boolean useIntegerNgram;

  private static final String START_SYM = "<s>";
  private int START_SYM_ID;
  private static final String STOP_SYM = "</s>";
  private int STOP_SYM_ID;

  /** Convenience constructor; useIntegerNgram defaults to true. */
  public NgramExtractor(SymbolTable symbolTable, int ngramStateID, int baselineLMOrder) {
    this(symbolTable, ngramStateID, true, baselineLMOrder);
  }

  public NgramExtractor(SymbolTable symbolTable, int ngramStateID, boolean useIntegerNgram, int baselineLMOrder) {
    this.symbolTable = symbolTable;
    this.ngramStateID = ngramStateID;
    this.useIntegerNgram = useIntegerNgram;
    this.baselineLMOrder = baselineLMOrder;
    this.START_SYM_ID = this.symbolTable.addTerminal(START_SYM);
    this.STOP_SYM_ID = this.symbolTable.addTerminal(STOP_SYM);
  }

  /** For a generative model, set startNgramOrder == endNgramOrder. */
  public HashMap<String,Integer> getTransitionNgrams(HyperEdge dt, int startNgramOrder, int endNgramOrder) {
    return getTransitionNgrams(dt.getRule(), dt.getAntNodes(), startNgramOrder, endNgramOrder);
  }

  /** For a generative model, set startNgramOrder == endNgramOrder. */
  public HashMap<String,Integer> getTransitionNgrams(Rule rule, List<HGNode> antNodes, int startNgramOrder, int endNgramOrder) {
    return computeTransitionNgrams(rule, antNodes, startNgramOrder, endNgramOrder);
  }

  /** Does not work for a generative model. */
  public HashMap<String,Integer> getRuleNgrams(Rule rule, int startNgramOrder, int endNgramOrder) {
    return computeRuleNgrams(rule, startNgramOrder, endNgramOrder);
  }

  /** For a generative model, always set startNgramOrder = 1 to allow partial ngrams. */
  public HashMap<String,Integer> getFutureNgrams(Rule rule, DPState curDPState, int startNgramOrder, int endNgramOrder) {
    //TODO: do not consider <s> and </s>
    boolean addStart = false;
    boolean addEnd = false;
    return computeFutureNgrams((NgramDPState) curDPState, startNgramOrder, endNgramOrder, addStart, addEnd);
  }

  /** For a generative model, always set startNgramOrder = 1 to allow partial ngrams. */
  public HashMap<String,Integer> getFinalTransitionNgrams(HyperEdge edge, int startNgramOrder, int endNgramOrder) {
    return getFinalTransitionNgrams(edge.getAntNodes().get(0), startNgramOrder, endNgramOrder);
  }

  /** For a generative model, always set startNgramOrder = 1 to allow partial ngrams. */
  public HashMap<String,Integer> getFinalTransitionNgrams(HGNode antNode, int startNgramOrder, int endNgramOrder) {
    return computeFinalTransitionNgrams(antNode, startNgramOrder, endNgramOrder);
  }

  /** Works for both the generative and discriminative models. */
  //TODO: consider speeding up this function
  private HashMap<String,Integer> computeTransitionNgrams(Rule rule, List<HGNode> antNodes, int startNgramOrder, int endNgramOrder) {
    if (baselineLMOrder < endNgramOrder) {
      System.out.println("baselineLMOrder is too small");
      System.exit(0);
    }

    //==== hyperedges not under the "goal item"
    HashMap<String, Integer> newNgramCounts = new HashMap<String, Integer>(); // new ngrams created by the combination
    HashMap<String, Integer> oldNgramCounts = new HashMap<String, Integer>(); // ngrams that have already been counted
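    /* newNgramCounts will hold every ngram found in the concatenated word sequences built below,
     * while oldNgramCounts holds the ngrams formed purely from an antecedent's LM-state words,
     * which were already counted when that antecedent was processed; the difference, taken at
     * the end of this method, is what the current hyperedge newly contributes. */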
    int[] enWords = rule.getEnglish();

    // a continuous sequence of words produced by the combination; the sequence stops
    // whenever the right LM state jumps in (i.e., when there are elided words)
    List<Integer> words = new ArrayList<Integer>();

    for (int c = 0; c < enWords.length; c++) {
      int curID = enWords[c];
      if (symbolTable.isNonterminal(curID)) { // nonterminal
        //== get the left and right context
        int index = symbolTable.getTargetNonterminalIndex(curID);
        HGNode antNode = antNodes.get(index);
        NgramDPState state = (NgramDPState) antNode.getDPState(this.ngramStateID);
        //System.out.println("lm_feat_is: " + this.lm_feat_id + " ; state is: " + state);
        List<Integer> leftContext = state.getLeftLMStateWords();
        List<Integer> rightContext = state.getRightLMStateWords();
        if (leftContext.size() != rightContext.size()) {
          System.out.println("getAllNgrams: left and right contexts have unequal lengths");
          System.exit(1);
        }

        //== find new ngrams created
        for (int t : leftContext)
          words.add(t);

        this.getNgrams(oldNgramCounts, startNgramOrder, endNgramOrder, leftContext);

        if (rightContext.size() >= baselineLMOrder - 1) { // the right and left contexts do NOT overlap
          this.getNgrams(oldNgramCounts, startNgramOrder, endNgramOrder, rightContext);
          this.getNgrams(newNgramCounts, startNgramOrder, endNgramOrder, words);
          // start a new chunk; the sequence stops whenever the right LM state jumps in (i.e., when there are elided words)
          words.clear();
          for (int t : rightContext)
            words.add(t);
        }
      } else { // terminal word
        words.add(curID);
      }
    }
    this.getNgrams(newNgramCounts, startNgramOrder, endNgramOrder, words);

    //=== now deduct the ngram counts that were already accounted for
    HashMap<String, Integer> res = new HashMap<String, Integer>();
    for (Map.Entry<String, Integer> entry : newNgramCounts.entrySet()) {
      String ngram = entry.getKey();
      int finalCount = entry.getValue();
      if (oldNgramCounts.containsKey(ngram)) {
        finalCount -= oldNgramCounts.get(ngram);
        if (finalCount < 0) {
          System.out.println("error: negative count for ngram: " + ngram + "; old: " + oldNgramCounts.get(ngram));
          System.exit(0);
        }
      }
      if (finalCount > 0)
        res.put(ngram, finalCount);
    }
    return res;
  }

  /** Works for both the generative and discriminative models. */
  private HashMap<String,Integer> computeFinalTransitionNgrams(HGNode antNode, int startNgramOrder, int endNgramOrder) {
    if (baselineLMOrder < endNgramOrder) {
      System.out.println("baselineLMOrder is too small");
      System.exit(0);
    }

    HashMap<String, Integer> res = new HashMap<String, Integer>();
    NgramDPState state = (NgramDPState) antNode.getDPState(this.ngramStateID);
    List<Integer> currentNgram = new ArrayList<Integer>();
    List<Integer> leftContext = state.getLeftLMStateWords();
    List<Integer> rightContext = state.getRightLMStateWords();
    if (leftContext.size() != rightContext.size()) {
      System.out.println("computeFinalTransition: left and right contexts have unequal lengths");
      System.exit(1);
    }

    //============ left context
    currentNgram.add(START_SYM_ID);
    for (int i = 0; i < leftContext.size(); i++) {
      int t = leftContext.get(i);
      currentNgram.add(t);

      if (currentNgram.size() >= startNgramOrder && currentNgram.size() <= endNgramOrder)
        this.getNgrams(res, currentNgram.size(), currentNgram.size(), currentNgram);

      if (currentNgram.size() == baselineLMOrder) {
        currentNgram.remove(0);
      }
    }

    //============ right context
    // switch context: get the last possible new ngram; this ngram can be as short as "<s> a </s>"
    int tSize = currentNgram.size();
    for (int i = 0; i < rightContext.size(); i++) { // replace the tail with the right context
      currentNgram.set(tSize - rightContext.size() + i, rightContext.get(i));
    }
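    /* currentNgram now ends in the node's right LM state, so appending </s> below produces the
     * sentence-final ngram (which, as noted above, can be as short as "<s> a </s>"); it is
     * counted only if its order falls within [startNgramOrder, endNgramOrder]. */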
    currentNgram.add(STOP_SYM_ID);
    if (currentNgram.size() >= startNgramOrder && currentNgram.size() <= endNgramOrder)
      this.getNgrams(res, currentNgram.size(), currentNgram.size(), currentNgram);

    return res;
  }

  /** TODO: This does not work for a generative model.
   * For example, for a rule "a b x_0 c d", under a generative model we only want the ngrams
   * "a", "a b", "c", and "c d", but not "b" and "d".
   */
  private HashMap<String, Integer> computeRuleNgrams(Rule rule, int startNgramOrder, int endNgramOrder) {
    if (baselineLMOrder < endNgramOrder) {
      System.out.println("baselineLMOrder is too small");
      System.exit(0);
    }

    HashMap<String, Integer> newNgramCounts = new HashMap<String, Integer>(); // new ngrams created by the combination
    int[] enWords = rule.getEnglish();

    List<Integer> words = new ArrayList<Integer>();
    for (int c = 0; c < enWords.length; c++) {
      int curWrd = enWords[c];
      if (symbolTable.isNonterminal(curWrd)) {
        this.getNgrams(newNgramCounts, startNgramOrder, endNgramOrder, words);
        words.clear();
      } else {
        words.add(curWrd);
      }
    }
    this.getNgrams(newNgramCounts, startNgramOrder, endNgramOrder, words);
    return newNgramCounts;
  }

  /** TODO:
   * This does not work when addStart == true or addEnd == true.
   * But if both addStart == false and addEnd == false, then it works for both the
   * discriminative and generative models.
   */
  private HashMap<String, Integer> computeFutureNgrams(NgramDPState state, int startNgramOrder, int endNgramOrder, boolean addStart, boolean addEnd) {
    if (baselineLMOrder < endNgramOrder) {
      System.out.println("baselineLMOrder is too small");
      System.exit(0);
    }

    HashMap<String, Integer> res = new HashMap<String, Integer>();
    List<Integer> currentNgram = new ArrayList<Integer>();
    List<Integer> leftContext = state.getLeftLMStateWords();
    List<Integer> rightContext = state.getRightLMStateWords();
    if (leftContext.size() != rightContext.size()) {
      System.out.println("computeFutureNgrams: left and right contexts have unequal lengths");
      System.exit(1);
    }

    //============ left context
    if (addStart) { // this does not really work
      /* TODO: this does not really work, as the new ngrams generated should be
       * different for the discriminative and generative models. */
      currentNgram.add(START_SYM_ID);
    }

    // approximate the full ngram with smaller-order ngrams
    for (int i = 0; i < leftContext.size(); i++) {
      int t = leftContext.get(i);
      currentNgram.add(t);

      if (currentNgram.size() >= startNgramOrder && currentNgram.size() <= endNgramOrder - 1)
        this.getNgrams(res, currentNgram.size(), currentNgram.size(), currentNgram);

      if (currentNgram.size() == baselineLMOrder) {
        currentNgram.remove(0);
      }
    }

    //============ right context
    // switch context: get the last possible new ngram; this ngram can be <s> a </s>
    if (addEnd) { // only when addEnd is true do we get a complete ngram; otherwise all ngrams in the right state are incomplete and we should do nothing
      /* TODO: this will be different for the discriminative and generative models.
       * For example, for a discriminative model we may get new ngrams like "a b </s>" and "b </s>",
       * but for a generative model we will get at most one ngram, whose order is baselineLMOrder. */
    }
    return res;
  }
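  /* The helpers below dispatch to joshua.util.Ngram; the useIntegerNgram flag selects between
   * the integer-ID overload and the SymbolTable-backed (string) overload when building the
   * ngram keys. */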
  /*
  private void getNgrams(HashMap<String,Integer> tbl, int startNgramOrder, int endNgramOrder, int[] wrds){
    if (useIntegerNgram)
      Ngram.getNgrams(tbl, startNgramOrder, endNgramOrder, wrds);
    else
      Ngram.getNgrams(symbolTable, tbl, startNgramOrder, endNgramOrder, wrds);
  }
  */

  private void getNgrams(HashMap<String,Integer> tbl, int startNgramOrder, int endNgramOrder, List<Integer> wrds) {
    if (useIntegerNgram)
      Ngram.getNgrams(tbl, startNgramOrder, endNgramOrder, wrds);
    else
      Ngram.getNgrams(symbolTable, tbl, startNgramOrder, endNgramOrder, wrds);
  }
}
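/* A minimal usage sketch (hypothetical: the symbolTable, rule, antNodes, and goalEdge objects
 * would come from the decoder, and the numbers are only illustrative):
 *
 *   int ngramStateID = 0;                       // ID under which the NgramDPState is stored
 *   int lmOrder = 3;                            // baseline LM order
 *   NgramExtractor extractor = new NgramExtractor(symbolTable, ngramStateID, lmOrder);
 *
 *   // generative model: request a single order (startNgramOrder == endNgramOrder)
 *   HashMap<String,Integer> transNgrams = extractor.getTransitionNgrams(rule, antNodes, lmOrder, lmOrder);
 *
 *   // final (goal) transition: start at order 1 to allow partial ngrams such as "<s> a"
 *   HashMap<String,Integer> finalNgrams = extractor.getFinalTransitionNgrams(goalEdge, 1, lmOrder);
 */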