package joshua.discriminative.semiring_parsingv2.applications; import joshua.corpus.vocab.BuildinSymbol; import joshua.corpus.vocab.SymbolTable; import joshua.decoder.hypergraph.DiskHyperGraph; import joshua.decoder.hypergraph.HGNode; import joshua.decoder.hypergraph.HyperEdge; import joshua.decoder.hypergraph.HyperGraph; import joshua.discriminative.semiring_parsingv2.DefaultIOParserWithXLinearCombinator; import joshua.discriminative.semiring_parsingv2.SignedValue; import joshua.discriminative.semiring_parsingv2.pmodule.ScalarPM; import joshua.discriminative.semiring_parsingv2.semiring.LogSemiring; public class EntropyOnHGUsingIO extends DefaultIOParserWithXLinearCombinator<LogSemiring,ScalarPM>{ double scale; public EntropyOnHGUsingIO(double scale){ super(); this.scale = scale; } @Override protected ScalarPM createNewXWeight() { return new ScalarPM(); } @Override protected LogSemiring createNewKWeight() { return new LogSemiring(); } @Override protected LogSemiring getEdgeKWeight(HyperEdge dt, HGNode parent_item) { double logProb = scale * dt.getTransitionLogP(false); return new LogSemiring(logProb); } @Override protected ScalarPM getEdgeXWeight(HyperEdge dt, HGNode parent_item) { double logProb = scale * dt.getTransitionLogP(false); LogSemiring p = new LogSemiring(logProb); double val = logProb; ScalarPM r = new ScalarPM( SignedValue.createSignedValueFromRealNumber(val) ); moduleMultiSemiring(r, p); return r; } @Override public void normalizeGoal() { LogSemiring goalKVal = getGoalK(); ScalarPM goalX = getGoalX(); //goalKVal.printInfor(); //goalXVal.printInfor(); double normConstant = goalKVal.getLogValue();//p goalX.getValue().multiLogNumber(-normConstant);//r } public double getEntropy(HyperGraph hg){ return getGoalK().getLogValue() - getGoalX().getValue().convertToRealValue(); } // ####################################################################### public static void main(String[] args) { if(args.length>4){ System.out.println("Wrong number of parameters, it must have at least four parameters: java NbestMinRiskAnnealer use_shortest_ref f_config gain_factor f_dev_src f_nbest_prefix f_dev_ref1 f_dev_ref2...."); System.exit(1); } String f_dev_hg_prefix=args[0].trim(); String f_dev_items = f_dev_hg_prefix +".items"; String f_dev_rules = f_dev_hg_prefix +".rules"; double scale = 1; if(args.length>=2) scale = new Double(args[1].trim()); int numSents =5; if(args.length>=3) numSents = new Integer(args[2].trim()); int numSrcWords =1; if(args.length>=4) numSrcWords = new Integer(args[3].trim()); SymbolTable symbolTbl = new BuildinSymbol(null); int ngramStateID =0; double sumEntropy = 0; EntropyOnHGUsingIO ds = new EntropyOnHGUsingIO(1.0); DiskHyperGraph diskHG = new DiskHyperGraph(symbolTbl, ngramStateID, true, null); //have model costs stored diskHG.initRead(f_dev_items, f_dev_rules,null); for(int sentID=0; sentID < numSents; sentID ++){ System.out.println("#Process sentence " + sentID); HyperGraph testHG = diskHG.readHyperGraph(); ds.setHyperGraph(testHG); ds.runInsideOutside(); ds.printGoalX(); ds.normalizeGoal(); double entropy = ds.getEntropy(testHG); System.out.println("entropy is " + entropy); sumEntropy += entropy; ds.clearState(); } System.out.println("scale=" + scale + "; num_sents=" + numSents +"; numSrcWords="+numSrcWords); //a nats has 1.44 bits System.out.println("sum_entropy: " + scale + " " + 1.44*sumEntropy/numSrcWords); } }