package joshua.decoder.chart_parser;

import java.util.HashMap;
import java.util.List;

import joshua.decoder.ff.FeatureFunction;
import joshua.decoder.ff.state_maintenance.DPState;
import joshua.decoder.ff.state_maintenance.StateComputer;
import joshua.decoder.ff.tm.Rule;
import joshua.decoder.hypergraph.HGNode;
import joshua.decoder.hypergraph.HyperEdge;

/**
 * Holds the log-probability scores and the dynamic-programming states
 * computed for a node built by applying a rule to a list of antecedent
 * nodes. Also offers static helpers that evaluate the feature functions'
 * transition log-probabilities for a hyperedge or for a rule application.
 *
 * @author Zhifei Li, <zhifei.work@gmail.com>
 * @version $LastChangedDate: 2009-12-22 14:00:36 -0500 (Tue, 22 Dec 2009) $
 */
public class ComputeNodeResult {

    /** Finalized total logP plus the weighted future logP estimates. */
    private double expectedTotalLogP;

    /** Best-derivation logPs of the antecedents plus the weighted transition logP. */
    private double finalizedTotalLogP;

    /** Weighted sum of the transition logPs of all feature functions. */
    private double transitionTotalLogP;

    /** DP states keyed by state id; stays {@code null} when no state was computed. */
    private HashMap<Integer,DPState> dpStates;

    /**
     * Computes the logPs and the DP states of the node.
     *
     * @param featureFunctions feature functions to evaluate; each contributes
     *                         its weight times its transition logP and its
     *                         weight times its future logP estimate
     * @param rule             the rule being applied
     * @param antNodes         antecedent nodes; may be {@code null}
     * @param i                passed through unchanged to the feature functions
     *                         and state computers
     * @param j                passed through unchanged to the feature functions
     *                         and state computers
     * @param srcPath          passed through unchanged to the feature functions
     *                         and state computers
     * @param stateComputers   computers that produce the DP states; may be
     *                         {@code null}
     * @param sentID           passed through unchanged to the feature functions
     */
    public ComputeNodeResult(List<FeatureFunction> featureFunctions, Rule rule,
            List<HGNode> antNodes, int i, int j, SourcePath srcPath,
            List<StateComputer> stateComputers, int sentID) {

        double finalizedSum = sumAntecedentLogP(antNodes);

        HashMap<Integer,DPState> states =
            computeDPStates(stateComputers, rule, antNodes, i, j, srcPath);

        //=== accumulate the weighted feature logPs
        double transitionSum = 0.0;
        double futureEstimate = 0.0;
        for (FeatureFunction feature : featureFunctions) {
            transitionSum += feature.getWeight()
                * feature.transitionLogP(rule, antNodes, i, j, srcPath, sentID);

            // A feature only receives a state when a state map was built.
            DPState featureState =
                (states != null) ? states.get(feature.getStateID()) : null;

            futureEstimate += feature.getWeight()
                * feature.estimateFutureLogP(rule, featureState, sentID);
        }

        /* The transition logP is deliberately computed on the fly here
         * rather than taken from the rule's cached stateless cost (i.e.
         * "transitionCostSum += rule.getStatelessCost()"). Using the cached
         * value would rely on the correctness of rule.statelesscost and
         * would cause a nasty bug for MERT: even if the weight vector of
         * the features changes from iteration to iteration, the HG cost
         * would not reflect that, because the Grammar is not re-estimated.
         * Of course, computing it on the fly slows down decoding (e.g.,
         * from 5 seconds to 6 seconds for the example test set).
         */
        finalizedSum += transitionSum;

        //== publish the final results
        this.expectedTotalLogP = finalizedSum + futureEstimate;
        this.finalizedTotalLogP = finalizedSum;
        this.transitionTotalLogP = transitionSum;
        this.dpStates = states;
    }

    /**
     * Adds up the best-derivation logP of every antecedent node
     * (semiring times, per the original note).
     *
     * @return the sum, or {@code 0.0} when {@code antNodes} is {@code null}
     */
    private static double sumAntecedentLogP(List<HGNode> antNodes) {
        double sum = 0.0;
        if (antNodes == null) {
            return sum;
        }
        for (HGNode antecedent : antNodes) {
            sum += antecedent.bestHyperedge.bestDerivationLogP;
        }
        return sum;
    }

    /**
     * Asks every state computer for its DP state and stores it under the
     * computer's state id.
     *
     * @return the map of states, or {@code null} when {@code stateComputers}
     *         is {@code null} or yields no element
     */
    private static HashMap<Integer,DPState> computeDPStates(
            List<StateComputer> stateComputers, Rule rule,
            List<HGNode> antNodes, int i, int j, SourcePath srcPath) {
        if (stateComputers == null) {
            return null;
        }
        HashMap<Integer,DPState> states = null;
        for (StateComputer computer : stateComputers) {
            DPState computed = computer.computeState(rule, antNodes, i, j, srcPath);
            // The map is created lazily, on the first computed state.
            if (states == null) {
                states = new HashMap<Integer,DPState>();
            }
            states.put(computer.getStateID(), computed);
        }
        return states;
    }

    /**
     * Returns the weighted sum of the transition logPs of all feature
     * functions for a hyperedge. The final-transition variant is used
     * when the edge has no rule.
     */
    public static double computeCombinedTransitionLogP(List<FeatureFunction> featureFunctions,
            HyperEdge edge, int i, int j, int sentID) {
        double total = 0;
        for (FeatureFunction feature : featureFunctions) {
            final boolean hasRule = edge.getRule() != null;
            final double weight = feature.getWeight();
            final double logP = hasRule
                ? feature.transitionLogP(edge, i, j, sentID)
                : feature.finalTransitionLogP(edge, i, j, sentID);
            total += weight * logP;
        }
        return total;
    }

    /**
     * Returns the weighted sum of the transition logPs of all feature
     * functions for a rule application. When {@code rule} is {@code null}
     * the final-transition variant is evaluated on the first antecedent
     * node.
     */
    public static double computeCombinedTransitionLogP(List<FeatureFunction> featureFunctions,
            Rule rule, List<HGNode> antNodes, int i, int j, SourcePath srcPath, int sentID) {
        final boolean hasRule = (rule != null);
        double total = 0;
        for (FeatureFunction feature : featureFunctions) {
            final double weight = feature.getWeight();
            final double logP = hasRule
                ? feature.transitionLogP(rule, antNodes, i, j, srcPath, sentID)
                : feature.finalTransitionLogP(antNodes.get(0), i, j, srcPath, sentID);
            total += weight * logP;
        }
        return total;
    }

    /**
     * Returns the unweighted transition logP of each feature function for
     * a hyperedge, in the iteration order of {@code featureFunctions}. The
     * final-transition variant is used when the edge has no rule.
     */
    public static double[] computeModelTransitionLogPs(List<FeatureFunction> featureFunctions,
            HyperEdge edge, int i, int j, int sentID) {
        double[] logPs = new double[featureFunctions.size()];
        int slot = 0;
        for (FeatureFunction feature : featureFunctions) {
            logPs[slot++] = (edge.getRule() != null)
                ? feature.transitionLogP(edge, i, j, sentID)
                : feature.finalTransitionLogP(edge, i, j, sentID);
        }
        return logPs;
    }

    /**
     * Returns the unweighted transition logP of each feature function for
     * a rule application, in the iteration order of
     * {@code featureFunctions}. When {@code rule} is {@code null} the
     * final-transition variant is evaluated on the first antecedent node.
     */
    public static double[] computeModelTransitionLogPs(List<FeatureFunction> featureFunctions,
            Rule rule, List<HGNode> antNodes, int i, int j, SourcePath srcPath, int sentID) {
        final boolean hasRule = (rule != null);
        double[] logPs = new double[featureFunctions.size()];
        int slot = 0;
        for (FeatureFunction feature : featureFunctions) {
            logPs[slot++] = hasRule
                ? feature.transitionLogP(rule, antNodes, i, j, srcPath, sentID)
                : feature.finalTransitionLogP(antNodes.get(0), i, j, srcPath, sentID);
        }
        return logPs;
    }

    /** Overwrites the expected total logP. */
    void setExpectedTotalLogP(double logP) {
        expectedTotalLogP = logP;
    }

    /** @return the expected total logP */
    public double getExpectedTotalLogP() {
        return expectedTotalLogP;
    }

    /** Overwrites the finalized total logP. */
    void setFinalizedTotalLogP(double logP) {
        finalizedTotalLogP = logP;
    }

    /** @return the finalized total logP */
    double getFinalizedTotalLogP() {
        return finalizedTotalLogP;
    }

    /** Overwrites the transition total logP. */
    void setTransitionTotalLogP(double logP) {
        transitionTotalLogP = logP;
    }

    /** @return the transition total logP */
    double getTransitionTotalLogP() {
        return transitionTotalLogP;
    }

    /** Replaces the map of DP states (keyed by state id). */
    void setDPStates(HashMap<Integer,DPState> states) {
        dpStates = states;
    }

    /** @return the map of DP states keyed by state id; may be {@code null} */
    HashMap<Integer,DPState> getDPStates() {
        return dpStates;
    }

    /** Prints the transition, finalized and expected logPs to standard output. */
    public void printInfo() {
        String line = "scores: " + transitionTotalLogP
            + "; " + finalizedTotalLogP
            + "; " + expectedTotalLogP;
        System.out.println(line);
    }
}