package joshua.discriminative.training.oracle; import java.util.ArrayList; import java.util.HashMap; import java.util.Iterator; import java.util.List; import joshua.decoder.hypergraph.HGNode; import joshua.decoder.hypergraph.HyperEdge; import joshua.decoder.hypergraph.HyperGraph; /* Zhifei Li, <zhifei.work@gmail.com> * Johns Hopkins University */ /*This class implements general ways of spliting the hypergraph based on coarse-to-fine idea * input is a hypergraph * output is another hypergraph that has changed state structures * */ public abstract class SplitHg { HashMap<HGNode, ArrayList<VirtualItem> > g_tbl_split_virtual_items = new HashMap<HGNode, ArrayList<VirtualItem> >();//Key: item; Value: a list of split virtual items //number of items or deductions after splitting the hypergraph public int g_num_virtual_items = 0; public int g_num_virtual_deductions = 0; //Note: the implementaion of the folowing two functions should call add_deduction protected abstract void process_one_combination_axiom(HGNode parent_item, HashMap virtual_item_sigs, HyperEdge cur_dt); protected abstract void process_one_combination_nonaxiom(HGNode parent_item, HashMap virtual_item_sigs, HyperEdge cur_dt, ArrayList<VirtualItem> l_ant_virtual_item); //#### all the functions should be called after running split_hg(), before clearing g_tbl_split_virtual_items public double get_best_goal_cost(HyperGraph hg, HashMap g_tbl_split_virtual_items){ double res = get_virtual_goal_item(hg, g_tbl_split_virtual_items).best_virtual_deduction.best_cost; //System.out.println("best bleu is " +res); return res; } public VirtualItem get_virtual_goal_item(HyperGraph original_hg, HashMap g_tbl_split_virtual_items){ ArrayList l_virtual_items = (ArrayList) g_tbl_split_virtual_items.get(original_hg.goalNode); if(l_virtual_items.size()!=1){System.out.println("number of virtual goal items is not equal to one"); System.exit(0);} return (VirtualItem)l_virtual_items.get(0); } //get the 1best tree hg, the 1-best is ranked by the split hypergraph, but the return hypergraph is in the form of the original hg public HyperGraph get_1best_tree_hg(HyperGraph original_hg, HashMap g_tbl_split_virtual_items){ VirtualItem virutal_goal_item = get_virtual_goal_item(original_hg, g_tbl_split_virtual_items); HGNode onebest_goal_item = clone_item_with_best_deduction(virutal_goal_item); HyperGraph res = new HyperGraph(onebest_goal_item, -1, -1, original_hg.sentID, original_hg.sentLen);//TODO: number of items/deductions get_1best_tree_item(virutal_goal_item, onebest_goal_item); return res; } private void get_1best_tree_item(VirtualItem virtual_it, HGNode onebest_item){ VirtualDeduction virtual_dt = virtual_it.best_virtual_deduction; if(virtual_dt.l_ant_virtual_items!=null) for(int i=0; i< virtual_dt.l_ant_virtual_items.size(); i++){ VirtualItem ant_it = (VirtualItem) virtual_dt.l_ant_virtual_items.get(i); HGNode new_it = clone_item_with_best_deduction(ant_it); onebest_item.bestHyperedge.getAntNodes().set(i, new_it); get_1best_tree_item(ant_it,new_it); } } //TODO: tbl_states private static HGNode clone_item_with_best_deduction(VirtualItem virtual_it){ HGNode original_it = virtual_it.p_item; ArrayList<HyperEdge> l_deductions = new ArrayList<HyperEdge>(); HyperEdge clone_dt = clone_deduction(virtual_it.best_virtual_deduction); l_deductions.add(clone_dt); return new HGNode(original_it.i, original_it.j, original_it.lhs, l_deductions, clone_dt, original_it.getDPStates()); } private static HyperEdge clone_deduction(VirtualDeduction virtual_dt){ HyperEdge originalEdge = virtual_dt.p_dt; ArrayList<HGNode> l_ant_items = null; if(originalEdge.getAntNodes()!=null) l_ant_items = new ArrayList<HGNode>(originalEdge.getAntNodes());//l_ant_items will be changed in get_1best_tree_item HyperEdge res = new HyperEdge(originalEdge.getRule(), originalEdge.bestDerivationLogP, originalEdge.getTransitionLogP(false), l_ant_items, null); return res; } // ############### split hg ##### public void split_hg(HyperGraph hg){ //TODO: more pre-process in the extended class g_tbl_split_virtual_items.clear(); g_num_virtual_items = 0; g_num_virtual_deductions = 0; split_item(hg.goalNode); } //for each original Item, get a list of VirtualItem private void split_item(HGNode it){ if(g_tbl_split_virtual_items.containsKey(it)) return;//already processed HashMap virtual_item_sigs = new HashMap(); //### recursive call on each deduction if( speed_up_item(it) ){ for(HyperEdge dt : it.hyperedges){ split_deduction(dt, virtual_item_sigs, it); } } //### item-specific operation ArrayList l_virtual_items = new ArrayList();//a list of items result by splitting me for(Iterator iter = virtual_item_sigs.keySet().iterator(); iter.hasNext();) l_virtual_items.add(virtual_item_sigs.get(iter.next())); g_tbl_split_virtual_items.put(it,l_virtual_items); g_num_virtual_items += l_virtual_items.size(); //if(virtual_item_sigs.size()!=1)System.out.println("num of split items is " + virtual_item_sigs.size()); //get_best_virtual_score(it);//debug } private void split_deduction(HyperEdge cur_dt, HashMap virtual_item_sigs, HGNode parent_item){ if(speed_up_deduction(cur_dt)==false) return;//no need to continue //### recursively split all my ant items, get a l_split_items for each original item if(cur_dt.getAntNodes()!=null) for(HGNode ant_it : cur_dt.getAntNodes()) split_item(ant_it); //### recombine the deduction redo_combine(cur_dt, virtual_item_sigs, parent_item); } private void redo_combine(HyperEdge cur_dt, HashMap virtual_item_sigs, HGNode parent_item){ List<HGNode> l_ant_items = cur_dt.getAntNodes(); if(l_ant_items!=null){ if(l_ant_items.size()==1){//arity: one HGNode it = (HGNode)l_ant_items.get(0); ArrayList<VirtualItem> l_virtual_items = (ArrayList<VirtualItem>)g_tbl_split_virtual_items.get(it); for(VirtualItem ant_virtual_item: l_virtual_items){ ArrayList l_ant_virtual_item = new ArrayList();//used in combination l_ant_virtual_item.add(ant_virtual_item); process_one_combination_nonaxiom(parent_item, virtual_item_sigs, cur_dt, l_ant_virtual_item); } }else if(l_ant_items.size()==2){//arity: two HGNode it1 = (HGNode)l_ant_items.get(0); HGNode it2 = (HGNode)l_ant_items.get(1); ArrayList<VirtualItem> l_virtual_items1 = (ArrayList<VirtualItem>)g_tbl_split_virtual_items.get(it1); ArrayList<VirtualItem> l_virtual_items2 = (ArrayList<VirtualItem>)g_tbl_split_virtual_items.get(it2); for(VirtualItem virtual_it1: l_virtual_items1){ for(VirtualItem virtual_it2: l_virtual_items2){ ArrayList l_ant_virtual_item = new ArrayList();//used in combination l_ant_virtual_item.add(virtual_it1); l_ant_virtual_item.add(virtual_it2); process_one_combination_nonaxiom(parent_item, virtual_item_sigs, cur_dt, l_ant_virtual_item); } } }else{ System.out.println("Sorry, we can only deal with rules with at most TWO non-terminals");System.exit(0); } }else{//axiom case: no nonterminal //System.out.println("no ant!!! " + Symbol.get_string(cur_dt.get_rule().english)); process_one_combination_axiom(parent_item, virtual_item_sigs, cur_dt); } } //this function should be called by process_one_combination_axiom/process_one_combination_nonaxiom //virtual_item_sigs is specific to parent_item protected void add_deduction(HGNode parent_item, HashMap virtual_item_sigs, VirtualDeduction t_ded, DPState dpstate, boolean maintain_onebest_only){ if(t_ded==null) {System.out.println("deduction is null"); System.exit(0);} String sig = VirtualItem.get_signature(parent_item, dpstate); VirtualItem t_virtual_item = (VirtualItem)virtual_item_sigs.get(sig); if(t_virtual_item!=null){ t_virtual_item.add_deduction(t_ded, dpstate, maintain_onebest_only); }else{ t_virtual_item = new VirtualItem(parent_item, dpstate, t_ded, maintain_onebest_only); virtual_item_sigs.put(sig,t_virtual_item ); } } //return false if we can skip the item; protected boolean speed_up_item(HGNode it){ return true;//e.g., if the lm state is not valid, then no need to continue } // return false if we can skip the deduction; protected boolean speed_up_deduction(HyperEdge dt){ return true;// if the rule state is not valid, then no need to continue } protected abstract static class DPState { protected abstract String get_signature(); }; /*In general, variables of items * (1) list of hyperedges * (2) best hyperedge * (3) DP state * (4) signature (operated on part/full of DP state) * */ protected static class VirtualItem { HGNode p_item =null;//pointer to the true item ArrayList<VirtualDeduction> l_virtual_deductions = null; VirtualDeduction best_virtual_deduction=null; DPState dp_state;//dynamic programming state: not all the variable in dp_state are in the signature public VirtualItem(HGNode item, DPState dstate, VirtualDeduction fdt, boolean maintain_onebest_only){ p_item = item; add_deduction(fdt, dstate, maintain_onebest_only); } public void add_deduction(VirtualDeduction fdt, DPState dstate, boolean maintain_onebest_only){ if(maintain_onebest_only==false){ if(l_virtual_deductions==null) l_virtual_deductions = new ArrayList<VirtualDeduction>();; l_virtual_deductions.add(fdt); } if( best_virtual_deduction==null || fdt.best_cost < best_virtual_deduction.best_cost ) { best_virtual_deduction = fdt; dp_state = dstate; } } // not all the variable in dp_state are in the signature public String get_signature(){ return get_signature(p_item, dp_state); } public static String get_signature(HGNode item, DPState dstate){ /*StringBuffer res = new StringBuffer(); //res.append(item); res.append(" ");//TODO: res.append(dstate.get_signature()); return res.toString();*/ return dstate.get_signature(); } } protected static class VirtualDeduction { HyperEdge p_dt =null;//pointer to the true deduction ArrayList<VirtualItem> l_ant_virtual_items=null; double best_cost=Double.POSITIVE_INFINITY;//the 1-best cost of all possible derivation: best costs of ant items + non_stateless_transition_cost + r.statelesscost public VirtualDeduction(HyperEdge dt, ArrayList<VirtualItem> ant_items, double best_cost_in){ p_dt=dt; l_ant_virtual_items = ant_items; best_cost = best_cost_in; } public double get_transition_cost(){//note: transition_cost is already linearly interpolated double res = best_cost; if(l_ant_virtual_items!=null) for(VirtualItem ant_it : l_ant_virtual_items) res -= ant_it.best_virtual_deduction.best_cost; return res; } } }