/** * */ package edu.berkeley.nlp.PCFGLA; import java.io.Serializable; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Random; import edu.berkeley.nlp.PCFGLA.HierarchicalAdaptiveUnaryRule.SubRule; import edu.berkeley.nlp.syntax.Tree; import edu.berkeley.nlp.syntax.Trees.PennTreeRenderer; import edu.berkeley.nlp.util.Numberer; import edu.berkeley.nlp.util.Pair; /** * @author petrov * */ public class HierarchicalAdaptiveBinaryRule extends HierarchicalBinaryRule { private static final long serialVersionUID = 1L; public short[][][] mapping; Tree<Double> hierarchy; public int nParam; public SubRule[] subRuleList; // assume for now that the rule being passed in is unsplit public HierarchicalAdaptiveBinaryRule(BinaryRule b) { super(b); hierarchy = new Tree<Double>(0.0); scores = new double[1][1][1]; mapping = new short[1][1][1]; //to parameters nParam = 1; } public Pair<Integer,Integer> countParameters(){ // first one is the max_depth, second one is the number of parameters int maxDepth = hierarchy.getDepth(); nParam = hierarchy.getYield().size(); return new Pair<Integer,Integer>(maxDepth, nParam); } public HierarchicalAdaptiveBinaryRule splitRule(short[] numSubStates, short[] newNumSubStates, Random random, double randomness, boolean doNotNormalize, int mode){ splitRuleHelper(hierarchy, random, 8); // mapping = new short[newNumSubStates[this.leftChildState]][newNumSubStates[this.rightChildState]][newNumSubStates[this.parentState]]; // int finalLevel = (int)(Math.log(mapping.length)/Math.log(2)); // updateMapping((short)0, 0, 0, 0, 0, finalLevel, hierarchy); return this; } // private short updateMapping(short myID, int nextLeftSubstate, int nextRightSubstate, int nextParentSubstate, int myDepth, int finalDepth, Tree<Double> tree) { // if (tree.isLeaf()){ // if (myDepth==finalDepth){ // mapping[nextLeftSubstate][nextRightSubstate][nextParentSubstate] = myID; // } else { // int substatesToCover = (int)Math.pow(2,finalDepth-myDepth); // nextLeftSubstate *= substatesToCover; // nextRightSubstate *= substatesToCover; // nextParentSubstate *= substatesToCover; // for (int i=0; i<substatesToCover; i++){ // for (int j=0; j<substatesToCover; j++){ // for (int k=0; k<substatesToCover; k++){ // mapping[nextLeftSubstate+i][nextRightSubstate+j][nextParentSubstate+k] = myID; // } // } // } // } // myID++; // } else { // int i = 0; // for (Tree<Double> child : tree.getChildren()){ // myID = updateMapping(myID, nextLeftSubstate*2 + (i/4), nextRightSubstate*2 + ((i/2)%2), nextParentSubstate*2 + (i%2), myDepth+1, finalDepth, child); // i++; // } // } // return myID; // } private void splitRuleHelper(Tree<Double> tree, Random random, int splitFactor) { if (tree.isLeaf()){ if (tree.getLabel()!=0||nParam==1){ // split it ArrayList<Tree<Double>> children = new ArrayList<Tree<Double>>(splitFactor); for (int i=0; i<splitFactor; i++){ Tree<Double> child = new Tree<Double>(random.nextDouble()/100.0); children.add(child); } tree.setChildren(children); nParam += splitFactor-1; // } else { //perturb it // tree.setLabel(random.nextDouble()/100.0); } } else { for (Tree<Double> child : tree.getChildren()){ splitRuleHelper(child, random, splitFactor); } } } public void explicitlyComputeScores(int finalLevel, short[] newNumSubStates){ // int nSubstates = (int)Math.pow(2, finalLevel); // scores = new double[nSubstates][nSubstates][nSubstates]; // int nextSubstate = fillScores((short)0, 0, 0, 0, 0, 0, finalLevel, hierarchy); // if (nextSubstate != nParam) // System.out.println("Didn't fill all scores!"); computeSubRuleList(); } // private short fillScores(short myID, double previousScore, int nextLeftSubstate, int nextRightSubstate, int nextParentSubstate, int myDepth, int finalDepth, Tree<Double> tree){ // if (tree.isLeaf()){ // double myScore = Math.exp(previousScore + tree.getLabel()); // if (myDepth==finalDepth){ // scores[nextLeftSubstate][nextRightSubstate][nextParentSubstate] = myScore; // } else { // int substatesToCover = (int)Math.pow(2,finalDepth-myDepth); // nextLeftSubstate *= substatesToCover; // nextRightSubstate *= substatesToCover; // nextParentSubstate *= substatesToCover; // for (int i=0; i<substatesToCover; i++){ // for (int j=0; j<substatesToCover; j++){ // for (int k=0; k<substatesToCover; k++){ // scores[nextLeftSubstate+i][nextRightSubstate+j][nextParentSubstate+k] = myScore; // } // } // } // } // myID++; // } else { // double myScore = previousScore + tree.getLabel(); // int i = 0; // for (Tree<Double> child : tree.getChildren()){ // myID = fillScores(myID, myScore, nextLeftSubstate*2 + (i/4), nextRightSubstate*2 + ((i/2)%2), nextParentSubstate*2 + (i%2), myDepth+1, finalDepth, child); // i++; // } // } // return myID; // } public void updateScores(double[] scores){ int nSubstates = updateHierarchy(hierarchy, 0, scores); if (nSubstates != nParam) System.out.println("Didn't update all parameters"); // if (subRuleList!=null){ // int i = 0; // for (SubRule r : subRuleList){ // r.score = scores[this.identifier + i++]; // } // } } private int updateHierarchy(Tree<Double> tree, int nextSubstate, double[] scores) { if (tree.isLeaf()){ double val = scores[identifier + nextSubstate++]; if (val>200) { val = 0; System.out.println("Ignored proposed binary value since it was danegrous"); } else tree.setLabel(val); } else { for (Tree<Double> child : tree.getChildren()){ nextSubstate = updateHierarchy(child, nextSubstate, scores); } } return nextSubstate; } public int mergeRule() { int paramBefore = nParam; compactifyHierarchy(hierarchy); scores = null; mapping = null; subRuleList = null; scoreHierarchy = null; return paramBefore - nParam; } /** * @return */ public List<Double> getFinalLevel() { return hierarchy.getYield(); } private void compactifyHierarchy(Tree<Double> tree){ if (tree.getDepth()==2){ boolean allZero = true; for (Tree<Double> child : tree.getChildren()){ allZero = allZero && child.getLabel()==0; } if (allZero) { nParam -= tree.getChildren().size()-1; tree.setChildren(Collections.EMPTY_LIST); } } else { for (Tree<Double> child : tree.getChildren()){ compactifyHierarchy(child); } } } public String toStringShort(){ Numberer n = Numberer.getGlobalNumberer("tags"); String lState = (String)n.object(leftChildState); String rState = (String)n.object(rightChildState); String pState = (String)n.object(parentState); return (pState+" -> "+lState+" "+rState); } public String toString(){ StringBuilder sb = new StringBuilder(); Numberer n = Numberer.getGlobalNumberer("tags"); String lState = (String)n.object(leftChildState); String rState = (String)n.object(rightChildState); String pState = (String)n.object(parentState); sb.append(pState+" -> "+lState+" "+rState+"\n"); if (subRuleList==null){ compactifyHierarchy(hierarchy); lastLevel = hierarchy.getDepth(); computeSubRuleList(); } for (SubRule rule : subRuleList){ sb.append(rule.toString(lastLevel-1)); sb.append("\n"); } // sb.append(PennTreeRenderer.render(hierarchy)); sb.append("\n"); // sb.append(Arrays.toString(scores)); return sb.toString(); } public int countNonZeroFeatures() { int total = 0; for (Tree<Double> d : hierarchy.getPreOrderTraversal()) { if (d.getLabel()!=0) total++; } return total; } public int countNonZeroFringeFeatures() { int total = 0; for (Tree<Double> d : hierarchy.getTerminals()) { if (d.getLabel()!=0) total++; } return total; } public void computeSubRuleList(){ subRuleList = new SubRule[nParam]; int nRules = computeSubRules(0, 0, 0, 0, 0, 0, hierarchy); if (nRules != nParam) System.out.println("A rule got lost"); } private int computeSubRules(int myID, double previousScore, int nextLeftSubstate, int nextRightSubstate, int nextParentSubstate, int myDepth, Tree<Double> tree){ if (tree.isLeaf()){ double myScore = Math.exp(previousScore + tree.getLabel()); SubRule rule = new SubRule((short)nextLeftSubstate, (short)nextRightSubstate, (short)nextParentSubstate, (short)myDepth, myScore); subRuleList[myID]=rule; myID++; } else { double myScore = previousScore + tree.getLabel(); int i = 0; for (Tree<Double> child : tree.getChildren()){ myID = computeSubRules(myID, myScore, nextLeftSubstate*2 + (i/4), nextRightSubstate*2 + ((i/2)%2), nextParentSubstate*2 + (i%2), myDepth+1, child); i++; } } return myID; } class SubRule implements Serializable{ private static final long serialVersionUID = 1L; short lChild, rChild, parent, level; double score; SubRule(short lC, short rC, short p, short l, double s){ lChild = lC; rChild = rC; parent = p; level = l; score = s; } public String toString(){ String s = "["+parent+"] \t -> \t ["+lChild+"] \t ["+rChild+"] \t "+score; return s; } public String toString(int finalLevel){ if (finalLevel==level) return toString(); int k = (int)Math.pow(2, finalLevel-level); String s = "["+(k*parent)+"-"+(k*parent+k-1)+"] \t -> \t ["+(k*lChild)+"-"+(k*lChild+k-1)+"] \t ["+(k*rChild)+"-"+(k*rChild+k-1)+"] \t "+score+"\t level: "+level; return s; } } }