/** * */ package edu.berkeley.nlp.PCFGLA; import java.io.IOException; import java.io.PrintWriter; import java.io.Writer; import java.util.ArrayList; import java.util.List; import edu.berkeley.nlp.PCFGLA.smoothing.Smoother; /** * @author petrov * */ public class HierarchicalGrammar extends Grammar { /** * @param nSubStates * @param findClosedPaths * @param smoother * @param oldGrammar * @param thresh */ public HierarchicalGrammar(short[] nSubStates, boolean findClosedPaths, Smoother smoother, Grammar oldGrammar, double thresh) { super(nSubStates, findClosedPaths, smoother, oldGrammar, thresh); } private static final long serialVersionUID = 1L; public HierarchicalGrammar(Grammar gr){ super(gr.numSubStates,gr.findClosedPaths,gr.smoother,gr,gr.threshold); for (BinaryRule oldRule : gr.binaryRuleMap.keySet()) { HierarchicalBinaryRule newRule = new HierarchicalBinaryRule(oldRule); addBinary(newRule); } for (UnaryRule oldRule : gr.unaryRuleMap.keySet()) { HierarchicalUnaryRule newRule = new HierarchicalUnaryRule(oldRule); addUnary(newRule); } if (true) { closedSumRulesWithParent = closedViterbiRulesWithParent = unaryRulesWithParent; closedSumRulesWithChild = closedViterbiRulesWithChild = unaryRulesWithC; } else computePairsOfUnaries(); makeCRArrays(); isGrammarTag = gr.isGrammarTag; } public void splitRules(){ explicitlyComputeScores(finalLevel); super.splitRules(); } public void explicitlyComputeScores(int finalLevel){ for (BinaryRule oldRule : binaryRuleMap.keySet()) { HierarchicalBinaryRule newRule = (HierarchicalBinaryRule)oldRule; newRule.explicitlyComputeScores(finalLevel, numSubStates); } for (UnaryRule oldRule : unaryRuleMap.keySet()) { HierarchicalUnaryRule newRule = (HierarchicalUnaryRule)oldRule; newRule.explicitlyComputeScores(finalLevel, numSubStates); } } public HierarchicalGrammar splitAllStates(double randomness, int[] counts, boolean moreSubstatesThanCounts, int mode){ short[] newNumSubStates = new short[numSubStates.length]; newNumSubStates[0] = 1; for (short i = 1; i < numSubStates.length; i++) { // don't split a state into more substates than times it was actaully seen if (!moreSubstatesThanCounts && numSubStates[i]>=counts[i]) { newNumSubStates[i]=numSubStates[i]; } else{ newNumSubStates[i] = (short)(numSubStates[i] * 2); } } HierarchicalGrammar newGrammar = newInstance(newNumSubStates);// HierarchicalGrammar(newNumSubStates,this.findClosedPaths,this.smoother,this,this.threshold); for (BinaryRule oldRule : binaryRuleMap.keySet()) { HierarchicalBinaryRule newRule = (HierarchicalBinaryRule)oldRule; newGrammar.addBinary(newRule.splitRule(numSubStates, newGrammar.numSubStates, GrammarTrainer.RANDOM, randomness, true, mode)); } for (UnaryRule oldRule : unaryRuleMap.keySet()) { HierarchicalUnaryRule newRule = (HierarchicalUnaryRule)oldRule; newGrammar.addUnary(newRule.splitRule(numSubStates, newGrammar.numSubStates, GrammarTrainer.RANDOM, randomness, true, mode)); } if (true) { newGrammar.closedSumRulesWithParent = newGrammar.closedViterbiRulesWithParent = newGrammar.unaryRulesWithParent; newGrammar.closedSumRulesWithChild = newGrammar.closedViterbiRulesWithChild = newGrammar.unaryRulesWithC; } else newGrammar.computePairsOfUnaries(); newGrammar.makeCRArrays(); newGrammar.isGrammarTag = this.isGrammarTag; return newGrammar; } public HierarchicalGrammar newInstance(short[] newNumSubStates) { return new HierarchicalGrammar(newNumSubStates,this.findClosedPaths,this.smoother,this,this.threshold); } public void mergeGrammar(){ int nBinaryMerged = 0, nUnaryMerged = 0; for (BinaryRule oldRule : binaryRuleMap.keySet()) { HierarchicalBinaryRule newRule = (HierarchicalBinaryRule)oldRule; nBinaryMerged += newRule.mergeRule(); } for (UnaryRule oldRule : unaryRuleMap.keySet()) { HierarchicalUnaryRule newRule = (HierarchicalUnaryRule)oldRule; nUnaryMerged += newRule.mergeRule(); } System.out.println("Removed "+nBinaryMerged+" binary and "+nUnaryMerged+" unary parameters."); } public HierarchicalGrammar copyGrammar(boolean noUnaryChains) { short[] newNumSubStates = numSubStates.clone(); HierarchicalGrammar grammar = newInstance(newNumSubStates); for (BinaryRule oldRule : binaryRuleMap.keySet()) { // HierarchicalBinaryRule newRule = new BinaryRule(oldRule); // grammar.addBinary(newRule); grammar.addBinary(oldRule); } for (UnaryRule oldRule : unaryRuleMap.keySet()) { // UnaryRule newRule = new UnaryRule(oldRule); // grammar.addUnary(newRule); grammar.addUnary(oldRule); } if (noUnaryChains) { closedSumRulesWithParent = closedViterbiRulesWithParent = unaryRulesWithParent; closedSumRulesWithChild = closedViterbiRulesWithChild = unaryRulesWithC; } else grammar.computePairsOfUnaries(); grammar.makeCRArrays(); grammar.isGrammarTag = this.isGrammarTag; /* grammar.ruleIndexer = ruleIndexer; grammar.startIndex = startIndex; grammar.nEntries = nEntries; grammar.toBeIgnored = toBeIgnored;*/ return grammar; } public String toString() { printLevelCounts(); return super.toString(); } void printLevelCounts(){ int nBinaryParams=0, nUnaryParams=0, nBinaryFringeParams=0, nUnaryFringeParams=0; for (int state = 0; state < numStates; state++) { int[] counts = new int[6]; BinaryRule[] parentRules = this.splitRulesWithP(state); if (parentRules.length==0) continue; for (int i = 0; i < parentRules.length; i++) { HierarchicalBinaryRule r =(HierarchicalBinaryRule)parentRules[i]; counts[r.lastLevel]++; nBinaryParams += r.countNonZeroFeatures(); // nBinaryFringeParams += r.nParam; } System.out.print(tagNumberer.object(state)+", binary rules per level: "); for (int i=1; i<6; i++){ System.out.print(counts[i]+" "); } System.out.print("\n"); } // for (int i=0; i<6; i++){ // System.out.println(counts[i]+" binary rules are split upto level "+i); // counts[i] = 0; // } for (int state = 0; state < numStates; state++) { int[] counts = new int[6]; UnaryRule[] unaries = this.getClosedSumUnaryRulesByParent(state); //this.getClosedSumUnaryRulesByParent(state);// if (unaries.length==0) continue; for (int r = 0; r < unaries.length; r++) { HierarchicalUnaryRule ur =(HierarchicalUnaryRule)unaries[r]; counts[ur.lastLevel]++; nUnaryParams += ur.countNonZeroFeatures(); // nUnaryFringeParams += ur.nParam; } System.out.print(tagNumberer.object(state)+", unary rules per level: "); for (int i=1; i<6; i++){ System.out.print(counts[i]+" "); } System.out.print("\n"); } System.out.println("There are "+nBinaryParams+" binary features");//, of which "+ nBinaryFringeParams+" are on the fringe."); System.out.println("There are "+nUnaryParams+" unary features");//, of which "+ nUnaryFringeParams+" are on the fringe."); } public void writeData(Writer w) throws IOException { printLevelCounts(); super.writeData(w); } }