package joshua.discriminative.training;

import java.io.BufferedReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;

import joshua.discriminative.DiscriminativeSupport;
import joshua.discriminative.FileUtilityOld;
import joshua.discriminative.training.learning_algorithm.DefaultCRF;
import joshua.discriminative.training.learning_algorithm.DefaultPerceptron;
import joshua.discriminative.training.learning_algorithm.GradientBasedOptimizer;

public class NBESTDiscriminativeLearner {

  GradientBasedOptimizer optimizer = null;

  static boolean usingCRF = true; //default is CRF

  //## feature related
  HashMap<String,Double> empiricalFeatsTbl = new HashMap<String,Double>(); //empirical feature counts
  HashMap<String,Double> modelFeatsTbl = new HashMap<String,Double>(); //feature counts assigned by the model
  HashSet<String> restrictedFeatureSet = null; //only consider features in this set; if null, no restriction is applied

  //## batch update related
  int numProcessedExamples = 0;

  //## reranking: baseline
  static String baselineFeatName = "baseline_lzf";
  static boolean fixBaseline = true; //TODO: the current version assumes the baseline weight is always fixed during training
  static double baselineScale = 1.0; //the scale of the baseline score relative to the corrective score

  static int startNgramOrder = 1;
  static int endNgramOrder = 2;

  static int hypStrID = 1; //hyp format: sent_id ||| hyp_str ||| features ||| final score

  public NBESTDiscriminativeLearner(GradientBasedOptimizer optimizer, HashSet<String> restrictedFeatureSet){
    this.optimizer = optimizer;
    this.restrictedFeatureSet = restrictedFeatureSet;
  }

  //process the n-best list of one sentence
  public void processOneSent(List<String> nbest, String hypOracle, String refSent){
    //#### feature extraction under the current model: fill modelFeatsTbl
    if(usingCRF){
      getFeatureExpectation(modelFeatsTbl, optimizer.getSumModel(), restrictedFeatureSet, nbest);
    }else{//perceptron: use the one-best under the current model
      String rerankedOnebest = rerankNbest(optimizer.getSumModel(), restrictedFeatureSet, nbest);
      featureExtraction(rerankedOnebest, modelFeatsTbl, restrictedFeatureSet, hypStrID, false);
    }
    //#### empirical counts from the oracle hypothesis
    featureExtraction(hypOracle, empiricalFeatsTbl, restrictedFeatureSet, 0, false);

    //#### common: update the models
    numProcessedExamples++;
    updateModel(false);
  }

  public void updateModel(boolean forceUpdate){
    if(forceUpdate || numProcessedExamples >= optimizer.getBatchSize()){
      optimizer.updateModel(empiricalFeatsTbl, modelFeatsTbl);
      resetBaselineFeat();
      empiricalFeatsTbl.clear();
      modelFeatsTbl.clear();
      numProcessedExamples = 0;
    }
  }

  public void resetBaselineFeat(){
    if(fixBaseline)
      optimizer.setFeatureWeight(baselineFeatName, baselineScale);
    else{
      System.out.println("unfixed baseline is not implemented");
      System.exit(0);
    }
  }
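  /*
   * Training flow, as a sketch (the learner/optimizer variable names and
   * the per-sentence loop here are illustrative; see main() below for the
   * actual driver): one call to processOneSent() per sentence accumulates
   * empirical vs. model feature counts, and updateModel() applies a
   * gradient step once optimizer.getBatchSize() sentences have been seen.
   *
   *   NBESTDiscriminativeLearner learner =
   *       new NBESTDiscriminativeLearner(optimizer, featureSet);
   *   for(...each sentence...){
   *     learner.processOneSent(nbest, oracleHyp, null);
   *   }
   *   learner.updateModel(true); //flush the last partial batch
   */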
  //rerank the n-best list under the corrective model; return the new one-best
  public String rerankNbest(HashMap<String,Double> correctiveModel, HashSet<String> restrictedFeatSet, List<String> nbest){
    double bestScore = Double.NEGATIVE_INFINITY;
    String bestHyp = null;
    HashMap<String,Double> featTbl = new HashMap<String,Double>();
    for(int i = 0; i < nbest.size(); i++){
      featTbl.clear();
      String curHyp = nbest.get(i);
      featureExtraction(curHyp, featTbl, restrictedFeatSet, hypStrID, true);
      double curScore = DiscriminativeSupport.computeLinearCombinationLogP(featTbl, correctiveModel);
      if(i == 0 || curScore > bestScore){//maximize
        bestScore = curScore;
        bestHyp = curHyp;
      }
    }
    return bestHyp;
  }

  //accumulate expected feature counts (under the current model) into resFeatTbl
  public void getFeatureExpectation(HashMap<String,Double> resFeatTbl, HashMap<String,Double> correctiveModel, HashSet<String> restrictedFeatSet, List<String> nbest){
    //### get the normalization constant; remember each hypothesis's features and combined linear score
    double normalizationConstant = Double.NEGATIVE_INFINITY; //log semiring
    List<HashMap<String,Double>> featsList = new ArrayList<HashMap<String,Double>>();
    List<Double> scoreList = new ArrayList<Double>();
    for(int i = 0; i < nbest.size(); i++){
      HashMap<String,Double> featTbl = new HashMap<String,Double>();
      String curHyp = nbest.get(i);
      featureExtraction(curHyp, featTbl, restrictedFeatSet, hypStrID, true);
      double curScore = DiscriminativeSupport.computeLinearCombinationLogP(featTbl, correctiveModel);
      normalizationConstant = addInLogSemiring(normalizationConstant, curScore, 0);
      featsList.add(featTbl);
      scoreList.add(curScore);
    }

    //### get the expected feature counts
    double sum = 0;
    for(int i = 0; i < nbest.size(); i++){
      double curScore = scoreList.get(i);
      HashMap<String,Double> feats = featsList.get(i);
      double postProb = Math.exp(curScore - normalizationConstant);
      sum += postProb;
      //accumulate feature counts, weighted by the posterior
      for(Map.Entry<String,Double> entry : feats.entrySet()){
        DiscriminativeSupport.increaseCount(resFeatTbl, entry.getKey(), entry.getValue() * postProb);
      }
    }
    System.out.println("Sum is " + sum); //sanity check: posteriors should sum to 1
  }

  //stable log-semiring addition; the naive Math.log(Math.exp(x) + Math.exp(y)) can overflow
  private double addInLogSemiring(double x, double y, int addMode){
    if(addMode == 0){//sum
      if(x == Double.NEGATIVE_INFINITY) //if y is also -infinity, this returns -infinity
        return y;
      if(y == Double.NEGATIVE_INFINITY)
        return x;
      if(y <= x)
        return x + Math.log(1 + Math.exp(y - x));
      else //x < y
        return y + Math.log(1 + Math.exp(x - y));
    }else if(addMode == 1){//Viterbi min
      return (x <= y) ? x : y;
    }else if(addMode == 2){//Viterbi max
      return (x >= y) ? x : y;
    }else{
      System.out.println("invalid add mode");
      System.exit(0);
      return 0;
    }
  }

  public static void featureExtraction(String hyp, HashMap<String,Double> featTbl, HashSet<String> restrictedFeatSet, int hypStrIndex, boolean extractBaselineFeat){
    String[] fds = hyp.split("\\s+\\|{3}\\s+");

    //### baseline feature: the final score field of the hypothesis line
    if(extractBaselineFeat){
      String score = replaceBadSymbol(fds[fds.length-1]);
      double baselineScore = Double.parseDouble(score);
      featTbl.put(baselineFeatName, baselineScore);
    }

    //### n-gram features
    String[] wrds = fds[hypStrIndex].split("\\s+");
    for(int i = 0; i < wrds.length; i++)
      for(int j = startNgramOrder-1; j < endNgramOrder && j+i < wrds.length; j++){//ngram: [i, i+j]
        StringBuilder ngram = new StringBuilder();
        for(int k = i; k <= i+j; k++){
          ngram.append(wrds[k]);
          if(k < i+j)
            ngram.append(" ");
        }
        String ngramStr = ngram.toString();
        if(restrictedFeatSet == null || restrictedFeatSet.contains(ngramStr)){//filter
          DiscriminativeSupport.increaseCount(featTbl, ngramStr, 1.0);
        }
      }
  }
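  /*
   * Example of the feature space (with the defaults startNgramOrder=1 and
   * endNgramOrder=2): for the hypothesis string "a b c",
   * featureExtraction() fires the unigram features "a", "b", "c" and the
   * bigram features "a b", "b c", each with count 1, plus the baseline
   * feature when extractBaselineFeat is true.
   */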
  //strip the duplicated minus sign from a malformed score string such as "--1.2"
  private static String replaceBadSymbol(String in){
    if(in.startsWith("--"))
      return in.substring(1);
    else
      return in;
  }

  public static void main(String[] args){
    if(args.length < 5){
      System.out.println("wrong number of arguments; usage: java NBESTDiscriminativeLearner is_using_crf f_l_train_nbest f_l_orc f_data_sel f_model_out_prefix [f_init_model]");
      System.out.println("num of args is " + args.length);
      for(int i = 0; i < args.length; i++)
        System.out.println("arg is: " + args[i]);
      System.exit(0);
    }
    long startTime = System.currentTimeMillis();

    boolean isUsingCRF = Boolean.parseBoolean(args[0].trim());
    NBESTDiscriminativeLearner.usingCRF = isUsingCRF;
    String fileListTrainNbest = args[1].trim();
    String fileListOracle = args[2].trim();
    String fileDataSel = args[3].trim();
    String modelOutPrefix = args[4].trim();
    String initModelFile = null;
    if(args.length > 5)
      initModelFile = args[5].trim();

    int maxLoop = 3; //TODO: number of training passes

    List<String> trainNbestFiles = DiscriminativeSupport.readFileList(fileListTrainNbest);
    List<String> oracleFiles = DiscriminativeSupport.readFileList(fileListOracle);
    HashMap tblSentSelected = DiscriminativeSupport.setupDataSelTbl(fileDataSel); //for data selection

    //####### training

    //###### init model
    NBESTDiscriminativeLearner ndl = null;

    //TODO: optimal parameters
    int trainSize = 610000;
    int batchUpdateSize = 1;
    int convergePass = 1;
    double initGain = 0.1;
    double sigma = 0.5;
    boolean isMinimizeScore = false;

    //setup the optimizer
    GradientBasedOptimizer optimizer = null;
    if(usingCRF){
      HashMap<String,Double> crfModel = new HashMap<String,Double>();
      if(initModelFile != null)
        DiscriminativeSupport.loadModel(initModelFile, crfModel, null);
      else{
        System.out.println("in CRF mode, the feature set must be specified");
        System.exit(0);
      }
      optimizer = new DefaultCRF(crfModel, trainSize, batchUpdateSize, convergePass, initGain, sigma, isMinimizeScore);
      optimizer.initModel(0, 0); //TODO: optimal initial parameters
      ndl = new NBESTDiscriminativeLearner(optimizer, new HashSet<String>(crfModel.keySet()));
      ndl.resetBaselineFeat(); //add and initialize the baseline feature
    }else{//perceptron
      HashMap<String,Double> perceptronSumModel = new HashMap<String,Double>();
      HashMap<String,Double> perceptronAvgModel = new HashMap<String,Double>();
      HashMap<String,Double> perceptronModel = new HashMap<String,Double>();
      if(initModelFile != null){
        DiscriminativeSupport.loadModel(initModelFile, perceptronModel, null);
        perceptronModel.put(baselineFeatName, 1.0);
      }else{
        System.out.println("in perceptron mode, a feature set should be specified");
      }
      optimizer = new DefaultPerceptron(perceptronSumModel, perceptronAvgModel, trainSize, batchUpdateSize, convergePass, initGain, sigma, isMinimizeScore);
      ndl = new NBESTDiscriminativeLearner(optimizer, new HashSet<String>(perceptronModel.keySet()));
      ndl.resetBaselineFeat();
    }
    //TODO: ndl.optimizer.set_no_cooling();
    //TODO: ndl.optimizer.set_no_regularization();
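    /*
     * Input format assumed by the loop below: each n-best file groups the
     * hypotheses of a sentence contiguously, one
     * "sent_id ||| hyp_str ||| features ||| final score" line per
     * hypothesis, while each oracle file has exactly one line per
     * sentence, in the same order. A change in sent_id therefore marks a
     * sentence boundary and triggers training on the accumulated n-best.
     */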
    for(int loopID = 0; loopID < maxLoop; loopID++){
      System.out.println("################################### Loop " + loopID);
      for(int fid = 0; fid < trainNbestFiles.size(); fid++){
        System.out.println("####### Process file id " + fid);
        BufferedReader nbestReader = FileUtilityOld.getReadFileStream(trainNbestFiles.get(fid), "UTF-8");
        BufferedReader oracleReader = FileUtilityOld.getReadFileStream(oracleFiles.get(fid), "UTF-8");

        String line = null;
        int oldSentID = -1;
        ArrayList<String> nbest = new ArrayList<String>();
        while((line = FileUtilityOld.readLineLzf(nbestReader)) != null){
          String[] fds = line.split("\\s+\\|{3}\\s+");
          int newSentID = Integer.parseInt(fds[0]);
          if(oldSentID != -1 && oldSentID != newSentID){
            //sentence boundary: consume the oracle line and train on the accumulated n-best
            String hypOracle = FileUtilityOld.readLineLzf(oracleReader);
            if(tblSentSelected.containsKey(oldSentID)){
              System.out.println("#Process sentence " + oldSentID);
              ndl.processOneSent(nbest, hypOracle, null);
            }else
              System.out.println("#Skip sentence " + oldSentID);
            nbest.clear();
          }
          oldSentID = newSentID;
          nbest.add(line);
        }
        //last n-best list in the file
        String hypOracle = FileUtilityOld.readLineLzf(oracleReader);
        if(tblSentSelected.containsKey(oldSentID)){
          System.out.println("#Process sentence " + oldSentID);
          ndl.processOneSent(nbest, hypOracle, null);
        }else
          System.out.println("#Skip sentence " + oldSentID);
        nbest.clear();

        FileUtilityOld.closeReadFile(nbestReader);
        FileUtilityOld.closeReadFile(oracleReader);
      }

      //flush any partial batch and write out the model of this pass
      if(usingCRF){
        ndl.updateModel(true);
        FileUtilityOld.printHashTbl(optimizer.getSumModel(), modelOutPrefix + ".crf." + loopID, false, false);
      }else{//perceptron
        ndl.updateModel(true);
        ((DefaultPerceptron)optimizer).force_update_avg_model();
        FileUtilityOld.printHashTbl(optimizer.getSumModel(), modelOutPrefix + ".sum." + loopID, false, false);
        FileUtilityOld.printHashTbl(optimizer.getAvgModel(), modelOutPrefix + ".avg." + loopID, false, true);
      }
      System.out.println("Time cost: " + ((System.currentTimeMillis() - startTime) / 1000) + " seconds");
    }
  }
}
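/*
 * Example invocation (the file names here are hypothetical):
 *
 *   java joshua.discriminative.training.NBESTDiscriminativeLearner \
 *       true nbest.file.list oracle.file.list data.sel model.out init.model
 *
 * In CRF mode the initial model file is required, since its feature set
 * defines the restricted feature set; in perceptron mode it seeds the
 * feature set, with the baseline feature weight added on top.
 */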