/** * */ package joshua.discriminative.variational_decoder; import java.util.HashMap; import java.util.Map; import joshua.corpus.vocab.SymbolTable; import joshua.decoder.hypergraph.HyperGraph; import joshua.decoder.hypergraph.TrivialInsideOutside; import joshua.discriminative.DiscriminativeSupport; import joshua.discriminative.feature_related.FeatureExtractionHG; import joshua.discriminative.feature_related.feature_template.FeatureTemplate; /**esimate a variational model * */ public class VariationalNgramApproximator { FeatureTemplate featureTemplate; HashMap<String, Double> featureTbl = new HashMap<String, Double>(); private double adjustAlpha; private static String ZERO_GRAM = "lzfzerogram"; private static String STOP_SYM="</s>"; private int STOP_SYM_ID; private int ngramOrder = 3; private SymbolTable symbolTable; public VariationalNgramApproximator(SymbolTable symbolTable, FeatureTemplate ft, double adjustAlpha, int ngramOrder){ this.symbolTable = symbolTable; this.featureTemplate = ft; this.ngramOrder = ngramOrder; this.STOP_SYM_ID = this.symbolTable.addTerminal(STOP_SYM); this.adjustAlpha = adjustAlpha; if(adjustAlpha<=0 || adjustAlpha >1){ System.out.println("adjustAlpha is not with range of (0,1]; it is " + adjustAlpha); System.exit(0); } } public HashMap<String, Double> estimateModel(HyperGraph hg, TrivialInsideOutside pInsideOutside){ featureTbl.clear(); //==== collect posterio count FeatureExtractionHG.featureExtractionOnHG(hg, pInsideOutside, this.featureTbl, null, this.featureTemplate); System.out.println("after feature extraction, feat tbl is " + featureTbl.size()); //=== normalize the model return getNormalizedLM(this.featureTbl, this.ngramOrder, this.adjustAlpha); } private HashMap<String, Double> getNormalizedLM(HashMap<String, Double> ngramFeatCountTbl, int order, double adjustAlpha){ HashMap<String, Double> denominatorTbl = new HashMap<String, Double>(); int[] numNgrams = new int[order]; //=== first get normalized constants System.out.println("#### Begin to get the normalization constants"); for(Map.Entry<String, Double> entry : ngramFeatCountTbl.entrySet()){ String ngram = entry.getKey(); double count = entry.getValue(); String[] wrds = ngram.split("\\s+"); //System.out.println("ngram: " + ngram); numNgrams[wrds.length-1]++; if(wrds.length==1){//unigram DiscriminativeSupport.increaseCount(denominatorTbl, ZERO_GRAM, count); }else{ StringBuffer history = new StringBuffer(); for(int i=0; i<wrds.length-1; i++){ history.append(wrds[i]); if(i<wrds.length-2) history.append(" "); } DiscriminativeSupport.increaseCount(denominatorTbl, history.toString(), count); } } //=== now adjust the denominator; if necessary if(adjustAlpha!=1.0){ adjustDenominator(denominatorTbl, ngramFeatCountTbl); } //=== now get change normalizedModel System.out.println("=== Begin to get normalize the original ngram tbl"); HashMap<String, Double> normalizedModel = new HashMap<String, Double>(); for(Map.Entry<String, Double> entry : ngramFeatCountTbl.entrySet()){ String ngram = entry.getKey(); double count = entry.getValue(); String[] wrds = ngram.split("\\s+"); if(wrds.length==1){//unigram normalizedModel.put(ngram, getNormalizedCost(ngram, count, denominatorTbl.get(ZERO_GRAM))); }else{ StringBuffer history = new StringBuffer(); for(int i=0; i<wrds.length-1; i++){ history.append(wrds[i]); //history.append(wrds[0]);//????????????????????????????????????? wrong version if(i<wrds.length-2) history.append(" "); } normalizedModel.put(ngram, getNormalizedCost(ngram, count, denominatorTbl.get(history.toString()))); } } //print stat for(int i=0; i<order; i++){ System.out.println((i+1) + "-gram: " + numNgrams[i]); } return normalizedModel; } private void adjustDenominator(HashMap<String, Double> tblDenominator, HashMap<String, Double> ngramFeatCountTbl){ if(this.adjustAlpha!=1.0){ System.out.println("=== Begin to reajust the denominator table, whose size is " + tblDenominator.size()); //TODO: what about the history is ZERO_GRAM; in general, we should never have a unigram that is STOP_SYM_ID, so we do not need to worry about it for(Map.Entry<String, Double> entry : tblDenominator.entrySet()){ String history = entry.getKey(); String stopNgram = history + " " + this.STOP_SYM_ID; // System.out.println("stop_ngram: " +stop_ngram); Double countForStop = ngramFeatCountTbl.get(stopNgram); if(countForStop!=null){ double oldVal = entry.getValue(); entry.setValue(oldVal+(adjustAlpha-1.0)*countForStop); //System.out.println("old: " + old_val + "; new:" + entry.getValue()); } } }else{ //do nothing } } //adjust_alpha is to reajust the probability to deal with the issue that the LM will favor short sentences private double getNormalizedCost(String ngram, double ngramCount, double historyCount){ //note: history_count has already been adjusted if(this.adjustAlpha!=1.0 && ngram.endsWith(" " + this.STOP_SYM_ID)){//last wrd is stop symbol return -Math.log(this.adjustAlpha*ngramCount/historyCount); }else{ return -Math.log(ngramCount/historyCount); } } }