/* This file is part of the Joshua Machine Translation System.
 *
 * Joshua is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2.1
 * of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library; if not, write to the Free
 * Software Foundation, Inc., 59 Temple Place, Suite 330, Boston,
 * MA 02111-1307 USA
 */
package joshua.decoder;

import joshua.util.io.LineReader;
import joshua.util.FileUtility;
import joshua.util.Ngram;
import joshua.util.Regex;

import java.io.BufferedWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map.Entry;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.PriorityBlockingQueue;
import java.util.concurrent.TimeUnit;

/**
 * This class implements n-best Minimum Bayes-Risk (MBR) reranking
 * using BLEU as the gain function.
 * <p>
 * It assumes that each string in the n-best list is unique. In Hiero,
 * due to spurious ambiguity, a string may correspond to many possible
 * derivations, and ideally the probability of a string should be the
 * sum over all derivations leading to that string. In practice, one
 * normally uses a Viterbi approximation: the probability of a string
 * is the probability of its best derivation. So, if one wants to deal
 * with spurious ambiguity, that should be done before calling this
 * class.
 *
 * @author Zhifei Li, <zhifei.work@gmail.com>
 * @version $LastChangedDate: 2010-01-07 22:36:11 -0600 (Thu, 07 Jan 2010) $
 */
public class NbestMinRiskReranker {

    // TODO: this functionality is not implemented yet; the default is to
    // produce the 1-best without any feature scores.
    boolean produceRerankedNbest = false;

    double scalingFactor = 1.0;

    static int bleuOrder = 4;
    static boolean doNgramClip = true;

    static boolean useGoogleLinearCorpusGain = false;

    final PriorityBlockingQueue<RankerResult> resultsQueue =
        new PriorityBlockingQueue<RankerResult>();


    public NbestMinRiskReranker(boolean produceRerankedNbest, double scalingFactor) {
        this.produceRerankedNbest = produceRerankedNbest;
        this.scalingFactor = scalingFactor;
    }


    public String processOneSent(List<String> nbest, int sentID) {
        System.out.println("Now process sentence " + sentID);

        // Step 0: preprocess.
        // Assumption: each hypothesis line has the format
        // "sent_id ||| hyp_itself ||| feature scores ||| linear-combination-of-feature-scores (this should be logP)"
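        // For example, a well-formed input line might look like this
        // (hypothetical values, for illustration only):
        //   7 ||| the cat sat on the mat ||| 0.5 -2.3 1.7 ||| -10.25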
        List<String> hypsItself = new ArrayList<String>();
        //ArrayList<String> l_feat_scores = new ArrayList<String>();
        // linear combination of all baseline features
        List<Double> baselineScores = new ArrayList<Double>();
        List<HashMap<String,Integer>> ngramTbls = new ArrayList<HashMap<String,Integer>>();
        List<Integer> sentLens = new ArrayList<Integer>();

        for (String hyp : nbest) {
            String[] fds = Regex.threeBarsWithSpace.split(hyp);
            int tSentID = Integer.parseInt(fds[0]);
            if (sentID != tSentID) {
                throw new RuntimeException("sentence_id does not match");
            }
            String hypothesis = (fds.length == 4) ? fds[1] : "";
            hypsItself.add(hypothesis);

            String[] words = Regex.spaces.split(hypothesis);
            sentLens.add(words.length);

            HashMap<String,Integer> ngramTbl = new HashMap<String,Integer>();
            Ngram.getNgrams(ngramTbl, 1, bleuOrder, words);
            ngramTbls.add(ngramTbl);

            //l_feat_scores.add(fds[2]);

            // The value of finalIndex is expected to be 3, unless the
            // hyp_itself is empty, in which case finalIndex will be 2.
            int finalIndex = fds.length - 1;
            baselineScores.add(Double.parseDouble(fds[finalIndex]));
        }

        // Step 1: get the normalized distribution.
        // The values in baselineScores are replaced by normalized probabilities.
        computeNormalizedProbs(baselineScores, scalingFactor);

        List<Double> normalizedProbs = baselineScores;

        // === required by the Google linear corpus gain
        HashMap<String,Double> posteriorCountsTbl = null;
        if (useGoogleLinearCorpusGain) {
            posteriorCountsTbl = new HashMap<String,Double>();
            getGooglePosteriorCounts(ngramTbls, normalizedProbs, posteriorCountsTbl);
        }

        // Step 2: rerank the n-best list.
        // TODO: zhifei: the reranking currently takes O(n^2), where n is
        // the size of the n-best list. We could significantly speed this
        // up (to O(n)) by first estimating a model on the n-best list and
        // then reranking the n-best list with the estimated model.
        double bestGain = -1000000000; // initialized to the worst gain
        String bestHyp = null;
        List<Double> gains = new ArrayList<Double>();

        for (int i = 0; i < hypsItself.size(); i++) {
            String curHyp = hypsItself.get(i);
            int curHypLen = sentLens.get(i);
            HashMap<String,Integer> curHypNgramTbl = ngramTbls.get(i);
            //double cur_gain = computeGain(cur_hyp, l_hyp_itself, l_normalized_probs);
            double curGain = 0;
            if (useGoogleLinearCorpusGain) {
                curGain = computeExpectedLinearCorpusGain(curHypLen, curHypNgramTbl, posteriorCountsTbl);
            } else {
                curGain = computeExpectedGain(curHypLen, curHypNgramTbl, ngramTbls, sentLens, normalizedProbs);
            }

            gains.add(curGain);
            if (i == 0 || curGain > bestGain) { // maximize
                bestGain = curGain;
                bestHyp = curHyp;
            }
        }

        // Step 3: output the 1-best or the n-best.
        if (this.produceRerankedNbest) {
            // TODO: sort the list and write the reranked n-best;
            // use Collections.sort(List list, Comparator c).
        } else {
            /*
            this.out.write(bestHyp);
            this.out.write("\n");
            out.flush();
            */
        }

        System.out.println("best gain: " + bestGain);
        if (null == bestHyp) {
            throw new RuntimeException("MBR-reranked 1-best is null; something must be wrong");
        }
        return bestHyp;
    }

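    /* A worked example of the normalization performed by
     * computeNormalizedProbs below (the numbers are hypothetical):
     * log-probabilities {-1.0, -2.0} with a scaling factor of 1.0 give
     * the normalization constant Z = log(e^-1 + e^-2) ~= -0.687, so the
     * normalized probabilities are e^(-1.0 + 0.687) ~= 0.731 and
     * e^(-2.0 + 0.687) ~= 0.269, which sum to one as required by the
     * sanity check.
     */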
    /**
     * Based on the list of log-probabilities in nbestLogProbs, obtain
     * the normalized distribution and store the normalized probabilities
     * (real values in [0,1]) back into nbestLogProbs.
     */
    static public void computeNormalizedProbs(List<Double> nbestLogProbs, double scalingFactor) {
        // === get the normalization constant
        double normalizationConstant = Double.NEGATIVE_INFINITY; // log-semiring zero
        for (double logp : nbestLogProbs) {
            normalizationConstant = addInLogSemiring(normalizationConstant, logp * scalingFactor, 0);
        }
        //System.out.println("normalization_constant (logP) is " + normalizationConstant);

        // === get the normalized probability for each hypothesis
        double tSum = 0;
        for (int i = 0; i < nbestLogProbs.size(); i++) {
            double logp = nbestLogProbs.get(i);
            double normalizedProb = Math.exp(logp * scalingFactor - normalizationConstant);
            if (Double.isNaN(normalizedProb)) {
                throw new RuntimeException(
                    "prob is NaN, must be wrong\nnbestLogProbs.get(i): " + logp
                    + "; scalingFactor: " + scalingFactor
                    + "; normalizationConstant: " + normalizationConstant);
            }
            tSum += normalizedProb;
            nbestLogProbs.set(i, normalizedProb);
            //logger.info("probability: " + normalizedProb);
        }

        // sanity check
        if (Math.abs(tSum - 1.0) > 1e-4) {
            throw new RuntimeException("probabilities do not sum to one, must be wrong");
        }
    }

    /**
     * Gain(e) = negative risk = \sum_{e'} G(e, e') P(e'),
     * where curHyp is e and trueHyp is e'.
     */
    public double computeExpectedGain(int curHypLen, HashMap<String,Integer> curHypNgramTbl,
        List<HashMap<String,Integer>> ngramTbls, List<Integer> sentLens, List<Double> nbestProbs
    ) {
        // accumulate the expected BLEU gain of curHyp over the n-best distribution
        double gain = 0;

        for (int i = 0; i < nbestProbs.size(); i++) {
            HashMap<String,Integer> trueHypNgramTbl = ngramTbls.get(i);
            double trueProb = nbestProbs.get(i);
            int trueLen = sentLens.get(i);
            gain += trueProb * BLEU.computeSentenceBleu(trueLen, trueHypNgramTbl,
                curHypLen, curHypNgramTbl, doNgramClip, bleuOrder);
        }
        //System.out.println("Gain is " + gain);
        return gain;
    }

    /**
     * Gain(e) = negative risk = \sum_{e'} G(e, e') P(e'),
     * where curHyp is e and trueHyp is e'.
     */
    static public double computeExpectedGain(String curHyp, List<String> nbestHyps, List<Double> nbestProbs) {
        // accumulate the expected BLEU gain of curHyp over the n-best distribution
        double gain = 0;

        for (int i = 0; i < nbestHyps.size(); i++) {
            String trueHyp = nbestHyps.get(i);
            double trueProb = nbestProbs.get(i);
            gain += trueProb * BLEU.computeSentenceBleu(trueHyp, curHyp, doNgramClip, bleuOrder);
        }
        //System.out.println("Gain is " + gain);
        return gain;
    }

    void getGooglePosteriorCounts(List<HashMap<String,Integer>> ngramTbls,
        List<Double> normalizedProbs, HashMap<String,Double> posteriorCountsTbl
    ) {
        // TODO
    }

    double computeExpectedLinearCorpusGain(int curHypLen,
        HashMap<String,Integer> curHypNgramTbl, HashMap<String,Double> posteriorCountsTbl
    ) {
        // TODO
        // thetas are the weights of the linear (corpus-level) BLEU
        // approximation: thetas[0] weights the hypothesis length, and
        // thetas[n] weights the posterior count of each matched n-gram.
        double[] thetas = { -1, 1, 1, 1, 1 };

        double res = 0;
        res += thetas[0] * curHypLen;
        for (Entry<String,Integer> entry : curHypNgramTbl.entrySet()) {
            String key = entry.getKey();
            String[] tem = Regex.spaces.split(key);

            double postProb = posteriorCountsTbl.get(key);
            res += entry.getValue() * postProb * thetas[tem.length];
        }
        return res;
    }

    /**
     * Adds two values in the log semiring while preventing overflow.
     * The naive alternative would be: return Math.log(Math.exp(x) + Math.exp(y)).
     */
    static private double addInLogSemiring(double x, double y, int addMode) {
        if (addMode == 0) { // sum
            if (x == Double.NEGATIVE_INFINITY) { // if y is also -infinity, then return -infinity
                return y;
            }
            if (y == Double.NEGATIVE_INFINITY) {
                return x;
            }
            if (y <= x) {
                return x + Math.log(1 + Math.exp(y - x));
            } else {
                return y + Math.log(1 + Math.exp(x - y));
            }
        } else if (addMode == 1) { // Viterbi-min
            return (x <= y) ? x : y;
        } else if (addMode == 2) { // Viterbi-max
            return (x >= y) ? x : y;
        } else {
            throw new RuntimeException("invalid add mode");
        }
    }

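    /* Example invocation of main (the file names are hypothetical):
     *
     *   java joshua.decoder.NbestMinRiskReranker nbest.out mbr.1best false 1.0 4
     *
     * This reads the n-best list from nbest.out, writes the MBR 1-best
     * translation of each sentence to mbr.1best, does not produce a
     * reranked n-best list, uses a scaling factor of 1.0, and runs with
     * four worker threads.
     */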
    public static void main(String[] args) throws IOException {

        // If you don't know what to use for the scaling factor, try 1.

        if (args.length < 4 || args.length > 5) {
            System.out.println("Usage: java NbestMinRiskReranker "
                + "f_nbest_in f_out produce_reranked_nbest scaling_factor [numThreads]");
            System.out.println("num of args is " + args.length);
            for (int i = 0; i < args.length; i++) {
                System.out.println("arg is: " + args[i]);
            }
            System.exit(-1);
        }
        long startTime = System.currentTimeMillis();
        String inputNbest = args[0].trim();
        String output = args[1].trim();
        boolean produceRerankedNbest = Boolean.valueOf(args[2].trim());
        double scalingFactor = Double.parseDouble(args[3].trim());
        int numThreads = (args.length == 5) ? Integer.parseInt(args[4].trim()) : 1;

        BufferedWriter outWriter = FileUtility.getWriteFileStream(output);

        NbestMinRiskReranker mbrReranker =
            new NbestMinRiskReranker(produceRerankedNbest, scalingFactor);

        System.out.println("##############running mbr reranking");

        int oldSentID = -1;
        LineReader nbestReader = new LineReader(inputNbest);
        List<String> nbest = new ArrayList<String>();

        if (numThreads == 1) {

            try {
                for (String line : nbestReader) {
                    String[] fds = Regex.threeBarsWithSpace.split(line);
                    int newSentID = Integer.parseInt(fds[0]);
                    if (oldSentID != -1 && oldSentID != newSentID) {
                        // nbest: a list of unique strings
                        String bestHyp = mbrReranker.processOneSent(nbest, oldSentID);
                        outWriter.write(bestHyp);
                        outWriter.newLine();
                        outWriter.flush();
                        nbest.clear();
                    }
                    oldSentID = newSentID;
                    nbest.add(line);
                }
            } finally {
                nbestReader.close();
            }

            // process the last n-best list
            String bestHyp = mbrReranker.processOneSent(nbest, oldSentID);
            outWriter.write(bestHyp);
            outWriter.newLine();
            outWriter.flush();
            nbest.clear();

            outWriter.close();

        } else {

            ExecutorService threadPool = Executors.newFixedThreadPool(numThreads);

            try {
                for (String line : nbestReader) {
                    String[] fds = Regex.threeBarsWithSpace.split(line);
                    int newSentID = Integer.parseInt(fds[0]);
                    if (oldSentID != -1 && oldSentID != newSentID) {
                        threadPool.execute(mbrReranker.new RankerTask(nbest, oldSentID));
                        nbest.clear();
                    }
                    oldSentID = newSentID;
                    nbest.add(line);
                }
            } finally {
                nbestReader.close();
            }

            // process the last n-best list
            threadPool.execute(mbrReranker.new RankerTask(nbest, oldSentID));
            nbest.clear();

            threadPool.shutdown();

            try {
                threadPool.awaitTermination(Integer.MAX_VALUE, TimeUnit.SECONDS);

                // drain the results in sentence order
                // (RankerResult sorts by sentence number)
                while (! mbrReranker.resultsQueue.isEmpty()) {
                    RankerResult result = mbrReranker.resultsQueue.remove();
                    String bestHyp = result.toString();
                    outWriter.write(bestHyp);
                    outWriter.newLine();
                }
                outWriter.flush();
            } catch (InterruptedException e) {
                e.printStackTrace();
            } finally {
                outWriter.close();
            }
        }

        System.out.println("Total running time (seconds) is "
            + (System.currentTimeMillis() - startTime) / 1000.0);
    }


    private class RankerTask implements Runnable {

        final List<String> nbest;
        final int sentID;

        RankerTask(final List<String> nbest, final int sentID) {
            // copy the list, since the caller clears and reuses it
            this.nbest = new ArrayList<String>(nbest);
            this.sentID = sentID;
        }

        public void run() {
            String result = processOneSent(nbest, sentID);
            resultsQueue.add(new RankerResult(result, sentID));
        }
    }


    private static class RankerResult implements Comparable<RankerResult> {

        final String result;
        final Integer sentenceNumber;

        RankerResult(String result, int sentenceNumber) {
            this.result = result;
            this.sentenceNumber = sentenceNumber;
        }

        public int compareTo(RankerResult o) {
            return sentenceNumber.compareTo(o.sentenceNumber);
        }

        public String toString() {
            return result;
        }
    }
}