package joshua.discriminative.variational_decoder.nbest;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;

import joshua.corpus.vocab.BuildinSymbol;
import joshua.corpus.vocab.SymbolTable;
import joshua.decoder.hypergraph.DiskHyperGraph;
import joshua.decoder.hypergraph.HyperGraph;
import joshua.decoder.hypergraph.KBestExtractor;
import joshua.discriminative.semiring_parsing.AtomicSemiring;
import joshua.discriminative.training.risk_annealer.nbest.NbestRiskGradientComputer;

public class EstimateEntropyGapOnNbest {

    int topN = 300;
    boolean useUniqueNbest = false;
    boolean useTreeNbest = false;
    boolean addCombinedCost = true;

    SymbolTable symbolTbl;
    KBestExtractor kbestExtractor;
    double scalingFactor = 1.0;

    AtomicSemiring atomicSemirng = new AtomicSemiring(1, 0);

    public EstimateEntropyGapOnNbest(SymbolTable symbol_, double scalingFactor, int topn) {
        this.scalingFactor = scalingFactor;
        this.topN = topn;
        this.symbolTbl = symbol_;
        this.kbestExtractor = new KBestExtractor(this.symbolTbl, this.useUniqueNbest,
                this.useTreeNbest, false, this.addCombinedCost, false, true);
    }

    /* If returnDisorderNbest is true, emit an unordered n-best list with summed log
     * probabilities (sent_id ||| hyp ||| empty_feature_scores ||| sum_log_prob);
     * otherwise emit just the 1-best hypothesis string itself.
     * Note: the flag is currently unused; the method always returns the entropy gap. */
    public double processOneSent(HyperGraph hg, int sentenceID, boolean returnDisorderNbest) {

        //### step-2: get the n-best derivations
        List<String> nbestNonUniqueStrings = new ArrayList<String>();
        kbestExtractor.lazyKBestExtractOnHG(hg, null, this.topN, sentenceID, nbestNonUniqueStrings);

        //### step-3: for each unique string, collect the sum and the list of its derivation probabilities
        HashMap<String, Double> uniqueStringsSumProbTbl = new HashMap<String, Double>();
        HashMap<String, List<Double>> uniqueStringsListProbTbl = new HashMap<String, List<Double>>();
        double globalLogNorm = Double.NEGATIVE_INFINITY;

        for (String derivation_string : nbestNonUniqueStrings) {
            String[] fds = derivation_string.split("\\s+\\|{3}\\s+");
            String hyp_string = fds[1];
            double log_prob = Double.parseDouble(fds[fds.length - 1]) * scalingFactor; // scaled log prob
            //TODO: use inside_outside_scaling_factor here

            // sum probability (in log space)
            Double old_sum = uniqueStringsSumProbTbl.get(hyp_string);
            if (old_sum == null) {
                old_sum = Double.NEGATIVE_INFINITY; // zero probability
            }
            uniqueStringsSumProbTbl.put(hyp_string, atomicSemirng.add_in_atomic_semiring(old_sum, log_prob));
            globalLogNorm = atomicSemirng.add_in_atomic_semiring(globalLogNorm, log_prob);

            // list of log probabilities for this string's derivations
            List<Double> oldList = uniqueStringsListProbTbl.get(hyp_string);
            if (oldList == null) {
                oldList = new ArrayList<Double>();
                uniqueStringsListProbTbl.put(hyp_string, oldList);
            }
            oldList.add(log_prob);
        }
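        /* Each derivation d in the n-best list yields exactly one translation string y,
         * so y is a deterministic function of d and the chain rule for entropy gives
         *
         *     H(D) = H(Y) + H(D|Y),  where  H(D|Y) = sum_y p(y) * H(D | Y = y).
         *
         * The "gap" accumulated below is H(D|Y): the entropy remaining in the
         * derivation distribution once the translation string is known; stringEntropy
         * is H(Y), computed from the per-string sums p(y). */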
        //### step-4: compute the string entropy H(Y) and the gap H(D|Y)
        ArrayList<Double> listSumProbs = new ArrayList<Double>();
        double tGlobalSum = 0;
        double gap = 0;
        for (String hyp : uniqueStringsSumProbTbl.keySet()) { // each unique string
            double sumProb = Math.exp(uniqueStringsSumProbTbl.get(hyp) - globalLogNorm);
            tGlobalSum += sumProb;
            listSumProbs.add(sumProb);

            // compute H(d|y): renormalize the derivation probabilities within this string
            double tLocalSum = 0;
            double sumLogProb = uniqueStringsSumProbTbl.get(hyp);
            List<Double> derivationProbs = uniqueStringsListProbTbl.get(hyp);
            for (int i = 0; i < derivationProbs.size(); i++) {
                double tProb = Math.exp(derivationProbs.get(i) - sumLogProb);
                tLocalSum += tProb;
                derivationProbs.set(i, tProb);
            }
            if (Math.abs(tLocalSum - 1.0) > 1e-4) {
                System.out.println("local P does not sum to one, must be wrong; " + tLocalSum);
                System.exit(1);
            }

            double entropyDGivenY = NbestRiskGradientComputer.computeEntropy(derivationProbs);
            gap += entropyDGivenY * sumProb;
        }

        double stringEntropy = NbestRiskGradientComputer.computeEntropy(listSumProbs);
        if (Math.abs(tGlobalSum - 1.0) > 1e-4) {
            System.out.println("global P does not sum to one, must be wrong");
            System.exit(1);
        }

        System.out.println("stringEntropy " + stringEntropy + " gap " + gap);
        return gap;
    }

    public static void main(String[] args) throws InterruptedException, IOException {
        if (args.length != 6) {
            System.out.println("Wrong number of parameters, it must be 6");
            System.exit(1);
        }

        //long start_time = System.currentTimeMillis();
        String f_test_items = args[0].trim();
        String f_test_rules = args[1].trim();
        int num_sents = Integer.parseInt(args[2].trim());
        //String f_1best = args[3].trim();//output (the 1-best writer below is commented out)
        int topN = Integer.parseInt(args[4].trim());
        double inside_outside_scaling_factor = Double.parseDouble(args[5].trim());

        int baseline_lm_feat_id = 0;
        SymbolTable p_symbol = new BuildinSymbol(null);

        EstimateEntropyGapOnNbest cruncher = new EstimateEntropyGapOnNbest(p_symbol,
                inside_outside_scaling_factor, topN);

        //BufferedWriter t_writer_1best = FileUtilityOld.getWriteFileStream(f_1best);
        System.out.println("############Process file " + f_test_items);
        DiskHyperGraph dhg_test = new DiskHyperGraph(p_symbol, baseline_lm_feat_id, true, null); // have model costs stored
        dhg_test.initRead(f_test_items, f_test_rules, null);

        double sumGap = 0;
        for (int sent_id = 0; sent_id < num_sents; sent_id++) {
            System.out.println("#Process sentence " + sent_id);
            HyperGraph hg_test = dhg_test.readHyperGraph();
            double gap = cruncher.processOneSent(hg_test, sent_id, false); // entropy gap for this sentence
            sumGap += gap;
        }

        sumGap *= 1.44; // nats to bits: log2(e) is approximately 1.4427
        System.out.println("sum of bits in gap is: " + sumGap);
        //FileUtilityOld.close_write_file(t_writer_1best);
    }
}
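/* Example invocation (file names are hypothetical; args[3] is read nowhere
 * because the 1-best output writer above is commented out):
 *
 *   java joshua.discriminative.variational_decoder.nbest.EstimateEntropyGapOnNbest \
 *       test.items test.rules 100 unused.1best 300 1.0
 *
 * i.e. <items file> <rules file> <num sentences> <unused 1-best output> <topN> <scaling factor>
 */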