/**
*
*/
package edu.berkeley.nlp.PCFGLA;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.syntax.Trees.PennTreeRenderer;
import edu.berkeley.nlp.util.Pair;
/**
* @author petrov
*
*/
public class HierarchicalAdaptiveLexicalRule implements Serializable{
private static final long serialVersionUID = 1L;
double[] scores;
public short[] mapping;
Tree<Double> hierarchy;
public int nParam;
public int identifier;
// HierarchicalAdaptiveLexicalRule(short t, int w){
// this.tag = t;
// this.wordIndex = w;
// }
HierarchicalAdaptiveLexicalRule(){
hierarchy = new Tree<Double>(0.0);
scores = new double[1];
mapping = new short[1];
nParam = 1;
}
public Pair<Integer,Integer> countParameters(){
// first one is the max_depth, second one is the number of parameters
int maxDepth = hierarchy.getDepth();
nParam = hierarchy.getYield().size();
return new Pair<Integer,Integer>(maxDepth, nParam);
}
public void splitRule(int nSubstates){
splitRuleHelper(hierarchy, 2);
mapping = new short[nSubstates];
int finalLevel = (int)(Math.log(mapping.length)/Math.log(2));
updateMapping((short)0, 0, 0, finalLevel, hierarchy);
// mapping[0] = (short)0;
// mapping[1] = (short)1;
}
private Pair<Short,Integer> updateMapping(short myID, int nextSubstate, int myDepth, int finalDepth, Tree<Double> tree) {
if (tree.isLeaf()){
if (myDepth==finalDepth){
mapping[nextSubstate++] = myID;
} else {
int substatesToCover = (int)Math.pow(2,finalDepth-myDepth);
for (int i=0; i<substatesToCover; i++){
mapping[nextSubstate++] = myID;
}
}
myID++;
} else {
for (Tree<Double> child : tree.getChildren()){
Pair<Short, Integer> tmp = updateMapping(myID, nextSubstate, myDepth+1, finalDepth, child);
myID = tmp.getFirst();
nextSubstate = tmp.getSecond();
}
}
return new Pair<Short, Integer>(myID, nextSubstate);
}
private void splitRuleHelper(Tree<Double> tree, int splitFactor) {
if (tree.isLeaf()){
if (tree.getLabel()!=0||nParam==1){ // split it
ArrayList<Tree<Double>> children = new ArrayList<Tree<Double>>(splitFactor);
for (int i=0; i<splitFactor; i++){
Tree<Double> child = new Tree<Double>((GrammarTrainer.RANDOM.nextDouble()-.5)/100.0);
children.add(child);
}
tree.setChildren(children);
nParam += splitFactor-1;
// } else { //perturb it
// tree.setLabel(GrammarTrainer.RANDOM.nextDouble()/100.0);
}
} else {
for (Tree<Double> child : tree.getChildren()){
splitRuleHelper(child, splitFactor);
}
}
}
public void explicitlyComputeScores(int finalLevel, final boolean usingOnlyLastLevel){
int nSubstates = (int)Math.pow(2, finalLevel);
scores = new double[nSubstates];
int nextSubstate = fillScores(0, 0, 0, finalLevel, hierarchy, usingOnlyLastLevel);
if (nextSubstate != nSubstates)
System.out.println("Didn't fill all lexical scores!");
mapping = new short[nSubstates];
updateMapping((short)0, 0, 0, finalLevel, hierarchy);
}
private int fillScores(double previousScore, int nextSubstate, int myDepth, int finalDepth, Tree<Double> tree, final boolean usingOnlyLastLevel){
if (tree.isLeaf()){
double myScore = (usingOnlyLastLevel) ? Math.exp(tree.getLabel()) : Math.exp(previousScore + tree.getLabel());
if (myDepth==finalDepth){
scores[nextSubstate++] = myScore;
} else {
int substatesToCover = (int)Math.pow(2,finalDepth-myDepth);
for (int i=0; i<substatesToCover; i++){
scores[nextSubstate++] = myScore;
}
}
} else {
double myScore = previousScore + tree.getLabel();
for (Tree<Double> child : tree.getChildren()){
nextSubstate = fillScores(myScore, nextSubstate, myDepth+1, finalDepth, child, usingOnlyLastLevel);
}
}
return nextSubstate;
}
public void updateScores(double[] scores){
int nSubstates = updateHierarchy(hierarchy, 0, scores);
if (nSubstates != nParam) System.out.println("Didn't update all parameters");
}
private int updateHierarchy(Tree<Double> tree, int nextSubstate, double[] scores) {
if (tree.isLeaf()){
double val = scores[identifier + nextSubstate++];
if (val>200) {
System.out.println("Ignored proposed lexical value since it was danegrous");
val = 0;
} else
tree.setLabel(val);
} else {
for (Tree<Double> child : tree.getChildren()){
nextSubstate = updateHierarchy(child, nextSubstate, scores);
}
}
return nextSubstate;
}
/**
* @return
*/
public List<Double> getFinalLevel() {
return hierarchy.getYield();
}
private void compactifyHierarchy(Tree<Double> tree){
if (tree.getDepth()==2){
boolean allZero = true;
for (Tree<Double> child : tree.getChildren()){
allZero = allZero && (child.getLabel()==0.0);
}
if (allZero) {
nParam -= tree.getChildren().size()-1;
tree.setChildren(Collections.EMPTY_LIST);
}
} else {
for (Tree<Double> child : tree.getChildren()){
compactifyHierarchy(child);
}
}
}
public String toString(){
StringBuilder sb = new StringBuilder();
compactifyHierarchy(hierarchy);
sb.append(Arrays.toString(scores));
sb.append("\n");
sb.append(PennTreeRenderer.render(hierarchy));
sb.append("\n");
return sb.toString();
}
public int mergeRule() {
int paramBefore = nParam;
compactifyHierarchy(hierarchy);
scores = null;
mapping = null;
return paramBefore - nParam;
}
public int countNonZeroFeatures() {
int total = 0;
for (Tree<Double> d : hierarchy.getPreOrderTraversal()) { if (d.getLabel()!=0) total++; }
return total;
}
public int countNonZeroFringeFeatures() {
int total = 0;
for (Tree<Double> d : hierarchy.getTerminals()) { if (d.getLabel()!=0) total++; }
return total;
}
}