package joshua.discriminative.training.risk_annealer.hypergraph;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.logging.Logger;

import joshua.corpus.vocab.BuildinSymbol;
import joshua.corpus.vocab.SymbolTable;
import joshua.decoder.BLEU;
import joshua.decoder.JoshuaDecoder;
import joshua.decoder.ff.FeatureFunction;
import joshua.decoder.hypergraph.DiskHyperGraph;
import joshua.decoder.hypergraph.KBestExtractor;
import joshua.discriminative.FileUtilityOld;
import joshua.discriminative.feature_related.feature_function.FeatureTemplateBasedFF;
import joshua.discriminative.feature_related.feature_template.BaselineFT;
import joshua.discriminative.feature_related.feature_template.FeatureTemplate;
import joshua.discriminative.feature_related.feature_template.IndividualBaselineFT;
import joshua.discriminative.feature_related.feature_template.MicroRuleFT;
import joshua.discriminative.feature_related.feature_template.NgramFT;
import joshua.discriminative.feature_related.feature_template.TMFT;
import joshua.discriminative.feature_related.feature_template.TargetTMFT;
import joshua.discriminative.ranker.HGRanker;
import joshua.discriminative.training.NbestMerger;
import joshua.discriminative.training.risk_annealer.AbstractMinRiskMERT;
import joshua.discriminative.training.risk_annealer.DeterministicAnnealer;
import joshua.discriminative.training.risk_annealer.GradientComputer;
import joshua.discriminative.training.risk_annealer.nbest.NbestMinRiskDAMert;
import joshua.util.FileUtility;
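/**
 * Hypergraph-based minimum-risk training with deterministic annealing.
 *
 * <p>Each iteration of {@link #mainLoop()}: (i) saves the current weights (optionally normalized
 * by the first feature) into per-iteration config/feature files, (ii) re-decodes the training set
 * into hypergraphs (unless oneTimeHGRerank is set), (iii) optionally merges the new hypergraphs or
 * n-best lists with those of earlier iterations and stops once the fraction of new hypotheses
 * falls below stop_hyp_ratio, and (iv) re-estimates the weights by running the
 * {@link DeterministicAnnealer} over risk gradients computed on the (merged) hypergraphs.
 */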
public class HGMinRiskDAMert extends AbstractMinRiskMERT {

  JoshuaDecoder joshuaDecoder;
  String sourceTrainingFile;

  SymbolTable symbolTbl;

  List<FeatureTemplate> featTemplates;
  HashMap<String, Integer> featureStringToIntegerMap;

  MicroRuleFT microRuleFeatureTemplate = null;

  String hypFilePrefix; //training hypothesis file prefix
  String curConfigFile;
  String curFeatureFile;
  String curHypFilePrefix;

  boolean useIntegerString = false; //TODO
  boolean haveRefereces = true;

  int oldTotalNumHyp = 0;

  //== for loss-augmented pruning
  double curLossScale = 0;
  int oralceFeatureID = 0;

  static private Logger logger = Logger.getLogger(HGMinRiskDAMert.class.getSimpleName());

  public HGMinRiskDAMert(String configFile, int numSentInDevSet, String[] devRefs,
      String hypFilePrefix, SymbolTable symbolTbl, String sourceTrainingFile) {
    super(configFile, numSentInDevSet, devRefs);
    this.symbolTbl = symbolTbl;

    if (devRefs != null) {
      haveRefereces = true;
      for (String refFile : devRefs) {
        logger.info("add symbols for file " + refFile);
        addAllWordsIntoSymbolTbl(refFile, symbolTbl);
      }
    } else {
      haveRefereces = false;
    }

    this.initialize();

    this.hypFilePrefix = hypFilePrefix;
    this.sourceTrainingFile = sourceTrainingFile;

    if (MRConfig.oneTimeHGRerank == false) {
      joshuaDecoder = JoshuaDecoder.getUninitalizedDecoder();
      joshuaDecoder.initialize(configFile);
    }

    //oracle id-related
    Integer id = inferOracleFeatureID(this.configFile);
    if (id != null && MRConfig.lossAugmentedPrune == false) {
      logger.severe("lossAugmentedPrune=false, but an oracle model is specified");
      System.exit(1);
    }

    if (MRConfig.lossAugmentedPrune == true) {
      if (id == null) {
        logger.severe("no oracle model while doing loss-augmented pruning, must be wrong");
        System.exit(1);
      } else {
        this.oralceFeatureID = id;
      }
      this.curLossScale = MRConfig.startLossScale;
      logger.info("startLossScale=" + MRConfig.startLossScale + "; oracleFeatureID=" + this.oralceFeatureID);
    }

    if (haveRefereces == false) { //minimize conditional entropy
      MRConfig.temperatureAtNoAnnealing = 1; //TODO
    } else {
      if (MRConfig.useModelDivergenceRegula) {
        System.out.println("supervised training, we should not use model divergence regularization");
        System.exit(0);
      }
    }
  }
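  // Per-iteration artifacts are named by appending the iteration number to the template files;
  // with a config "joshua.config", feature file "features", and hyp prefix "hyp" (illustrative
  // names only), iteration 1 would produce:
  //   joshua.config.1, features.1      -- model saved by saveLastModel()
  //   hyp.1.hg.items, hyp.1.hg.rules   -- hypergraphs written by the re-decoding step
  //   hyp.merged.1                     -- hypergraphs/n-best merged across iterations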
  public void mainLoop() {
    /** Here, we need multiple iterations as we do pruning when generating the hypergraphs.
     * Note that the DeterministicAnnealer itself may need to solve an optimization problem at
     * each temperature, and each optimization is solved by LBFGS, which itself involves many
     * iterations (of computing gradients).
     */
    for (int iter = 1; iter <= MRConfig.maxNumIter; iter++) {

      //==== re-normalize weights, and save config files
      this.curConfigFile = configFile + "." + iter;
      this.curFeatureFile = MRConfig.featureFile + "." + iter;
      if (MRConfig.normalizeByFirstFeature)
        normalizeWeightsByFirstFeature(lastWeightVector, 0);
      saveLastModel(configFile, curConfigFile, MRConfig.featureFile, curFeatureFile);
      //writeConfigFile(lastWeightVector, configFile, configFile+"." + iter);

      //==== re-decode based on the new weights
      if (MRConfig.oneTimeHGRerank) {
        this.curHypFilePrefix = hypFilePrefix;
      } else {
        this.curHypFilePrefix = hypFilePrefix + "." + iter;
        decodingTestSet(null, curHypFilePrefix);
      }

      //==== merge hypergraphs and check convergence
      if (MRConfig.hyp_merge_mode > 0) {
        try {
          String oldMergedFile = hypFilePrefix + ".merged." + (iter - 1);
          String newMergedFile = hypFilePrefix + ".merged." + (iter);
          int newTotalNumHyp = 0;

          if (MRConfig.use_kbest_hg == false && MRConfig.hyp_merge_mode == 2) {
            System.out.println("use_kbest_hg==false && hyp_merge_mode==2; we will look at the nbest");
            if (iter == 1) {
              FileUtility.copyFile(curHypFilePrefix, newMergedFile);
              newTotalNumHyp = FileUtilityOld.numberLinesInFile(newMergedFile);
            } else {
              newTotalNumHyp = NbestMerger.mergeNbest(oldMergedFile, curHypFilePrefix, newMergedFile);
            }
          } else {
            if (iter == 1) {
              FileUtility.copyFile(curHypFilePrefix + ".hg.items", newMergedFile + ".hg.items");
              FileUtility.copyFile(curHypFilePrefix + ".hg.rules", newMergedFile + ".hg.rules");
            } else {
              boolean saveModelCosts = true;
              /** TODO: this assumes that the feature values for the same hypothesis do not change,
               * though the weights for these features can change. In particular, this means we
               * cannot tune the weight of the aggregate discriminative model while we are tuning
               * the individual discriminative features. This is also true for the bestHyperEdge pointer.
               */
              newTotalNumHyp = DiskHyperGraph.mergeDiskHyperGraphs(MRConfig.ngramStateID, saveModelCosts,
                  this.numTrainingSentence, MRConfig.use_unique_nbest, MRConfig.use_tree_nbest,
                  oldMergedFile, curHypFilePrefix, newMergedFile, (MRConfig.hyp_merge_mode == 2));
            }
            this.curHypFilePrefix = newMergedFile;
          }

          //check convergence
          double newRatio = (newTotalNumHyp - oldTotalNumHyp) * 1.0 / oldTotalNumHyp;
          if (iter <= 2 || newRatio > MRConfig.stop_hyp_ratio) {
            System.out.println("oldTotalNumHyp=" + oldTotalNumHyp + "; newTotalNumHyp=" + newTotalNumHyp
                + "; newRatio=" + newRatio + "; at iteration " + iter);
            oldTotalNumHyp = newTotalNumHyp;
          } else {
            System.out.println("No new hypotheses generated at iteration " + iter
                + " for stop_hyp_ratio=" + MRConfig.stop_hyp_ratio);
            break;
          }
        } catch (IOException e) {
          e.printStackTrace();
        }
      }

      Map<String, Integer> ruleStringToIDTable = DiskHyperGraph.obtainRuleStringToIDTable(curHypFilePrefix + ".hg.rules");

      //try to abbreviate the features if possible
      addAbbreviatedNames(ruleStringToIDTable);

      //micro rule features
      if (MRConfig.useSparseFeature && MRConfig.useMicroTMFeat) {
        this.microRuleFeatureTemplate.setupTbl(ruleStringToIDTable, featureStringToIntegerMap.keySet());
      }

      //===== compute onebest BLEU
      computeOneBestBLEU(curHypFilePrefix);

      //==== run the DA annealer to obtain the optimal weight vector, using the hypergraphs as training data
      HyperGraphFactory hgFactory = new HyperGraphFactory(curHypFilePrefix, referenceFiles,
          MRConfig.ngramStateID, symbolTbl, this.haveRefereces);
      GradientComputer gradientComputer = new HGRiskGradientComputer(MRConfig.useSemiringV2,
          numTrainingSentence, numPara, MRConfig.gainFactor, 1.0, 0.0, true,
          MRConfig.fixFirstFeature, hgFactory,
          MRConfig.maxNumHGInQueue, MRConfig.numThreads,
          MRConfig.ngramStateID, MRConfig.baselineLMOrder, symbolTbl,
          featureStringToIntegerMap, featTemplates,
          MRConfig.linearCorpusGainThetas, this.haveRefereces);

      annealer = new DeterministicAnnealer(numPara, lastWeightVector, MRConfig.isMinimizer, gradientComputer,
          MRConfig.useL2Regula, MRConfig.varianceForL2,
          MRConfig.useModelDivergenceRegula, MRConfig.lambda, MRConfig.printFirstN);

      if (MRConfig.annealingMode == 0) //do not anneal
        lastWeightVector = annealer.runWithoutAnnealing(MRConfig.isScalingFactorTunable,
            MRConfig.startScaleAtNoAnnealing, MRConfig.temperatureAtNoAnnealing);
      else if (MRConfig.annealingMode == 1)
        lastWeightVector = annealer.runQuenching(1.0);
      else if (MRConfig.annealingMode == 2)
        lastWeightVector = annealer.runDAAndQuenching();
      else {
        logger.severe("unsupported annealing mode, " + MRConfig.annealingMode);
        System.exit(0);
      }

      //===== re-compute onebest BLEU
      if (MRConfig.normalizeByFirstFeature)
        normalizeWeightsByFirstFeature(lastWeightVector, 0);
      computeOneBestBLEU(curHypFilePrefix);

      //@todo: check convergence

      //@todo: delete files
      if (false) {
        FileUtility.deleteFile(this.curHypFilePrefix + ".hg.items");
        FileUtility.deleteFile(this.curHypFilePrefix + ".hg.rules");
      }

      if (MRConfig.lossAugmentedPrune) {
        this.curLossScale -= MRConfig.lossDecreaseConstant;
        if (this.curLossScale <= 0)
          this.curLossScale = 0;
      }
    }

    //final output
    if (MRConfig.normalizeByFirstFeature)
      normalizeWeightsByFirstFeature(lastWeightVector, 0);
    saveLastModel(configFile, configFile + ".final", MRConfig.featureFile, MRConfig.featureFile + ".final");
    //writeConfigFile(lastWeightVector, configFile, configFile+".final");

    //System.out.println("#### Final weights are: ");
    //annealer.getLBFGSRunner().printStatistics(-1, -1, null, lastWeightVector);
  }

  public void decodingTestSet(double[] weights, String hypFilePrefix) {
    /** three scenarios:
     * (1) individual baseline features
     * (2) baselineCombo + sparse features
     * (3) individual baseline features + sparse features
     */
    if (MRConfig.useSparseFeature)
      joshuaDecoder.changeFeatureWeightVector(getIndividualBaselineWeights(), this.curFeatureFile);
    else
      joshuaDecoder.changeFeatureWeightVector(getIndividualBaselineWeights(), null);

    //call the Joshua decoder to produce a hypergraph using the new weight vector
    joshuaDecoder.decodeTestSet(sourceTrainingFile, hypFilePrefix);
  }
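  /**
   * Reranks each training hypergraph with the current weight vector and prints the average
   * one-best model score, sentence-level BLEU, and linear corpus (Google) gain, so progress can
   * be monitored before and after each optimization round. Does nothing when no references are
   * available.
   */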
  private void computeOneBestBLEU(String curHypFilePrefix) {
    if (this.haveRefereces == false)
      return;

    double bleuSum = 0;
    double googleGainSum = 0;
    double modelSum = 0;

    //==== feature function based on the feature templates
    int featID = 999;
    double weight = 1.0;
    HashSet<String> restrictedFeatureSet = null;
    HashMap<String, Double> modelTbl = obtainModelTable(this.featureStringToIntegerMap, this.lastWeightVector);
    //System.out.println("modelTable: " + modelTbl);
    FeatureFunction ff = new FeatureTemplateBasedFF(featID, weight, modelTbl, this.featTemplates, restrictedFeatureSet);

    //==== reranker
    List<FeatureFunction> features = new ArrayList<FeatureFunction>();
    features.add(ff);
    HGRanker reranker = new HGRanker(features);

    //==== kbest
    boolean addCombinedCost = false;
    KBestExtractor kbestExtractor = new KBestExtractor(symbolTbl, MRConfig.use_unique_nbest,
        MRConfig.use_tree_nbest, false, addCombinedCost, false, true);

    //==== loop
    HyperGraphFactory hgFactory = new HyperGraphFactory(curHypFilePrefix, referenceFiles,
        MRConfig.ngramStateID, symbolTbl, true);
    hgFactory.startLoop();
    for (int sentID = 0; sentID < this.numTrainingSentence; sentID++) {
      HGAndReferences res = hgFactory.nextHG();
      reranker.rankHG(res.hg); //reset the best pointer and the transition prob

      String hypSent = kbestExtractor.getKthHyp(res.hg.goalNode, 1, -1, null, null);
      double bleu = BLEU.computeSentenceBleu(res.referenceSentences, hypSent);
      bleuSum += bleu;

      double googleGain = BLEU.computeLinearCorpusGain(MRConfig.linearCorpusGainThetas, res.referenceSentences, hypSent);
      googleGainSum += googleGain;

      modelSum += res.hg.bestLogP();
      //System.out.println("logP=" + res.hg.bestLogP() + "; Bleu=" + bleu + "; googleGain=" + googleGain);
    }
    hgFactory.endLoop();

    System.out.println("AvgLogP=" + modelSum / this.numTrainingSentence
        + "; AvgBleu=" + bleuSum / this.numTrainingSentence
        + "; AvgGoogleGain=" + googleGainSum / this.numTrainingSentence
        + "; SumGoogleGain=" + googleGainSum);
  }

  public void saveLastModel(String configTemplate, String configOutput,
      String sparseFeaturesTemplate, String sparseFeaturesOutput) {
    if (MRConfig.useSparseFeature) {
      JoshuaDecoder.writeConfigFile(getIndividualBaselineWeights(), configTemplate, configOutput, sparseFeaturesOutput);
      saveSparseFeatureFile(sparseFeaturesTemplate, sparseFeaturesOutput);
    } else {
      JoshuaDecoder.writeConfigFile(getIndividualBaselineWeights(), configTemplate, configOutput, null);
    }
  }

  private void initialize() {
    //===== read configurations
    MRConfig.readConfigFile(this.configFile);

    //===== initialize googleCorpusBLEU
    if (MRConfig.useGoogleLinearCorpusGain == true) {
      //do nothing
    } else {
      logger.severe("For hypergraph-based training, we must use the linear corpus gain.");
      System.exit(1);
    }

    //===== initialize the featureTemplates
    setupFeatureTemplates();

    //====== initialize featureStringToIntegerMap and weights
    initFeatureMapAndWeights(MRConfig.featureFile);
  }

  //TODO: should merge with setupFeatureTemplates in HGMinRiskDAMert
  private void setupFeatureTemplates() {

    this.featTemplates = new ArrayList<FeatureTemplate>();

    if (MRConfig.useBaseline) {
      FeatureTemplate ft = new BaselineFT(MRConfig.baselineFeatureName, true);
      featTemplates.add(ft);
    }

    if (MRConfig.useIndividualBaselines) {
      for (int id : MRConfig.baselineFeatIDsToTune) {
        String featName = MRConfig.individualBSFeatNamePrefix + id;
        FeatureTemplate ft = new IndividualBaselineFT(featName, id, true);
        featTemplates.add(ft);
      }
    }

    if (MRConfig.useSparseFeature) {
      if (MRConfig.useMicroTMFeat) {
        //FeatureTemplate ft = new TMFT(symbolTbl, useIntegerString, MRConfig.useRuleIDName);
        this.microRuleFeatureTemplate = new MicroRuleFT(MRConfig.useRuleIDName,
            MRConfig.startTargetNgramOrder, MRConfig.endTargetNgramOrder, MRConfig.wordMapFile);
        featTemplates.add(microRuleFeatureTemplate);
      }

      if (MRConfig.useTMFeat) {
        FeatureTemplate ft = new TMFT(symbolTbl, useIntegerString, MRConfig.useRuleIDName);
        featTemplates.add(ft);
      }

      if (MRConfig.useTMTargetFeat) {
        FeatureTemplate ft = new TargetTMFT(symbolTbl, useIntegerString);
        featTemplates.add(ft);
      }

      if (MRConfig.useLMFeat) {
        FeatureTemplate ft = new NgramFT(symbolTbl, useIntegerString, MRConfig.ngramStateID,
            MRConfig.baselineLMOrder, MRConfig.startNgramOrder, MRConfig.endNgramOrder);
        featTemplates.add(ft);
      }
    }

    System.out.println("feature templates are " + featTemplates.toString());
  }
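  // The sparse feature file has one feature per line in the form "feature_key ||| initial_weight",
  // where the key itself may contain "|||". A purely hypothetical example:
  //   glue_rule ||| 0.5
  //   target_ngram ||| the cat ||| -0.2
  // Feature IDs are assigned in reading order, continuing after the baseline feature entries.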
  //read the feature map into featureStringToIntegerMap
  //TODO: we assume the featureID is the line ID (starting from zero)
  private void initFeatureMapAndWeights(String featureFile) {

    featureStringToIntegerMap = new HashMap<String, Integer>();
    List<Double> temInitWeights = new ArrayList<Double>();
    int featID = 0;

    //==== baseline feature
    if (MRConfig.useBaseline) {
      featureStringToIntegerMap.put(MRConfig.baselineFeatureName, featID++);
      temInitWeights.add(MRConfig.baselineFeatureWeight);
    }

    //==== individual baseline features
    if (MRConfig.useIndividualBaselines) {
      List<Double> weights = readBaselineFeatureWeights(this.configFile);
      for (int id : MRConfig.baselineFeatIDsToTune) {
        String featName = MRConfig.individualBSFeatNamePrefix + id;
        featureStringToIntegerMap.put(featName, featID++);
        double weight = weights.get(id);
        temInitWeights.add(weight);
      }
    }

    //==== features in the file
    if (MRConfig.useSparseFeature) {
      BufferedReader reader = FileUtilityOld.getReadFileStream(featureFile, "UTF-8");
      String line;
      while ((line = FileUtilityOld.readLineLzf(reader)) != null) {
        //feature_key ||| feature value; the feature_key itself may contain "|||"
        String[] fds = line.split("\\s+\\|{3}\\s+");
        StringBuffer featKey = new StringBuffer();
        for (int i = 0; i < fds.length - 1; i++) {
          featKey.append(fds[i]);
          if (i < fds.length - 2)
            featKey.append(" ||| ");
        }
        double initWeight = new Double(fds[fds.length - 1]); //initial weight
        temInitWeights.add(initWeight);
        featureStringToIntegerMap.put(featKey.toString(), featID++);
      }
      FileUtilityOld.closeReadFile(reader);
    }

    //==== initialize lastWeightVector
    numPara = temInitWeights.size();
    lastWeightVector = new double[numPara];
    for (int i = 0; i < numPara; i++)
      lastWeightVector[i] = temInitWeights.get(i);
  }
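  /**
   * Maps the tuned weights back onto the decoder's dense baseline weight vector: each baseline
   * feature listed in baselineFeatIDsToTune receives its tuned weight (scaled by the aggregate
   * baseline weight when useBaseline is on), and under loss-augmented pruning the oracle feature
   * is set to the current loss scale.
   */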
  private double[] getIndividualBaselineWeights() {

    double baselineWeight = 1.0;
    if (MRConfig.useBaseline)
      baselineWeight = getBaselineWeight();

    List<Double> weights = readBaselineFeatureWeights(this.configFile);

    //change the weights we are tuning
    if (MRConfig.useIndividualBaselines) {
      for (int id : MRConfig.baselineFeatIDsToTune) {
        String featName = MRConfig.individualBSFeatNamePrefix + id;
        int featID = featureStringToIntegerMap.get(featName);
        weights.set(id, baselineWeight * lastWeightVector[featID]);
      }
    }

    if (MRConfig.lossAugmentedPrune) {
      String featName = MRConfig.individualBSFeatNamePrefix + this.oralceFeatureID;
      if (featureStringToIntegerMap.containsKey(featName)) {
        logger.severe("we are tuning the oracle model, must be wrong in specifying baselineFeatIDsToTune");
        System.exit(1);
      }
      weights.set(this.oralceFeatureID, this.curLossScale);
      System.out.println("curLossScale=" + this.curLossScale + "; oracleFeatureID=" + this.oralceFeatureID);
    }

    double[] res = new double[weights.size()];
    for (int i = 0; i < res.length; i++)
      res[i] = weights.get(i);
    return res;
  }

  private double getBaselineWeight() {
    String featName = MRConfig.baselineFeatureName;
    int featID = featureStringToIntegerMap.get(featName);
    double weight = lastWeightVector[featID];
    System.out.println("baseline weight is " + weight);
    return weight;
  }

  private void saveSparseFeatureFile(String fileTemplate, String outputFile) {
    BufferedReader template = FileUtilityOld.getReadFileStream(fileTemplate, "UTF-8");
    BufferedWriter writer = FileUtilityOld.getWriteFileStream(outputFile);
    String line;
    while ((line = FileUtilityOld.readLineLzf(template)) != null) {

      //== construct the feature name
      //feature_key ||| feature value; the feature_key itself may contain "|||"
      String[] fds = line.split("\\s+\\|{3}\\s+");
      StringBuffer featKey = new StringBuffer();
      for (int i = 0; i < fds.length - 1; i++) {
        featKey.append(fds[i]);
        if (i < fds.length - 2)
          featKey.append(" ||| ");
      }

      //== write the learnt weight
      //double oldWeight = new Double(fds[fds.length-1]); //initial weight
      int featID = featureStringToIntegerMap.get(featKey.toString());
      double newWeight = lastWeightVector[featID]; //last model
      //System.out.println(featKey + "; old=" + oldWeight + "; new=" + newWeight);
      FileUtilityOld.writeLzf(writer, featKey.toString() + " ||| " + newWeight + "\n");

      featID++;
    }
    FileUtilityOld.closeReadFile(template);
    FileUtilityOld.closeWriteFile(writer);
  }

  private HashMap<String, Double> obtainModelTable(HashMap<String, Integer> featureStringToIntegerMap, double[] weightVector) {
    HashMap<String, Double> modelTbl = new HashMap<String, Double>();
    for (Map.Entry<String, Integer> entry : featureStringToIntegerMap.entrySet()) {
      int featID = entry.getValue();
      double weight = lastWeightVector[featID]; //last model
      modelTbl.put(entry.getKey(), weight);
    }
    return modelTbl;
  }

  private void addAbbreviatedNames(Map<String, Integer> rulesIDTable) {
    //try to abbreviate the features if possible
    if (MRConfig.useRuleIDName) {
      //add the abbreviated feature name into featureStringToIntegerMap
      //System.out.println("size1=" + featureStringToIntegerMap.size());
      for (Entry<String, Integer> entry : rulesIDTable.entrySet()) {
        Integer featureID = featureStringToIntegerMap.get(entry.getKey());
        if (featureID != null) {
          String abbrFeatName = "r" + entry.getValue(); //TODO??????
          featureStringToIntegerMap.put(abbrFeatName, featureID);
          //System.out.println("full=" + entry.getKey() + "; abbrFeatName=" + abbrFeatName + "; id=" + featureID);
        }
      }
      //System.out.println("size2=" + featureStringToIntegerMap);
      //System.exit(0);
    }
  }

  static public void addAllWordsIntoSymbolTbl(String file, SymbolTable symbolTbl) {
    BufferedReader reader = FileUtilityOld.getReadFileStream(file, "UTF-8");
    String line;
    while ((line = FileUtilityOld.readLineLzf(reader)) != null) {
      symbolTbl.addTerminals(line);
    }
    FileUtilityOld.closeReadFile(reader);
  }
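  /**
   * Command line (a sketch; the file names are placeholders):
   * <pre>
   *   java joshua.discriminative.training.risk_annealer.hypergraph.HGMinRiskDAMert \
   *       joshua.config dev.src hyp.prefix [dev.ref.0 dev.ref.1 ...]
   * </pre>
   * If no reference files are given, training runs in the unsupervised (conditional-entropy)
   * setting.
   */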
  public static void main(String[] args) {

    /*String f_joshua_config="C:/data_disk/java_work_space/discriminative_at_clsp/edu/jhu/joshua/discriminative_training/lbfgs/example.config.javalm";
    String f_dev_src="C:/data_disk/java_work_space/sf_trunk/example/example.test.in";
    String f_nbest_prefix="C:/data_disk/java_work_space/discriminative_at_clsp/edu/jhu/joshua/discriminative_training/lbfgs/example.nbest.javalm.out";
    String f_dev_ref="C:/data_disk/java_work_space/sf_trunk/example/example.test.ref.0";
    */

    if (args.length < 3) {
      System.out.println("Wrong number of parameters!");
      System.exit(1);
    }

    //long start_time = System.currentTimeMillis();
    String joshuaConfigFile = args[0].trim();
    String sourceTrainingFile = args[1].trim();
    String hypFilePrefix = args[2].trim();

    String[] devRefs = null;
    if (args.length > 3) {
      devRefs = new String[args.length - 3];
      for (int i = 3; i < args.length; i++) {
        devRefs[i - 3] = args[i].trim();
        System.out.println("Use ref file " + devRefs[i - 3]);
      }
    }

    SymbolTable symbolTbl = new BuildinSymbol(null);

    int numSentInDevSet = FileUtilityOld.numberLinesInFile(sourceTrainingFile);

    HGMinRiskDAMert trainer = new HGMinRiskDAMert(joshuaConfigFile, numSentInDevSet, devRefs,
        hypFilePrefix, symbolTbl, sourceTrainingFile);

    trainer.mainLoop();
  }
}