/**
 *
 */
package edu.berkeley.nlp.PCFGLA;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

import edu.berkeley.nlp.util.ArrayUtil;
import edu.berkeley.nlp.util.Numberer;

/**
 * A binary rule whose scores are stored as a hierarchy of refinement levels
 * instead of a single flat table.
 *
 * <p>Each entry of {@link #scoreHierarchy} is a table indexed as
 * {@code [leftSubState][rightSubState][parentSubState]}. The tables hold
 * log-domain values: {@link #explicitlyComputeScores(int, short[])} sums the
 * matching entries of all levels and exponentiates the sum to fill the
 * inherited {@code scores} table.
 *
 * @author petrov
 */
public class HierarchicalBinaryRule extends BinaryRule {

    private static final long serialVersionUID = 1L;

    /**
     * Before: {@code scores[leftSubState][rightSubState][parentSubState]} gave
     * the score for this rule. Now: a hierarchy of refinements, one table per
     * level, with index 0 being the coarsest level.
     */
    List<double[][][]> scoreHierarchy;

    /** Index of the finest level currently stored in {@link #scoreHierarchy}. */
    public int lastLevel = -1;

    /**
     * Copy constructor. Every level of the hierarchy is cloned through
     * {@code ArrayUtil.clone}; the explicit {@code scores} table is reset to
     * {@code null} and has to be recomputed with
     * {@link #explicitlyComputeScores(int, short[])}.
     *
     * @param b the hierarchical rule to copy
     */
    public HierarchicalBinaryRule(HierarchicalBinaryRule b) {
        super(b);
        this.scoreHierarchy = new ArrayList<double[][][]>();
        for (double[][][] levelScores : b.scoreHierarchy) {
            this.scoreHierarchy.add(ArrayUtil.clone(levelScores));
        }
        this.lastLevel = b.lastLevel;
        this.scores = null;
    }

    /**
     * Builds a one-level hierarchy from a plain binary rule.
     *
     * <p>Assumes for now that the rule being passed in is unsplit: only
     * {@code b.scores[0][0][0]} is read, and its logarithm becomes the single
     * entry of level 0.
     *
     * @param b the (unsplit) rule to convert
     */
    public HierarchicalBinaryRule(BinaryRule b) {
        super(b);
        this.scoreHierarchy = new ArrayList<double[][][]>();
        double[][][] rootLevel = new double[1][1][1];
        rootLevel[0][0][0] = Math.log(b.scores[0][0][0]);
        this.scoreHierarchy.add(rootLevel);
        this.lastLevel = 0;
        this.scores = null;
    }

    /**
     * Returns the number of substates to use for one state: 2^exponent, capped
     * by the number of substates available for that state.
     */
    private static int cappedSubStateCount(int exponent, short available) {
        int maxStates = (int) Math.pow(2, exponent);
        return Math.min(maxStates, available);
    }

    /**
     * Materializes the inherited {@code scores} table from the hierarchy.
     *
     * <p>The table has, per dimension, {@code min(2^(finalLevel+1),
     * newNumSubStates[state])} entries. For every cell the values of levels
     * {@code 0..lastLevel} are summed (a coarser level's entry is shared by all
     * finer cells that map onto it through integer division) and the sum is
     * exponentiated. {@code null} levels are skipped.
     *
     * <p>NOTE(review): if a stored level has more entries in some dimension
     * than the target table, the integer divisor becomes 0 and the lookup
     * throws {@link ArithmeticException}; callers are presumably expected to
     * pass a {@code finalLevel} that is at least as fine as every stored level.
     *
     * @param finalLevel      level that determines the resolution of the table
     * @param newNumSubStates number of substates per state, indexed by state id
     */
    public void explicitlyComputeScores(int finalLevel, short[] newNumSubStates) {
        int newPStates = cappedSubStateCount(finalLevel + 1, newNumSubStates[this.parentState]);
        int newLStates = cappedSubStateCount(finalLevel + 1, newNumSubStates[this.leftChildState]);
        int newRStates = cappedSubStateCount(finalLevel + 1, newNumSubStates[this.rightChildState]);

        this.scores = new double[newLStates][newRStates][newPStates];

        // Accumulate the log-domain contributions of every level.
        for (int level = 0; level <= lastLevel; level++) {
            double[][][] scoresThisLevel = scoreHierarchy.get(level);
            if (scoresThisLevel == null) {
                continue;
            }
            // How many fine-grained substates share one entry of this level.
            int divisorL = newLStates / scoresThisLevel.length;
            int divisorR = newRStates / scoresThisLevel[0].length;
            int divisorP = newPStates / scoresThisLevel[0][0].length;
            for (int lChild = 0; lChild < newLStates; lChild++) {
                for (int rChild = 0; rChild < newRStates; rChild++) {
                    for (int parent = 0; parent < newPStates; parent++) {
                        this.scores[lChild][rChild][parent] +=
                                scoresThisLevel[lChild / divisorL][rChild / divisorR][parent / divisorP];
                    }
                }
            }
        }

        // Leave the log domain.
        for (int lChild = 0; lChild < newLStates; lChild++) {
            for (int rChild = 0; rChild < newRStates; rChild++) {
                for (int parent = 0; parent < newPStates; parent++) {
                    this.scores[lChild][rChild][parent] = Math.exp(this.scores[lChild][rChild][parent]);
                }
            }
        }
    }

    /**
     * @return the table stored at index {@link #lastLevel} of the hierarchy
     */
    public double[][][] getLastLevel() {
        return this.scoreHierarchy.get(lastLevel);
    }

    /**
     * Returns a copy of this rule with one additional, randomly initialized
     * refinement level appended.
     *
     * <p>The new level has, per dimension, {@code min(2^(lastLevel+1),
     * newNumSubStates[state])} entries, each set to
     * {@code random.nextDouble() / 100.0}. This rule itself is not modified.
     * {@code numSubStates}, {@code randomness} and {@code doNotNormalize} are
     * not read by this implementation.
     *
     * @param numSubStates    unused here
     * @param newNumSubStates number of substates per state after the split
     * @param random          source of the initial values of the new level
     * @param randomness      unused here
     * @param doNotNormalize  unused here
     * @param mode            split mode; only {@code 2} is supported
     * @return the refined copy of this rule
     * @throws Error if {@code mode} is not {@code 2}
     */
    public HierarchicalBinaryRule splitRule(short[] numSubStates, short[] newNumSubStates, Random random,
            double randomness, boolean doNotNormalize, int mode) {
        if (mode != 2) {
            throw new Error("Can't split hiereachical rule in this mode!");
        }

        int newPStates = cappedSubStateCount(lastLevel + 1, newNumSubStates[this.parentState]);
        int newLStates = cappedSubStateCount(lastLevel + 1, newNumSubStates[this.leftChildState]);
        int newRStates = cappedSubStateCount(lastLevel + 1, newNumSubStates[this.rightChildState]);

        // Fill order (left, right, parent) is kept so a seeded Random yields
        // the same table as before.
        double[][][] newScores = new double[newLStates][newRStates][newPStates];
        for (int lChild = 0; lChild < newLStates; lChild++) {
            for (int rChild = 0; rChild < newRStates; rChild++) {
                for (int parent = 0; parent < newPStates; parent++) {
                    newScores[lChild][rChild][parent] = random.nextDouble() / 100.0;
                }
            }
        }

        HierarchicalBinaryRule newRule = new HierarchicalBinaryRule(this);
        newRule.scoreHierarchy.add(newScores);
        newRule.lastLevel++;
        return newRule;
    }

    /** Returns true when every entry of the table equals 0.0. */
    private static boolean isAllZero(double[][][] table) {
        for (double[][] byRightChild : table) {
            for (double[] byParent : byRightChild) {
                for (double value : byParent) {
                    if (value != 0.0) {
                        return false;
                    }
                }
            }
        }
        return true;
    }

    /**
     * Drops the finest level of the hierarchy if it contributes nothing, i.e.
     * if all of its entries are exactly {@code 0.0}.
     *
     * @return {@code 1} if the last level was removed (and {@link #lastLevel}
     *         decremented), {@code 0} otherwise
     */
    public int mergeRule() {
        if (!isAllZero(scoreHierarchy.get(lastLevel))) {
            return 0;
        }
        // lastLevel is an int, so this is List.remove(int index).
        scoreHierarchy.remove(lastLevel);
        lastLevel--;
        return 1;
    }

    /**
     * Renders the rule as {@code "parent -> left right"} followed by a newline.
     * If the explicit {@code scores} table has been computed, it and every
     * level of the hierarchy are appended, each on its own line, followed by a
     * blank line.
     */
    @Override
    public String toString() {
        Numberer n = Numberer.getGlobalNumberer("tags");
        String lState = (String) n.object(leftChildState);
        String rState = (String) n.object(rightChildState);
        String pState = (String) n.object(parentState);
        String header = pState + " -> " + lState + " " + rState + "\n";
        if (scores == null) {
            return header;
        }

        StringBuilder sb = new StringBuilder();
        sb.append(header);
        sb.append(ArrayUtil.toString(scores) + "\n");
        for (double[][][] levelScores : scoreHierarchy) {
            sb.append(ArrayUtil.toString(levelScores) + "\n");
        }
        sb.append("\n");
        return sb.toString();
    }

    /**
     * Counts the entries that differ from zero across levels
     * {@code 0..lastLevel}; {@code null} levels are skipped.
     *
     * <p>Each dimension is walked using its own extent. The tables are not
     * necessarily cubic, because {@link #splitRule} caps the left-child,
     * right-child and parent dimensions independently.
     *
     * @return the number of non-zero entries in the hierarchy
     */
    public int countNonZeroFeatures() {
        int total = 0;
        for (int level = 0; level <= lastLevel; level++) {
            double[][][] scoresThisLevel = scoreHierarchy.get(level);
            if (scoresThisLevel == null) {
                continue;
            }
            for (double[][] byRightChild : scoresThisLevel) {
                for (double[] byParent : byRightChild) {
                    for (double value : byParent) {
                        if (value != 0) {
                            total++;
                        }
                    }
                }
            }
        }
        return total;
    }
}