package joshua.discriminative.training;

import java.io.BufferedReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;

import joshua.corpus.vocab.BuildinSymbol;
import joshua.corpus.vocab.SymbolTable;
import joshua.decoder.hypergraph.DefaultInsideOutside;
import joshua.decoder.hypergraph.DiskHyperGraph;
import joshua.decoder.hypergraph.HyperGraph;
import joshua.discriminative.DiscriminativeSupport;
import joshua.discriminative.FileUtilityOld;
import joshua.discriminative.feature_related.FeatureBasedInsideOutside;
import joshua.discriminative.feature_related.FeatureExtractionHG;
import joshua.discriminative.feature_related.feature_template.BaselineFT;
import joshua.discriminative.feature_related.feature_template.EdgeBigramFT;
import joshua.discriminative.feature_related.feature_template.FeatureTemplate;
import joshua.discriminative.feature_related.feature_template.NgramFT;
import joshua.discriminative.feature_related.feature_template.TMFT;
import joshua.discriminative.ranker.RescorerHGSimple;
import joshua.discriminative.training.learning_algorithm.DefaultCRF;
import joshua.discriminative.training.learning_algorithm.DefaultPerceptron;
import joshua.discriminative.training.learning_algorithm.GradientBasedOptimizer;

/** This class implements hypergraph-based discriminative reranking:
 * (1) hypergraph-based feature templates,
 * (2) reranking against a baseline feature, and
 * (3) batch model updates.
 */
public class HGDiscriminativeLearner {

  GradientBasedOptimizer optimizer = null;

  static boolean usingCRF = true; // default is CRF
  static boolean usingStringOracle = false; // TODO

  // ## feature related
  HashMap empiricalFeatsTbl = new HashMap(); // empirical feature counts
  HashMap modelFeatsTbl = new HashMap(); // feature counts assigned by the model
  HashSet<String> restrictedFeatureSet = null; // only consider features in this set; if null, ignore the restriction

  // ## batch update related
  int numProcessedExamples = 0;

  // ## reranking: baseline
  public static String baselineFeatName = "baseline_lzf";
  static boolean fixBaseline = true; // TODO: the current version assumes the baseline is always fixed during training
  static double baselineScale = 1.0; // scale of the baseline relative to the corrective score

  RescorerHGSimple reranker = new RescorerHGSimple();

  public HGDiscriminativeLearner(GradientBasedOptimizer optimizer, HashSet<String> restrictedFeatureSet) {
    this.optimizer = optimizer;
    this.restrictedFeatureSet = restrictedFeatureSet;
  }

  /** All hypotheses are represented as a hypergraph. */
  public void processOneSent(HyperGraph fullHG, Object oracle, String refSent,
      List<FeatureTemplate> featTemplates, List<FeatureTemplate> featTemplatesNobaseline) {

    //String sent_original_1best_debug = HyperGraph.extract_best_string(fullHG.goal_item);
    //String sent_rerank_1best_debug = null;
    //String orc_string_debug = null;
    //HyperGraph hg_original_1best = fullHG.get_1best_tree_hg(); // debug: 1-best under the original model (no corrective score)

    // #### feature extraction using the current model, filling modelFeatsTbl
    if (usingCRF) {
      // run inference under the current model
      DefaultInsideOutside insideOutsider = new FeatureBasedInsideOutside(
          optimizer.getSumModel(), featTemplates, restrictedFeatureSet);
      insideOutsider.runInsideOutside(fullHG, 0, 1, 1);
      //insideOutsider.sanity_check_hg(fullHG);
      FeatureExtractionHG.featureExtractionOnHG(fullHG, insideOutsider, modelFeatsTbl,
          restrictedFeatureSet, featTemplates);
      insideOutsider.clearState();
      //sent_rerank_1best_debug = RescorerHG.rerank_hg_and_get_1best_string(fullHG,
      //    optimizer.getSumModel(), baselineScale, restrictedFeatureSet, featTemplatesNobaseline, false);
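      // Note (illustrative, not from the original comments): the inside-outside pass
      // yields posterior probabilities p(e | x) over hyperedges, so the counts
      // collected into modelFeatsTbl are feature expectations under the current model,
      //   modelFeatsTbl[f] = sum over edges e of p(e | x) * f(e),
      // and the later update uses the CRF gradient empiricalFeatsTbl - modelFeatsTbl.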
    } else { // perceptron
      HyperGraph rerankedOnebest = reranker.rerankHGAndGet1best(fullHG,
          optimizer.getSumModel(), restrictedFeatureSet, featTemplatesNobaseline, false);
      FeatureExtractionHG.featureExtractionOnHG(rerankedOnebest, modelFeatsTbl,
          restrictedFeatureSet, featTemplates);
      //sent_rerank_1best_debug = HyperGraph.extract_best_string(rerankedOnebest.goal_item);
    }

    // #### feature extraction on the empirical (oracle) side, filling empiricalFeatsTbl
    // TODO: in the case of hidden variables, we should run inside-outside to get the expectation under the current model
    if (usingStringOracle == false) {
      FeatureExtractionHG.featureExtractionOnHG((HyperGraph) oracle, null, empiricalFeatsTbl,
          restrictedFeatureSet, featTemplates);
      //orc_string_debug = HyperGraph.extract_best_string(((HyperGraph) oracle).goal_item);
    } else {
      NBESTDiscriminativeLearner.featureExtraction((String) oracle, empiricalFeatsTbl,
          restrictedFeatureSet, 0, false);
      //orc_string_debug = (String) oracle;
    }

    //double original_bleu = OracleExtractionHG.compute_sentence_bleu(orc_string_debug, sent_original_1best_debug, true, 4);
    //double reranked_bleu = OracleExtractionHG.compute_sentence_bleu(orc_string_debug, sent_rerank_1best_debug, true, 4);
    //System.out.println("reranked bleu: " + reranked_bleu + "; original bleu: " + original_bleu + "; difference: " + (reranked_bleu - original_bleu));

    /* debug:
    System.out.println("modelFeatsTbl size: " + modelFeatsTbl.size());
    System.out.println("empiricalFeatsTbl size: " + empiricalFeatsTbl.size());
    System.out.println("restrictedFeatureSet size: " + restrictedFeatureSet.size());
    */

    // #### common: update the model
    numProcessedExamples++;
    update_model(false);

    /* ########### DEBUG
    //System.out.println("sum table size is " + g_tbl_sum_model.size());
    //System.out.println("avg table size is " + g_tbl_avg_model.size());
    //Support.print_hash_tbl(g_tbl_sum_model);
    //Support.print_hash_tbl(g_tbl_avg_model);
    String sent_orc = HyperGraph.extract_best_string(hg_oracle2.goal_item);
    String sent_original_1best = HyperGraph.extract_best_string(hg_original_1best.goal_item);
    String sent_reranked_1best = HyperGraph.extract_best_string(hg_reranked_1best.goal_item);
    //System.out.println("ref: " + refSent);
    //System.out.println("orc: " + sent_orc);
    //System.out.println("1or: " + sent_original_1best);
    //System.out.println("1re: " + sent_reranked_1best);
    //OracleExtractor.compute_sentence_bleu(refSent, sent_orc, true);
    //OracleExtractor.compute_sentence_bleu(refSent, sent_original_1best, true);
    //OracleExtractor.compute_sentence_bleu(refSent, sent_reranked_1best, true);
    //OracleExtractor.compute_sentence_bleu(sent_orc, sent_1best);
    ########### END DEBUG */
  }

  public void update_model(boolean force_update) {
    if (force_update || numProcessedExamples >= optimizer.getBatchSize()) {
      /* debug:
      System.out.println("baseline feature empirical " + empiricalFeatsTbl.get(baselineFeatName));
      System.out.println("baseline feature model " + modelFeatsTbl.get(baselineFeatName));
      */
      optimizer.updateModel(empiricalFeatsTbl, modelFeatsTbl);
      //System.out.println("baseline feature weight " + optimizer.getSumModel().get(baselineFeatName));
      reset_baseline_feat();
      //System.out.println("baseline feature weight " + optimizer.getSumModel().get(baselineFeatName));
      empiricalFeatsTbl.clear();
      modelFeatsTbl.clear();
      numProcessedExamples = 0;
    }
  }
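  /* A minimal sketch of what a batch update typically computes (illustrative only;
   * the actual arithmetic lives in GradientBasedOptimizer.updateModel, and the names
   * below are hypothetical):
   *
   *   for (Object feat : empirical.keySet()) {
   *     double delta = count(empirical, feat) - count(model, feat);
   *     weights.put(feat, weight(feat) + learningRate * delta); // move toward the oracle
   *   }
   *
   * i.e., a perceptron/CRF-style step toward the oracle's features and away from the
   * features preferred (or, for the CRF, expected) by the current model, applied once
   * per batch of numProcessedExamples sentences.
   */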
  public void reset_baseline_feat() {
    if (fixBaseline) {
      optimizer.setFeatureWeight(baselineFeatName, baselineScale);
    } else {
      System.out.println("not implemented");
      System.exit(0);
    }
  }

  public static void main(String[] args) {
    // ## read configuration information
    if (args.length < 11) {
      System.out.println("Wrong command; the correct command is: java HGDiscriminativeLearner"
          + " is_crf lf_train_items lf_train_rules lf_orc_items lf_orc_rules f_l_num_sents"
          + " f_data_sel f_model_out_prefix use_tm_feat use_lm_feat use_edge_bigram_feat_only"
          + " f_feature_set use_joint_tm_lm_feature");
      System.out.println("num of args is " + args.length);
      for (int i = 0; i < args.length; i++)
        System.out.println("arg is: " + args[i]);
      System.exit(0);
    }
    long start_time = System.currentTimeMillis();

    SymbolTable symbolTbl = new BuildinSymbol(null);

    boolean is_using_crf = Boolean.valueOf(args[0].trim());
    HGDiscriminativeLearner.usingCRF = is_using_crf;
    String f_l_train_items = args[1].trim();
    String f_l_train_rules = args[2].trim();
    String f_l_orc_items = args[3].trim();
    String f_l_orc_rules = args[4].trim();
    String f_l_num_sents = args[5].trim();
    String f_data_sel = args[6].trim();
    String f_model_out_prefix = args[7].trim();
    boolean use_tm_feat = Boolean.valueOf(args[8].trim());
    boolean use_lm_feat = Boolean.valueOf(args[9].trim());
    boolean use_edge_ngram_only = Boolean.valueOf(args[10].trim());

    String f_feature_set = null;
    if (args.length > 11)
      f_feature_set = args[11].trim();
    boolean use_joint_tm_lm_feature = false;
    if (args.length > 12)
      use_joint_tm_lm_feature = Boolean.valueOf(args[12].trim());

    boolean saveModelCosts = true; // TODO: correct setting unclear
    int ngramStateID = 0; // TODO: correct setting unclear

    // ## set up the feature template lists
    ArrayList<FeatureTemplate> l_feat_templates = new ArrayList<FeatureTemplate>();
    ArrayList<FeatureTemplate> l_feat_templates_nobaseline = new ArrayList<FeatureTemplate>();

    FeatureTemplate ft_bs = new BaselineFT(baselineFeatName, true); // baseline feature
    l_feat_templates.add(ft_bs);

    boolean useIntegerString = false;
    boolean useRuleIDName = false;
    if (use_tm_feat) {
      FeatureTemplate ft = new TMFT(symbolTbl, useIntegerString, useRuleIDName);
      l_feat_templates.add(ft);
      l_feat_templates_nobaseline.add(ft);
    }

    int baseline_lm_order = 3; // TODO
    if (use_lm_feat) {
      FeatureTemplate ft = new NgramFT(symbolTbl, false, ngramStateID, baseline_lm_order, 1, 2); // TODO: unigrams and bigrams
      l_feat_templates.add(ft);
      l_feat_templates_nobaseline.add(ft);
    } else if (use_edge_ngram_only) { // mutually exclusive with use_lm_feat
      FeatureTemplate ft = new EdgeBigramFT(symbolTbl, ngramStateID, baseline_lm_order, useIntegerString);
      l_feat_templates.add(ft);
      l_feat_templates_nobaseline.add(ft);
    }
    if (use_joint_tm_lm_feature) {
      // TODO: not implemented
    }
    System.out.println("feature templates are " + l_feat_templates.toString());
    System.out.println("feature templates (no baseline) are " + l_feat_templates_nobaseline.toString());

    int max_loop = 3; // TODO

    List<String> l_file_train_items = DiscriminativeSupport.readFileList(f_l_train_items);
    List<String> l_file_train_rules = DiscriminativeSupport.readFileList(f_l_train_rules);
    List<String> l_file_orc_items = DiscriminativeSupport.readFileList(f_l_orc_items);
    List<String> l_file_orc_rules = null;
    if (f_l_orc_rules.compareTo("flat") != 0) { // TODO: oracle is a hypergraph, not a flat string
      System.out.println("oracles are in a hypergraph by " + f_l_orc_rules);
      HGDiscriminativeLearner.usingStringOracle = false;
      l_file_orc_rules = DiscriminativeSupport.readFileList(f_l_orc_rules);
    } else {
      System.out.println("flat oracles");
      HGDiscriminativeLearner.usingStringOracle = true;
    }
    List<String> l_num_sents = DiscriminativeSupport.readFileList(f_l_num_sents);
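    // The file lists above are parallel: entry fid of each list names the items file,
    // rules file, etc. for shard fid. A minimal sketch of the per-shard reading pattern
    // used in the training loop below (the same calls as in this file):
    //   DiskHyperGraph dhg = new DiskHyperGraph(symbolTbl, ngramStateID, saveModelCosts, null);
    //   dhg.initRead(itemsFile, rulesFile, tbl_sent_selected);
    //   HyperGraph hg = dhg.readHyperGraph(); // null when the sentence was skipped by data selection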
    HashMap tbl_sent_selected = DiscriminativeSupport.setupDataSelTbl(f_data_sel); // for data selection

    // ###### initialize the model
    HGDiscriminativeLearner hgdl = null;

    // TODO: optimal parameters
    int train_size = 610000;
    int batch_update_size = 30;
    int converge_pass = 1;
    double init_gain = 0.1;
    double sigma = 0.5;
    boolean is_minimize_score = true;

    // set up the optimizer
    GradientBasedOptimizer optimizer = null;
    if (usingCRF) {
      HashMap crfModel = new HashMap();
      if (f_feature_set != null) {
        DiscriminativeSupport.loadModel(f_feature_set, crfModel, null);
      } else {
        System.out.println("In CRF mode, a feature set must be specified");
        System.exit(0);
      }
      optimizer = new DefaultCRF(crfModel, train_size, batch_update_size, converge_pass,
          init_gain, sigma, is_minimize_score);
      optimizer.initModel(-1, 1); // TODO: optimal initial parameters
      hgdl = new HGDiscriminativeLearner(optimizer, new HashSet<String>(crfModel.keySet()));
      hgdl.reset_baseline_feat(); // add and initialize the baseline feature
      System.out.println("sum model size: " + optimizer.getSumModel().size());
    } else { // perceptron
      HashMap perceptron_sum_model = new HashMap();
      HashMap perceptron_avg_model = new HashMap();
      HashMap perceptronModel = new HashMap();
      if (f_feature_set != null) {
        DiscriminativeSupport.loadModel(f_feature_set, perceptronModel, null);
        perceptronModel.put(baselineFeatName, 1.0);
        System.out.println("feature set size is " + perceptronModel.size());
      } else {
        System.out.println("In perceptron mode, a feature set should be specified");
      }
      optimizer = new DefaultPerceptron(perceptron_sum_model, perceptron_avg_model, train_size,
          batch_update_size, converge_pass, init_gain, sigma, is_minimize_score);
      hgdl = new HGDiscriminativeLearner(optimizer, new HashSet<String>(perceptronModel.keySet()));
      hgdl.reset_baseline_feat();
    }

    // ##### begin training
    int g_sent_id = 0;
    for (int loop_id = 0; loop_id < max_loop; loop_id++) {
      System.out.println("################################### Loop " + loop_id);
      for (int fid = 0; fid < l_file_train_items.size(); fid++) {
        System.out.println("############ Process file id " + fid);
        DiskHyperGraph dhg_train = new DiskHyperGraph(symbolTbl, ngramStateID, saveModelCosts, null);
        dhg_train.initRead((String) l_file_train_items.get(fid), (String) l_file_train_rules.get(fid), tbl_sent_selected);

        DiskHyperGraph dhg_orc = null;
        BufferedReader t_reader_orc = null;
        if (l_file_orc_rules != null) {
          dhg_orc = new DiskHyperGraph(symbolTbl, ngramStateID, saveModelCosts, null);
          dhg_orc.initRead((String) l_file_orc_items.get(fid), (String) l_file_orc_rules.get(fid), tbl_sent_selected);
        } else {
          t_reader_orc = FileUtilityOld.getReadFileStream((String) l_file_orc_items.get(fid), "UTF-8");
        }

        int total_num_sent = Integer.parseInt((String) l_num_sents.get(fid));
        for (int sent_id = 0; sent_id < total_num_sent; sent_id++) {
          System.out.println("# Process sentence " + g_sent_id);
          HyperGraph hg_train = dhg_train.readHyperGraph();
          HyperGraph hg_orc = null;
          String hyp_oracle = null;
          if (l_file_orc_rules != null)
            hg_orc = dhg_orc.readHyperGraph();
          else
            hyp_oracle = FileUtilityOld.readLineLzf(t_reader_orc);
          if (hg_train != null) { // the sentence was not skipped
            if (l_file_orc_rules != null)
              hgdl.processOneSent(hg_train, hg_orc, null, l_feat_templates, l_feat_templates_nobaseline);
            else
              hgdl.processOneSent(hg_train, hyp_oracle, null, l_feat_templates, l_feat_templates_nobaseline);
          }
          g_sent_id++;
        }
      }
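      // End of epoch: update_model(true) below force-flushes any partially filled
      // batch before the model is written to disk. For the perceptron, the averaged
      // weights are dumped as well; averaging,
      //   w_avg = (1/T) * sum over t of w_t,
      // is the standard stabilization trick for averaged-perceptron training (a fact
      // about the algorithm family; DefaultPerceptron's internals are not shown here).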
      if (usingCRF) {
        hgdl.update_model(true);
        FileUtilityOld.printHashTbl(optimizer.getSumModel(), f_model_out_prefix + ".crf." + loop_id, false, false);
      } else { // perceptron
        hgdl.update_model(true);
        ((DefaultPerceptron) optimizer).force_update_avg_model();
        FileUtilityOld.printHashTbl(optimizer.getSumModel(), f_model_out_prefix + ".sum." + loop_id, false, false);
        FileUtilityOld.printHashTbl(optimizer.getAvgModel(), f_model_out_prefix + ".avg." + loop_id, false, true);
      }
      System.out.println("Time cost: " + ((System.currentTimeMillis() - start_time) / 1000));
    }
  }
}
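// Example invocation (a sketch; the file names are hypothetical, the argument order
// follows the usage message printed in main):
//   java joshua.discriminative.training.HGDiscriminativeLearner true \
//     train.items.list train.rules.list orc.items.list orc.rules.list \
//     num_sents.list data_sel.txt model.out true true false feature_set.txt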