package edu.berkeley.nlp.PCFGLA; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.List; import edu.berkeley.nlp.syntax.Tree; import edu.berkeley.nlp.syntax.Trees; import edu.berkeley.nlp.util.Numberer; public class TreeGenerator { static Grammar grammar; static SophisticatedLexicon lexicon; static Numberer tagNumberer; /** * @param args */ public static void main(String[] args) { if (args.length <3) { System.out.println("usage: java TreeGenerator <input file for grammar> <maxLength> <nTrees>\n"); System.exit(2); } String inFileName = args[0]; int maxLength = Integer.parseInt(args[1]); int nTrees = Integer.parseInt(args[2]); System.out.println("Loading grammar from " + inFileName + "."); ParserData pData = ParserData.Load(inFileName); if (pData == null) { System.out.println("Failed to load grammar from file" + inFileName + "."); System.exit(1); } grammar = pData.getGrammar(); lexicon = (SophisticatedLexicon)pData.getLexicon(); Numberer.setNumberers(pData.getNumbs()); tagNumberer = Numberer.getGlobalNumberer("tags"); grammar.splitRules(); int nGen = 0; while (nGen < nTrees){ Tree<String> artTree = generateTree(0, 0); System.out.println(artTree.getYield().toString()); Tree<String> tree = TreeAnnotations.unAnnotateTree(artTree); if (tree.getYield().size() > maxLength) continue; System.out.println("Generated tree of length "+tree.getYield().size()+".\n"+Trees.PennTreeRenderer.render(tree)+"\n"); nGen++; } } private static Tree<String> generateTree(int pState, int pSubState) { String root = (String)tagNumberer.object(pState); //System.out.println("Current parent: "+root+"-"+pSubState); BinaryRule[] bRules = grammar.splitRulesWithP(pState); //System.out.println("Number of binary rules: " +bRules.length); double randval = GrammarTrainer.RANDOM.nextDouble(); double sum=0; ArrayList<Tree<String>> children = new ArrayList<Tree<String>>(); for (int i = 0; i < bRules.length; i++) { double[][][] scores = bRules[i].scores; for (int lC=0; lC<scores.length; lC++){ for (int rC=0; rC<scores[lC].length; rC++){ if (scores[lC][rC]!=null) sum += scores[lC][rC][pSubState]; if (sum>randval){ children.add( generateTree(bRules[i].leftChildState, lC) ); children.add( generateTree(bRules[i].rightChildState, rC) ); return new Tree<String>( root, children ); } } } } List<UnaryRule> uRulesList = grammar.getUnaryRulesByParent(pState); //) getClosedViterbiUnaryRulesByParent( //for (int i = 0; i < uRules.length; i++) { //double[][] scores = uRules[i].scores; for (UnaryRule uRule : uRulesList){ double[][] scores = uRule.scores; for (int uC=0; uC<scores.length; uC++){ if (uRule.parentState==uRule.childState) continue; if (scores[uC]!=null) sum += scores[uC][pSubState]; if (sum>randval){ children.add( generateTree(uRule.childState, uC) ); return new Tree<String>( root, children ); } } } if (sum==0) { //System.out.println("There are no rules with "+root+" as parent."); String word = sampleWord(pState, pSubState); List<Tree<String>> child = Collections.singletonList( new Tree<String>(word) ); return new Tree<String>( root, child ); } else throw new Error("rule probability sum "+sum+" is more than 1!"); } // P(T|W) = P(W|T)*P(W)/P(T) private static String sampleWord(int tag, int substate) { String w = (String)tagNumberer.object(tag); double randval = GrammarTrainer.RANDOM.nextDouble(); double sum=0; HashMap<String,double[]> wordToTagCounter = lexicon.wordToTagCounters[tag]; for (String word : wordToTagCounter.keySet()){ double c_TW = 0; if (lexicon.wordToTagCounters[tag]!=null && lexicon.wordToTagCounters[tag].get(word)!=null) { c_TW = wordToTagCounter.get(word)[substate]; } double c_W = lexicon.wordCounter.getCount(word); double c_T = lexicon.tagCounter[tag][substate]; double total = lexicon.totalTokens; double pb_T_W = c_TW / c_W; double p_T = (c_T / total); double p_W = (c_W / total); double pb_W_T = pb_T_W * p_W / p_T; sum += pb_W_T; if (sum>randval) return word; } return w; } }