/** * */ package edu.berkeley.nlp.PCFGLA; import java.io.BufferedReader; import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.InputStreamReader; import java.io.OutputStreamWriter; import java.io.PrintWriter; import java.util.Arrays; import java.util.List; import edu.berkeley.nlp.math.DoubleArrays; import edu.berkeley.nlp.math.SloppyMath; import edu.berkeley.nlp.syntax.StateSet; import edu.berkeley.nlp.syntax.Tree; import edu.berkeley.nlp.syntax.Trees.PennTreeReader; import edu.berkeley.nlp.util.Numberer; import edu.berkeley.nlp.util.PriorityQueue; import edu.berkeley.nlp.util.ScalingTools; /** * @author petrov * Takes an unannotated tree a returns the log-likelihood of all derivations corresponding to the given tree, * under a number of different grammars. */ public class TreeReranker{ public static class Options { @Option(name = "-grammar", required = true, usage = "Input Files for Grammar") public String inFileName; @Option(name = "-inputFile", usage = "Input File for Parse Trees.") public String inputFile; @Option(name = "-outputFile", usage = "Store output in this file instead of printing it to STDOUT.") public String outputFile; @Option(name = "-nGrammars", usage = "Number of grammars") public int nGrammars; @Option(name = "-kBest", usage = "Print the k best trees") public int kbest = 1; } public static void main(String[] args) { OptionParser optParser = new OptionParser(Options.class); Options opts = (Options) optParser.parse(args, true); // provide feedback on command-line arguments System.err.println("Calling with " + optParser.getPassedInOptions()); String inFileName = opts.inFileName; if (inFileName==null) { throw new Error("Did not provide a grammar."); } ArrayParser[] parsers = new ArrayParser[opts.nGrammars]; short[][] numSubstates = new short[opts.nGrammars][]; ParserData[] pData = new ParserData[opts.nGrammars]; Numberer[] tagNumberer = new Numberer[opts.nGrammars]; int v_markov = 1, h_markov = 0; Binarization bin = Binarization.RIGHT; for (int i=0; i<opts.nGrammars; i++) { System.err.println("Loading grammar from "+inFileName+"."+(i+1)); pData[i] = ParserData.Load(inFileName+"."+(i+1)); if (pData==null) { System.out.println("Failed to load grammar from file"+inFileName+"."); System.exit(1); } Grammar grammar = pData[i].getGrammar(); grammar.splitRules(); SophisticatedLexicon lexicon = (SophisticatedLexicon)pData[i].getLexicon(); parsers[i] = new ArrayParser(grammar, lexicon); numSubstates[i] = grammar.numSubStates; v_markov = pData[i].v_markov; h_markov = pData[i].h_markov; bin = pData[i].bin; Numberer.setNumberers(pData[i].getNumbs()); tagNumberer[i] = Numberer.getGlobalNumberer("tags"); } try{ BufferedReader inputData = (opts.inputFile==null) ? new BufferedReader(new InputStreamReader(System.in)) : new BufferedReader(new InputStreamReader(new FileInputStream(opts.inputFile), "UTF-8")); // PennTreeReader treeReader = new PennTreeReader(inputData); PrintWriter outputData = (opts.outputFile==null) ? new PrintWriter(new OutputStreamWriter(System.out)) : new PrintWriter(new OutputStreamWriter(new FileOutputStream(opts.outputFile), "UTF-8"), true); Tree<String> tree = null; String line = ""; double bestScore = Double.NEGATIVE_INFINITY; Tree<String> bestTree = null; PriorityQueue<Tree<String>> pQ = new PriorityQueue<Tree<String>>(); int index=1; while ((line = inputData.readLine()) != null) { tree = PennTreeReader.parseEasy(line); if (line.equals("\n") || tree==null || tree.getYield().get(0).equals("") ) { // done with the block if (bestTree == null) { outputData.write("(())\n"); } else { if (opts.kbest == 1) { outputData.write(bestTree+"\n"); } else { int nTrees = Math.min(opts.kbest, pQ.size()); outputData.write(nTrees+"\t"+opts.inputFile+"-"+(index++)+"\n"); for (int i=0; i<nTrees; i++){ double p = pQ.getPriority(); outputData.write(p+"\n"+pQ.next()+"\n"); } outputData.write("\n"); } } outputData.flush(); bestScore = Double.NEGATIVE_INFINITY; bestTree = null; pQ = new PriorityQueue<Tree<String>>(); System.err.println("Picked best tree."); continue; } Tree<String> processedTree = TreeAnnotations.processTree(tree,v_markov,h_markov,bin,false); double[] logScores = new double[opts.nGrammars]; for (int i=0; i<opts.nGrammars; i++){ Tree<StateSet> stateSetTree = StateSetTreeList.stringTreeToStatesetTree(processedTree, numSubstates[i], false, tagNumberer[i]); allocate(stateSetTree); parsers[i].doInsideScores(stateSetTree, false, false, null); logScores[i] = Math.log(stateSetTree.getLabel().getIScore(0)) + (stateSetTree.getLabel().getIScale()*ScalingTools.LOGSCALE); } // double totalScore = SloppyMath.logAdd(logScores); double totalScore = DoubleArrays.add(logScores);///opts.nGrammars; if (opts.kbest > 1 && totalScore != Double.NEGATIVE_INFINITY) { pQ.add(tree, totalScore); } if (totalScore > bestScore) { // System.err.println(totalScore); bestScore = totalScore; bestTree = tree; } } outputData.flush(); outputData.close(); }catch (Exception ex) { ex.printStackTrace(); } System.exit(0); } /* * Allocate the inside and outside score arrays for the whole tree */ static void allocate(Tree<StateSet> tree) { tree.getLabel().allocate(); for (Tree<StateSet> child : tree.getChildren()) { allocate(child); } } }