/** * */ package edu.berkeley.nlp.PCFGLA; import java.io.BufferedReader; import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; import java.io.InputStreamReader; import java.io.ObjectInputStream; import java.io.OutputStreamWriter; import java.io.PrintWriter; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.zip.GZIPInputStream; import edu.berkeley.nlp.syntax.Tree; import edu.berkeley.nlp.util.ArrayUtil; import edu.berkeley.nlp.util.Numberer; import edu.berkeley.nlp.util.ScalingTools; public class PosteriorMerger{ public static class Options { @Option(name = "-grammarFiles", required = true, usage = "Input Files for Grammars.") public String grammarFiles; @Option(name = "-inputFile", usage = "Read input from this file instead of reading it from STDIN.") public String inputFile; @Option(name = "-outputFile", usage = "Store output in this file instead of printing it to STDOUT.") public String outputFile; @Option(name = "-nGrammars", usage = "Number of Grammars") public int nGrammars; @Option(name = "-maxLength", usage = "Maximum sentence length (Default = 200).") public int maxLength = 200; } static double[][][] maxcScore; // start, end, state --> logProb static int[][][] maxcSplit; // start, end, state -> split position static int[][][] maxcChild; // start, end, state -> unary child (if any) static int[][][] maxcLeftChild; // start, end, state -> left child static int[][][] maxcRightChild; // start, end, state -> right child public static void main(String[] args) { OptionParser optParser = new OptionParser(Options.class); Options opts = (Options) optParser.parse(args, true); // provide feedback on command-line arguments System.err.println("Calling with " + optParser.getPassedInOptions()); String inFileName = opts.grammarFiles; if (inFileName==null) { throw new Error("Did not provide a grammar."); } short[][] numSubstates = new short[opts.nGrammars][]; Grammar[] grammars = new Grammar[opts.nGrammars]; Lexicon[] lexicons = new Lexicon[opts.nGrammars]; for (int gr=0; gr<opts.nGrammars; gr++) { System.err.println("Loading grammar from "+inFileName+"."+(gr+1)); ParserData pData = ParserData.Load(inFileName+"."+(gr+1)); if (pData==null) { System.out.println("Failed to load grammar from file"+inFileName+"."); System.exit(1); } numSubstates[gr] = pData.getGrammar().numSubStates; Numberer.setNumberers(pData.getNumbs()); grammars[gr]= pData.getGrammar(); lexicons[gr] = pData.getLexicon(); } int nGrammars = numSubstates.length; CoarseToFineMaxRuleParser parser = new CoarseToFineMaxRuleParser(grammars[0], lexicons[0], 1.0, -1, false, false, false, true, false, false, false); try{ BufferedReader inputData = (opts.inputFile==null) ? new BufferedReader(new InputStreamReader(System.in)) : new BufferedReader(new InputStreamReader(new FileInputStream(opts.inputFile), "UTF-8")); PrintWriter outputData = (opts.outputFile==null) ? new PrintWriter(new OutputStreamWriter(System.out)) : new PrintWriter(new OutputStreamWriter(new FileOutputStream(opts.outputFile), "UTF-8"), true); String line = ""; int blockIndex = 0; int lineIndex = 0; List<Posterior>[] posteriors = null; while ((line = inputData.readLine()) != null) { List<String> sentence = Arrays.asList(line.split(" ")); if (posteriors == null || lineIndex == posteriors[0].size()){ posteriors = new ArrayList[nGrammars]; for (int gr=0; gr<nGrammars; gr++) { String fileName = opts.grammarFiles + "." + (gr+1) + ".posteriors." + blockIndex; posteriors[gr] = loadPosteriors(fileName); } lineIndex = 0; blockIndex++; } int length = sentence.size(); if (length > opts.maxLength){ // lineIndex++; outputData.write("(())\n"); continue; } List<double[][][][]> iScores = new ArrayList<double[][][][]>(nGrammars); List<double[][][][]> oScores = new ArrayList<double[][][][]>(nGrammars); List<int[][][]> iScales = new ArrayList<int[][][]>(nGrammars); List<int[][][]> oScales = new ArrayList<int[][][]>(nGrammars); boolean[][][] allowedStates = null; boolean skip = false; for (int gr=0; gr<nGrammars; gr++) { Posterior posterior = posteriors[gr].get(lineIndex); iScores.add(posterior.iScore); oScores.add(posterior.oScore); iScales.add(posterior.iScale); oScales.add(posterior.oScale); allowedStates = mergeAllowedStates(allowedStates, posterior.allowedStates); countAllowedStates(allowedStates); if (posterior.iScale !=null){ skip=true; System.err.println("Scaling will be used."); if (length != posterior.iScale.length){ System.err.println("G: " +gr + " sentence "+ lineIndex +" Length mismatch. Expected: " +length + " Got: " + posterior.iScale.length); } } } lineIndex++; if (skip==true){ outputData.write("(()) \n"); continue; } doCombinedMaxCScores(sentence, iScores, oScores, iScales, oScales, allowedStates, grammars, lexicons, numSubstates, iScales.get(0)!=null); System.err.println("Done with scores"); if (maxcScore[0][sentence.size()][0]==Double.NEGATIVE_INFINITY){ System.err.println("MaxCscore for ROOT is -Inf."); outputData.write("(()) \n"); continue; } parser.maxcScore = maxcScore; parser.maxcChild = maxcChild; parser.maxcLeftChild = maxcLeftChild; parser.maxcRightChild = maxcRightChild; parser.maxcSplit = maxcSplit; parser.allowedStates = allowedStates; Tree<String> parsedTree = parser.extractBestMaxRuleParse(0, sentence.size(), sentence); parsedTree = TreeAnnotations.unAnnotateTree(parsedTree); outputData.write(parsedTree+"\n"); outputData.flush(); } outputData.flush(); outputData.close(); }catch (Exception ex) { ex.printStackTrace(); } System.exit(0); } private static boolean[][][] mergeAllowedStates(boolean[][][] allowedStates, boolean[][][] allowedStates2) { if (allowedStates==null) return allowedStates2; for (int i=0; i<allowedStates.length; i++){ for (int j=i+1; j<allowedStates[i].length; j++){ for (int k=0; k<allowedStates[i][j].length; k++){ if (!allowedStates2[i][j][k] && allowedStates[i][j][k]) allowedStates[i][j][k] = false; } } } return allowedStates; } private static void countAllowedStates(boolean[][][] allowedStates) { int total = 0; int allowed = 0; for (int i=0; i<allowedStates.length; i++){ for (int j=i+1; j<allowedStates[i].length; j++){ for (int k=0; k<allowedStates[i][j].length; k++){ if (allowedStates[i][j][k]) allowed++; total++; } } } System.err.println(allowed+"/"+total+" allowed for sentence of length "+allowedStates.length); } static void doCombinedMaxCScores(List<String> sentence, List<double[][][][]> iScores, List<double[][][][]> oScores, List<int[][][]> iScales, List<int[][][]> oScales, boolean[][][] allowedStates, Grammar[] grammars, Lexicon[] lexicons, short[][] numSubstates, boolean scale) { int length = sentence.size(); int nGrammars = numSubstates.length; int numStates = numSubstates[0].length; boolean[] grammarTags = grammars[0].isGrammarTag; Numberer tagNumberer = Numberer.getGlobalNumberer("tags"); maxcScore = new double[length][length + 1][numStates]; maxcSplit = new int[length][length + 1][numStates]; maxcChild = new int[length][length + 1][numStates]; maxcLeftChild = new int[length][length + 1][numStates]; maxcRightChild = new int[length][length + 1][numStates]; ArrayUtil.fill(maxcScore, Double.NEGATIVE_INFINITY); double[] logNormalizer = new double[nGrammars]; for (int i=0; i<nGrammars; i++){ logNormalizer[i] = iScores.get(i)[0][length][0][0]; } for (int diff = 1; diff <= length; diff++) { for (int start = 0; start < (length - diff + 1); start++) { int end = start + diff; Arrays.fill(maxcSplit[start][end], -1); Arrays.fill(maxcChild[start][end], -1); Arrays.fill(maxcLeftChild[start][end], -1); Arrays.fill(maxcRightChild[start][end], -1); if (diff > 1) { // diff > 1: Try binary rules for (short pState=0; pState<numStates; pState++){ if (!allowedStates[start][end][pState]) continue; BinaryRule[] parentRules = grammars[0].splitRulesWithP(pState); for (int i = 0; i < parentRules.length; i++) { BinaryRule r = parentRules[i]; short lState = r.leftChildState; short rState = r.rightChildState; double scoreToBeat = maxcScore[start][end][pState]; for (int split = start+1; split <= end-1; split++) { if (!allowedStates[start][split][lState]) continue; if (!allowedStates[split][end][rState]) continue; double leftChildScore = maxcScore[start][split][lState]; double rightChildScore = maxcScore[split][end][rState]; if (leftChildScore==Double.NEGATIVE_INFINITY||rightChildScore==Double.NEGATIVE_INFINITY) continue; double scalingFactor = 0.0; if (scale) { for (int gr=0; gr<nGrammars; gr++) { scalingFactor += oScales.get(gr)[start][end][pState]+iScales.get(gr)[start][split][lState]+ iScales.get(gr)[split][end][rState]-iScales.get(gr)[0][length][0]; } // System.err.println(scalingFactor); scalingFactor = Math.log(ScalingTools.calcScaleFactor(scalingFactor)); } double gScore = leftChildScore + scalingFactor + rightChildScore; if (gScore < scoreToBeat) continue; // no chance of finding a better derivation for (int gr=0; gr<nGrammars; gr++) { double ruleScore = 0; BinaryRule rule = grammars[gr].getBinaryRule(pState, lState, rState); if (rule==null){ System.err.println("Dont have rule "+ (String)tagNumberer.object(pState) +" -> "+(String)tagNumberer.object(lState) +" "+(String)tagNumberer.object(rState)+" in grammar "+gr); continue; } double[][][] scores = rule.getScores2(); int nParentStates = numSubstates[gr][pState]; // == scores[0][0].length; int nLeftChildStates = numSubstates[gr][lState]; // == scores.length; int nRightChildStates = numSubstates[gr][rState]; // == scores[0].length; for (int lp = 0; lp < nLeftChildStates; lp++) { double lIS = iScores.get(gr)[start][split][lState][lp]; if (lIS == 0) continue; for (int rp = 0; rp < nRightChildStates; rp++) { if (scores[lp][rp]==null) continue; double rIS = iScores.get(gr)[split][end][rState][rp]; if (rIS == 0) continue; for (int np = 0; np < nParentStates; np++) { double pOS = oScores.get(gr)[start][end][pState][np]; if (pOS == 0) continue; double ruleS = scores[lp][rp][np]; if (ruleS == 0) continue; ruleScore += (pOS * ruleS * lIS * rIS) / logNormalizer[gr]; } } } // if (ruleScore==0) continue; gScore += Math.log(ruleScore); } if (gScore > scoreToBeat) { scoreToBeat = gScore; maxcScore[start][end][pState] = gScore; maxcSplit[start][end][pState] = split; maxcLeftChild[start][end][pState] = lState; maxcRightChild[start][end][pState] = rState; } } } } } else { // diff == 1 // We treat TAG --> word exactly as if it was a unary rule, except the score of the rule is // given by the lexicon rather than the grammar and that we allow another unary on top of it. //for (int tag : lexicon.getAllTags()){ for (int tag=0; tag<numStates; tag++){ if (!allowedStates[start][end][tag]) continue; String word = sentence.get(start); if (grammarTags[tag]) continue; double lexiconScores = 0; for (int gr=0; gr<nGrammars; gr++) { double ruleScore = 0; double[] lexiconScoreArray = lexicons[gr].score(word, (short) tag, start, false,false); for (int tp = 0; tp < numSubstates[gr][tag]; tp++) { double pOS = oScores.get(gr)[start][end][tag][tp]; double ruleS = lexiconScoreArray[tp]; ruleScore += (pOS * ruleS) / logNormalizer[gr]; // The inside score of a word is 0.0f } // if (ruleScore==0) continue; lexiconScores += Math.log(ruleScore); } if (length != iScores.get(0).length){ System.err.println("Length mismatch. Expected: " +length + " Got: " + iScores.get(0).length); System.err.println(sentence); } double scalingFactor = 0.0; if (scale) { for (int gr=0; gr<nGrammars; gr++) { try{ scalingFactor += oScales.get(gr)[start][end][tag]-iScales.get(gr)[0][length][0]; } catch (java.lang.ArrayIndexOutOfBoundsException e){ System.err.println("Start "+start); System.err.println("End "+end); System.err.println("Length "+length); System.err.println("Tag "+tag); System.err.println("Grammar "+gr); int[][][] oS = oScales.get(gr); System.err.println("oS.l "+oS.length); System.err.println("oS[].l "+oS[start].length); System.err.println("oS[][].l "+oS[start][end].length); int[][][] iS = iScales.get(gr); System.err.println("iS.l "+iS.length); System.err.println("iS[].l "+iS[start].length); System.err.println("iS[][].l "+iS[start][end].length); double[][][][] isS = iScores.get(gr); System.err.println("iS.l "+isS.length); System.err.println("iS[].l "+isS[start].length); System.err.println("iS[][].l "+isS[start][end].length); System.err.println("Length mismatch. Expected: " +length + " Got: " + iScales.get(gr).length); System.err.println(sentence); } } // System.err.println(scalingFactor); scalingFactor = Math.log(ScalingTools.calcScaleFactor(scalingFactor)); } maxcScore[start][end][tag] = lexiconScores + scalingFactor; } } // Try unary rules // Replacement for maxcScore[start][end], which is updated in batch double[] maxcScoreStartEnd = new double[numStates]; for (int i = 0; i < numStates; i++) { maxcScoreStartEnd[i] = maxcScore[start][end][i]; } for (short pState=0; pState<numStates; pState++){ if (!allowedStates[start][end][pState]) continue; UnaryRule[] unaries = grammars[0].getClosedSumUnaryRulesByParent(pState); for (int r = 0; r < unaries.length; r++) { UnaryRule ur = unaries[r]; short cState = ur.childState; if ((pState == cState)) continue;// && (np == cp))continue; if (!allowedStates[start][end][cState]) continue; double childScore = maxcScore[start][end][cState]; if (childScore==Double.NEGATIVE_INFINITY) continue; double scalingFactor = 0.0; if (scale) { for (int gr=0; gr<nGrammars; gr++) { scalingFactor += oScales.get(gr)[start][end][pState]+iScales.get(gr)[start][end][cState]-iScales.get(gr)[0][length][0]; } // System.err.println(scalingFactor); scalingFactor = Math.log(ScalingTools.calcScaleFactor(scalingFactor)); } double gScore = scalingFactor + childScore; if (gScore < maxcScoreStartEnd[pState]) continue; for (int gr=0; gr<nGrammars; gr++) { double ruleScore = 0; // TODO: this could be a problem //ClosedSumUnaryRulesByParent(pState); // double[][] scores = grammars[gr].getUnaryRule(pState, cState).getScores2(); UnaryRule rule = grammars[gr].getUnaryRule(pState, cState); if (rule==null){ System.err.println("Dont have rule "+ (String)tagNumberer.object(pState) +" -> "+(String)tagNumberer.object(cState)+" in grammar "+gr); continue; } double[][] scores = rule.getScores2(); int nChildStates = numSubstates[gr][cState]; // == scores.length; int nParentStates = numSubstates[gr][pState]; // == scores[0].length; for (int cp = 0; cp < nChildStates; cp++) { double cIS = iScores.get(gr)[start][end][cState][cp]; if (cIS == 0) continue; if (scores[cp]==null) continue; for (int np = 0; np < nParentStates; np++) { double pOS = oScores.get(gr)[start][end][pState][np]; if (pOS < 0) continue; double ruleS = scores[cp][np]; if (ruleS == 0) continue; ruleScore += (pOS * ruleS * cIS) / logNormalizer[gr]; } } // if (ruleScore==0) continue; gScore += Math.log(ruleScore); } if (gScore > maxcScoreStartEnd[pState]) { maxcScoreStartEnd[pState] = gScore; maxcChild[start][end][pState] = cState; } } } maxcScore[start][end] = maxcScoreStartEnd; } } } public static List<Posterior> loadPosteriors(String fileName) { List<Posterior> posteriors = null; try { FileInputStream fis = new FileInputStream(fileName); // Load from file GZIPInputStream gzis = new GZIPInputStream(fis); // Compressed ObjectInputStream in = new ObjectInputStream(gzis); // Load objects posteriors = (List<Posterior>)in.readObject(); // Read the mix of grammars in.close(); // And close the stream. gzis.close(); fis.close(); } catch (IOException e) { System.out.println("IOException\n"+e); return null; } catch (ClassNotFoundException e) { System.out.println("Class not found!"); return null; } return posteriors; } }