package joshua.discriminative.training.risk_annealer.nbest;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;

import joshua.discriminative.FileUtilityOld;
import joshua.discriminative.training.NbestMerger;
import joshua.discriminative.training.risk_annealer.AbstractMinRiskMERT;
import joshua.discriminative.training.risk_annealer.DeterministicAnnealer;
import joshua.discriminative.training.risk_annealer.GradientComputer;
import joshua.discriminative.training.risk_annealer.hypergraph.MRConfig;
import joshua.util.FileUtility;

/**
 * @author Zhifei Li, <zhifei.work@gmail.com>
 * @version $LastChangedDate: 2008-10-20 00:12:30 -0400 $
 */
public abstract class NbestMinRiskDAMert extends AbstractMinRiskMERT {

  int totalNumHyp = 0;
  String nbestPrefix;
  boolean useShortestRef;

  private static Logger logger = Logger.getLogger(NbestMinRiskDAMert.class.getSimpleName());

  public NbestMinRiskDAMert(boolean useShortestRef, String decoderConfigFile, int numSentInTrainSet,
      String[] refFiles, String nbestPrefix) {
    super(decoderConfigFile, numSentInTrainSet, refFiles);
    this.nbestPrefix = nbestPrefix;
    this.useShortestRef = useShortestRef;
    initialize();
  }

  public abstract void writeConfigFile(double[] weights, String configFileTemplate, String configOutFile);

  public void mainLoop() {
    for (int iter = 1; iter <= MRConfig.maxNumIter; iter++) {

      //== re-normalize weights
      normalizeWeightsByFirstFeature(lastWeightVector, 0);

      //== decoding: call the decoder to produce an n-best list using the new weight vector
      writeConfigFile(lastWeightVector, configFile, configFile + "." + iter);
      String curNbestFile = nbestPrefix + "." + iter;
      decodingTestSet(lastWeightVector, curNbestFile);

      //== merge n-best lists and check convergence
      String oldNbestMergedFile = nbestPrefix + ".merged." + (iter - 1);
      String newNbestMergedFile = nbestPrefix + ".merged." + iter;
      if (iter == 1) {
        copyNbest(curNbestFile, newNbestMergedFile);
      } else {
        int newTotalNumHyp = NbestMerger.mergeNbest(oldNbestMergedFile, curNbestFile, newNbestMergedFile);
        boolean haveNewHyp = (newTotalNumHyp != totalNumHyp);
        totalNumHyp = newTotalNumHyp;
        if (!haveNewHyp) {
          System.out.println("No new hypotheses generated at iteration " + iter);
          break;
        }
      }
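      //== weight re-estimation on the merged n-best list: NbestRiskGradientComputer
      // evaluates the expected risk of the n-best hypotheses against the references,
      // and DeterministicAnnealer optimizes the weights (via its L-BFGS runner) under
      // the schedule selected by MRConfig.annealingMode below.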
      GradientComputer gradientComputer = new NbestRiskGradientComputer(newNbestMergedFile,
          referenceFiles, useShortestRef, numTrainingSentence, numPara, MRConfig.gainFactor,
          1.0, 0.0, true, MRConfig.linearCorpusGainThetas);

      annealer = new DeterministicAnnealer(numPara, lastWeightVector, MRConfig.isMinimizer,
          gradientComputer, MRConfig.useL2Regula, MRConfig.varianceForL2,
          MRConfig.useModelDivergenceRegula, MRConfig.lambda, MRConfig.printFirstN);

      if (MRConfig.annealingMode == 0) { //do not anneal
        lastWeightVector = annealer.runWithoutAnnealing(MRConfig.isScalingFactorTunable,
            MRConfig.startScaleAtNoAnnealing, MRConfig.temperatureAtNoAnnealing);
      } else if (MRConfig.annealingMode == 1) {
        lastWeightVector = annealer.runQuenching(1.0);
      } else if (MRConfig.annealingMode == 2) {
        lastWeightVector = annealer.runDAAndQuenching();
      } else {
        logger.severe("unsupported annealing mode, " + MRConfig.annealingMode);
        System.exit(1);
      }

      //lastWeightVector will be used as the initial weights in the next iteration
    }

    //== final output
    normalizeWeightsByFirstFeature(lastWeightVector, 0);
    writeConfigFile(lastWeightVector, configFile, configFile + ".final");
    System.out.println("#### Final weights are: ");
    annealer.getLBFGSRunner().printStatistics(-1, -1, null, lastWeightVector);
  }

  //returns false if the new n-best list does not add any new hypothesis
  //TODO: declare convergence when the number of new hypotheses generated is very small
  //TODO: terminate decoding when the weights do not change much; this makes more sense,
  //      as new hypotheses will be rare once the weights are stable
  public static boolean mergeNbest(String oldMergedNbestFile, String newNbestFile, String newMergedNbestFile) {
    boolean haveNewHyp = false;
    BufferedReader oldMergedNbestReader = FileUtilityOld.getReadFileStream(oldMergedNbestFile);
    BufferedReader newNbestReader = FileUtilityOld.getReadFileStream(newNbestFile);
    BufferedWriter newMergedNbestWriter = FileUtilityOld.getWriteFileStream(newMergedNbestFile);

    int oldSentID = -1;
    String line;
    String previousLineInNewNbest = FileUtilityOld.readLineLzf(newNbestReader);

    //key: the hypothesis itself; value: the remaining fields, excluding the sentence id
    HashMap<String, String> oldNbests = new HashMap<String, String>();

    while ((line = FileUtilityOld.readLineLzf(oldMergedNbestReader)) != null) {
      String[] fds = line.split("\\s+\\|{3}\\s+");
      int newSentID = Integer.parseInt(fds[0]);
      if (oldSentID != -1 && oldSentID != newSentID) {
        boolean[] tHaveNewHyp = new boolean[1];
        previousLineInNewNbest = processNbest(newNbestReader, newMergedNbestWriter, oldSentID,
            oldNbests, previousLineInNewNbest, tHaveNewHyp);
        if (tHaveNewHyp[0]) {
          haveNewHyp = true;
        }
      }
      oldSentID = newSentID;
      oldNbests.put(fds[1], fds[2]); //the last field is not needed
    }

    //process the last sentence's n-best
    boolean[] tHaveNewHyp = new boolean[1];
    previousLineInNewNbest = processNbest(newNbestReader, newMergedNbestWriter, oldSentID,
        oldNbests, previousLineInNewNbest, tHaveNewHyp);
    if (previousLineInNewNbest != null) {
      System.out.println("last line is not null; something must be wrong");
      System.exit(1);
    }
    if (tHaveNewHyp[0]) {
      haveNewHyp = true;
    }

    FileUtilityOld.closeReadFile(oldMergedNbestReader);
    FileUtilityOld.closeReadFile(newNbestReader);
    FileUtilityOld.closeWriteFile(newMergedNbestWriter);
    return haveNewHyp;
  }
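  /**
   * Merges one sentence's hypotheses from the new n-best file into {@code oldNbests},
   * writes the merged set to {@code newMergedNbestWriter}, and returns the first line
   * of the next sentence (or null at end of file). Each n-best line is assumed to
   * follow the "sentID ||| hypothesis ||| other fields" format used above, e.g.
   * "12 ||| the cat sat ||| 0.52 -3.1 -14.2".
   */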
  private static String processNbest(BufferedReader newNbestReader, BufferedWriter newMergedNbestWriter,
      int oldSentID, HashMap<String, String> oldNbests, String previousLine, boolean[] haveNewHyp) {
    haveNewHyp[0] = false;
    String previousLineInNewNbest = previousLine;

    //== read the new n-best list and merge it into oldNbests
    while (true) {
      String[] tFds = previousLineInNewNbest.split("\\s+\\|{3}\\s+");
      int tNewID = Integer.parseInt(tFds[0]);
      if (tNewID == oldSentID) {
        if (!oldNbests.containsKey(tFds[1])) { //a new hypothesis
          haveNewHyp[0] = true;
          oldNbests.put(tFds[1], tFds[2]); //merge into the n-best set
        }
      } else {
        break;
      }
      previousLineInNewNbest = FileUtilityOld.readLineLzf(newNbestReader);
      if (previousLineInNewNbest == null) {
        break;
      }
    }

    //== print the n-best list; order is not important, and the last field is ignored
    for (Map.Entry<String, String> entry : oldNbests.entrySet()) {
      FileUtilityOld.writeLzf(newMergedNbestWriter,
          oldSentID + " ||| " + entry.getKey() + " ||| " + entry.getValue() + "\n");
    }
    oldNbests.clear();
    return previousLineInNewNbest;
  }

  public static void copyNbest(String newNbestFile, String newMergedNbestFile) {
    try {
      FileUtility.copyFile(newNbestFile, newMergedNbestFile);
    } catch (IOException e) {
      e.printStackTrace();
    }
  }

  //set lastWeightVector and numPara from the decoder config
  private void initialize() {
    //== read configurations
    MRConfig.readConfigFile(this.configFile);
    logger.info("initialize features and weights");

    //== get the baseline feature weights
    List<Double> weights = readBaselineFeatureWeights(this.configFile);

    //== initialize the weights
    this.lastWeightVector = new double[weights.size()];
    for (int i = 0; i < lastWeightVector.length; i++) {
      lastWeightVector[i] = weights.get(i);
      logger.info("weight: " + lastWeightVector[i]);
    }
    this.numPara = this.lastWeightVector.length;
  }
}
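/* A minimal sketch of a concrete subclass, for illustration only. The decoder
 * invocation and the config-rewriting details are assumptions, not part of this
 * class; the signature of decodingTestSet is inferred from the call in mainLoop().
 *
 *   public class MyNbestMinRiskDAMert extends NbestMinRiskDAMert {
 *     public MyNbestMinRiskDAMert(String configFile, int numSent, String[] refFiles) {
 *       super(true, configFile, numSent, refFiles, "minrisk.nbest");
 *     }
 *     public void writeConfigFile(double[] weights, String template, String outFile) {
 *       //copy the template config, replacing each feature weight with weights[i]
 *     }
 *     public void decodingTestSet(double[] weights, String nbestFile) {
 *       //run the decoder with the rewritten config; write the n-best list to nbestFile
 *     }
 *   }
 */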