package joshua.discriminative.training.oracle;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

import joshua.corpus.vocab.SymbolTable;
import joshua.decoder.BLEU;
import joshua.decoder.Support;
import joshua.decoder.hypergraph.HGNode;
import joshua.decoder.hypergraph.HyperEdge;
import joshua.decoder.hypergraph.HyperGraph;
import joshua.decoder.hypergraph.KBestExtractor;

/**
 * Oracle extraction with respect to sentence-level BLEU, either over a k-best
 * list ({@link #oracleExtractOnNbest}) or directly over a hypergraph
 * ({@link #oracleExtractOnHG(HyperGraph, int, int, String)}).
 *
 * <p>For the hypergraph variant the class plugs into {@link RefineHG}: every
 * refined node carries a {@link DPStateOracle} holding the hypothesis length,
 * the per-order n-gram match counts, the left/right boundary words, and the
 * BLEU score computed from them.
 *
 * <p>The instance keeps per-sentence state ({@link #srcSentLen},
 * {@link #refNgramsTbl}, {@link #refSentLen}) that is overwritten by each call
 * to {@code oracleExtractOnHG}.
 */
public class OracleExtractionOnHGV3 extends RefineHG<DPStateOracle> {

    SymbolTable symbolTable;

    /** Source sentence length of the sentence currently being processed. */
    protected int srcSentLen = 0;

    EquivLMState equiv;

    /** If true, n-gram matches are clipped against the reference counts at every hyperedge. */
    static protected boolean doLocalNgramClip = false;

    // NOTE: not read anywhere in this class; name (including its spelling) is kept
    // because the field is protected and may be referenced elsewhere.
    static protected boolean maitainLengthState = false;

    /** Maximum n-gram order used for BLEU. */
    static protected int bleuOrder = 4;

    //== the way to compute effective reference length
    boolean useShortestRef = false; //TODO

    //derived from reference sentence
    /** N-gram (space-joined word ids) to count table of the current reference(s). */
    protected HashMap<String, Integer> refNgramsTbl = new HashMap<String, Integer>();

    /** (Effective) reference length of the current sentence. */
    protected double refSentLen = 0;

    public OracleExtractionOnHGV3(SymbolTable symbolTable) {
        this.symbolTable = symbolTable;
        this.equiv = new EquivLMState(this.symbolTable, bleuOrder);
    }

    /**
     * Finds the oracle hypothesis (highest sentence BLEU) among the top {@code n}
     * entries of the k-best list.
     *
     * @param kbest_extractor extractor used to enumerate hypotheses; its state is reset first
     * @param hg hypergraph to extract from
     * @param n maximum number of hypotheses to examine
     * @param do_ngram_clip whether n-gram matches are clipped against the reference counts
     * @param ref_sent whitespace-tokenized reference sentence
     * @return a two-element array {@code {String oracleSentence, Double oracleBleu}}
     *         (sentence is {@code null} and BLEU is -1 when no hypothesis was produced),
     *         or {@code null} if the hypergraph has no goal node
     */
    public Object[] oracleExtractOnNbest(KBestExtractor kbest_extractor, HyperGraph hg, int n, boolean do_ngram_clip, String ref_sent) {
        if (hg.goalNode == null) {
            return null;
        }
        kbest_extractor.resetState();

        double orc_bleu = -1;
        String orc_sent = null;
        // Bounded by n up front so that hypothesis n+1 is never extracted needlessly.
        for (int k = 1; k <= n; k++) {
            String hyp_sent = kbest_extractor.getKthHyp(hg.goalNode, k, -1, null, null);//?????????
            if (hyp_sent == null) {
                break; // k-best list exhausted
            }
            double t_bleu = computeSentenceBleu(this.symbolTable, ref_sent, hyp_sent, do_ngram_clip, bleuOrder);
            if (t_bleu > orc_bleu) {
                orc_bleu = t_bleu;
                orc_sent = hyp_sent;
            }
        }
        System.out.println("Oracle sent in nbest: " + orc_sent);
        System.out.println("Oracle bleu in nbest: " + orc_bleu);

        Object[] res = new Object[2];
        res[0] = orc_sent;
        res[1] = orc_bleu;
        return res;
    }

    /**
     * Refines the hypergraph against a single reference.
     *
     * @param hg hypergraph to refine
     * @param srcSentLenIn source sentence length
     * @param baselineLMOrder currently unused (TODO)
     * @param refSentStr whitespace-tokenized reference sentence
     * @return the hypergraph returned by {@code splitHG(hg)}
     */
    public HyperGraph oracleExtractOnHG(HyperGraph hg, int srcSentLenIn, int baselineLMOrder, String refSentStr) {
        //TODO: baselineLMOrder
        srcSentLen = srcSentLenIn;

        //== ref tbl and effective ref len
        int[] refWords = this.symbolTable.addTerminals(tokenize(refSentStr));
        refSentLen = refWords.length;
        refNgramsTbl.clear();
        getNgrams(refNgramsTbl, bleuOrder, refWords, false);

        equiv.setupPrefixAndSurfixTbl(refNgramsTbl);
        return splitHG(hg);
    }

    /**
     * Refines the hypergraph against multiple references. The effective reference
     * length and the per-n-gram reference counts are delegated to
     * {@code BLEU.computeEffectiveLen} and {@code BLEU.computeMaxRefCountTbl}.
     *
     * @param hg hypergraph to refine
     * @param srcSentLenIn source sentence length
     * @param baselineLMOrder currently unused (TODO)
     * @param refSentStrs whitespace-tokenized reference sentences
     * @return the hypergraph returned by {@code splitHG(hg)}
     */
    public HyperGraph oracleExtractOnHG(HyperGraph hg, int srcSentLenIn, int baselineLMOrder, String[] refSentStrs) {
        // TODO: baselineLMOrder
        srcSentLen = srcSentLenIn;

        //== ref tbl and effective ref len
        int[] refLens = new int[refSentStrs.length];
        ArrayList<HashMap<String, Integer>> listRefNgramTbl = new ArrayList<HashMap<String, Integer>>();
        for (int i = 0; i < refSentStrs.length; i++) {
            int[] refWords = this.symbolTable.addTerminals(tokenize(refSentStrs[i]));
            refLens[i] = refWords.length;
            HashMap<String, Integer> tRefNgramsTbl = new HashMap<String, Integer>();
            getNgrams(tRefNgramsTbl, bleuOrder, refWords, false);
            listRefNgramTbl.add(tRefNgramsTbl);
        }
        refSentLen = BLEU.computeEffectiveLen(refLens, useShortestRef);
        refNgramsTbl = BLEU.computeMaxRefCountTbl(listRefNgramTbl);

        equiv.setupPrefixAndSurfixTbl(refNgramsTbl);
        return splitHG(hg);
    }

    /**
     * Splits a sentence on whitespace after trimming it, so that leading
     * whitespace does not produce a spurious empty first token.
     */
    private static String[] tokenize(String sentence) {
        return sentence.trim().split("\\s+");
    }

    /**
     * Reference length attributed to a span: the full reference length when the
     * span covers the whole source sentence, otherwise the reference length
     * scaled by the fraction of the source sentence covered.
     */
    private double computeAvgLen(int spanLen, int srcSentLen, double refSentLen) {
        if (spanLen >= srcSentLen) {
            return refSentLen;
        }
        return spanLen * refSentLen * 1.0 / srcSentLen;//avg len?
    }

    /**
     * Creates the refined counterpart of {@code originalEdge}. Compared with the
     * original edge, two things change:
     * (1) the list of antecedent nodes becomes {@code antVirtualItems};
     * (2) the edge score becomes {@code dps.bestDerivationLogP} (the BLEU score
     *     stored in the state) instead of the model score.
     */
    @Override
    protected HyperEdge createNewHyperEdge(HyperEdge originalEdge, List<HGNode> antVirtualItems, DPStateOracle dps) {
        return new HyperEdge(originalEdge.getRule(), dps.bestDerivationLogP, null, antVirtualItems, originalEdge.getSourcePath());
    }

    /**
     * Computes the oracle DP state obtained by applying {@code originalEdge} to
     * the given (refined) antecedents.
     *
     * <p>The hypothesis length and per-order n-gram match counts are the sum of
     * what the edge itself contributes (see {@link #computeLMState}) and what
     * the antecedent states already carry; BLEU is then computed against the
     * span-proportional reference length.
     *
     * @param originalParentItem node in the original hypergraph; its span {@code j - i} scales the reference length
     * @param originalEdge edge in the original hypergraph
     * @param antVirtualItems refined antecedent nodes (may be {@code null} for edges with a rule)
     * @return the resulting state
     * @throws IllegalStateException if a rule-less (goal) edge does not have exactly one antecedent
     */
    protected DPStateOracle computeState(HGNode originalParentItem, HyperEdge originalEdge, List<HGNode> antVirtualItems) {
        double refLen = computeAvgLen(originalParentItem.j - originalParentItem.i, srcSentLen, refSentLen);

        //=== hyperedges under "goal item" do not have a rule
        if (originalEdge.getRule() == null) {
            if (antVirtualItems.size() != 1) {
                throw new IllegalStateException("error deduction under goal item have more than one item");
            }
            double bleu = antVirtualItems.get(0).bestHyperedge.bestDerivationLogP;
            return new DPStateOracle(0, null, null, null, bleu);//no DPState at all
        }

        ComputeOracleStateResult lmState = computeLMState(originalEdge, antVirtualItems);

        int hypLen = lmState.numNewWordsAtEdge;
        int[] numNgramMatches = new int[bleuOrder];

        // Matches newly created at this edge.
        for (Map.Entry<String, Integer> entry : lmState.newNgramsTbl.entrySet()) {
            String ngram = entry.getKey();
            int finalCount = entry.getValue();
            int orderIdx = ngramOrderIndex(ngram);
            if (doLocalNgramClip) {
                numNgramMatches[orderIdx] += Support.findMin(finalCount, refNgramsTbl.get(ngram));
            } else {
                numNgramMatches[orderIdx] += finalCount; //do not do any clipping
            }
        }

        // Length and matches inherited from the antecedents.
        if (antVirtualItems != null) {
            for (int i = 0; i < antVirtualItems.size(); i++) {
                DPStateOracle antDPState = (DPStateOracle) ((RefinedNode) antVirtualItems.get(i)).dpState;
                hypLen += antDPState.bestLen;
                for (int t = 0; t < bleuOrder; t++) {
                    numNgramMatches[t] += antDPState.ngramMatches[t];
                }
            }
        }

        double bleu = computeBleu(hypLen, refLen, numNgramMatches, bleuOrder);
        return new DPStateOracle(hypLen, numNgramMatches, lmState.leftEdgeWords, lmState.rightEdgeWords, bleu);
    }

    /** Zero-based order index of an n-gram key, i.e. (number of whitespace-separated tokens) - 1. */
    private static int ngramOrderIndex(String ngram) {
        return ngram.split("\\s+").length - 1;
    }

    /**
     * Computes what a single hyperedge contributes on top of its antecedents:
     * the number of terminal words in the rule, the reference-matching n-grams
     * newly created by the combination, and the left/right boundary states.
     *
     * <p>Newly created n-grams are obtained by counting all n-grams over the
     * word chunks formed from rule terminals and antecedent boundary words, and
     * then subtracting the n-grams lying entirely inside an antecedent boundary
     * state (those were already counted when the antecedent was built). Only
     * n-grams present in {@link #refNgramsTbl} are kept.
     *
     * @param dt the hyperedge
     * @param antVirtualItems refined antecedents, indexed by target nonterminal index
     * @return the per-edge result; for a rule-less (goal) edge all tables are {@code null} and the length is 0
     * @throws IllegalStateException if subtracting the old counts yields a negative count
     */
    protected ComputeOracleStateResult computeLMState(HyperEdge dt, List<HGNode> antVirtualItems) {
        //=== hyperedges under "goal item" do not have a rule
        if (dt.getRule() == null) {
            /**TODO: we did not consider <s> and </s> here */
            HashMap<String, Integer> finalNewGramCounts = null;
            int hypLen = 0;
            return new ComputeOracleStateResult(finalNewGramCounts, null, null, hypLen);
        }

        //======== hyperedges *not* under "goal item"
        HashMap<String, Integer> newGramCounts = new HashMap<String, Integer>();//new ngrams created due to the combination
        HashMap<String, Integer> oldNgramCounts = new HashMap<String, Integer>();//the ngrams that have already been computed
        int hypLen = 0;
        int[] enWords = dt.getRule().getEnglish();

        List<Integer> words = new ArrayList<Integer>(); // current chunk of contiguous known words
        List<Integer> leftStateSequence = new ArrayList<Integer>();
        List<Integer> rightStateSequence = new ArrayList<Integer>();

        //==== get left_state_sequence, right_state_sequence, total_hyp_len, num_ngram_match
        for (int c = 0; c < enWords.length; c++) {
            int word = enWords[c];
            if (symbolTable.isNonterminal(word)) {
                int index = this.symbolTable.getTargetNonterminalIndex(word);
                DPStateOracle antState = (DPStateOracle) ((RefinedNode) antVirtualItems.get(index)).dpState;//TODO
                List<Integer> leftContext = antState.leftLMState;
                List<Integer> rightContext = antState.rightLMState;

                for (int t : leftContext) {//always have l_context
                    words.add(t);
                    if (leftStateSequence.size() < bleuOrder - 1) {
                        leftStateSequence.add(t);
                    }
                }
                getNgrams(oldNgramCounts, bleuOrder, leftContext, true);

                if (rightContext.size() >= bleuOrder - 1) {//the right and left are NOT overlapping
                    getNgrams(newGramCounts, bleuOrder, words, true);
                    getNgrams(oldNgramCounts, bleuOrder, rightContext, true);
                    words.clear();//start a new chunk
                    rightStateSequence.clear();
                    for (int t : rightContext) {
                        words.add(t);
                    }
                }
                for (int t : rightContext) {
                    rightStateSequence.add(t);
                }
            } else {
                words.add(word);
                hypLen++;
                if (leftStateSequence.size() < bleuOrder - 1) {
                    leftStateSequence.add(word);
                }
                rightStateSequence.add(word);
            }
        }
        getNgrams(newGramCounts, bleuOrder, words, true);

        //=== now deduct ngram counts
        HashMap<String, Integer> finalNewGramCounts = new HashMap<String, Integer>();
        for (Map.Entry<String, Integer> entry : newGramCounts.entrySet()) {
            String ngram = entry.getKey();
            if (!refNgramsTbl.containsKey(ngram)) {//TODO
                continue; // only n-grams that occur in the reference can ever match
            }
            int finalCount = entry.getValue();
            Integer oldCount = oldNgramCounts.get(ngram);
            if (oldCount != null) {
                finalCount -= oldCount;
                if (finalCount < 0) {
                    throw new IllegalStateException("error: negative count for ngram: " + ngram
                            + "; new: " + entry.getValue() + "; old: " + oldCount);
                }
            }
            if (finalCount > 0) {
                finalNewGramCounts.put(ngram, finalCount);
            }
        }

        List<Integer> leftLMState = equiv.getLeftEquivState(leftStateSequence);
        List<Integer> rightLMState = equiv.getRightEquivState(rightStateSequence);

        return new ComputeOracleStateResult(finalNewGramCounts, leftLMState, rightLMState, hypLen);
    }

    // =================================================================================================
    // ==================== BLEU-related functions ==========================================
    // =================================================================================================

    //TODO: consider merge with joshua.decoder.BLEU

    /**
     * Sentence-level BLEU of a hypothesis string against a reference string.
     * Both strings are trimmed, split on whitespace, and mapped to ids through
     * {@code p_symbol.addTerminals}.
     *
     * @param do_ngram_clip whether matches are clipped to the reference n-gram counts
     * @param bleu_order maximum n-gram order
     */
    public double computeSentenceBleu(SymbolTable p_symbol, String ref_sent, String hyp_sent, boolean do_ngram_clip, int bleu_order) {
        int[] numeric_ref_sent = p_symbol.addTerminals(tokenize(ref_sent));
        int[] numeric_hyp_sent = p_symbol.addTerminals(tokenize(hyp_sent));
        return computeSentenceBleu(numeric_ref_sent, numeric_hyp_sent, do_ngram_clip, bleu_order);
    }

    /**
     * Sentence-level BLEU of a hypothesis against a reference, both given as word ids.
     *
     * @param ref_sent reference word ids
     * @param hyp_sent hypothesis word ids
     * @param do_ngram_clip whether matches are clipped to the reference n-gram counts
     * @param bleu_order maximum n-gram order; used both for n-gram extraction and for scoring
     * @return the BLEU score as computed by {@link #computeBleu}
     * @throws IllegalArgumentException if either sentence is empty
     */
    public double computeSentenceBleu(int[] ref_sent, int[] hyp_sent, boolean do_ngram_clip, int bleu_order) {
        // Extract n-grams up to the requested order (not a fixed 4) so that the
        // match array is always large enough for computeBleu.
        HashMap<String, Integer> ref_ngram_tbl = new HashMap<String, Integer>();
        getNgrams(ref_ngram_tbl, bleu_order, ref_sent, false);
        HashMap<String, Integer> hyp_ngram_tbl = new HashMap<String, Integer>();
        getNgrams(hyp_ngram_tbl, bleu_order, hyp_sent, false);

        int[] num_ngram_match = new int[bleu_order];
        for (Map.Entry<String, Integer> entry : hyp_ngram_tbl.entrySet()) {
            String ngram = entry.getKey();
            if (!ref_ngram_tbl.containsKey(ngram)) {
                continue;
            }
            int orderIdx = ngramOrderIndex(ngram);
            if (do_ngram_clip) {
                num_ngram_match[orderIdx] += Support.findMin(ref_ngram_tbl.get(ngram), entry.getValue()); //ngram clip
            } else {
                num_ngram_match[orderIdx] += entry.getValue();//without ngram count clipping
            }
        }
        return computeBleu(hyp_sent.length, ref_sent.length, num_ngram_match, bleu_order);
    }

    /**
     * Sentence BLEU: {@code BLEU = bp * prec}, where
     * {@code prec = exp(sum_t (1/bleuOrder) * log(prec[t]))} and
     * {@code bp = 1} if {@code hypLen >= refLen}, else {@code exp(1 - refLen/hypLen)}.
     *
     * <p>Orders with zero matches are smoothed: the k-th such order (counting
     * from the lowest) uses {@code 0.5^k} in place of the match count. Orders
     * longer than the hypothesis are skipped.
     *
     * @param hypLen hypothesis length, must be positive
     * @param refLen (effective) reference length, must be positive
     * @param numNgramMatches match count per order; index {@code t} holds the (t+1)-gram matches
     * @param bleuOrder maximum n-gram order
     * @return the BLEU score
     * @throws IllegalArgumentException if {@code hypLen <= 0} or {@code refLen <= 0}
     */
    public static double computeBleu(int hypLen, double refLen, int[] numNgramMatches, int bleuOrder) {
        if (hypLen <= 0 || refLen <= 0) {
            throw new IllegalArgumentException("error: ref or hyp is zero len (hypLen=" + hypLen + ", refLen=" + refLen + ")");
        }
        double wt = 1.0 / bleuOrder;
        double prec = 0;
        double smoothFactor = 1.0;
        for (int t = 0; t < bleuOrder && t < hypLen; t++) {
            int possible = hypLen - t; // number of (t+1)-grams in the hypothesis
            if (numNgramMatches[t] > 0) {
                prec += wt * Math.log(numNgramMatches[t] * 1.0 / possible);
            } else {
                smoothFactor *= 0.5;//TODO
                prec += wt * Math.log(smoothFactor / possible);
            }
        }
        double bp = (hypLen >= refLen) ? 1.0 : Math.exp(1 - refLen / hypLen);
        return bp * Math.exp(prec);
    }

    // =================================================================================================
    // ==================== ngram extraction functions ==========================================
    // =================================================================================================

    /**
     * Accumulates into {@code tbl} the counts of all n-grams of length
     * 1..{@code order} in {@code wrds}. Keys are the word ids joined by a single space.
     *
     * @param ignoreNullEquivSymbol accepted for interface compatibility; no
     *        null-symbol detection is implemented here, so it has no effect
     */
    protected void getNgrams(HashMap<String, Integer> tbl, int order, int[] wrds, boolean ignoreNullEquivSymbol) {
        for (int i = 0; i < wrds.length; i++) {
            // Grow the n-gram starting at i one word at a time: [i, i+j]
            StringBuilder ngram = new StringBuilder();
            for (int j = 0; j < order && i + j < wrds.length; j++) {
                if (j > 0) {
                    ngram.append(" ");
                }
                ngram.append(wrds[i + j]);
                incrementCount(tbl, ngram.toString());
            }
        }
    }

    /**
     * Accumulates into {@code tbl} the counts of all n-grams of length
     * 1..{@code order} in {@code wrds}. Keys are the word ids joined by a single space.
     *
     * @param ignoreNullEquivSymbol accepted for interface compatibility; no
     *        null-symbol detection is implemented here, so it has no effect
     */
    protected void getNgrams(HashMap<String, Integer> tbl, int order, List<Integer> wrds, boolean ignoreNullEquivSymbol) {
        int size = wrds.size();
        for (int i = 0; i < size; i++) {
            // Grow the n-gram starting at i one word at a time: [i, i+j]
            StringBuilder ngram = new StringBuilder();
            for (int j = 0; j < order && i + j < size; j++) {
                if (j > 0) {
                    ngram.append(" ");
                }
                int t_wrd = wrds.get(i + j);
                ngram.append(t_wrd);
                incrementCount(tbl, ngram.toString());
            }
        }
    }

    /** Adds one to the count stored under {@code key}, starting from zero if absent. */
    private static void incrementCount(HashMap<String, Integer> tbl, String key) {
        Integer current = tbl.get(key);
        tbl.put(key, current == null ? 1 : current + 1);
    }
}