/**
*
*/
package edu.berkeley.nlp.scripts;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.InputStreamReader;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.List;
import edu.berkeley.nlp.PCFGLA.Binarization;
import edu.berkeley.nlp.PCFGLA.Grammar;
import edu.berkeley.nlp.PCFGLA.Lexicon;
import edu.berkeley.nlp.PCFGLA.Option;
import edu.berkeley.nlp.PCFGLA.OptionParser;
import edu.berkeley.nlp.PCFGLA.ParserData;
import edu.berkeley.nlp.PCFGLA.SimpleLexicon;
import edu.berkeley.nlp.PCFGLA.SophisticatedLexicon;
import edu.berkeley.nlp.PCFGLA.StateSetTreeList;
import edu.berkeley.nlp.PCFGLA.GrammarTrainer.Options;
import edu.berkeley.nlp.PCFGLA.smoothing.NoSmoothing;
import edu.berkeley.nlp.PCFGLA.smoothing.SmoothAcrossParentSubstate;
import edu.berkeley.nlp.PCFGLA.smoothing.Smoother;
import edu.berkeley.nlp.syntax.StateSet;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.syntax.Trees.PennTreeReader;
import edu.berkeley.nlp.util.Numberer;
/**
 * Command-line tool that reads a corpus of trees whose node labels already carry an
 * annotation (the part of the label starting at its first '-'), tallies those trees into a
 * {@link Grammar} and a {@link SophisticatedLexicon}, and saves the result as a
 * {@link ParserData} file.
 *
 * @author petrov
 *
 */
public class ObservedGrammarExtractor {

    /** Command-line options, filled in by {@link OptionParser}. */
    public static class Options {
        /** File the resulting {@link ParserData} is saved to. */
        @Option(name = "-out", required = true, usage = "Output File for Grammar (Required)")
        public String outFileName;

        /** Corpus file to read the trees from. */
        @Option(name = "-path", usage = "Path to Corpus File (Default: null)")
        public String path = null;

        /** Whether smoothers are installed on the grammar and the lexicon. */
        @Option(name = "-smooth", usage = "Smooth the grammar if possible")
        public boolean smooth = false;
    }

    /**
     * Loads the corpus, builds the grammar and saves it.
     * The process exits with status 0 when the grammar was saved and 1 when saving failed.
     *
     * @param args command-line arguments, see {@link Options}
     */
    public static void main(String[] args) {
        OptionParser optParser = new OptionParser(Options.class);
        Options opts = (Options) optParser.parse(args, true);

        List<Tree<String>> trainTrees = loadTrees(opts.path);
        ParserData pData = createGrammar(trainTrees, opts.smooth);

        boolean saved = pData.Save(opts.outFileName);
        if (saved) {
            System.out.println("Saved grammar.");
        } else {
            System.out.println("Saving failed!");
        }
        // Report a failed save to the calling shell instead of always exiting with 0.
        System.exit(saved ? 0 : 1);
    }

    /** Numberer for the base tags (label part before the first '-'); set by createGrammar. */
    static Numberer tagNumberer;

    /** One numberer per tag index, numbering the annotation strings seen with that tag. */
    static List<Numberer> substateNumberers;

    /**
     * Builds grammar and lexicon from the annotated trees.
     * Resets the static {@link #tagNumberer} and {@link #substateNumberers} as a side effect.
     *
     * @param trainTrees trees whose non-leaf labels have the form tag[-annotation]
     * @param smooth     if true, SmoothAcrossParentSubstate smoothers are installed
     *                   (0.01 for the grammar, 0.1 for the lexicon)
     * @return the parser data holding lexicon, grammar and numberers
     */
    private static ParserData createGrammar(List<Tree<String>> trainTrees, boolean smooth) {
        tagNumberer = Numberer.getGlobalNumberer("tags");
        substateNumberers = new ArrayList<Numberer>();

        short[] numSubStates = countSymbols(trainTrees);
        List<Tree<String>> trainTreesNoAnnotation = stripOffAnnotation(trainTrees);
        StateSetTreeList stateSetTrees = new StateSetTreeList(trainTreesNoAnnotation, numSubStates, false, tagNumberer);

        // NOTE(review): the meaning of the literal constructor arguments below (-1, {0.5,0.1}, 0)
        // is defined by Grammar/SophisticatedLexicon and is not visible here.
        Grammar grammar = new Grammar(numSubStates, false, new NoSmoothing(), null, -1);
        Lexicon lexicon = new SophisticatedLexicon(numSubStates, SophisticatedLexicon.DEFAULT_SMOOTHING_CUTOFF,
                new double[]{0.5, 0.1}, new NoSmoothing(), 0);

        if (smooth) {
            System.out.println("Will smooth the grammar.");
            Smoother grSmoother = new SmoothAcrossParentSubstate(0.01);
            Smoother lexSmoother = new SmoothAcrossParentSubstate(0.1);
            grammar.setSmoother(grSmoother);
            lexicon.setSmoother(lexSmoother);
        }

        System.out.print("Creating grammar...");
        int index = 0;
        boolean secondHalf = false;
        int nTrees = trainTrees.size();
        // stateSetTrees is iterated in parallel with trainTrees: the i-th state-set tree is
        // paired with the i-th annotated tree.
        for (Tree<StateSet> stateSetTree : stateSetTrees) {
            Tree<String> tree = trainTrees.get(index++);
            // index is already 1-based here; true once more than half of the trees were seen.
            secondHalf = (index > nTrees / 2.0);
            setScores(stateSetTree, tree);
            lexicon.trainTree(stateSetTree, 0, null, secondHalf, false, 4);
            grammar.tallyStateSetTree(stateSetTree, grammar);
        }
        lexicon.optimize();
        grammar.optimize(0);
        System.out.println("done.");

        // NOTE(review): the arguments 1 and 0 are presumably markovization settings - confirm
        // against the ParserData constructor.
        ParserData pData = new ParserData(lexicon, grammar, null, Numberer.getNumberers(), numSubStates, 1, 0,
                Binarization.RIGHT);
        return pData;
    }

    /**
     * Walks both trees in parallel and, for every non-leaf node, sets inside and outside
     * score of the one substate named by the annotated label to 1.0 (with scale 0).
     *
     * @param stateSetTree tree of state sets built from the annotation-free labels
     * @param tree         the corresponding tree with the annotated labels
     */
    private static void setScores(Tree<StateSet> stateSetTree, Tree<String> tree) {
        if (tree.isLeaf()) return;

        String[] labels = splitLabel(tree.getLabel());
        StateSet stateSet = stateSetTree.getLabel();
        // Look up the substate index of this node's annotation among those seen for its tag.
        int substate = substateNumberers.get(stateSet.getState()).number(labels[1]);
        stateSet.setIScore(substate, 1.0);
        stateSet.setIScale(0);
        stateSet.setOScore(substate, 1.0);
        stateSet.setOScale(0);

        int nChildren = tree.getChildren().size();
        if (nChildren != stateSetTree.getChildren().size()) System.err.println("Mismatch!");
        for (int i = 0; i < nChildren; i++) {
            setScores(stateSetTree.getChildren().get(i), tree.getChildren().get(i));
        }
    }

    /**
     * Returns clones of the given trees in which every non-leaf label is cut off at its
     * first '-'. Leaf labels (the words) are kept unchanged.
     *
     * @param trainTrees the annotated trees
     * @return cloned trees carrying only the base tags on their non-leaf nodes
     */
    private static List<Tree<String>> stripOffAnnotation(List<Tree<String>> trainTrees) {
        List<Tree<String>> trainTreesNoGF = new ArrayList<Tree<String>>(trainTrees.size());
        // NOTE(review): assumes shallowClone() creates new nodes, so relabeling the clones
        // leaves the labels of the input trees intact (createGrammar still reads them).
        for (Tree<String> tree : trainTrees) {
            trainTreesNoGF.add(tree.shallowClone());
        }
        for (Tree<String> tree : trainTreesNoGF) {
            for (Tree<String> node : tree.getPostOrderTraversal()) {
                // Test the visited node, not the root: testing the root never skipped
                // anything, so words containing '-' were truncated as well.
                if (node.isLeaf()) continue;
                String label = node.getLabel();
                int cutIndex = label.indexOf('-');
                if (cutIndex != -1) label = label.substring(0, cutIndex);
                node.setLabel(label);
            }
        }
        return trainTreesNoGF;
    }

    /**
     * Registers all tags and annotations of the trees with the numberers, prints one line
     * per tag (tag, tab, count) and returns the number of distinct annotations per tag.
     *
     * @param trainTrees the annotated trees
     * @return array indexed by tag number holding the number of annotations seen for that tag
     */
    private static short[] countSymbols(List<Tree<String>> trainTrees) {
        System.out.print("Counting symbols...");
        for (Tree<String> tree : trainTrees) {
            processTree(tree);
        }
        short[] numSubStates = new short[tagNumberer.total()];
        for (int substate = 0; substate < numSubStates.length; substate++) {
            numSubStates[substate] = (short) substateNumberers.get(substate).total();
        }
        System.out.println("done.");
        for (int tag = 0; tag < tagNumberer.size(); tag++) {
            System.out.println((String) tagNumberer.object(tag) + "\t" + numSubStates[tag]);
        }
        return numSubStates;
    }

    /**
     * Numbers the tag and the annotation of the given node and recurses into all non-leaf
     * children.
     *
     * @param tree the node to process; its own label is always numbered
     */
    private static void processTree(Tree<String> tree) {
        String[] labelParts = splitLabel(tree.getLabel());
        int state = tagNumberer.number(labelParts[0]);
        // A tag seen for the first time needs its own annotation numberer.
        if (state >= substateNumberers.size()) {
            substateNumberers.add(new Numberer());
        }
        substateNumberers.get(state).number(labelParts[1]);
        for (Tree<String> child : tree.getChildren()) {
            if (!child.isLeaf()) processTree(child);
        }
    }

    /**
     * Splits a label at its first '-'.
     *
     * @param label the annotated label
     * @return a two-element array: the part before the first '-' and the rest including
     *         the '-' itself; if there is no '-', the whole label and the empty string
     */
    private static String[] splitLabel(String label) {
        int breakPoint = label.indexOf("-");
        String substateString = (breakPoint < 0) ? "" : label.substring(breakPoint);
        String stateString = (breakPoint < 0) ? label : label.substring(0, breakPoint);
        return new String[]{stateString, substateString};
    }

    /**
     * Reads all trees from the given UTF-8 encoded file.
     *
     * @param inputFile path of the corpus file
     * @return the trees in file order
     * @throws IllegalArgumentException if no path was given or the file cannot be opened
     */
    private static List<Tree<String>> loadTrees(String inputFile) {
        if (inputFile == null) {
            throw new IllegalArgumentException("No corpus file given (use -path).");
        }
        System.out.print("Loading trees...");
        InputStreamReader inputData;
        try {
            inputData = new InputStreamReader(new FileInputStream(inputFile), "UTF-8");
        } catch (UnsupportedEncodingException e) {
            throw new IllegalStateException("UTF-8 encoding is not supported", e);
        } catch (FileNotFoundException e) {
            // Fail here with the file name rather than later on a null reader.
            throw new IllegalArgumentException("Cannot open corpus file: " + inputFile, e);
        }

        List<Tree<String>> trainTrees = new ArrayList<Tree<String>>();
        try {
            PennTreeReader treeReader = new PennTreeReader(inputData);
            while (treeReader.hasNext()) {
                trainTrees.add(treeReader.next());
            }
        } finally {
            try {
                inputData.close();
            } catch (java.io.IOException ignored) {
                // The reader was only read from; a failed close cannot affect the trees
                // that were already collected.
            }
        }
        System.out.println("done.");
        return trainTrees;
    }
}