/* This file is part of the Joshua Machine Translation System. * * Joshua is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2.1 * of the License, or (at your option) any later version. * * This library is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU * Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with this library; if not, write to the Free * Software Foundation, Inc., 59 Temple Place, Suite 330, Boston, * MA 02111-1307 USA */ package joshua.oracle; import java.io.BufferedReader; import java.io.BufferedWriter; import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import joshua.corpus.vocab.BuildinSymbol; import joshua.corpus.vocab.SymbolTable; import joshua.decoder.Support; import joshua.decoder.ff.state_maintenance.NgramDPState; import joshua.decoder.hypergraph.DiskHyperGraph; import joshua.decoder.hypergraph.HGNode; import joshua.decoder.hypergraph.HyperEdge; import joshua.decoder.hypergraph.HyperGraph; import joshua.decoder.hypergraph.KBestExtractor; import joshua.decoder.hypergraph.ViterbiExtractor; import joshua.util.FileUtility; /** * approximated BLEU * (1) do not consider clipping effect * (2) in the dynamic programming, do not maintain different states * for different hyp length * (3) brief penalty is calculated based on the avg ref length * (4) using sentence-level BLEU, instead of doc-level BLEU * * @author Zhifei Li, <zhifei.work@gmail.com> (Johns Hopkins University) * @version $LastChangedDate: 2010-02-08 13:03:13 -0600 (Mon, 08 Feb 2010) $ */ public class OracleExtractionHG extends SplitHg { static String BACKOFF_LEFT_LM_STATE_SYM="<lzfbo>"; public int BACKOFF_LEFT_LM_STATE_SYM_ID;//used for equivelant state static String NULL_LEFT_LM_STATE_SYM="<lzflnull>"; public int NULL_LEFT_LM_STATE_SYM_ID;//used for equivelant state static String NULL_RIGHT_LM_STATE_SYM="<lzfrnull>"; public int NULL_RIGHT_LM_STATE_SYM_ID;//used for equivelant state // int[] ref_sentence;//reference string (not tree) protected int src_sent_len =0; protected int ref_sent_len =0; protected int g_lm_order=4; //only used for decide whether to get the LM state by this class or not in compute_state static protected boolean do_local_ngram_clip =false; static protected boolean maitain_length_state = false; static protected int g_bleu_order=4; static boolean using_left_equiv_state = true; static boolean using_right_equiv_state = true; //TODO Add generics to hash tables in this class HashMap<String, Boolean> tbl_suffix = new HashMap<String, Boolean>(); HashMap<String, Boolean> tbl_prefix = new HashMap<String, Boolean>(); static PrefixGrammar grammar_prefix = new PrefixGrammar();//TODO static PrefixGrammar grammar_suffix = new PrefixGrammar();//TODO // key: item; value: best_deduction, best_bleu, best_len, # of n-gram match where n is in [1,4] protected HashMap<String, Integer> tbl_ref_ngrams = new HashMap<String, Integer>(); static boolean always_maintain_seperate_lm_state = true; //if true: the virtual item maintain its own lm state regardless whether lm_order>=g_bleu_order SymbolTable p_symbolTable; int lm_feat_id=0; //the baseline LM feature id /** * Constructs a new object capable of extracting a tree * from a hypergraph that most closely matches a provided * oracle sentence. * <p> * It seems that the symbol table here should only need to * represent monolingual terminals, plus nonterminals. * * @param symbolTable * @param lm_feat_id_ */ public OracleExtractionHG(SymbolTable symbolTable, int lm_feat_id_){ this.p_symbolTable = symbolTable; this.lm_feat_id = lm_feat_id_; this.BACKOFF_LEFT_LM_STATE_SYM_ID = p_symbolTable.addTerminal(BACKOFF_LEFT_LM_STATE_SYM); this.NULL_LEFT_LM_STATE_SYM_ID = p_symbolTable.addTerminal(NULL_RIGHT_LM_STATE_SYM); this.NULL_RIGHT_LM_STATE_SYM_ID = p_symbolTable.addTerminal(NULL_RIGHT_LM_STATE_SYM); } /*for 919 sent, time_on_reading: 148797 time_on_orc_extract: 580286*/ public static void main(String[] args) throws IOException { /*String f_hypergraphs="C:\\Users\\zli\\Documents\\mt03.src.txt.ss.nbest.hg.items"; String f_rule_tbl="C:\\Users\\zli\\Documents\\mt03.src.txt.ss.nbest.hg.rules"; String f_ref_files="C:\\Users\\zli\\Documents\\mt03.ref.txt.1"; String f_orc_out ="C:\\Users\\zli\\Documents\\mt03.orc.txt";*/ if (6 != args.length) { System.out.println("Usage: java Decoder f_hypergraphs f_rule_tbl f_ref_files f_orc_out lm_order orc_extract_nbest"); System.out.println("num of args is "+ args.length); for (int i = 0; i < args.length; i++) { System.out.println("arg is: " + args[i]); } System.exit(1); } String f_hypergraphs = args[0].trim(); String f_rule_tbl = args[1].trim(); String f_ref_files = args[2].trim(); String f_orc_out = args[3].trim(); int lm_order = Integer.parseInt(args[4].trim()); boolean orc_extract_nbest = Boolean.valueOf(args[5].trim()); // oracle extraction from nbest or hg //?????????????????????????????????????? int baseline_lm_feat_id = 0; //?????????????????????????????????????? SymbolTable p_symbolTable = new BuildinSymbol(null); KBestExtractor kbest_extractor = null; int topN = 300;//TODO boolean extract_unique_nbest = true;//TODO boolean do_ngram_clip_nbest = true; //TODO if (orc_extract_nbest) { System.out.println("oracle extraction from nbest list"); kbest_extractor = new KBestExtractor(p_symbolTable, extract_unique_nbest, false, false, false, false, true); } BufferedWriter orc_out = FileUtility.getWriteFileStream(f_orc_out); long start_time0 = System.currentTimeMillis(); long time_on_reading = 0; long time_on_orc_extract = 0; BufferedReader t_reader_ref = FileUtility.getReadFileStream(f_ref_files); DiskHyperGraph dhg_read = new DiskHyperGraph(p_symbolTable, baseline_lm_feat_id, true, null); dhg_read.initRead(f_hypergraphs, f_rule_tbl, null); OracleExtractionHG orc_extractor = new OracleExtractionHG(p_symbolTable, baseline_lm_feat_id); String ref_sent= null; long start_time = System.currentTimeMillis(); int sent_id=0; while( (ref_sent=FileUtility.read_line_lzf(t_reader_ref))!= null ){ System.out.println("############Process sentence " + sent_id); start_time = System.currentTimeMillis(); sent_id++; //if(sent_id>10)break; HyperGraph hg = dhg_read.readHyperGraph(); if(hg==null)continue; String orc_sent=null; double orc_bleu=0; //System.out.println("read disk hyp: " + (System.currentTimeMillis()-start_time)); time_on_reading += System.currentTimeMillis()-start_time; start_time = System.currentTimeMillis(); if(orc_extract_nbest){ Object[] res = orc_extractor.oracle_extract_nbest(kbest_extractor, hg, topN, do_ngram_clip_nbest, ref_sent); orc_sent = (String) res[0]; orc_bleu = (Double) res[1]; }else{ HyperGraph hg_oracle = orc_extractor.oracle_extract_hg(hg, hg.sentLen, lm_order, ref_sent); orc_sent = ViterbiExtractor.extractViterbiString(p_symbolTable, hg_oracle.goalNode); orc_bleu = orc_extractor.get_best_goal_cost(hg, orc_extractor.g_tbl_split_virtual_items); time_on_orc_extract += System.currentTimeMillis()-start_time; System.out.println("num_virtual_items: " + orc_extractor.g_num_virtual_items + " num_virtual_dts: " + orc_extractor.g_num_virtual_deductions); //System.out.println("oracle extract: " + (System.currentTimeMillis()-start_time)); } orc_out.write(orc_sent+"\n"); System.out.println("orc bleu is " + orc_bleu); } t_reader_ref.close(); orc_out.close(); System.out.println("time_on_reading: " + time_on_reading); System.out.println("time_on_orc_extract: " + time_on_orc_extract); System.out.println("total running time: " + (System.currentTimeMillis() - start_time0)); } //find the oracle hypothesis in the nbest list public Object[] oracle_extract_nbest(KBestExtractor kbest_extractor, HyperGraph hg, int n, boolean do_ngram_clip, String ref_sent){ if(hg.goalNode==null) return null; kbest_extractor.resetState(); int next_n=0; double orc_bleu=-1; String orc_sent=null; while(true){ String hyp_sent = kbest_extractor.getKthHyp(hg.goalNode, ++next_n, -1, null, null);//????????? if(hyp_sent==null || next_n > n) break; double t_bleu = compute_sentence_bleu(this.p_symbolTable, ref_sent, hyp_sent, do_ngram_clip, 4); if(t_bleu>orc_bleu){ orc_bleu = t_bleu; orc_sent = hyp_sent; } } System.out.println("Oracle sent: " + orc_sent); System.out.println("Oracle bleu: " + orc_bleu); Object[] res = new Object[2]; res[0]=orc_sent; res[1]=orc_bleu; return res; } public HyperGraph oracle_extract_hg(HyperGraph hg, int src_sent_len_in, int lm_order, String ref_sent_str) { int[] ref_sent = this.p_symbolTable.addTerminals(ref_sent_str.split("\\s+")); g_lm_order=lm_order; src_sent_len = src_sent_len_in; ref_sent_len = ref_sent.length; tbl_ref_ngrams.clear(); get_ngrams(tbl_ref_ngrams, g_bleu_order, ref_sent, false); if(using_left_equiv_state || using_right_equiv_state){ tbl_prefix.clear(); tbl_suffix.clear(); setup_prefix_suffix_tbl(ref_sent, g_bleu_order, tbl_prefix, tbl_suffix); setup_prefix_suffix_grammar(ref_sent, g_bleu_order, grammar_prefix, grammar_suffix);//TODO } split_hg(hg); //System.out.println("best bleu is " + get_best_goal_cost( hg, g_tbl_split_virtual_items)); return get_1best_tree_hg(hg, g_tbl_split_virtual_items); } /*This procedure does * (1) identify all possible match * (2) add a new deduction for each matches*/ protected void process_one_combination_axiom(HGNode parent_item, HashMap<String, VirtualItem> virtual_item_sigs, HyperEdge cur_dt){ if (null == cur_dt.getRule()) { throw new RuntimeException("error null rule in axiom"); } double avg_ref_len = (parent_item.j-parent_item.i>=src_sent_len) ? ref_sent_len : (parent_item.j-parent_item.i)*ref_sent_len*1.0/src_sent_len;//avg len? double bleu_score[] = new double[1]; DPStateOracle dps = compute_state(parent_item, cur_dt, null, tbl_ref_ngrams, do_local_ngram_clip, g_lm_order, avg_ref_len, bleu_score, tbl_suffix, tbl_prefix); VirtualDeduction t_dt = new VirtualDeduction(cur_dt, null, -bleu_score[0]);//cost: -best_bleu g_num_virtual_deductions++; add_deduction(parent_item, virtual_item_sigs, t_dt, dps, true); } /*This procedure does * (1) create a new deduction (based on cur_dt and ant_virtual_item) * (2) find whether an Item can contain this deduction (based on virtual_item_sigs which is a hashmap specific to a parent_item) * (2.1) if yes, add the deduction, * (2.2) otherwise * (2.2.1) create a new item * (2.2.2) and add the item into virtual_item_sigs **/ protected void process_one_combination_nonaxiom(HGNode parent_item, HashMap<String, VirtualItem> virtual_item_sigs, HyperEdge cur_dt, ArrayList<VirtualItem> l_ant_virtual_item){ if (null == l_ant_virtual_item) { throw new RuntimeException("wrong call in process_one_combination_nonaxiom"); } double avg_ref_len = (parent_item.j-parent_item.i>=src_sent_len) ? ref_sent_len : (parent_item.j-parent_item.i)*ref_sent_len*1.0/src_sent_len;//avg len? double bleu_score[] = new double[1]; DPStateOracle dps = compute_state(parent_item, cur_dt, l_ant_virtual_item, tbl_ref_ngrams, do_local_ngram_clip, g_lm_order, avg_ref_len, bleu_score, tbl_suffix, tbl_prefix); VirtualDeduction t_dt = new VirtualDeduction(cur_dt, l_ant_virtual_item, -bleu_score[0]);//cost: -best_bleu g_num_virtual_deductions++; add_deduction(parent_item, virtual_item_sigs, t_dt, dps, true); } //DPState maintain all the state information at an item that is required during dynamic programming protected static class DPStateOracle extends DPState { int best_len; //this may not be used in the signature int[] ngram_matches; int[] left_lm_state; int[] right_lm_state; public DPStateOracle(int blen, int[] matches, int[] left, int[] right){ best_len = blen; ngram_matches = matches; left_lm_state = left; right_lm_state = right; } protected String get_signature() { StringBuffer res = new StringBuffer(); if (maitain_length_state) { res.append(best_len); res.append(' '); } if (null != left_lm_state) { // goal-item have null state for (int i = 0; i < left_lm_state.length; i++) { res.append(left_lm_state[i]); res.append(' '); } } res.append("lzf "); if (null != right_lm_state) { // goal-item have null state for (int i = 0; i < right_lm_state.length; i++) { res.append(right_lm_state[i]); res.append(' '); } } //if(left_lm_state==null || right_lm_state==null)System.out.println("sig is: " + res.toString()); return res.toString(); } protected void print(){ StringBuffer res = new StringBuffer(); res.append("DPstate: best_len: "); res.append(best_len); for(int i=0; i<ngram_matches.length; i++){ res.append("; ngram: "); res.append(ngram_matches[i]); } System.out.println(res.toString()); } } // ########################## commmon funcions ##################### //based on tbl_oracle_states, tbl_ref_ngrams, and dt, get the state //get the new state: STATE_BEST_DEDUCT STATE_BEST_BLEU STATE_BEST_LEN NGRAM_MATCH_COUNTS protected DPStateOracle compute_state( HGNode parent_item, HyperEdge dt, ArrayList<VirtualItem> l_ant_virtual_item, HashMap<String,Integer> tbl_ref_ngrams, boolean do_local_ngram_clip, int lm_order, double ref_len, double[] bleu_score, HashMap<String, Boolean> tbl_suffix, HashMap<String, Boolean> tbl_prefix ) { //##### deductions under "goal item" does not have rule if (null == dt.getRule()) { if (l_ant_virtual_item.size() != 1) { throw new RuntimeException("error deduction under goal item have more than one item"); } bleu_score[0] = -l_ant_virtual_item.get(0).best_virtual_deduction.best_cost; return new DPStateOracle(0, null, null,null); // no DPState at all } //################## deductions *not* under "goal item" HashMap<String, Integer> new_ngram_counts = new HashMap<String, Integer>();//new ngrams created due to the combination HashMap<String, Integer> old_ngram_counts = new HashMap<String, Integer>();//the ngram that has already been computed int total_hyp_len =0; int[] num_ngram_match = new int[g_bleu_order]; int[] en_words = dt.getRule().getEnglish(); //####calulate new and old ngram counts, and len ArrayList<Integer> words= new ArrayList<Integer>(); // used for compute left- and right- lm state ArrayList<Integer> left_state_sequence = null; // used for compute left- and right- lm state ArrayList<Integer> right_state_sequence = null; int correct_lm_order = lm_order; if (always_maintain_seperate_lm_state || lm_order < g_bleu_order) { left_state_sequence = new ArrayList<Integer>(); right_state_sequence = new ArrayList<Integer>(); correct_lm_order = g_bleu_order; // if lm_order is smaller than g_bleu_order, we will get the lm state by ourself } //#### get left_state_sequence, right_state_sequence, total_hyp_len, num_ngram_match for (int c = 0; c < en_words.length; c++) { int c_id = en_words[c]; if (this.p_symbolTable.isNonterminal(c_id)) { int index = this.p_symbolTable.getTargetNonterminalIndex(c_id); DPStateOracle ant_state = (DPStateOracle) l_ant_virtual_item.get(index).dp_state; total_hyp_len += ant_state.best_len; for (int t = 0; t < g_bleu_order; t++) { num_ngram_match[t] += ant_state.ngram_matches[t]; } int[] l_context = ant_state.left_lm_state; int[] r_context = ant_state.right_lm_state; for (int t : l_context) { // always have l_context words.add(t); if (null != left_state_sequence && left_state_sequence.size() < g_bleu_order-1) { left_state_sequence.add(t); } } get_ngrams(old_ngram_counts, g_bleu_order, l_context, true); if (r_context.length >= correct_lm_order-1) { // the right and left are NOT overlapping get_ngrams(new_ngram_counts, g_bleu_order, words, true); get_ngrams(old_ngram_counts, g_bleu_order, r_context, true); words.clear();//start a new chunk if (null != right_state_sequence) { right_state_sequence.clear(); } for (int t : r_context) { words.add(t); } } if (null != right_state_sequence) { for(int t : r_context) { right_state_sequence.add(t); } } } else { words.add(c_id); total_hyp_len += 1; if (null != left_state_sequence && left_state_sequence.size() < g_bleu_order-1) { left_state_sequence.add(c_id); } if (null != right_state_sequence) { right_state_sequence.add(c_id); } } } get_ngrams(new_ngram_counts, g_bleu_order, words, true); //####now deduct ngram counts for (String ngram : new_ngram_counts.keySet()) { if (tbl_ref_ngrams.containsKey(ngram)) { int final_count = (Integer)new_ngram_counts.get(ngram); if (old_ngram_counts.containsKey(ngram)) { final_count -= (Integer)old_ngram_counts.get(ngram); // BUG: Whoa, is that an actual hard-coded ID in there? :) if (final_count < 0) { throw new RuntimeException("negative count for ngram: " + this.p_symbolTable.getWord(11844) + "; new: " + new_ngram_counts.get(ngram) + "; old: " + old_ngram_counts.get(ngram) ); } } if (final_count > 0) { // TODO: not correct/global ngram clip if (do_local_ngram_clip) { // BUG: use joshua.util.Regex.spaces.split(...) num_ngram_match[ngram.split("\\s+").length-1] += Support.findMin(final_count, (Integer)tbl_ref_ngrams.get(ngram)); } else { // BUG: use joshua.util.Regex.spaces.split(...) num_ngram_match[ngram.split("\\s+").length-1] += final_count; //do not do any cliping } } } } //####now calculate the BLEU score and state int[] left_lm_state = null; int[] right_lm_state = null; if (!always_maintain_seperate_lm_state && lm_order >= g_bleu_order) { //do not need to change lm state, just use orignal lm state NgramDPState state = (NgramDPState) parent_item.getDPState(this.lm_feat_id); left_lm_state = intListToArray( state.getLeftLMStateWords() ); right_lm_state = intListToArray( state.getRightLMStateWords() ); } else { left_lm_state = get_left_equiv_state(left_state_sequence, tbl_suffix); right_lm_state = get_right_equiv_state(right_state_sequence, tbl_prefix); //debug //System.out.println("lm_order is " + lm_order); //compare_two_int_arrays(left_lm_state, (int[])parent_item.tbl_states.get(Symbol.LM_L_STATE_SYM_ID)); //compare_two_int_arrays(right_lm_state, (int[])parent_item.tbl_states.get(Symbol.LM_R_STATE_SYM_ID)); //end } bleu_score[0] = compute_bleu(total_hyp_len, ref_len, num_ngram_match, g_bleu_order); //System.out.println("blue score is " + bleu_score[0]); return new DPStateOracle(total_hyp_len, num_ngram_match, left_lm_state, right_lm_state); } private int[] intListToArray(List<Integer> words){ int[] res = new int[words.size()]; int i=0; for(int wrd : words) res[i++] = wrd; return res; } private int[] get_left_equiv_state(ArrayList<Integer> left_state_sequence, HashMap<String, Boolean> tbl_suffix) { int l_size = (left_state_sequence.size()<g_bleu_order-1)? left_state_sequence.size() : (g_bleu_order-1); int[] left_lm_state = new int[l_size]; if (!using_left_equiv_state || l_size < g_bleu_order-1) { // regular for (int i = 0; i < l_size; i++) { left_lm_state[i] = left_state_sequence.get(i); } } else { for (int i = l_size-1; i >= 0; i--) { // right to left if (is_a_suffix_in_tbl(left_state_sequence, 0, i, tbl_suffix)) { //if(is_a_suffix_in_grammar(left_state_sequence, 0, i, grammar_suffix)){ for (int j = i; j >= 0; j--) { left_lm_state[j] = left_state_sequence.get(j); } break; } else { left_lm_state[i] = this.NULL_LEFT_LM_STATE_SYM_ID; } } //System.out.println("origi left:" + Symbol.get_string(left_state_sequence) + "; equiv left:" + Symbol.get_string(left_lm_state)); } return left_lm_state; } private boolean is_a_suffix_in_tbl(ArrayList<Integer> left_state_sequence, int start_pos, int end_pos, HashMap<String, Boolean> tbl_suffix) { if ((Integer)left_state_sequence.get(end_pos) == this.NULL_LEFT_LM_STATE_SYM_ID) { return false; } StringBuffer suffix = new StringBuffer(); for (int i = end_pos; i >= start_pos; i--) { // right-most first suffix.append(left_state_sequence.get(i)); if (i > start_pos) suffix.append(' '); } return (Boolean) tbl_suffix.containsKey(suffix.toString()); } // TODO: never called. remove? private boolean is_a_suffix_in_grammar( ArrayList<Integer> left_state_sequence, int start_pos, int end_pos, PrefixGrammar grammar_suffix) { if ((Integer)left_state_sequence.get(end_pos) == this.NULL_LEFT_LM_STATE_SYM_ID) { return false; } ArrayList<Integer> suffix = new ArrayList<Integer>(); for (int i = end_pos; i >= start_pos; i--) { // right-most first suffix.add(left_state_sequence.get(i)); } return grammar_suffix.contain_ngram(suffix, 0, suffix.size()-1); } private int[] get_right_equiv_state( ArrayList<Integer> right_state_sequence, HashMap<String, Boolean> tbl_prefix) { int r_size = (right_state_sequence.size() < g_bleu_order-1) ? right_state_sequence.size() : (g_bleu_order-1); int[] right_lm_state = new int[r_size]; if (!using_right_equiv_state || r_size < g_bleu_order-1) { // regular for (int i = 0; i < r_size; i++) { right_lm_state[i] = (Integer)right_state_sequence.get(right_state_sequence.size()-r_size+i); } } else { for (int i = 0; i < r_size; i++) { // left to right if (is_a_prefix_in_tbl(right_state_sequence, right_state_sequence.size()-r_size+i, right_state_sequence.size()-1, tbl_prefix)) { //if(is_a_prefix_in_grammar(right_state_sequence, right_state_sequence.size()-r_size+i, right_state_sequence.size()-1, grammar_prefix)){ for (int j = i; j < r_size; j++) { right_lm_state[j] = (Integer)right_state_sequence.get(right_state_sequence.size()-r_size+j); } break; } else { right_lm_state[i] = this.NULL_RIGHT_LM_STATE_SYM_ID; } } //System.out.println("origi right:" + Symbol.get_string(right_state_sequence)+ "; equiv right:" + Symbol.get_string(right_lm_state)); } return right_lm_state; } private boolean is_a_prefix_in_tbl(ArrayList<Integer> right_state_sequence, int start_pos, int end_pos, HashMap<String, Boolean> tbl_prefix) { if (right_state_sequence.get(start_pos) == this.NULL_RIGHT_LM_STATE_SYM_ID) { return false; } StringBuffer prefix = new StringBuffer(); for (int i = start_pos; i <= end_pos; i++) { prefix.append(right_state_sequence.get(i)); if (i < end_pos) prefix.append(' '); } return (Boolean) tbl_prefix.containsKey(prefix.toString()); } // TODO: never called. remove? private boolean isAPrefixInGrammar( ArrayList<Integer> right_state_sequence, int start_pos, int end_pos, PrefixGrammar gr_prefix) { if (right_state_sequence.get(start_pos) == this.NULL_RIGHT_LM_STATE_SYM_ID) { return false; } return gr_prefix.contain_ngram(right_state_sequence, start_pos, end_pos); } public static void compare_two_int_arrays(int[] a, int[] b) { if (a.length != b.length) { throw new RuntimeException("two arrays do not have same size"); } for (int i = 0; i<a.length; i++) { if (a[i] != b[i]) { throw new RuntimeException("elements in two arrays are not same"); } } } //sentence-bleu: BLEU= bp * prec; where prec = exp (sum 1/4 * log(prec[order])) public static double compute_bleu(int hyp_len, double ref_len, int[] num_ngram_match, int bleu_order){ if (hyp_len <= 0 || ref_len <= 0){ throw new RuntimeException("ref or hyp is zero len"); } double res = 0; double wt = 1.0/bleu_order; double prec = 0; double smooth_factor=1.0; for (int t = 0; t < bleu_order && t < hyp_len; t++) { if (num_ngram_match[t] > 0) { prec += wt*Math.log(num_ngram_match[t]*1.0/(hyp_len-t)); } else { smooth_factor *= 0.5;//TODO prec += wt*Math.log(smooth_factor/(hyp_len-t)); } } double bp = (hyp_len>=ref_len) ? 1.0 : Math.exp(1-ref_len/hyp_len); res = bp*Math.exp(prec); //System.out.println("hyp_len: " + hyp_len + "; ref_len:" + ref_len + "prec: " + Math.exp(prec) + "; bp: " + bp + "; bleu: " + res); return res; } //accumulate ngram counts into tbl public void get_ngrams(HashMap<String,Integer> tbl, int order, int[] wrds, boolean ignore_null_equiv_symbol) { for (int i = 0; i < wrds.length; i++) { for (int j = 0; j < order && j+i < wrds.length; j++) { // ngram: [i,i+j] boolean contain_null = false; StringBuffer ngram = new StringBuffer(); for (int k = i; k <= i+j; k++) { if (wrds[k] == this.NULL_LEFT_LM_STATE_SYM_ID || wrds[k] == this.NULL_RIGHT_LM_STATE_SYM_ID) { contain_null = true; if (ignore_null_equiv_symbol) break; } ngram.append(wrds[k]); if (k < i+j) ngram.append(' '); } if (ignore_null_equiv_symbol && contain_null) continue; // skip this ngram String ngram_str = ngram.toString(); if (tbl.containsKey(ngram_str)) { tbl.put(ngram_str, (Integer)tbl.get(ngram_str)+1); } else { tbl.put(ngram_str, 1); } } } } /** accumulate ngram counts into tbl. */ public void get_ngrams(HashMap<String, Integer> tbl, int order, ArrayList<Integer> wrds, boolean ignore_null_equiv_symbol) { for (int i = 0; i < wrds.size(); i++) { // ngram: [i,i+j] for (int j = 0; j < order && j+i < wrds.size(); j++) { boolean contain_null = false; StringBuffer ngram = new StringBuffer(); for (int k = i; k <= i+j; k++) { int t_wrd = (Integer) wrds.get(k); if (t_wrd == this.NULL_LEFT_LM_STATE_SYM_ID || t_wrd == this.NULL_RIGHT_LM_STATE_SYM_ID) { contain_null = true; if (ignore_null_equiv_symbol) break; } ngram.append(t_wrd); if (k < i+j) ngram.append(' '); } // skip this ngram if (ignore_null_equiv_symbol && contain_null) continue; String ngram_str = ngram.toString(); if (tbl.containsKey(ngram_str)) { tbl.put(ngram_str, (Integer)tbl.get(ngram_str)+1); } else { tbl.put(ngram_str, 1); } } } } //do_ngram_clip: consider global n-gram clip public double compute_sentence_bleu(SymbolTable p_symbol, String ref_sent, String hyp_sent, boolean do_ngram_clip, int bleu_order) { // BUG: use joshua.util.Regex.spaces.split(...) int[] numeric_ref_sent = p_symbol.addTerminals(ref_sent.split("\\s+")); int[] numeric_hyp_sent = p_symbol.addTerminals(hyp_sent.split("\\s+")); return compute_sentence_bleu(numeric_ref_sent, numeric_hyp_sent, do_ngram_clip, bleu_order); } public double compute_sentence_bleu(int[] ref_sent, int[] hyp_sent, boolean do_ngram_clip, int bleu_order) { double res_bleu = 0; int order = 4; HashMap<String, Integer> ref_ngram_tbl = new HashMap<String, Integer>(); get_ngrams(ref_ngram_tbl, order, ref_sent, false); HashMap<String, Integer> hyp_ngram_tbl = new HashMap<String, Integer>(); get_ngrams(hyp_ngram_tbl, order, hyp_sent, false); int[] num_ngram_match = new int[order]; for (String ngram : hyp_ngram_tbl.keySet()) { if (ref_ngram_tbl.containsKey(ngram)) { if (do_ngram_clip) { // BUG: use joshua.util.Regex.spaces.split(...) num_ngram_match[ngram.split("\\s+").length-1] += Support.findMin((Integer)ref_ngram_tbl.get(ngram),(Integer)hyp_ngram_tbl.get(ngram)); //ngram clip } else { // BUG: use joshua.util.Regex.spaces.split(...) num_ngram_match[ngram.split("\\s+").length-1] += (Integer)hyp_ngram_tbl.get(ngram);//without ngram count clipping } } } res_bleu = compute_bleu(hyp_sent.length, ref_sent.length, num_ngram_match, bleu_order); //System.out.println("hyp_len: " + hyp_sent.length + "; ref_len:" + ref_sent.length + "; bleu: " + res_bleu +" num_ngram_matches: " + num_ngram_match[0] + " " +num_ngram_match[1]+ // " " + num_ngram_match[2] + " " +num_ngram_match[3]); return res_bleu; } // TODO: never called, remove? private static void printState(Object[] state) { System.out.println("State is"); for (int i = 0; i < state.length; i++) { System.out.print(state[i] + " ---- "); } System.out.println(); } //#### equivalent lm stuff ############ public static void setup_prefix_suffix_tbl(int[] wrds, int order, HashMap<String, Boolean> prefix_tbl, HashMap<String, Boolean> suffix_tbl) { for (int i = 0; i < wrds.length; i++) { for (int j = 0; j < order && j+i < wrds.length; j++) { // ngram: [i,i+j] StringBuffer ngram = new StringBuffer(); //### prefix for (int k = i; k < i+j; k++) { // all ngrams [i,i+j-1] ngram.append(wrds[k]); prefix_tbl.put(ngram.toString(), true); ngram.append(' '); } //### suffix: right-most wrd first ngram = new StringBuffer(); for (int k = i+j; k > i; k--) { // all ngrams [i+1,i+j]: reverse order ngram.append(wrds[k]); suffix_tbl.put(ngram.toString(), true);//stored in reverse order ngram.append(' '); } } } } // #### equivalent lm stuff ############ public static void setup_prefix_suffix_grammar(int[] wrds, int order, PrefixGrammar prefix_gr, PrefixGrammar suffix_gr) { for (int i = 0; i < wrds.length; i++) { for (int j = 0; j < order && j+i < wrds.length; j++) { // ngram: [i,i+j] //### prefix prefix_gr.add_ngram(wrds, i, i+j-1);//ngram: [i,i+j-1] //### suffix: right-most wrd first int[] reverse_wrds = new int[j]; for (int k = i+j, t = 0; k > i; k--) { // all ngrams [i+1,i+j]: reverse order reverse_wrds[t++] = wrds[k]; } suffix_gr.add_ngram(reverse_wrds, 0, j-1); } } } /* a backoff node is a hashtable, it may include: * (1) probabilititis for next words * (2) pointers to a next-layer backoff node (hashtable) * (3) backoff weight for this node * (4) suffix/prefix flag to indicate that there is ngrams start from this suffix */ private static class PrefixGrammar { private static class PrefixGrammarNode extends HashMap<Integer, PrefixGrammarNode> { private static final long serialVersionUID = 1L; }; PrefixGrammarNode root = new PrefixGrammarNode(); //add prefix information public void add_ngram(int[] wrds, int start_pos, int end_pos) { //######### identify the position, and insert the trinodes if necessary PrefixGrammarNode pos = root; for (int k = start_pos; k <= end_pos; k++) { int cur_sym_id = wrds[k]; PrefixGrammarNode next_layer = pos.get(cur_sym_id); if (null != next_layer) { pos = next_layer; } else { // next layer node PrefixGrammarNode tmp = new PrefixGrammarNode(); pos.put(cur_sym_id, tmp); pos = tmp; } } } public boolean contain_ngram(ArrayList<Integer> wrds, int start_pos, int end_pos) { if (end_pos < start_pos) return false; PrefixGrammarNode pos = root; for (int k = start_pos; k <= end_pos; k++) { int cur_sym_id = wrds.get(k); PrefixGrammarNode next_layer = pos.get(cur_sym_id); if (next_layer != null) { pos = next_layer; } else { return false; } } return true; } } }