/**
*
*/
package edu.berkeley.nlp.PCFGLA;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.List;
import edu.berkeley.nlp.syntax.StateSet;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.syntax.Trees.PennTreeReader;
import edu.berkeley.nlp.util.Numberer;
import edu.berkeley.nlp.util.ScalingTools;
/**
* @author petrov
* Takes an unannotated tree and returns the log-likelihood of all derivations corresponding to the given tree.
*
*/
public class TreeScorer {

  /** Command-line options for {@code TreeScorer}. */
  public static class Options {
    @Option(name = "-gr", required = true, usage = "Input File for Grammar (Required)\n")
    public String inFileName;

    @Option(name = "-inputFile", usage = "Read input from this file instead of reading it from STDIN.")
    public String inputFile;

    @Option(name = "-outputFile", usage = "Store output in this file instead of printing it to STDOUT.")
    public String outputFile;

    @Option(name = "-printStats", usage = "Compute and print subcategory usage statistics")
    public boolean printStats;
  }

  /**
   * Command-line entry point. Loads a grammar, then reads one bracketed tree per input line and
   * writes one output line per input line:
   * <ul>
   *   <li>an empty input line produces an empty output line;</li>
   *   <li>a tree that could not be read, or whose first yield token is empty, produces {@code ()};</li>
   *   <li>otherwise the log score of the tree computed from the grammar's inside scores.</li>
   * </ul>
   * With {@code -printStats}, accumulated substate posteriors are printed to STDOUT afterwards:
   * one row per state, tag name first, values sorted in descending order.
   *
   * <p>Exits with status 1 if the grammar cannot be loaded or an error occurs while reading or
   * scoring trees; otherwise exits with status 0.
   *
   * @param args command-line arguments, see {@link Options}
   */
  public static void main(String[] args) {
    OptionParser optParser = new OptionParser(Options.class);
    Options opts = (Options) optParser.parse(args, true);
    // provide feedback on command-line arguments
    System.err.println("Calling with " + optParser.getPassedInOptions());

    String inFileName = opts.inFileName;
    if (inFileName == null) {
      throw new Error("Did not provide a grammar.");
    }
    System.err.println("Loading grammar from " + inFileName + ".");
    ParserData pData = ParserData.Load(inFileName);
    if (pData == null) {
      // Report on STDERR so the failure does not end up mixed into the scores stream.
      System.err.println("Failed to load grammar from file " + inFileName + ".");
      System.exit(1);
    }

    Grammar grammar = pData.getGrammar();
    grammar.splitRules();
    SophisticatedLexicon lexicon = (SophisticatedLexicon) pData.getLexicon();
    ArrayParser parser = new ArrayParser(grammar, lexicon);
    Numberer.setNumberers(pData.getNumbs());
    Numberer tagNumberer = Numberer.getGlobalNumberer("tags");
    short[] numSubstates = grammar.numSubStates;

    // Non-null exactly when statistics were requested; filled in by countPosteriors.
    double[][] cumulativePosteriors = opts.printStats ? allocatePosteriors(numSubstates) : null;

    boolean success = scoreTrees(opts, pData, parser, numSubstates, tagNumberer, cumulativePosteriors);

    if (opts.printStats) {
      printStatistics(cumulativePosteriors, tagNumberer);
    }
    System.exit(success ? 0 : 1);
  }

  /** Creates one zeroed accumulator per state, sized by that state's number of substates. */
  private static double[][] allocatePosteriors(short[] numSubstates) {
    double[][] posteriors = new double[numSubstates.length][];
    for (int state = 0; state < numSubstates.length; state++) {
      posteriors[state] = new double[numSubstates[state]];
    }
    return posteriors;
  }

  /**
   * Reads trees line by line and writes one result line per input line.
   *
   * @return {@code true} if all input was processed, {@code false} if an exception stopped it
   */
  private static boolean scoreTrees(Options opts, ParserData pData, ArrayParser parser,
      short[] numSubstates, Numberer tagNumberer, double[][] cumulativePosteriors) {
    BufferedReader inputData = null;
    PrintWriter outputData = null;
    try {
      // NOTE(review): STDIN/STDOUT use the platform charset while files use UTF-8 — confirm intended.
      inputData = (opts.inputFile == null)
          ? new BufferedReader(new InputStreamReader(System.in))
          : new BufferedReader(new InputStreamReader(new FileInputStream(opts.inputFile), "UTF-8"));
      outputData = (opts.outputFile == null)
          ? new PrintWriter(new OutputStreamWriter(System.out))
          : new PrintWriter(new OutputStreamWriter(new FileOutputStream(opts.outputFile), "UTF-8"), true);

      String line;
      while ((line = inputData.readLine()) != null) {
        if (line.equals("")) {
          outputData.write("\n");
          continue;
        }
        Tree<String> tree = PennTreeReader.parseEasy(line);
        if (isFailedParse(tree)) { // empty tree -> parse failure
          outputData.write("()\n");
          continue;
        }
        double logScore = computeLogScore(tree, pData, parser, numSubstates, tagNumberer,
            cumulativePosteriors);
        outputData.write(logScore + "\n");
        outputData.flush();
      }
      return true;
    } catch (Exception ex) {
      ex.printStackTrace();
      return false;
    } finally {
      // write() never triggers PrintWriter's auto-flush, so flush explicitly; otherwise trailing
      // "\n"/"()" lines can be lost when the caller invokes System.exit.
      if (outputData != null) {
        outputData.flush();
        // Do not close a writer wrapping System.out: the statistics are printed there afterwards.
        if (opts.outputFile != null) {
          outputData.close();
        }
      }
      if (inputData != null && opts.inputFile != null) {
        try {
          inputData.close();
        } catch (IOException ignored) {
          // Nothing useful to do; all input has already been consumed or processing has failed.
        }
      }
    }
  }

  /**
   * Returns {@code true} for a tree that should be reported as a parse failure: a {@code null}
   * tree (guarded defensively in case the reader returns null for malformed input), an empty
   * yield, or a yield whose first token is the empty string.
   */
  private static boolean isFailedParse(Tree<String> tree) {
    if (tree == null) {
      return true;
    }
    List<String> yield = tree.getYield();
    return yield.isEmpty() || yield.get(0).equals("");
  }

  /**
   * Annotates and binarizes {@code tree} the same way as the loaded grammar, runs the inside
   * pass (or inside-outside pass when {@code cumulativePosteriors} is non-null, accumulating
   * posteriors into it), and returns {@code log(iScore(0)) + iScale * LOGSCALE} of the root.
   */
  private static double computeLogScore(Tree<String> tree, ParserData pData, ArrayParser parser,
      short[] numSubstates, Numberer tagNumberer, double[][] cumulativePosteriors) {
    Tree<String> annotated =
        TreeAnnotations.processTree(tree, pData.v_markov, pData.h_markov, pData.bin, false);
    Tree<StateSet> stateSetTree =
        StateSetTreeList.stringTreeToStatesetTree(annotated, numSubstates, false, tagNumberer);
    allocate(stateSetTree);
    if (cumulativePosteriors != null) {
      parser.doInsideOutsideScores(stateSetTree, false, false);
      parser.countPosteriors(cumulativePosteriors, stateSetTree,
          stateSetTree.getLabel().getIScore(0), stateSetTree.getLabel().getIScale());
    } else {
      parser.doInsideScores(stateSetTree, false, false, null);
    }
    StateSet root = stateSetTree.getLabel();
    // The scale exponent is folded back in so the result is comparable across trees.
    return Math.log(root.getIScore(0)) + (root.getIScale() * ScalingTools.LOGSCALE);
  }

  /**
   * Prints one tab-separated row per state to STDOUT. Each row holds the tag name (with a trailing
   * "^g" removed) followed by that state's accumulated posteriors in descending order.
   * Sorts each row of {@code cumulativePosteriors} in place.
   */
  private static void printStatistics(double[][] cumulativePosteriors, Numberer tagNumberer) {
    for (int state = 0; state < cumulativePosteriors.length; state++) {
      String tagname = (String) tagNumberer.object(state);
      if (tagname.endsWith("^g")) {
        tagname = tagname.substring(0, tagname.length() - 2);
      }
      double[] row = cumulativePosteriors[state];
      Arrays.sort(row);
      StringBuilder sb = new StringBuilder(tagname);
      for (int substate = row.length - 1; substate >= 0; substate--) {
        sb.append('\t').append(row[substate]);
      }
      sb.append('\n');
      System.out.print(sb);
    }
  }

  /*
   * Allocate the inside and outside score arrays for the whole tree
   */
  static void allocate(Tree<StateSet> tree) {
    tree.getLabel().allocate();
    for (Tree<StateSet> child : tree.getChildren()) {
      allocate(child);
    }
  }
}