/**
*
*/
package edu.berkeley.nlp.PCFGLA;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import edu.berkeley.nlp.util.ArrayUtil;
import edu.berkeley.nlp.util.Numberer;
/**
 * A {@link UnaryRule} whose substate scores are stored as a hierarchy of
 * refinements instead of a single flat table.
 *
 * <p>Level {@code k} of {@link #scoreHierarchy} holds log-domain scores on a grid
 * of at most {@code 2^k} child substates by {@code 2^k} parent substates (the
 * parent side stays at 1 for state 0, ROOT, which is never split). The explicit
 * score of a (child, parent) substate pair is {@code exp} of the sum, over all
 * stored levels, of the cell covering that pair; see
 * {@link #explicitlyComputeScores(int, short[])}.
 *
 * @author petrov
 */
public class HierarchicalUnaryRule extends UnaryRule {
    private static final long serialVersionUID = 1L;

    /**
     * Log-domain score tables, one per refinement level (index {@code k} is level
     * {@code k}), each indexed {@code [childSubState][parentSubState]}. This
     * replaces the flat {@code scores} table, which stays {@code null} until
     * {@link #explicitlyComputeScores(int, short[])} materializes it.
     */
    List<double[][]> scoreHierarchy;

    /** Index of the finest level currently stored in {@link #scoreHierarchy}. */
    public int lastLevel = -1;

    /**
     * Copy constructor: clones every level of {@code b}'s hierarchy (via
     * {@code ArrayUtil.clone}) and clears the explicit {@code scores} table.
     *
     * @param b the rule to copy
     */
    public HierarchicalUnaryRule(HierarchicalUnaryRule b) {
        super(b);
        this.scoreHierarchy = new ArrayList<double[][]>(b.scoreHierarchy.size());
        for (double[][] level : b.scoreHierarchy) {
            this.scoreHierarchy.add(ArrayUtil.clone(level));
        }
        this.lastLevel = b.lastLevel;
        this.scores = null;
    }

    /**
     * Builds a one-level hierarchy from a plain rule. Level 0 is a 1x1 table
     * holding {@code log(b.scores[0][0])}. Assumes for now that {@code b} is
     * unsplit; any other substate entries of {@code b} are ignored.
     *
     * @param b the unsplit rule to convert
     */
    public HierarchicalUnaryRule(UnaryRule b) {
        super(b);
        this.scoreHierarchy = new ArrayList<double[][]>();
        this.scoreHierarchy.add(new double[][] { { Math.log(b.scores[0][0]) } });
        this.lastLevel = 0;
        this.scores = null;
    }

    /**
     * Materializes the flat {@code scores[childSubState][parentSubState]} table
     * from the hierarchy.
     *
     * <p>The target grid has {@code min(2^(finalLevel+1), newNumSubStates[s])}
     * substates per side (parent side fixed at 1 for state 0). For every target
     * cell, the log scores of the covering cell at each stored level are summed
     * and the sum is exponentiated.
     *
     * <p>A target substate is mapped to level cell {@code index * levelWidth /
     * targetWidth}. When the level width divides the target width this equals the
     * original block mapping {@code index / (targetWidth / levelWidth)}; unlike
     * that form it never divides by zero and never indexes past the level when
     * the widths do not divide. NOTE(review): a level wider than the target grid
     * is subsampled rather than aggregated.
     *
     * @param finalLevel level that determines the target resolution
     * @param newNumSubStates substate counts per state, used to cap the grid
     */
    public void explicitlyComputeScores(int finalLevel, short[] newNumSubStates) {
        int newMaxStates = (int) Math.pow(2, finalLevel + 1);
        int newCStates = Math.min(newMaxStates, newNumSubStates[this.childState]);
        // the parent side of state 0 (ROOT) is never split
        int newPStates = (this.parentState == 0) ? 1
                : Math.min(newMaxStates, newNumSubStates[this.parentState]);

        double[][] logScores = new double[newCStates][newPStates];
        for (int level = 0; level <= lastLevel; level++) {
            double[][] scoresThisLevel = scoreHierarchy.get(level);
            if (scoresThisLevel == null) {
                continue;
            }
            int levelC = scoresThisLevel.length;
            int levelP = scoresThisLevel[0].length;
            for (int child = 0; child < newCStates; child++) {
                double[] levelRow = scoresThisLevel[child * levelC / newCStates];
                for (int parent = 0; parent < newPStates; parent++) {
                    logScores[child][parent] += levelRow[parent * levelP / newPStates];
                }
            }
        }
        // convert the accumulated log scores back to the probability domain
        for (double[] row : logScores) {
            for (int parent = 0; parent < row.length; parent++) {
                row[parent] = Math.exp(row[parent]);
            }
        }
        this.scores = logScores;
    }

    /**
     * Returns the finest stored level's table (the live array, not a copy).
     *
     * @return {@code scoreHierarchy.get(lastLevel)}
     */
    public double[][] getLastLevel() {
        return this.scoreHierarchy.get(lastLevel);
    }

    /**
     * Returns a copy of this rule with one additional, finer refinement level;
     * {@code this} is left unchanged.
     *
     * <p>The new level has {@code min(2^(lastLevel+1), newNumSubStates[s])}
     * substates per side (parent side fixed at 1 for state 0, ROOT). Each cell is
     * initialized to {@code random.nextDouble() / 100.0}, i.e. in {@code [0, 0.01)}.
     *
     * @param numSubStates unused
     * @param newNumSubStates substate counts after the split
     * @param random source of the initial noise
     * @param randomness unused
     * @param doNotNormalize unused
     * @param mode split mode; only mode 2 is supported
     * @return the split copy
     * @throws Error if {@code mode != 2}
     */
    public HierarchicalUnaryRule splitRule(short[] numSubStates, short[] newNumSubStates, Random random, double randomness, boolean doNotNormalize, int mode) {
        // when splitting on parent, never split on ROOT, but otherwise split everything
        if (mode != 2) throw new Error("Can't split hiereachical rule in this mode!");
        int newMaxStates = (int) Math.pow(2, lastLevel + 1);
        int newCStates = Math.min(newMaxStates, newNumSubStates[this.childState]);
        int newPStates = (parentState == 0) ? 1
                : Math.min(newMaxStates, newNumSubStates[this.parentState]);

        double[][] newScores = new double[newCStates][newPStates];
        for (int child = 0; child < newCStates; child++) {
            for (int parent = 0; parent < newPStates; parent++) {
                newScores[child][parent] = random.nextDouble() / 100.0;
            }
        }
        HierarchicalUnaryRule newRule = new HierarchicalUnaryRule(this);
        newRule.scoreHierarchy.add(newScores);
        newRule.lastLevel++;
        return newRule;
    }

    /**
     * Drops the finest refinement level if every entry in it is exactly 0.0;
     * such a level adds nothing to the log-domain sum, so explicit scores are
     * unaffected.
     *
     * <p>The base level (level 0) is never dropped. Removing it would leave
     * {@code lastLevel == -1}, which makes {@link #getLastLevel()} and a repeated
     * {@code mergeRule()} fail on {@code get(-1)} and makes the next
     * {@code splitRule} add a redundant 1x1 level instead of a finer one.
     *
     * @return 1 if a level was removed, 0 otherwise
     */
    public int mergeRule() {
        if (lastLevel <= 0) {
            return 0;
        }
        for (double[] row : scoreHierarchy.get(lastLevel)) {
            for (double value : row) {
                if (value != 0.0) {
                    return 0;
                }
            }
        }
        scoreHierarchy.remove(lastLevel);
        lastLevel--;
        return 1;
    }

    /**
     * Renders {@code "parent -> child"} using the global {@code "tags"} numberer.
     * When explicit scores exist, appends the explicit table, then every level of
     * the hierarchy (one per line), then a blank line.
     */
    public String toString() {
        Numberer n = Numberer.getGlobalNumberer("tags");
        String cState = (String) n.object(childState);
        String pState = (String) n.object(parentState);
        String header = pState + " -> " + cState + "\n";
        if (scores == null) {
            return header;
        }
        StringBuilder sb = new StringBuilder(header);
        sb.append(ArrayUtil.toString(scores)).append("\n");
        for (double[][] s : scoreHierarchy) {
            sb.append(ArrayUtil.toString(s)).append("\n");
        }
        sb.append("\n");
        return sb.toString();
    }

    /**
     * Counts the non-zero entries over all stored levels {@code 0..lastLevel}.
     *
     * @return number of non-zero cells in the hierarchy
     */
    public int countNonZeroFeatures() {
        int total = 0;
        for (int level = 0; level <= lastLevel; level++) {
            double[][] scoresThisLevel = scoreHierarchy.get(level);
            if (scoresThisLevel == null) {
                continue;
            }
            for (double[] row : scoresThisLevel) {
                for (double value : row) {
                    if (value != 0) {
                        total++;
                    }
                }
            }
        }
        return total;
    }
}