/** * */ package edu.berkeley.nlp.scripts; import java.io.FileInputStream; import java.io.FileNotFoundException; import java.io.InputStreamReader; import java.io.UnsupportedEncodingException; import java.util.ArrayList; import java.util.List; import edu.berkeley.nlp.PCFGLA.Binarization; import edu.berkeley.nlp.PCFGLA.Grammar; import edu.berkeley.nlp.PCFGLA.Lexicon; import edu.berkeley.nlp.PCFGLA.Option; import edu.berkeley.nlp.PCFGLA.OptionParser; import edu.berkeley.nlp.PCFGLA.ParserData; import edu.berkeley.nlp.PCFGLA.SimpleLexicon; import edu.berkeley.nlp.PCFGLA.SophisticatedLexicon; import edu.berkeley.nlp.PCFGLA.StateSetTreeList; import edu.berkeley.nlp.PCFGLA.GrammarTrainer.Options; import edu.berkeley.nlp.PCFGLA.smoothing.NoSmoothing; import edu.berkeley.nlp.PCFGLA.smoothing.SmoothAcrossParentSubstate; import edu.berkeley.nlp.PCFGLA.smoothing.Smoother; import edu.berkeley.nlp.syntax.StateSet; import edu.berkeley.nlp.syntax.Tree; import edu.berkeley.nlp.syntax.Trees.PennTreeReader; import edu.berkeley.nlp.util.Numberer; /** * @author petrov * */ public class ObservedGrammarExtractor { public static class Options { @Option(name = "-out", required = true, usage = "Output File for Grammar (Required)") public String outFileName; @Option(name = "-path", usage = "Path to Corpus File (Default: null)") public String path = null; @Option(name = "-smooth", usage = "Smooth the grammar if possible") public boolean smooth = false; } public static void main(String[] args) { OptionParser optParser = new OptionParser(Options.class); Options opts = (Options) optParser.parse(args, true); List<Tree<String>> trainTrees = loadTrees(opts.path); ParserData pData = createGrammar(trainTrees, opts.smooth); if (pData.Save(opts.outFileName)) System.out.println("Saved grammar."); else System.out.println("Saving failed!"); System.exit(0); } static Numberer tagNumberer; static List<Numberer> substateNumberers; private static ParserData createGrammar(List<Tree<String>> trainTrees, boolean smooth) { tagNumberer = Numberer.getGlobalNumberer("tags"); substateNumberers = new ArrayList<Numberer>(); short[] numSubStates = countSymbols(trainTrees); List<Tree<String>> trainTreesNoAnnotation = stripOffAnnotation(trainTrees); StateSetTreeList stateSetTrees = new StateSetTreeList(trainTreesNoAnnotation, numSubStates, false, tagNumberer); Grammar grammar = new Grammar(numSubStates, false, new NoSmoothing(), null, -1); Lexicon lexicon = new SophisticatedLexicon(numSubStates,SophisticatedLexicon.DEFAULT_SMOOTHING_CUTOFF,new double[]{0.5,0.1}, new NoSmoothing(),0); if (smooth){ System.out.println("Will smooth the grammar."); Smoother grSmoother = new SmoothAcrossParentSubstate(0.01); Smoother lexSmoother = new SmoothAcrossParentSubstate(0.1); grammar.setSmoother(grSmoother); lexicon.setSmoother(lexSmoother); } System.out.print("Creating grammar..."); int index = 0; boolean secondHalf = false; int nTrees = trainTrees.size(); for (Tree<StateSet> stateSetTree : stateSetTrees){ Tree<String> tree = trainTrees.get(index++); secondHalf = (index>nTrees/2.0); setScores(stateSetTree, tree); lexicon.trainTree(stateSetTree, 0, null, secondHalf,false,4); grammar.tallyStateSetTree(stateSetTree, grammar); } lexicon.optimize(); grammar.optimize(0); System.out.println("done."); ParserData pData = new ParserData(lexicon, grammar, null, Numberer.getNumberers(), numSubStates, 1, 0, Binarization.RIGHT); return pData; } private static void setScores(Tree<StateSet> stateSetTree, Tree<String> tree) { if (tree.isLeaf()) return; String[] labels = splitLabel(tree.getLabel()); StateSet stateSet = stateSetTree.getLabel(); int substate = substateNumberers.get(stateSet.getState()).number(labels[1]); stateSet.setIScore(substate, 1.0); stateSet.setIScale(0); stateSet.setOScore(substate, 1.0); stateSet.setOScale(0); int nChildren = tree.getChildren().size(); if (nChildren != stateSetTree.getChildren().size()) System.err.println("Mismatch!"); for (int i=0; i<nChildren; i++){ setScores(stateSetTree.getChildren().get(i), tree.getChildren().get(i)); } } private static List<Tree<String>> stripOffAnnotation(List<Tree<String>> trainTrees) { List<Tree<String>> trainTreesNoGF = new ArrayList<Tree<String>>(trainTrees.size()); for (Tree<String> tree : trainTrees){ trainTreesNoGF.add(tree.shallowClone()); } for (Tree<String> tree : trainTreesNoGF){ for (Tree<String> node : tree.getPostOrderTraversal()){ if (tree.isLeaf()) continue; String label = node.getLabel(); int cutIndex = label.indexOf('-'); if (cutIndex!=-1) label = label.substring(0,cutIndex); node.setLabel(label); } } return trainTreesNoGF; } private static short[] countSymbols(List<Tree<String>> trainTrees) { System.out.print("Counting symbols..."); for (Tree<String> tree : trainTrees){ processTree(tree); } short[] numSubStates = new short[tagNumberer.total()]; for (int substate=0; substate<numSubStates.length; substate++){ numSubStates[substate] = (short)substateNumberers.get(substate).total(); } System.out.println("done."); for (int tag=0; tag<tagNumberer.size(); tag++){ System.out.println((String)tagNumberer.object(tag)+"\t"+numSubStates[tag]); } return numSubStates; } private static void processTree(Tree<String> tree) { String[] labelParts = splitLabel(tree.getLabel()); int state = tagNumberer.number(labelParts[0]); if (state >= substateNumberers.size()) { substateNumberers.add(new Numberer()); } substateNumberers.get(state).number(labelParts[1]); for (Tree<String> child : tree.getChildren()){ if (!child.isLeaf()) processTree(child); } } private static String[] splitLabel(String label) { int breakPoint = label.indexOf("-"); String substateString = (breakPoint<0) ? "" : label.substring(breakPoint); String stateString = (breakPoint<0) ? label : label.substring(0, breakPoint); return new String[]{stateString,substateString}; } private static List<Tree<String>> loadTrees(String inputFile) { System.out.print("Loading trees..."); InputStreamReader inputData = null; try { inputData = new InputStreamReader(new FileInputStream(inputFile), "UTF-8"); } catch (UnsupportedEncodingException e) { // TODO Auto-generated catch block e.printStackTrace(); } catch (FileNotFoundException e) { // TODO Auto-generated catch block e.printStackTrace(); } PennTreeReader treeReader = new PennTreeReader(inputData); List<Tree<String>> trainTrees = new ArrayList<Tree<String>>(); Tree<String> tree = null; while(treeReader.hasNext()){ tree = treeReader.next(); // trainTrees.add(TreeAnnotations.processTree(tree, 1, 0, Binarization.LEFT, false, false, false)); trainTrees.add(tree); } System.out.println("done."); return trainTrees; } }