/**
 *
 */
package edu.berkeley.nlp.PCFGLA;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

import edu.berkeley.nlp.util.ArrayUtil;
import edu.berkeley.nlp.util.Numberer;

/**
 * A unary rule whose score table is kept as a list of per-level arrays
 * ({@link #scoreHierarchy}) instead of a single table.
 *
 * <p>Before: {@code scores[childSubState][parentSubState]} gives the score for
 * this rule. Now: there is a hierarchy of refinements. Level 0 is seeded with
 * the logarithm of an existing rule score, and
 * {@link #explicitlyComputeScores(int, short[])} adds the levels up cell by
 * cell and exponentiates the sums into the {@code scores} table.
 *
 * <p>{@code scores}, {@code parentState} and {@code childState} are not
 * declared in this class; they are used here as members provided by
 * {@link UnaryRule}.
 *
 * @author petrov
 */
public class HierarchicalUnaryRule extends UnaryRule {

    private static final long serialVersionUID = 1L;

    /**
     * One score array per refinement level, indexed
     * {@code [childSubState][parentSubState]}. Entries are summed before being
     * exponentiated, so they act as log-domain contributions. Readers in this
     * class skip {@code null} levels, except {@link #mergeRule()} and
     * {@link #getLastLevel()}.
     */
    List<double[][]> scoreHierarchy;

    /** Index of the finest level in {@link #scoreHierarchy}; -1 when it is empty. */
    public int lastLevel = -1;

    /**
     * Copy constructor. Each level of {@code b}'s hierarchy is copied with
     * {@code ArrayUtil.clone}, {@code lastLevel} is carried over, and the
     * explicit {@code scores} table is reset to {@code null}.
     *
     * @param b the hierarchical rule to copy
     */
    public HierarchicalUnaryRule(HierarchicalUnaryRule b) {
        super(b);
        this.scoreHierarchy = new ArrayList<double[][]>();
        for (double[][] levelScores : b.scoreHierarchy) {
            this.scoreHierarchy.add(ArrayUtil.clone(levelScores));
        }
        this.lastLevel = b.lastLevel;
        this.scores = null;
    }

    /**
     * Builds a single-level hierarchy from a plain unary rule. Assumes for now
     * that the rule being passed in is unsplit: only {@code b.scores[0][0]} is
     * read, and its natural logarithm becomes the sole entry of level 0.
     *
     * @param b the unsplit rule to convert
     */
    public HierarchicalUnaryRule(UnaryRule b) {
        super(b);
        this.scoreHierarchy = new ArrayList<double[][]>();
        double[][] rootLevel = new double[1][1];
        rootLevel[0][0] = Math.log(b.scores[0][0]);
        this.scoreHierarchy.add(rootLevel);
        this.lastLevel = 0;
        this.scores = null;
    }

    /**
     * Fills {@code scores} by summing all levels {@code 0..lastLevel} and
     * exponentiating every cell.
     *
     * <p>The table has at most {@code 2^(finalLevel+1)} substates per side,
     * further capped by {@code newNumSubStates}; the parent side is forced to a
     * single substate when {@code parentState == 0}. A cell of the table reads
     * the cell of each coarser level obtained by integer-dividing its indices
     * by the size ratio between the table and that level.
     *
     * <p>Precondition: no stored level may have more substates than the table
     * in either dimension. Otherwise the size ratio truncates to zero and the
     * index computation throws {@link ArithmeticException}.
     *
     * @param finalLevel level that bounds the size of the computed table
     * @param newNumSubStates number of substates per state, indexed by state
     */
    public void explicitlyComputeScores(int finalLevel, short[] newNumSubStates) {
        int cap = maxSubStatesForLevel(finalLevel);
        int newPStates = parentSubStateCount(cap, newNumSubStates);
        int newCStates = childSubStateCount(cap, newNumSubStates);

        this.scores = new double[newCStates][newPStates];

        // Accumulate level by level so every cell sums its contributions in level order.
        for (int level = 0; level <= lastLevel; level++) {
            double[][] scoresThisLevel = scoreHierarchy.get(level);
            if (scoresThisLevel == null) {
                continue;
            }
            // How many cells of the target table share one cell of this level.
            int divisorC = newCStates / scoresThisLevel.length;
            int divisorP = newPStates / scoresThisLevel[0].length;
            for (int child = 0; child < newCStates; child++) {
                for (int parent = 0; parent < newPStates; parent++) {
                    this.scores[child][parent] += scoresThisLevel[child / divisorC][parent / divisorP];
                }
            }
        }

        // The sums are log-domain; convert them back.
        for (int child = 0; child < newCStates; child++) {
            for (int parent = 0; parent < newPStates; parent++) {
                this.scores[child][parent] = Math.exp(this.scores[child][parent]);
            }
        }
    }

    /**
     * Returns the array stored at index {@code lastLevel} of the hierarchy.
     * The array itself is returned, not a copy. Throws
     * {@link IndexOutOfBoundsException} when the hierarchy is empty
     * ({@code lastLevel == -1}).
     *
     * @return the finest level of the hierarchy
     */
    public double[][] getLastLevel() {
        return this.scoreHierarchy.get(lastLevel);
    }

    /**
     * Returns a copy of this rule with one additional, finer level appended.
     *
     * <p>The new level has at most {@code 2^(lastLevel+1)} substates per side,
     * capped by {@code newNumSubStates}. When splitting on the parent, ROOT
     * (state 0) is never split, but everything else is. Every cell of the new
     * level is set to {@code random.nextDouble()/100.0}, drawn child-major.
     * This rule itself is not modified.
     *
     * @param numSubStates not used by this implementation
     * @param newNumSubStates number of substates per state after the split
     * @param random source of the initial values for the new level
     * @param randomness not used by this implementation
     * @param doNotNormalize not used by this implementation
     * @param mode must be 2
     * @return a new rule whose hierarchy has one more level than this one
     * @throws Error if {@code mode} is not 2
     */
    public HierarchicalUnaryRule splitRule(short[] numSubStates, short[] newNumSubStates,
            Random random, double randomness, boolean doNotNormalize, int mode) {
        if (mode != 2) {
            throw new Error("Can't split hiereachical rule in this mode!");
        }

        int cap = maxSubStatesForLevel(lastLevel);
        int newPStates = parentSubStateCount(cap, newNumSubStates);
        int newCStates = childSubStateCount(cap, newNumSubStates);

        double[][] newScores = new double[newCStates][newPStates];
        for (int child = 0; child < newCStates; child++) {
            for (int parent = 0; parent < newPStates; parent++) {
                newScores[child][parent] = random.nextDouble() / 100.0;
            }
        }

        HierarchicalUnaryRule newRule = new HierarchicalUnaryRule(this);
        newRule.scoreHierarchy.add(newScores);
        newRule.lastLevel++;
        return newRule;
    }

    /**
     * Drops the finest level of the hierarchy if every entry in it is exactly
     * zero, since such a level adds nothing to the summed scores.
     *
     * <p>An empty hierarchy ({@code lastLevel < 0}) has nothing to merge and
     * yields 0 instead of failing on an out-of-range index. This state can be
     * reached by merging repeatedly, because level 0 of a rule with score 1.0
     * holds {@code log(1.0) == 0.0}.
     *
     * @return 1 if a level was removed, 0 otherwise
     */
    public int mergeRule() {
        if (lastLevel < 0) {
            return 0;
        }
        double[][] scoresFinalLevel = scoreHierarchy.get(lastLevel);
        for (int child = 0; child < scoresFinalLevel.length; child++) {
            for (int parent = 0; parent < scoresFinalLevel[0].length; parent++) {
                if (scoresFinalLevel[child][parent] != 0.0) {
                    // A single non-zero entry means the level still carries information.
                    return 0;
                }
            }
        }
        scoreHierarchy.remove(lastLevel);
        lastLevel--;
        return 1;
    }

    /**
     * Renders the rule as {@code "parent -> child"} using the global "tags"
     * numberer. When the explicit {@code scores} table has been computed, it is
     * followed by that table and by every level of the hierarchy; otherwise
     * only the header line is returned.
     */
    @Override
    public String toString() {
        Numberer n = Numberer.getGlobalNumberer("tags");
        String cState = (String) n.object(childState);
        String pState = (String) n.object(parentState);
        String header = pState + " -> " + cState + "\n";
        if (scores == null) {
            return header;
        }
        StringBuilder sb = new StringBuilder();
        sb.append(header);
        sb.append(ArrayUtil.toString(scores) + "\n");
        for (double[][] levelScores : scoreHierarchy) {
            sb.append(ArrayUtil.toString(levelScores) + "\n");
        }
        sb.append("\n");
        return sb.toString();
    }

    /**
     * Counts the entries that are not equal to zero across levels
     * {@code 0..lastLevel}, skipping {@code null} levels.
     *
     * @return the number of non-zero entries in the hierarchy
     */
    public int countNonZeroFeatures() {
        int total = 0;
        for (int level = 0; level <= lastLevel; level++) {
            double[][] scoresThisLevel = scoreHierarchy.get(level);
            if (scoresThisLevel == null) {
                continue;
            }
            for (int child = 0; child < scoresThisLevel.length; child++) {
                for (int parent = 0; parent < scoresThisLevel[0].length; parent++) {
                    if (scoresThisLevel[child][parent] != 0) {
                        total++;
                    }
                }
            }
        }
        return total;
    }

    /** Upper bound on substates per side for the given level: {@code 2^(level+1)}. */
    private static int maxSubStatesForLevel(int level) {
        return (int) Math.pow(2, level + 1);
    }

    /** Parent-side table size: capped by {@code cap}, and always 1 for state 0 (ROOT). */
    private int parentSubStateCount(int cap, short[] newNumSubStates) {
        int capped = Math.min(cap, newNumSubStates[this.parentState]);
        return (this.parentState == 0) ? 1 : capped;
    }

    /** Child-side table size: the child's substate count capped by {@code cap}. */
    private int childSubStateCount(int cap, short[] newNumSubStates) {
        return Math.min(cap, newNumSubStates[this.childState]);
    }
}