package joshua.discriminative.training.contrastive_estimation; import java.io.BufferedWriter; import java.io.IOException; import java.util.ArrayList; import java.util.Collections; import java.util.Formatter; import java.util.HashMap; import java.util.Iterator; import java.util.List; import joshua.corpus.vocab.BuildinSymbol; import joshua.corpus.vocab.SymbolTable; import joshua.decoder.ff.tm.Rule; import joshua.decoder.ff.tm.hiero.MemoryBasedBatchGrammar; import joshua.decoder.hypergraph.DiskHyperGraph; import joshua.decoder.hypergraph.HGNode; import joshua.decoder.hypergraph.HyperEdge; import joshua.decoder.hypergraph.HyperGraph; import joshua.discriminative.FileUtilityOld; /* Zhifei Li, <zhifei.work@gmail.com> * Johns Hopkins University */ public class ConfusionExtractor { /**TODO: [X,1] should be synchronized with TMGrammar * */ static protected String nonterminalRegexp = "^\\[[A-Z]+\\,[0-9]*\\]$"; static String KEY_SEPARATOR=" ||| "; static String DEFAULT_NON_TERMINAL="X"; HashMap<String, Double> oneWayConfusionTbl =new HashMap<String, Double>(); HashMap<HGNode, Integer> processedItemsTbl = new HashMap<HGNode, Integer>();//Cell-spcific: used for chart construction; Item-specific: for the confusion collection int numProcessedNodes=0; int numHyperEdges=0; //chart, which can be contructed from the hyper-graph ArrayList<HGNode>[][] bins; SymbolTable symbolTbl; /** conditions to decide if two rules are confusible * */ boolean mustNotSameRule = false; boolean mustHaveSameLHS = false; boolean mustHaveSameArity = true; boolean mustNotOOVRule = true; //boolean mustHaveSameAntItemSpans = false; public ConfusionExtractor(SymbolTable symbol_){ symbolTbl = symbol_; } //===================================================================================== //*****Cell specific confusion (but the lhs, cell span, ant spans are the same)******** //===================================================================================== public void cellSpecificConfusionExtraction(HyperGraph hg, int fr_sent_len){ reconstructChartFromHypergraph(hg, fr_sent_len); //get confusion for(int width=1; width<=fr_sent_len; width++){ for(int i=0; i<=fr_sent_len-width; i++){ int j= i + width; if(bins[i][j]!=null) getConfusionWithinCell(bins[i][j]); } } } private void getConfusionWithinCell(List<HGNode> l_items){ //===first get a list of hyper-edges List<HyperEdge> listHyperedges = new ArrayList<HyperEdge>(); for(HGNode it : l_items) listHyperedges.addAll(it.hyperedges); //===O(n^2) symetric comparison getConfusionFromRules( getListRules(listHyperedges) ); } //===================================================================================== //*****reconstruct a chart from a hypergraph ******** //===================================================================================== @SuppressWarnings("unchecked") private void reconstructChartFromHypergraph(HyperGraph hg, int fr_sent_len){ processedItemsTbl.clear(); bins = new ArrayList[fr_sent_len][fr_sent_len+1]; //TODO: ignore confusion in goal_item for(HyperEdge dt : hg.goalNode.hyperedges){ if(dt.getAntNodes()!=null) for(HGNode ant_it : dt.getAntNodes()) reconstructChartForItem(ant_it); } } private void reconstructChartForItem(HGNode it){ if(processedItemsTbl.containsKey(it)) return; //if(it==null)System.out.println("Item i j is :" + it.i + " " + it.j); processedItemsTbl.put(it,1); numProcessedNodes++; if(bins[it.i][it.j]==null) bins[it.i][it.j] = new ArrayList<HGNode>(); bins[it.i][it.j].add(it); for(HyperEdge dt : it.hyperedges){ if(dt.getAntNodes()!=null) for(HGNode ant_it : dt.getAntNodes()) reconstructChartForItem(ant_it); } } // ===================================================================================== // *****LM item specific confusion extraction ******** // ===================================================================================== public void itemSpecificConfusionExtraction(HyperGraph hg){ processedItemsTbl.clear(); System.out.println("----before call: number of forward entries is---: " + oneWayConfusionTbl.size()); getConfusionWithinLMItem(hg.goalNode); } //get confusion existing in a given LM item private void getConfusionWithinLMItem(HGNode it){ if(processedItemsTbl.containsKey(it)) return; processedItemsTbl.put(it,1); numProcessedNodes++; numHyperEdges += it.hyperedges.size(); //process current item: O(n^2) comparison, symetric getConfusionFromRules( getListRules(it.hyperedges) ); //recursively call ant items //TODO: ?? what if an item are shared by many times, presently: we only process for each unique item, otherwise it is too slow for(HyperEdge dt : it.hyperedges){ if(dt.getAntNodes()!=null) for(HGNode ant_it : dt.getAntNodes()) getConfusionWithinLMItem(ant_it); } } // ===================================================================================== // ***** common functions ******** // ===================================================================================== private List<Rule> getListRules(List<HyperEdge> edges){ List<Rule> res = new ArrayList<Rule>(); for(HyperEdge ed : edges){ res.add(ed.getRule()); } return res; } // O(n^2) comparisons protected void getConfusionFromRules(List<Rule> rules, List<Double> probs){ for(int i=0; i<rules.size(); i++){ Rule rule1= rules.get(i); for(int j=0; j<rules.size(); j++){ Rule rule2= rules.get(j); /**use the probability of rule j **/ processRulePair(rule1, rule2, probs.get(j)); } } } //O(n^2) comparisons protected void getConfusionFromRules(List<Rule> rules){ for(int i=0; i<rules.size(); i++){ Rule rule1= rules.get(i); for(int j=0; j<rules.size(); j++){ Rule rule2= rules.get(j); processRulePair(rule1, rule2, 1.0);//TODO } } } //one direction only private void processRulePair(Rule rule1, Rule rule2, double softCount){ if(isConfusable(rule1, rule2)){ String key1 = getRulePairKey(rule1, rule2); Double oldCount = oneWayConfusionTbl.get(key1); if(oldCount!=null){ oneWayConfusionTbl.put(key1, oldCount+softCount); }else{ oneWayConfusionTbl.put(key1, softCount); } } } /* //O(n^2) comparison protected void getConfusionFromRules(List<Rule> rules){ for(int i=0; i<rules.size(); i++){ Rule rule1= rules.get(i); for(int j=i; j<rules.size(); j++){ Rule rule2= rules.get(j); processRulePair(rule1,rule2, 1.0);//TODO } } } //two directions private void processRulePair(Rule rule1, Rule rule2, double softCount){ if(isConfusable(rule1, rule2)){ String key1 = getRulePairKey(rule1, rule2); String key2 = getRulePairKey(rule2, rule1);//reverse Double oldCount = oneWayConfusionTbl.get(key1); if(oldCount!=null){ oneWayConfusionTbl.put(key1, oldCount+softCount); oneWayConfusionTbl.put(key2, oldCount+softCount); }else{ oneWayConfusionTbl.put(key1, softCount); oneWayConfusionTbl.put(key2, softCount); } } } */ private boolean isConfusable(Rule from, Rule to){ if(from==null || to==null) return false; if( (mustNotSameRule && from.getRuleID() == to.getRuleID()) || //must not be the same rule (mustHaveSameLHS && from.getLHS() != to.getLHS()) || //must have the same lhs (mustHaveSameArity && from.getArity() != to.getArity()) || //must have the same arity (mustNotOOVRule && isOutOfVocabularyRule(from)) || (mustNotOOVRule && isOutOfVocabularyRule(to)) ) //must not be the oov rule return false; /* //all ant items must have the same span if(mustHaveSameAntItemSpans){ for(int i=0; i<from.get_rule().getArity(); i++){ HGNode it1= from.get_ant_items().get(i); HGNode it2= to.get_ant_items().get(i); if(it1.i!=it2.i || it1.j!=it2.j) return false; } }*/ return true; } private final boolean isOutOfVocabularyRule(Rule rl) { return (rl.getRuleID() == MemoryBasedBatchGrammar.OOV_RULE_ID); } private String getRulePairKey(Rule rl1, Rule rl2){ return getRuleSignatureInEnglish(rl1) + KEY_SEPARATOR + getRuleSignatureInEnglish(rl2); } //TODO: the lhs symbol private String getRuleSignatureInEnglish(Rule rl){ /*StringBuffer res = new StringBuffer(); for(int i=0; i<rl.english.length; i++){ res.append(rl.english[i]); if(i<rl.english.length-1)res.append(" "); } return res.toString();*/ return symbolTbl.getWords(rl.getEnglish()); } // ===================================================================================== // ***** normalize and print mono-lingual Synchronous Grammar ******** // ===================================================================================== public void printConfusionTbl(String file){ BufferedWriter out= FileUtilityOld.handleNullFile(file); System.out.println("----number of hyper-edges ---: " + numHyperEdges); System.out.println("----number of processed items is---: " + numProcessedNodes); System.out.println("----number of confusion entries is---: " + oneWayConfusionTbl.size()); normalizeHashtable(oneWayConfusionTbl, out); FileUtilityOld.closeWriteFile(out); //System.out.println("----number of inverse entries is---: " + tbl_confusion_inverse.size()); //normalize_hashtable(tbl_confusion_inverse, null, out); //merge_normalized_hashtable(tbl_confusion_forward, tbl_confusion_inverse); } //assume an input table with format( key: (key_sub1 ||| key_sub2); value: count) //output the normalized grammmar rules private void normalizeHashtable(HashMap<String, Double> oneWayConfusionTbl, BufferedWriter out){ String keyPart1=null; //all the entries with the same french side HashMap<String, Double> valuesTbl =new HashMap<String, Double>(); double totalCount =0; for (Iterator<String> e = getSortedKeysIterator(oneWayConfusionTbl); e.hasNext();) { String keyFull = e.next(); String[] fds = keyFull.split("\\s+\\|{3}\\s+");//TODO: key separator if(fds.length!=2){ System.out.println("The key does not have two fds, must be error"); System.exit(0); } //== we get all the possible Englsih for the same french, now normalize if(keyPart1!=null && fds[0].compareTo(keyPart1)!=0){ saveEnglishsForSameFrench(out, valuesTbl, keyPart1, totalCount); valuesTbl.clear(); totalCount=0; } keyPart1 = fds[0]; double tCount = oneWayConfusionTbl.get(keyFull); totalCount += tCount; valuesTbl.put(fds[1], tCount); } //for the last one saveEnglishsForSameFrench(out, valuesTbl, keyPart1, totalCount); } private void saveEnglishsForSameFrench(BufferedWriter out, HashMap<String, Double> valuesTbl, String keyPart1, double totalCount){ for(Iterator<String> itVal = getSortedKeysIterator(valuesTbl); itVal.hasNext();){ String keyPart2 = itVal.next(); double tCount = valuesTbl.get(keyPart2); //TODO: only one non-terminal FileUtilityOld.writeLzf(out, "["+DEFAULT_NON_TERMINAL+"]" + KEY_SEPARATOR + correctIndexOrder(keyPart1 , keyPart2) + KEY_SEPARATOR + new Formatter().format("%.3f", -Math.log( tCount*1.0/totalCount) ) +"\n"); } } //get the correct order for non-terminals such that the order in the french string is strictly increasing private String correctIndexOrder(String french, String english){ StringBuffer res = new StringBuffer(); HashMap<Integer, Integer> id_maps = new HashMap<Integer, Integer>();//old_id -> new_id int cur_id=1; //french String[] wrds = french.split("\\s+"); for(int i=0; i<wrds.length; i++){ if(isNonTerminal(nonterminalRegexp, wrds[i])){ int old_id = symbolTbl.getTargetNonterminalIndex(wrds[i]); wrds[i] = "["+DEFAULT_NON_TERMINAL+","+cur_id+"]";//replace id_maps.put(old_id, cur_id); cur_id++; } res.append(wrds[i]); if(i< wrds.length-1) res.append(" "); } res.append(KEY_SEPARATOR); //english wrds = english.split("\\s+"); for(int i=0; i<wrds.length; i++){ if(isNonTerminal(nonterminalRegexp,wrds[i])){ int old_id = symbolTbl.getTargetNonterminalIndex(wrds[i]); wrds[i] = "["+DEFAULT_NON_TERMINAL+","+(Integer)id_maps.get(old_id)+"]";//replace } res.append(wrds[i]); if(i< wrds.length-1) res.append(" "); } return res.toString(); } private static final boolean isNonTerminal(String nonterminalRegexp_, String symbol) { return symbol.matches(nonterminalRegexp_); } //################################### not used ##################################### /* private void merge_normalized_hashtable(HashMap tbl1, HashMap tbl2){ if(tbl1.size()!=tbl2.size()){System.out.println("in merge, tbl sizes are different"); System.exit(0);} for (Iterator e = get_sorted_keys_iterator(tbl1); e.hasNext();) { String key = (String)e.next(); String[] fds = key.split("\\s+\\|{3}\\s+");//TODO: key separator String key_inverse = fds[1] + KEY_SEPARATOR + fds[0]; double val1 = (Double)tbl1.get(key); double val2 = (Double)tbl2.get(key_inverse); System.out.println(key + KEY_SEPARATOR + new Formatter().format("%.3f %.3f", val1, val2)); } } //assume a input table with format: key (key_sub1 ||| key_sub2), and count: private void normalize_hashtable(HashMap tbl){ String key_part1=null; HashMap values_tbl =new HashMap(); int total_count =0; for (Iterator e = get_sorted_keys_iterator(tbl); e.hasNext();) { String key_full = (String)e.next(); System.out.println(key_full); String[] fds = key_full.split("\\s+\\|{3}\\s+");//TODO: key separator if(fds.length!=2){System.out.println("The key does not have two fds, must be error"); System.exit(0);} if(key_part1!=null && fds[0].compareTo(key_part1)!=0){//normalize //for(Iterator it_val = values_tbl.keySet().iterator(); it_val.hasNext();){ for(Iterator it_val = get_sorted_keys_iterator(values_tbl); it_val.hasNext();){ String val = (String)it_val.next(); int t_c =(Integer) values_tbl.get(val); //System.out.println(key_part1 + KEY_SEPARATOR + val + KEY_SEPARATOR + t_c + KEY_SEPARATOR + new Formatter().format("%.3f", t_c*1.0/total_count)); tbl.put(key_part1 + KEY_SEPARATOR + val, t_c*1.0/total_count); } values_tbl.clear(); total_count=0; } key_part1 = fds[0]; int t_c =(Integer) tbl.get(key_full); total_count += t_c; values_tbl.put(fds[1], t_c); } //for the last one //for(Iterator it_val = values_tbl.keySet().iterator(); it_val.hasNext();){ for(Iterator it_val = get_sorted_keys_iterator(values_tbl); it_val.hasNext();){ String val = (String)it_val.next(); int t_c =(Integer) values_tbl.get(val); //System.out.println(key_part1 + KEY_SEPARATOR + val + KEY_SEPARATOR + t_c + KEY_SEPARATOR + new Formatter().format("%.3f", t_c*1.0/total_count)); tbl.put(key_part1 + KEY_SEPARATOR + val , t_c*1.0/total_count); } }*/ /*private void process_deduction_pair(Deduction dt1, Deduction dt2, HashMap tbl_exclude_rules){ if(is_confusable(dt1, dt2, tbl_exclude_rules)){ String key1 = get_rule_pair_key(dt1.get_rule(), dt2.get_rule()); if(tbl_confusion_forward.containsKey(key1)) tbl_confusion_forward.put(key1, (Integer)tbl_confusion_forward.get(key1)+1); else tbl_confusion_forward.put(key1, 1);//either key1 or key2 is fine String key2 = get_rule_pair_key(dt2.get_rule(), dt1.get_rule()); if(tbl_confusion_inverse.containsKey(key2)) tbl_confusion_inverse.put(key2, (Integer)tbl_confusion_inverse.get(key2)+1); else tbl_confusion_inverse.put(key2, 1);//either key1 or key2 is fine //if(dt1.get_rule().arity<=1){System.out.println("key is " +key1); System.exit(0);}//debug } }*/ //update one single table /*private void process_deduction_pair(Deduction dt1, Deduction dt2, HashMap tbl_exclude_rules){ if(is_confusable(dt1, dt2, tbl_exclude_rules)){ String key1 = get_rule_pair_key(dt1.get_rule(), dt2.get_rule()); if(tbl_confusion_forward.containsKey(key1)) tbl_confusion_forward.put(key1, (Integer)tbl_confusion_forward.get(key1)+1); else{ String key2 = get_rule_pair_key(dt2.get_rule(), dt1.get_rule());//reverse if(tbl_confusion_forward.containsKey(key2)) tbl_confusion_forward.put(key2, (Integer)tbl_confusion_forward.get(key2)+1); else tbl_confusion_forward.put(key1, 1);//either key1 or key2 is fine } //if(dt1.get_rule().arity<=1){System.out.println("key is " +key1); System.exit(0);}//debug } }*/ public static Iterator<String> getSortedKeysIterator(HashMap<String,Double> tbl) { ArrayList<String> v = new ArrayList<String>(tbl.keySet()); Collections.sort(v); return v.iterator(); } //======================== main method ================================ public static void main(String[] args) throws IOException{ if(args.length<3){ System.out.println("Wrong command, it should be: java ConfusionExtractor f_hypergraphs_items f_hypergraphs_grammar f_confusion_grammar total_num_sent"); } SymbolTable p_symbol = new BuildinSymbol(null); int baseline_lm_feat_id=0;//TODO boolean saveModelCosts = true; boolean itemSpecific=false; String f_hypergraphs = args[0]; String f_rule_tbl = args[1]; String f_confusion_grammar = args[2]; int total_num_sent = new Integer(args[3]); /* String f_hypergraphs="C:\\data_disk\\java_work_space\\sf_trunk\\example\\example.nbest.javalm.out.hg.items"; String f_rule_tbl="C:\\data_disk\\java_work_space\\sf_trunk\\example\\example.nbest.javalm.out.hg.rules"; String f_confusion_grammar; if(itemSpecific) f_confusion_grammar="C:\\Users\\zli\\Documents\\itemspecific.confusion.grammar"; else f_confusion_grammar="C:\\Users\\zli\\Documents\\cellspecific.confusion.grammar"; */ ConfusionExtractor g_con = new ConfusionExtractor(p_symbol); DiskHyperGraph dhg = new DiskHyperGraph(p_symbol, baseline_lm_feat_id, saveModelCosts, null); dhg.initRead(f_hypergraphs, f_rule_tbl, null); //int total_num_sent = 5; for(int sent_id=0; sent_id < total_num_sent; sent_id ++){ System.out.println("############Process sentence " + sent_id); HyperGraph hg = dhg.readHyperGraph(); if(itemSpecific) g_con.itemSpecificConfusionExtraction(hg); else g_con.cellSpecificConfusionExtraction(hg,hg.sentLen); } g_con.printConfusionTbl(f_confusion_grammar); } // ======================== end ================================ }