package joshua.discriminative.semiring_parsingv2; import java.util.HashMap; import joshua.decoder.hypergraph.HGNode; import joshua.decoder.hypergraph.HyperEdge; import joshua.discriminative.semiring_parsingv2.semiring.Semiring; /**This implements the algorithm of Figure-3 in the emnlp paper (Li and Eisner, 2009), * in addition to the algorithm of Figure-2 that is implemented by the parent class DefaultInsideSemiringParser. * We may not want to extend this class directly. Instead, we will extend its child class DefaultIOParserWithXLinearCombinator, * which runs inside-outside and also collects the posterior counts * */ public abstract class DefaultInsideOutsideSemiringParser<K extends Semiring<K>> extends DefaultInsideSemiringParser<K> { private HashMap<HGNode, K> outsideSemiringWeightsTable ; /** * for each node, remember how many hyperedges pointering * to me, this is needed for outside estimation. * An node will recursive call its incoming * hyperedges to do outside-estimation only after it itself * is done with outside estimation, this is necessary * because the outside estimation of the node's incoding edges * require the node's outside value */ private HashMap<HGNode,Integer> numParentHyperedgesTable; public DefaultInsideOutsideSemiringParser() { super(); outsideSemiringWeightsTable = new HashMap<HGNode, K>(); numParentHyperedgesTable = new HashMap<HGNode,Integer>(); } /**for correctness and saving memory, * external class should call this method*/ @Override public void clearState(){ super.clearState(); outsideSemiringWeightsTable.clear(); numParentHyperedgesTable.clear(); } @Override protected K insideEstimationOverNode(HGNode it){ rememberNumParentHyperedges(it); return super.insideEstimationOverNode(it); } // =============================== top-downn outside estimation ====================== public void outsideEstimationOverHG(){ outsideSemiringWeightsTable.clear(); K initWeight = createNewKWeight(); initWeight.setToOne(); outsideSemiringWeightsTable.put(hg.goalNode, initWeight);//initialize for(HyperEdge dt : hg.goalNode.hyperedges) outsideEstimationOverHyperedge(dt, hg.goalNode, initWeight); } final protected void outsideEstimationOverNode(HGNode node, K parentNodeOutsideWeight, HyperEdge parentEdge, K parentEdgeWeight){ Integer numCalled = numParentHyperedgesTable.get(node); if (null == numCalled || 0 == numCalled) { System.out.println("num_called="+numCalled); throw new RuntimeException("un-expected call (the number of calls is greater than the number of parent hyperedges), must be wrong"); } numParentHyperedgesTable.put(node, numCalled-1); //====== compute: outside(v) * k_e * product of inside prob of sibling nodes K additionalOutsideProb = createNewKWeight(); additionalOutsideProb.setToOne(); //=== upper item's outside weight //K outsideProbNode = outsideSemiringWeightsTable.get(parentNode);//outside prob additionalOutsideProb.multi(parentNodeOutsideWeight);//outside(v) //=== parent hyperedge weight additionalOutsideProb.multi(parentEdgeWeight);//k_e //=== sibing specifc inside weights if(parentEdge.getAntNodes()!=null && parentEdge.getAntNodes().size()>1) for(HGNode antNode : parentEdge.getAntNodes()){ if(antNode != node){ K nodeInsideProb = insideSemiringWeightsTable.get(antNode);//inside prob additionalOutsideProb.multi(nodeInsideProb); } } //=== add to old prob K oldOutsideProb = outsideSemiringWeightsTable.get(node); if (oldOutsideProb == null) { oldOutsideProb = createNewKWeight(); oldOutsideProb.setToZero(); } oldOutsideProb.add(additionalOutsideProb); outsideSemiringWeightsTable.put(node, oldOutsideProb); //=== recursive call on each deduction if( numCalled-1<=0){//i am done for(HyperEdge dt : node.hyperedges){ outsideEstimationOverHyperedge(dt, node, oldOutsideProb); } } } protected void outsideEstimationOverHyperedge(HyperEdge dt, HGNode parentNode, K parentNodeOutsideWeight){ //we do not need to compute outside prob if no ant items if(dt.getAntNodes()!=null){ //=== deduction specific prob K edgeWeight = getEdgeKWeight(dt, parentNode); //=== recursive call on each ant item for(HGNode antNode : dt.getAntNodes()){ outsideEstimationOverNode(antNode, parentNodeOutsideWeight, dt, edgeWeight); } } } //=============================== end outside estimation //================ get number of hyperedges that point to me /**This function will be used to get the number of parent hyperedges, * which in turn will be used for outside estimation*/ final private void rememberNumParentHyperedges(HGNode node){ //System.out.println("called"); Integer numCalled = (Integer)numParentHyperedgesTable.get(node); if (null == numCalled) { numParentHyperedgesTable.put(node, 1); } else { numParentHyperedgesTable.put(node, numCalled+1); } } }