package joshua.discriminative.variational_decoder; import java.io.BufferedReader; import java.io.BufferedWriter; import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import joshua.corpus.vocab.BuildinSymbol; import joshua.corpus.vocab.SymbolTable; import joshua.decoder.ff.FeatureFunction; import joshua.decoder.hypergraph.DiskHyperGraph; import joshua.decoder.hypergraph.HyperGraph; import joshua.decoder.hypergraph.KBestExtractor; import joshua.decoder.hypergraph.TrivialInsideOutside; import joshua.discriminative.FileUtilityOld; import joshua.discriminative.feature_related.feature_function.FeatureTemplateBasedFF; import joshua.discriminative.feature_related.feature_function.EdgeTblBasedBaselineFF; import joshua.discriminative.ranker.HGRanker; public class VariationalDecoder { private int topN=300; private boolean useUniqueNbest =true; private boolean useTreeNbest = false; private boolean addCombinedCost = true; private KBestExtractor kbestExtractor; SymbolTable symbolTbl; List<FeatureFunction> featFunctions ; //for HG reranking and kbest extraction HashMap<VariationalNgramApproximator, FeatureTemplateBasedFF> approximatorMap; double insideOutsideScalingFactor= 0.5; HGRanker ranker; public VariationalDecoder(){ //do nothing; } // return changed hg public HyperGraph decoding(HyperGraph hg, int sentenceID, BufferedWriter out) { //=== step-1: run inside-outside //note, inside and outside will use the transition_cost of each hyperedge, this cost is already linearly interpolated TrivialInsideOutside pInsideOutside = new TrivialInsideOutside(); pInsideOutside.runInsideOutside(hg, 0, 1, insideOutsideScalingFactor);//ADD_MODE=0=sum; LOG_SEMIRING=1; //=== initialize baseline table //TODO:???????????????????? ((EdgeTblBasedBaselineFF)featFunctions.get(0)).collectTransitionLogPs(hg); //=== step-2: model extraction based on the definition of Q for(Map.Entry<VariationalNgramApproximator, FeatureTemplateBasedFF> entry : approximatorMap.entrySet()){ VariationalNgramApproximator approximator = entry.getKey(); FeatureTemplateBasedFF featureFunction = entry.getValue(); HashMap<String, Double> model = approximator.estimateModel(hg, pInsideOutside); featureFunction.setModel(model); } //clean up pInsideOutside.clearState(); //=== step-3: rank the HG using the baseline and variational feature this.ranker.rankHG(hg); //=== step-4: kbest extraction from the reranked HG: remember to add the new feature function into the model list try{ kbestExtractor.lazyKBestExtractOnHG(hg, this.featFunctions, this.topN, sentenceID, out); } catch (IOException e) { e.printStackTrace(); } return hg; } public void initializeDecoder(String configFile){ VariationalDecoderConfiguration.readConfigFile(configFile); this.symbolTbl = new BuildinSymbol(null); this.featFunctions = new ArrayList<FeatureFunction>(); this.ranker = new HGRanker(featFunctions); this.approximatorMap = new HashMap<VariationalNgramApproximator, FeatureTemplateBasedFF>(); VariationalDecoderConfiguration.initializeModels(configFile, this.symbolTbl, this.featFunctions, this.approximatorMap); this.insideOutsideScalingFactor = VariationalDecoderConfiguration.insideoutsideScalingFactor; //this.kbestExtractor = new KBestExtractor(p_symbol); this.kbestExtractor = new KBestExtractor(this.symbolTbl, this.useUniqueNbest, this.useTreeNbest, false, this.addCombinedCost, false, true); } public void decodingTestSet(String testItemsFile, String testRulesFile, int numSents, String nbestFile) { BufferedWriter nbestWriter = FileUtilityOld.getWriteFileStream(nbestFile); System.out.println("############Process file " + testItemsFile); DiskHyperGraph diskHG = new DiskHyperGraph(symbolTbl, VariationalDecoderConfiguration.ngramStateID, true, null); //have model costs stored diskHG.initRead(testItemsFile, testRulesFile,null); for(int sentID=0; sentID < numSents; sentID ++){ System.out.println("#Process sentence " + sentID); HyperGraph testhg = diskHG.readHyperGraph(); /*if(use_constituent_decoding) vdecoder.constitudent_decoding(hg_test, sent_id, t_writer_nbest); else*/ decoding(testhg, sentID, nbestWriter); //Thread.sleep(3000); } FileUtilityOld.closeWriteFile(nbestWriter); } public void writeConfigFile(double[] newWeights, String template, String fileToWrite){ BufferedReader configReader = FileUtilityOld.getReadFileStream(template); BufferedWriter configWriter = FileUtilityOld.getWriteFileStream(fileToWrite); String line; int featID = 0; while ((line = FileUtilityOld.readLineLzf(configReader)) != null) { line = line.trim(); if (line.matches("^\\s*\\#.*$") || line.matches("^\\s*$") || line.indexOf("=") != -1) {//comment, empty line, or parameter lines: just copy FileUtilityOld.writeLzf(configWriter, line + "\n"); }else{//models: replace the weight String[] fds = line.split("\\s+"); StringBuffer new_line = new StringBuffer(); if(fds[fds.length-1].matches("^[\\d\\.\\-\\+]+")==false){System.out.println("last field is not a number, must be wrong; the field is: " + fds[fds.length-1]); System.exit(1);}; for(int i=0; i<fds.length-1; i++){ new_line.append(fds[i]); new_line.append(" "); } new_line.append(newWeights[featID++]); FileUtilityOld.writeLzf(configWriter, new_line.toString() + "\n"); } } if(featID!=newWeights.length){System.out.println("number of models does not match number of weights, must be wrong"); System.exit(1);}; FileUtilityOld.closeReadFile(configReader); FileUtilityOld.closeWriteFile(configWriter); } /*this assumes that the weight_vector is ordered according to the decoder config file * */ public void changeFeatureWeightVector(double[] weight_vector){ if(featFunctions.size()!=weight_vector.length){ System.out.println("In updateFeatureWeightVector: number of weights does not match number of feature functions"); System.exit(0); } for(int i=0; i<featFunctions.size(); i++){ FeatureFunction ff = featFunctions.get(i); double old_weight = ff.getWeight(); ff.setWeight(weight_vector[i]); System.out.println("Feature function : " + ff.getClass().getSimpleName() + "; weight changed from " + old_weight + " to " + ff.getWeight()); } } //============================ main function ============================== public static void main(String[] args) throws InterruptedException, IOException { /*//##read configuration information if(args.length<8){ System.out.println("wrong command, correct command should be: java Perceptron_HG is_crf lf_train_items lf_train_rules lf_orc_items lf_orc_rules f_l_num_sents f_data_sel f_model_out_prefix use_tm_feat use_lm_feat use_edge_bigram_feat_only f_feature_set use_joint_tm_lm_feature"); System.out.println("num of args is "+ args.length); for(int i=0; i <args.length; i++)System.out.println("arg is: " + args[i]); System.exit(0); }*/ if(args.length!=5){ System.out.println("Wrong number of parameters, it must be 5"); System.exit(1); } //long start_time = System.currentTimeMillis(); String testItemsFile=args[0].trim(); String testRulesFile=args[1].trim(); int numSents=new Integer(args[2].trim()); String nbestFile=args[3].trim(); String configFile=args[4].trim(); VariationalDecoder vdecoder = new VariationalDecoder(); vdecoder.initializeDecoder(configFile); vdecoder.decodingTestSet(testItemsFile, testRulesFile, numSents, nbestFile); } /* //return changed hg public HyperGraph decoding_old(HyperGraph hg, int sentenceID, BufferedWriter out){ //### step-1: run inside-outside TrivialInsideOutside p_inside_outside = new TrivialInsideOutside(); p_inside_outside.run_inside_outside(hg, 0, 1);//ADD_MODE=0=sum; LOG_SEMIRING=1; //### step-2: model extraction based on the definition of Q //step-2.1: baseline feature: do not use insideoutside this.baseline_feat_tbl.clear(); FeatureExtractionHG.feature_extraction_hg(hg, this.baseline_feat_tbl, null, this.baseline_ft_template); //step-2.2: ngram feature this.ngram_feat_tbl.clear(); FeatureExtractionHG.feature_extraction_hg(hg, p_inside_outside, this.ngram_feat_tbl, null, this.ngram_ft_template); VariationalLMFeature.getNormalizedLM(this.ngram_feat_tbl, this.baseline_lm_order, false, false, false); //step-2.2: constituent feature this.constituent_feat_tbl.clear(); FeatureExtractionHG.feature_extraction_hg(hg, p_inside_outside, this.constituent_feat_tbl, null, this.constituent_template); //clean p_inside_outside.clear_state(); //### step-3: rank the HG using the baseline and variational feature l_feat_functions.clear(); FeatureFunction base_ff = new BaselineFF(this.baseline_feat_id, this.baseline_weight, this.baseline_feat_tbl); l_feat_functions.add(base_ff); FeatureFunction variational_ff = new VariationalLMFeature(this.baseline_lm_feat_id, this.baseline_lm_order, this.p_symbol, this.ngram_feat_tbl, this.ngam_weight, false, false, false); l_feat_functions.add(variational_ff); FeatureFunction constituent_ff = new BaselineFF(this.constituent_feat_id, this.constituent_weight, this.constituent_feat_tbl); l_feat_functions.add(constituent_ff); RankHG.rankHG(hg, l_feat_functions); //### step-4: kbest extraction from the reranked HG: remember to add the new feature function into the model list //two features: baseline feature, and reranking feature KbestExtraction kbestExtractor = new KbestExtraction(p_symbol); kbestExtractor.lazy_k_best_extract_hg(hg, this.l_feat_functions, this.topN, this.use_unique_nbest, sentenceID, out, this.use_tree_nbest, this.add_combined_cost); return hg; } // return changed hg public HyperGraph constitudent_decoding(HyperGraph hg, int sentenceID, BufferedWriter out){ ConstituentVariationalDecoder decoder = new ConstituentVariationalDecoder(); decoder.decoding(hg); //### step-2: model extraction based on the definition of Q //step-2.1: baseline feature: do not use insideoutside this.baseline_feat_tbl.clear(); FeatureExtractionHG.feature_extraction_hg(hg, this.baseline_feat_tbl, null, this.baseline_ft_template); //### step-3: rank the HG using the baseline and variational feature l_feat_functions.clear(); FeatureFunction base_ff = new BaselineFF(this.baseline_feat_id, this.baseline_weight, this.baseline_feat_tbl); l_feat_functions.add(base_ff); //RankHG.rankHG(hg, l_feat_functions); //### step-4: kbest extraction from the reranked HG: remember to add the new feature function into the model list //two features: baseline feature, and reranking feature this.kbestExtractor.lazy_k_best_extract_hg(hg, this.l_feat_functions, this.topN, this.use_unique_nbest, sentenceID, out, this.use_tree_nbest, this.add_combined_cost); return hg; } */ }