package joshua.decoder.chart_parser; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.PriorityQueue; import joshua.decoder.JoshuaConfiguration; import joshua.decoder.ff.FeatureFunction; import joshua.decoder.ff.state_maintenance.StateComputer; import joshua.decoder.ff.tm.Rule; import joshua.decoder.hypergraph.HGNode; public class CubePruneCombiner implements Combiner{ private List<FeatureFunction> featureFunctions; private List<StateComputer> stateComputers; public CubePruneCombiner(List<FeatureFunction> featureFunctions, List<StateComputer> stateComputers){ this.featureFunctions = featureFunctions; this.stateComputers = stateComputers; } //BUG:???????????????????? CubePrune will depend on relativeThresholdPruning, but cell.beamPruner can be null ???????????????? public void addAxioms(Chart chart, Cell cell, int i, int j, List<Rule> rules, SourcePath srcPath) { for (Rule rule : rules) { addAxiom(chart, cell, i, j, rule, srcPath); } } public void addAxiom(Chart chart, Cell cell, int i, int j, Rule rule, SourcePath srcPath) { cell.addHyperEdgeInCell( new ComputeNodeResult(this.featureFunctions, rule, null, i, j, srcPath, stateComputers, chart.segmentID), rule, i, j, null, srcPath, false); } /** Add complete Items in Chart pruning inside this function */ // TODO: our implementation do the prunining for each DotItem // under each grammar, not aggregated as in the python // version // TODO: the implementation is little bit different from // the description in Liang'2007 ACL paper public void combine(Chart chart, Cell cell, int i, int j, List<SuperNode> superNodes, List<Rule> rules, int arity, SourcePath srcPath) { //combinations: rules, antecent nodes //in the paper, combinationHeap is called cand[v] PriorityQueue<CubePruneState> combinationHeap = new PriorityQueue<CubePruneState>(); // rememeber which state has been explored HashMap<String,Integer> cubeStateTbl = new HashMap<String,Integer>(); if (null == rules || rules.size() <= 0) { return; } //===== seed the heap with best node Rule currentRule = rules.get(0); List<HGNode> currentAntNodes = new ArrayList<HGNode>(); for (SuperNode si : superNodes) { // TODO: si.nodes must be sorted currentAntNodes.add(si.nodes.get(0)); } ComputeNodeResult result = new ComputeNodeResult(featureFunctions, currentRule, currentAntNodes, i, j, srcPath, stateComputers, chart.segmentID); int[] ranks = new int[1+superNodes.size()]; // rule, ant items for (int d = 0; d < ranks.length; d++) { ranks[d] = 1; } CubePruneState bestState = new CubePruneState(result, ranks, currentRule, currentAntNodes); combinationHeap.add(bestState); cubeStateTbl.put(bestState.getSignature(),1); // cube_state_tbl.put(best_state,1); //====== extend the heap Rule oldRule = null; HGNode oldItem = null; int tem_c = 0; while (combinationHeap.size() > 0) { //========== decide if the top in the heap should be pruned tem_c++; CubePruneState curState = combinationHeap.poll(); currentRule = curState.rule; currentAntNodes = new ArrayList<HGNode>(curState.antNodes); // critical to create a new list //cube_state_tbl.remove(cur_state.get_signature()); // TODO, repeat cell.addHyperEdgeInCell(curState.nodeStatesTbl, curState.rule, i, j, curState.antNodes, srcPath, false); // pre-pruning inside this function //if the best state is pruned, then all the remaining states should be pruned away if (curState.nodeStatesTbl.getExpectedTotalLogP() < cell.beamPruner.getCutoffLogP() - JoshuaConfiguration.fuzz1) { //n_prepruned += heap_cands.size(); chart.nPreprunedFuzz1 += combinationHeap.size(); break; } //========== extend the curState, and add the candidates into the heap for (int k = 0; k < curState.ranks.length; k++) { //GET new_ranks int[] newRanks = new int[curState.ranks.length]; for (int d = 0; d < curState.ranks.length; d++) { newRanks[d] = curState.ranks[d]; } newRanks[k] = curState.ranks[k] + 1; String new_sig = CubePruneState.getSignature(newRanks); if (cubeStateTbl.containsKey(new_sig) // explored before || (k == 0 && newRanks[k] > rules.size()) || (k != 0 && newRanks[k] > superNodes.get(k-1).nodes.size()) ) { continue; } if (k == 0) { // slide rule oldRule = currentRule; currentRule = rules.get(newRanks[k]-1); } else { // slide ant oldItem = currentAntNodes.get(k-1); // conside k == 0 is rule currentAntNodes.set(k-1, superNodes.get(k-1).nodes.get(newRanks[k]-1)); } CubePruneState tState = new CubePruneState( new ComputeNodeResult(featureFunctions, currentRule, currentAntNodes, i, j, srcPath, stateComputers, chart.segmentID), newRanks, currentRule, currentAntNodes); // add state into heap cubeStateTbl.put(new_sig,1); if (result.getExpectedTotalLogP() > cell.beamPruner.getCutoffLogP() - JoshuaConfiguration.fuzz2) { combinationHeap.add(tState); } else { //n_prepruned += 1; chart.nPreprunedFuzz2 += 1; } // recover if (k == 0) { // rule currentRule = oldRule; } else { // ant currentAntNodes.set(k-1, oldItem); } } } } // =============================================================== // CubePruneState class // =============================================================== private static class CubePruneState implements Comparable<CubePruneState> { int[] ranks; ComputeNodeResult nodeStatesTbl; Rule rule; List<HGNode> antNodes; public CubePruneState(ComputeNodeResult state, int[] ranks, Rule rule, List<HGNode> antecedents) { this.nodeStatesTbl = state; this.ranks = ranks; this.rule = rule; // create a new vector is critical, because // currentAntecedents will change later this.antNodes = new ArrayList<HGNode>(antecedents); } private static String getSignature(int[] ranks2) { StringBuffer sb = new StringBuffer(); if (null != ranks2) { for (int i = 0; i < ranks2.length; i++) { sb.append(' ').append(ranks2[i]); } } return sb.toString(); } private String getSignature() { return getSignature(ranks); } /** * Compares states by ExpectedTotalLogP, allowing states * to be sorted according to their inverse order (high-prob first). */ public int compareTo(CubePruneState another) { if (this.nodeStatesTbl.getExpectedTotalLogP() < another.nodeStatesTbl.getExpectedTotalLogP()) { return 1; } else if (this.nodeStatesTbl.getExpectedTotalLogP() == another.nodeStatesTbl.getExpectedTotalLogP()) { return 0; } else { return -1; } } } }