package edu.berkeley.nlp.PCFGLA;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import edu.berkeley.nlp.discPCFG.Linearizer;
import edu.berkeley.nlp.syntax.StateSet;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.math.DoubleArrays;
import edu.berkeley.nlp.util.ArrayUtil;
import edu.berkeley.nlp.util.Numberer;
import edu.berkeley.nlp.util.ScalingTools;
/**
* Simple mixture parser.
*/
public class ArrayParser implements Parser {
// Short-typed 0 and 1, kept as fields because Java has no literal suffix for
// short values.
short zero = 0, one = 1;
// Global numberer used to translate tag strings (e.g. "ROOT") to state
// indices and back.
protected Numberer tagNumberer = Numberer.getGlobalNumberer("tags");
// inside scores
protected double[][][] iScore; // start idx, end idx, state -> logProb
// outside scores
protected double[][][] oScore; // start idx, end idx, state -> logProb
protected int[][] narrowLExtent = null; // the rightmost left extent of state
// s ending at position i
protected int[][] wideLExtent = null; // the leftmost left extent of state s
// ending at position i
protected int[][] narrowRExtent = null; // the leftmost right extent of state
// s starting at position i
protected int[][] wideRExtent = null; // the rightmost right extent of state s
// starting at position i
// Number of words in the sentence currently being parsed.
protected short length;
// Size argument of the last successful createArrays() call (sentence length
// + 1); 0 means no chart has been allocated yet.
protected int arraySize = 0;
// Longest sentence for which chart allocation is attempted; lowered when an
// allocation runs out of memory.
protected int myMaxLength = 200;
Lexicon lexicon;
// Number of grammar states, copied from the grammar in the constructor.
int numStates;
// Largest number of substates of any state (see maxSubStates()).
int maxNSubStates;
// Scratch buffers allocated in the constructor; their use is not visible in
// this part of the file (presumably by subclasses - TODO confirm).
int[] idxC;
double[] scoresToAdd;
// Counter incremented for every rule application tried in the chart passes.
int touchedRules;
double[] tmpCountsArray;
Grammar grammar;
// Indexed by state to pick the span-score class in the tree-based
// inside/outside passes; not assigned in the visible code.
int[] stateClass;
/**
 * Creates a parser without a grammar or lexicon; all fields keep the values
 * given by their declarations.
 */
public ArrayParser() {
  // intentionally empty
}
public ArrayParser(Grammar gr, Lexicon lex) {
this.touchedRules = 0;
this.grammar = gr;
this.lexicon = lex;
this.tagNumberer = Numberer.getGlobalNumberer("tags");
this.numStates = gr.numStates;
this.maxNSubStates = maxSubStates(gr);
this.idxC = new int[maxNSubStates];
this.scoresToAdd = new double[maxNSubStates];
tmpCountsArray = new double[scoresToAdd.length*scoresToAdd.length*scoresToAdd.length];
// System.out.println("This grammar has " + numStates
// + " states and a total of " + grammar.totalSubStates() + " substates.");
}
/**
 * Parses the sentence with the Viterbi inside and outside passes and returns,
 * for every span, the states whose Viterbi posterior
 * (inside + outside - sentence score, all in log space) exceeds
 * {@code logThreshold}.
 *
 * <p>The sentence score is read at the ROOT state, the same state at which
 * the outside chart is seeded, so the posterior is normalized consistently.
 *
 * @param sentence     words of the sentence to parse
 * @param logThreshold log-space pruning threshold; states at or below it are dropped
 * @return array indexed by [start][end] holding the surviving state indices;
 *         entries with end &lt;= start are left null
 */
@SuppressWarnings("unchecked")
public List<Integer>[][] getPossibleStates(List<String> sentence, double logThreshold){
  length = (short)sentence.size();
  initializeArrays();
  initializeChart(sentence, false);
  doInsideScores();
  final int rootState = tagNumberer.number("ROOT");
  // The outside pass never writes to iScore, so this value stays valid below.
  final double sentenceProb = iScore[0][length][rootState];
  if (sentenceProb > Double.NEGATIVE_INFINITY) {
    System.out.println("\nFound a parse for sentence with length " + length
        + ". The LL is " + sentenceProb + ".");
  } else {
    System.out.println("Did NOT find a parse for sentence with length "
        + length + ".");
  }
  oScore[0][length][rootState] = 0.0;
  doOutsideScores();
  // Generic array creation is not possible in Java, hence the raw ArrayList[][].
  List<Integer>[][] possibleStates = new ArrayList[length + 1][length + 1];
  int unprunedStates = 0; // states with a finite posterior before thresholding
  int prunedStates = 0;   // states that survive the threshold
  for (int diff = 1; diff <= length; diff++) {
    for (int start = 0; start < (length - diff + 1); start++) {
      int end = start + diff;
      possibleStates[start][end] = new ArrayList<Integer>();
      for (int state = 0; state < numStates; state++) {
        double viterbiPosterior = iScore[start][end][state]
            + oScore[start][end][state] - sentenceProb;
        if (!Double.isInfinite(viterbiPosterior)) {
          unprunedStates++;
        }
        if (viterbiPosterior > logThreshold) {
          possibleStates[start][end].add(Integer.valueOf(state));
          prunedStates++;
        }
      }
    }
  }
  System.out.print("Down to " + prunedStates + " states from "
      + unprunedStates + ". ");
  return possibleStates;
}
//belongs in the grammar but i didnt want to change the signature for now...
/**
 * Returns the largest substate count of any state in the given grammar.
 *
 * <p>The loop is bounded by the state count of the {@code grammar} argument
 * rather than by this parser's own {@code numStates} field, so the result is
 * correct for any grammar passed in, not only the one the parser was built
 * with.
 *
 * @param grammar grammar whose {@code numSubStates} table is scanned
 * @return the maximum entry of {@code grammar.numSubStates}, or 0 if the
 *         grammar has no states
 */
public int maxSubStates(Grammar grammar) {
  int max = 0;
  for (int i = 0; i < grammar.numStates; i++) {
    if (grammar.numSubStates[i] > max) {
      max = grammar.numSubStates[i];
    }
  }
  return max;
}
/**
 * Parses the sentence with the Viterbi inside pass and returns the best tree
 * rooted in ROOT, with collapsed unary chains expanded again.
 *
 * <p>Only substate 0 of every rule is consulted, i.e. the grammar is treated
 * as unsplit.
 *
 * @param sentence words of the sentence to parse
 * @return the best parse, or a bare {@code ROOT} node when the chart holds no
 *         parse or no tree could be extracted from it
 */
public Tree<String> getBestParse(List<String> sentence) {
  System.out
      .println("This parser assumes an unsplit grammar (= split grammar with 1 substate)");
  length = (short)sentence.size();
  initializeArrays();
  initializeChart(sentence, false);
  doInsideScores();
  Tree<String> bestTree = new Tree<String>("ROOT");
  final int rootState = tagNumberer.number("ROOT");
  double score = iScore[0][length][rootState];
  if (score > Double.NEGATIVE_INFINITY) {
    System.out.println("\nFound a parse for sentence with length " + length
        + ". The LL is " + score + ".");
    Tree<String> extracted = extractBestParse(rootState, 0, length, sentence);
    // extractBestParse returns null (after printing a warning) when it cannot
    // reconstruct a derivation; keep the bare ROOT tree in that case instead
    // of handing null to restoreUnaries.
    if (extracted != null) {
      bestTree = extracted;
      restoreUnaries(bestTree);
    }
  } else {
    System.out.println("Did NOT find a parse for sentence with length "
        + length + ".");
  }
  return bestTree;
}
/**
 * Reports whether the inside chart holds a finite score for ROOT over the
 * whole current sentence.
 *
 * @return {@code false} when no chart is allocated, when the current sentence
 *         is empty or does not fit into the allocated chart, or when ROOT
 *         has no finite inside score over the full span
 */
public boolean hasParse() {
  // Chart cells exist only for end > start, so there is nothing to look up
  // for an empty sentence or before any chart has been allocated.
  if (iScore == null || length <= 0 || length > arraySize) {
    return false;
  }
  return (iScore[0][length][tagNumberer.number("ROOT")] > Double.NEGATIVE_INFINITY);
}
/**
 * Makes sure the charts can hold the current sentence and resets them.
 *
 * <p>If the sentence is longer than the allocated charts allow, new charts
 * are allocated. When that allocation runs out of memory, the charts are
 * re-created at their previous size (if there was one) and the original
 * error is rethrown. Afterwards every inside/outside cell of the current
 * sentence is set to negative infinity and the extent tables are reset to
 * their "nothing found yet" values.
 *
 * @throws OutOfMemoryError if the sentence exceeds {@code myMaxLength} or the
 *         allocation fails
 * @throws RuntimeException if even the previously allocated size cannot be
 *         re-created; the underlying error is attached as the cause
 */
void initializeArrays() {
  if (length > arraySize) {
    if (length > myMaxLength) {
      throw new OutOfMemoryError("Refusal to create such large arrays.");
    }
    try {
      createArrays(length + 1);
    } catch (OutOfMemoryError e) {
      // NOTE(review): the guard above uses '>', so a later sentence of
      // exactly this length is attempted again - confirm this is intended.
      myMaxLength = length;
      if (arraySize > 0) {
        try {
          createArrays(arraySize);
        } catch (OutOfMemoryError e2) {
          throw new RuntimeException(
              "CANNOT EVEN CREATE ARRAYS OF ORIGINAL SIZE!!!", e2);
        }
      }
      throw e;
    }
    arraySize = length + 1;
  }
  for (int start = 0; start < length; start++) {
    for (int end = start + 1; end <= length; end++) {
      Arrays.fill(iScore[start][end], Double.NEGATIVE_INFINITY);
      Arrays.fill(oScore[start][end], Double.NEGATIVE_INFINITY);
    }
  }
  for (int loc = 0; loc <= length; loc++) {
    // the rightmost left with state s ending at i that we can get is the
    // beginning
    Arrays.fill(narrowLExtent[loc], -1);
    // the leftmost left with state s ending at i that we can get is the end
    Arrays.fill(wideLExtent[loc], length + 1);
    // the leftmost right with state s starting at i that we can get is the
    // end
    Arrays.fill(narrowRExtent[loc], length + 1);
    // the rightmost right with state s starting at i that we can get is the
    // beginning
    Arrays.fill(wideRExtent[loc], -1);
  }
}
/**
 * Seeds the chart with lexical scores: for every word and every state that
 * is not flagged as a grammar tag, the first lexicon score is written to the
 * one-word span and all four extent tables are set to that span.
 *
 * The lexicon works on word strings while the grammar works on the integer
 * state indices handed out by the Numberer.
 *
 * @param sentence    words of the sentence
 * @param noSmoothing passed through to the lexicon
 */
void initializeChart(List<String> sentence, boolean noSmoothing) {
  int position = 0;
  for (String word : sentence) {
    final int next = position + 1;
    for (short tag = 0; tag < numStates; tag++) {
      if (!grammar.isGrammarTag[tag]) {
        double[] tagScores = lexicon.score(word, tag, position, noSmoothing, false);
        iScore[position][next][tag] = tagScores[0];
        narrowRExtent[position][tag] = next;
        narrowLExtent[next][tag] = position;
        wideRExtent[position][tag] = next;
        wideLExtent[next][tag] = position;
      }
    }
    position = next;
  }
}
/**
* Fills in the Viterbi (max) inside chart iScore for every span of the
* current sentence, shortest spans first (span length 1 upward). For each
* span all binary rules are applied first, then the closed unary rules.
*
* Only substate 0 of each rule is read (getScore(0, 0, 0) / getScore(0, 0)).
* The narrow/wide extent tables are updated the first time a state gets a
* finite score over a span, and are used to bound the split points tried for
* binary rules.
*
* Note: This places the grammar and lexicon into logarithm mode!
*/
void doInsideScores() {
grammar.logarithmMode();
lexicon.logarithmMode();
// for all symbol lengths
for (int diff = 1; diff <= length; diff++) {
// for all symbol starting positions
for (int start = 0; start < (length - diff + 1); start++) {
int end = start + diff;
// for all symbols, calculate the inside score without unaries
for (int pparentState = 0; pparentState < numStates; pparentState++) {
BinaryRule[] parentRules = grammar.splitRulesWithP(pparentState);
// for all rules with this parent symbol
for (int i = 0; i < parentRules.length; i++) {
BinaryRule r = parentRules[i];
int leftState = r.leftChildState;
int parentState = r.parentState;
int narrowR = narrowRExtent[start][leftState];
boolean iPossibleL = (narrowR < end); // can this left constituent
// leave space for a right
// constituent?
if (!iPossibleL) {
continue;
}
int narrowL = narrowLExtent[end][r.rightChildState];
boolean iPossibleR = (narrowL >= narrowR); // can this right
// constituent fit next
// to the left
// constituent?
if (!iPossibleR) {
continue;
}
int min1 = narrowR;
int min2 = wideLExtent[end][r.rightChildState];
int min = (min1 > min2 ? min1 : min2); // can this right
// constituent stretch far
// enough to reach the left
// constituent?
if (min > narrowL) {
continue;
}
int max1 = wideRExtent[start][leftState];
int max2 = narrowL;
int max = (max1 < max2 ? max1 : max2); // can this left constituent
// stretch far enough to
// reach the right
// constituent?
if (min > max) {
continue;
}
// rule score for substate 0 of parent and both children
double pS = r.getScore(0, 0, 0);
double oldIScore = iScore[start][end][parentState];
double bestIScore = oldIScore;
boolean foundBetter; // always set below for this rule
// try every split point that the extent tables still allow
for (int split = min; split <= max; split++) {
double lS = iScore[start][split][leftState];
if (Double.isInfinite(lS)) {
continue;
}
double rS = iScore[split][end][r.rightChildState];
if (Double.isInfinite(rS)) {
continue;
}
touchedRules++;
double tot = pS + lS + rS;
if (tot > bestIScore) {
bestIScore = tot;
}
}
foundBetter = bestIScore > oldIScore;
if (foundBetter) { // this way of making "parentState" is better
// than previous
iScore[start][end][parentState] = bestIScore;
// extents only change the first time the state becomes reachable
// over this span
if (Double.isInfinite(oldIScore)) {
if (start > narrowLExtent[end][parentState]) {
narrowLExtent[end][parentState] = start;
wideLExtent[end][parentState] = start;
} else {
if (start < wideLExtent[end][parentState]) {
wideLExtent[end][parentState] = start;
}
}
if (end < narrowRExtent[start][parentState]) {
narrowRExtent[start][parentState] = end;
wideRExtent[start][parentState] = end;
} else {
if (end > wideRExtent[start][parentState]) {
wideRExtent[start][parentState] = end;
}
}
}
}
}
}
// for all symbols, close all unary productions
for (int pState = 0; pState < numStates; pState++) {
UnaryRule[] unaries = grammar.getClosedViterbiUnaryRulesByParent(pState);
double cur = iScore[start][end][pState];
double best = cur;
for (int r = 0; r < unaries.length; r++) {
UnaryRule ur = unaries[r];
int cState = ur.childState;
// skip self-loops
if (pState == cState) continue;
double pS = ur.getScore(0, 0);
double iS = iScore[start][end][cState];
if (Double.isInfinite(iS)) {
continue;
}
double tot = iS + pS;
touchedRules++;
if (tot > best){
best = tot;
}
}
if (best>cur) {
iScore[start][end][pState] = best;
// same extent bookkeeping as for binary rules above
if (cur == Double.NEGATIVE_INFINITY) {
if (start > narrowLExtent[end][pState]) {
narrowLExtent[end][pState] = start;
wideLExtent[end][pState] = start;
} else {
if (start < wideLExtent[end][pState]) {
wideLExtent[end][pState] = start;
}
}
if (end < narrowRExtent[start][pState]) {
narrowRExtent[start][pState] = end;
wideRExtent[start][pState] = end;
} else {
if (end > wideRExtent[start][pState]) {
wideRExtent[start][pState] = end;
}
}
}
}
}// ~ for all symbols
}// ~for all symbol starting positions
}// ~for all symbol lengths
}
/** Calculate Viterbi (max) outside scores using the internal arrays.
*
* Spans are visited from longest to shortest. For each span the closed unary
* rules are applied first (parent to child over the same span), then the
* binary rules push the parent's outside score down to both children for
* every split point permitted by the extent tables. Only substate 0 of each
* rule is read. The caller is expected to have filled iScore and to have
* seeded the outside score of the root span beforehand.
*
* Note: This places the grammar and lexicon into logarithm mode!
*/
private void doOutsideScores() {
grammar.logarithmMode();
lexicon.logarithmMode();
//TODO: this almost certainly underflows!
for (int diff = length; diff >= 1; diff--) {
for (int start = 0; start + diff <= length; start++) {
int end = start + diff;
// do unaries
for (int s = 0; s < numStates; s++) {
double oS = oScore[start][end][s];
if (Double.isInfinite(oS)) {
continue;
}
UnaryRule[] rules = grammar.getClosedViterbiUnaryRulesByParent(s);
for (int r = 0; r < rules.length; r++) {
UnaryRule ur = rules[r];
double pS = ur.getScore(0, 0);
double tot = oS + pS;
touchedRules++;
// only children that have a finite inside score over this span
// receive an outside score
if (tot > oScore[start][end][ur.childState]
&& iScore[start][end][ur.childState] > Double.NEGATIVE_INFINITY) {
oScore[start][end][ur.childState] = tot;
}
}
}
// do binaries
for (int s = 0; s < numStates; s++) {
BinaryRule[] rules = grammar.splitRulesWithP(s);
for (int r = 0; r < rules.length; r++) {
BinaryRule br = rules[r];
double oS = oScore[start][end][br.parentState];
if (Double.isInfinite(oS)) {
continue;
}
// narrowest possible left child / right child bound the split range
int min1 = narrowRExtent[start][br.leftChildState];
if (end < min1) {
continue;
}
int max1 = narrowLExtent[end][br.rightChildState];
if (max1 < min1) {
continue;
}
int min = min1;
int max = max1;
// the wide extents are only consulted when the range is larger than 2
if (max - min > 2) {
int min2 = wideLExtent[end][br.rightChildState];
min = (min1 > min2 ? min1 : min2);
if (max1 < min) {
continue;
}
int max2 = wideRExtent[start][br.leftChildState];
max = (max1 < max2 ? max1 : max2);
if (max < min) {
continue;
}
}
double pS = br.getScore(0, 0, 0);
for (int split = min; split <= max; split++) {
double lS = iScore[start][split][br.leftChildState];
if (Double.isInfinite(lS)) {
continue;
}
double rS = iScore[split][end][br.rightChildState];
if (Double.isInfinite(rS)) {
continue;
}
// outside(left) = rule + inside(right) + outside(parent)
double totL = pS + rS + oS;
touchedRules++;
if (totL > oScore[start][split][br.leftChildState]) {
oScore[start][split][br.leftChildState] = totL;
}
// outside(right) = rule + inside(left) + outside(parent)
double totR = pS + lS + oS;
if (totR > oScore[split][end][br.rightChildState]) {
oScore[split][end][br.rightChildState] = totR;
}
}
}
}
// A second, disabled pass that iterated the binary rules by right child
// (grammar.splitRulesWithRC) and performed the same updates used to be
// kept here as commented-out code.
}
}
}
/**
 * Calculate the inside scores, P(words_i,j|nonterminal_i,j), of a tree given
 * the string of words it should parse to.
 *
 * <p>The tree is processed bottom-up: children first, then the node itself.
 * Pre-terminals take their scores from the lexicon; unary and binary nodes
 * sum rule score times child inside score(s) over all substate combinations.
 * The scores are stored on the node's StateSet and rescaled there.
 *
 * @param tree        (sub)tree whose inside scores are computed
 * @param noSmoothing passed through to the lexicon
 * @param debugOutput when true, diagnostics are printed if no rule
 *                    contributed to a node
 * @param spanScores  optional per-span multipliers applied to binary nodes,
 *                    indexed [from][to][stateClass[state]]; may be null
 * @throws Error if the grammar or lexicon is in logarithm mode, or a node
 *               has more than two children
 */
void doInsideScores(Tree<StateSet> tree, boolean noSmoothing, boolean debugOutput, double[][][] spanScores) {
  if (grammar.isLogarithmMode() || lexicon.isLogarithmMode())
    throw new Error("Grammar in logarithm mode! Cannot do inside scores!");
  if (tree.isLeaf()) {
    return;
  }
  List<Tree<StateSet>> children = tree.getChildren();
  for (Tree<StateSet> child : children) {
    if (!child.isLeaf()) doInsideScores(child, noSmoothing, debugOutput, spanScores);
  }
  StateSet parent = tree.getLabel();
  short pState = parent.getState();
  int nParentStates = parent.numSubStates();
  if (tree.isPreTerminal()) {
    // Plays a role similar to initializeChart()
    StateSet wordStateSet = tree.getChildren().get(0).getLabel();
    double[] lexiconScores = lexicon.score(wordStateSet, pState, noSmoothing, false);
    if (lexiconScores.length != nParentStates) {
      System.out.println("Have more scores than substates!");
    }
    parent.setIScores(lexiconScores);
    parent.scaleIScores(0);
    return;
  }
  switch (children.size()) {
  case 0:
    break;
  case 1: {
    StateSet child = children.get(0).getLabel();
    short cState = child.getState();
    int nChildStates = child.numSubStates();
    // indexed [child substate][parent substate]; a null row means no rule
    // produces that child substate
    double[][] uscores = grammar.getUnaryScore(pState, cState);
    double[] iScores = new double[nParentStates];
    boolean foundOne = false;
    for (int j = 0; j < nChildStates; j++) {
      if (uscores[j] == null) continue;
      double cS = child.getIScore(j);
      if (cS == 0) continue;
      for (int i = 0; i < nParentStates; i++) {
        double rS = uscores[j][i]; // rule score
        if (rS == 0) continue;
        iScores[i] += rS * cS;
        foundOne = true;
      }
    }
    if (debugOutput && !foundOne) {
      System.out.println("iscore reached zero!");
      System.out.println(grammar.getUnaryRule(pState, cState));
      System.out.println(Arrays.toString(iScores));
      System.out.println(ArrayUtil.toString(uscores));
      System.out.println(Arrays.toString(child.getIScores()));
    }
    parent.setIScores(iScores);
    parent.scaleIScores(child.getIScale());
    break;
  }
  case 2: {
    StateSet leftChild = children.get(0).getLabel();
    StateSet rightChild = children.get(1).getLabel();
    int nLeftChildStates = leftChild.numSubStates();
    int nRightChildStates = rightChild.numSubStates();
    short lState = leftChild.getState();
    short rState = rightChild.getState();
    // indexed [left substate][right substate][parent substate]
    double[][][] bscores = grammar.getBinaryScore(pState, lState, rState);
    double[] iScores2 = new double[nParentStates];
    boolean foundOne2 = false;
    for (int j = 0; j < nLeftChildStates; j++) {
      double lcS = leftChild.getIScore(j);
      if (lcS == 0) continue;
      for (int k = 0; k < nRightChildStates; k++) {
        double rcS = rightChild.getIScore(k);
        if (rcS == 0) continue;
        // check whether one of the parents can produce these kids
        if (bscores[j][k] == null) continue;
        for (int i = 0; i < nParentStates; i++) {
          double rS = bscores[j][k][i];
          if (rS == 0) continue;
          iScores2[i] += rS * lcS * rcS;
          foundOne2 = true;
        }
      }
    }
    if (spanScores != null) {
      for (int i = 0; i < nParentStates; i++) {
        iScores2[i] *= spanScores[parent.from][parent.to][stateClass[pState]];
      }
    }
    if (debugOutput && !foundOne2) {
      System.out.println("iscore reached zero!");
      System.out.println(grammar.getBinaryRule(pState, lState, rState));
      System.out.println(Arrays.toString(iScores2));
      // bscores is a nested array: deepToString prints the values, whereas
      // toString would only print the identities of the inner arrays.
      System.out.println(Arrays.deepToString(bscores));
      System.out.println(Arrays.toString(leftChild.getIScores()));
      System.out.println(Arrays.toString(rightChild.getIScores()));
    }
    parent.setIScores(iScores2);
    parent.scaleIScores(leftChild.getIScale() + rightChild.getIScale());
    break;
  }
  default:
    throw new Error("Malformed tree: more than two children");
  }
}
/**
 * Initializes the outside pass at the root: the outside score of substate 0
 * of the root label is set to 1 and its outside scale to 0.
 *
 * @param tree tree whose root label is initialized
 */
void setRootOutsideScore(Tree<StateSet> tree) {
  StateSet rootLabel = tree.getLabel();
  rootLabel.setOScore(0, 1);
  rootLabel.setOScale(0);
}
/**
 * Calculate the outside scores of a tree; that is,
 * P(nonterminal_i,j|words_0,i; words_j,end). They are calculated top-down
 * from the parent's outside scores and the inside scores of the tree, so the
 * inside pass has to be run first.
 *
 * <p>
 * Note: when calling this, call setRootOutsideScore() first.
 *
 * @param tree       (sub)tree whose children receive outside scores
 * @param unaryAbove whether the node directly above {@code tree} was a unary
 *                   node; span scores are not applied in that case
 * @param spanScores optional per-span multipliers, indexed
 *                   [from][to][stateClass[state]]; may be null
 * @throws Error if the grammar or lexicon is in logarithm mode, or a node
 *               has more than two children
 */
void doOutsideScores(Tree<StateSet> tree, boolean unaryAbove, double[][][] spanScores) {
  if (grammar.isLogarithmMode() || lexicon.isLogarithmMode())
    throw new Error("Grammar in logarithm mode! Cannot do outside scores!");
  // leaves and pre-terminals have no children that need outside scores
  if (tree.isLeaf() || tree.isPreTerminal()) {
    return;
  }
  List<Tree<StateSet>> children = tree.getChildren();
  StateSet parent = tree.getLabel();
  short pState = parent.getState();
  int nParentStates = parent.numSubStates();
  double[] parentScores = parent.getOScores();
  if (spanScores != null && !unaryAbove) {
    // NOTE(review): this scales the array returned by getOScores() in place;
    // whether that is the parent's own storage is not visible here.
    for (int i = 0; i < nParentStates; i++) {
      parentScores[i] *= spanScores[parent.from][parent.to][stateClass[pState]];
    }
  }
  switch (children.size()) {
  case 0:
    // Nothing to do
    break;
  case 1: {
    StateSet child = children.get(0).getLabel();
    short cState = child.getState();
    int nChildStates = child.numSubStates();
    // indexed [child substate][parent substate]
    double[][] uscores = grammar.getUnaryScore(pState, cState);
    double[] oScores = new double[nChildStates];
    for (int j = 0; j < nChildStates; j++) {
      if (uscores[j] == null) continue;
      double childScore = 0;
      for (int i = 0; i < nParentStates; i++) {
        double pS = parentScores[i];
        if (pS == 0) continue;
        double rS = uscores[j][i]; // rule score
        if (rS == 0) continue;
        childScore += pS * rS;
      }
      oScores[j] = childScore;
    }
    child.setOScores(oScores);
    child.scaleOScores(parent.getOScale());
    unaryAbove = true;
    break;
  }
  case 2: {
    StateSet leftChild = children.get(0).getLabel();
    StateSet rightChild = children.get(1).getLabel();
    int nLeftChildStates = leftChild.numSubStates();
    int nRightChildStates = rightChild.numSubStates();
    short lState = leftChild.getState();
    short rState = rightChild.getState();
    // indexed [left substate][right substate][parent substate]
    double[][][] bscores = grammar.getBinaryScore(pState, lState, rState);
    double[] lOScores = new double[nLeftChildStates];
    double[] rOScores = new double[nRightChildStates];
    for (int j = 0; j < nLeftChildStates; j++) {
      double lcS = leftChild.getIScore(j);
      double leftScore = 0;
      for (int k = 0; k < nRightChildStates; k++) {
        double rcS = rightChild.getIScore(k);
        if (bscores[j][k] == null) continue;
        for (int i = 0; i < nParentStates; i++) {
          double pS = parentScores[i];
          if (pS == 0) continue;
          double rS = bscores[j][k][i];
          if (rS == 0) continue;
          // each child's outside score uses the sibling's inside score
          leftScore += pS * rS * rcS;
          rOScores[k] += pS * rS * lcS;
        }
      }
      // stored once per left substate, after all right substates are summed
      lOScores[j] = leftScore;
    }
    leftChild.setOScores(lOScores);
    leftChild.scaleOScores(parent.getOScale() + rightChild.getIScale());
    rightChild.setOScores(rOScores);
    rightChild.scaleOScores(parent.getOScale() + leftChild.getIScale());
    unaryAbove = false;
    break;
  }
  default:
    throw new Error("Malformed tree: more than two children");
  }
  for (Tree<StateSet> child : children) {
    doOutsideScores(child, unaryAbove, spanScores);
  }
}
/**
 * Runs the inside pass, seeds the root outside score and runs the outside
 * pass on the tree, applying the given span scores.
 *
 * @return the natural log of the root's inside score for substate 0, plus
 *         ScalingTools.LOGSCALE times the root's inside scale
 */
public double doInsideOutsideScores(Tree<StateSet> tree, boolean noSmoothing, boolean debugOutput, double[][][] spanScores) {
  doInsideScores(tree, noSmoothing, debugOutput, spanScores);
  setRootOutsideScore(tree);
  doOutsideScores(tree, false, spanScores);
  double rootInside = tree.getLabel().getIScore(0);
  int rootScale = tree.getLabel().getIScale();
  double logUnscaled = Math.log(rootInside);
  double scaleCorrection = ScalingTools.LOGSCALE * rootScale;
  return logUnscaled + scaleCorrection;
}
/**
 * Runs the inside pass, seeds the root outside score and runs the outside
 * pass on the tree, without span scores.
 */
public void doInsideOutsideScores(Tree<StateSet> tree, boolean noSmoothing, boolean debugOutput) {
  final double[][][] noSpanScores = null;
  doInsideScores(tree, noSmoothing, debugOutput, noSpanScores);
  setRootOutsideScore(tree);
  doOutsideScores(tree, false, noSpanScores);
}
/**
 * Allocates the inside/outside charts and the four extent tables for the
 * given size. Existing arrays are released first so that a reallocation
 * after an out-of-memory condition does not keep the old ones alive.
 *
 * @param length chart size; charts get cells for 0 <= start < end <= length,
 *               extent tables get length + 1 rows
 */
private void createArrays(int length) {
  clearArrays();
  final int positions = length + 1;
  // Only the cells with end > start are ever used, so only those are allocated.
  iScore = new double[length][positions][];
  fillSpanCells(iScore, length);
  oScore = new double[length][positions][];
  fillSpanCells(oScore, length);
  narrowRExtent = new int[positions][numStates];
  wideRExtent = new int[positions][numStates];
  narrowLExtent = new int[positions][numStates];
  wideLExtent = new int[positions][numStates];
}

/** Allocates one score vector of numStates entries for every cell with end > start. */
private void fillSpanCells(double[][][] chart, int size) {
  for (int start = 0; start < size; start++) {
    double[][] row = chart[start];
    for (int end = start + 1; end <= size; end++) {
      row[end] = new double[numStates];
    }
  }
}
/** Drops the references to both charts and all four extent tables. */
protected void clearArrays() {
  iScore = null;
  oScore = null;
  narrowRExtent = null;
  wideRExtent = null;
  narrowLExtent = null;
  wideLExtent = null;
}
// borrowed from the stanford parser
/**
 * Reconstructs the best tree for {@code goal} over [start, end) from the
 * inside chart. No back-pointers are stored; instead the rule and split
 * point whose score matches the chart entry (within a tolerance) are
 * searched for again. Note that the returned tree may be missing
 * intermediate nodes in a unary chain because it parses with a unary-closed
 * grammar.
 *
 * @param goal     state index to derive
 * @param start    start position of the span (inclusive)
 * @param end      end position of the span (exclusive)
 * @param sentence words of the sentence
 * @return the best subtree, or {@code null} (after printing a warning) when
 *         no rule reproduces the chart score
 */
public Tree<String> extractBestParse(int goal, int start, int end,
    List<String> sentence) {
  grammar.logarithmMode();
  lexicon.logarithmMode();
  // find sources of inside score
  // no backtraces so we can speed up the parsing for its primary use
  double bestScore = iScore[start][end][goal];
  String goalStr = (String) tagNumberer.object(goal);
  if (end - start == 1) {
    // handle the (pre)terminal nodes differently
    if (!grammar.isGrammarTag[goal]) {
      // the goal is not a grammar tag, so we are already at the preterminal
      // level: attach the word directly
      List<Tree<String>> child = new ArrayList<Tree<String>>();
      child.add(new Tree<String>(sentence.get(start)));
      return new Tree<String>(goalStr, child);
    }
    // otherwise find the best unary rewrite from the goal to a tag that is
    // not a grammar tag
    double veryBestScore = Double.NEGATIVE_INFINITY;
    int newIndex = -1;
    UnaryRule[] tagUnaries = grammar.getClosedViterbiUnaryRulesByParent(goal);
    for (int r = 0; r < tagUnaries.length; r++) {
      UnaryRule ur = tagUnaries[r];
      double ruleScore = iScore[start][end][ur.childState] + grammar.getUnaryScore(ur)[0][0];
      if ((ruleScore > veryBestScore) && (goal != ur.childState)
          && (!grammar.isGrammarTag[ur.getChildState()])) {
        veryBestScore = ruleScore;
        newIndex = ur.childState;
      }
    }
    // NOTE(review): newIndex stays -1 when no rule qualifies; the label of
    // the inserted node then depends on what tagNumberer.object(-1) yields.
    // insert the nonterminal tag into the tree
    List<Tree<String>> child1 = new ArrayList<Tree<String>>();
    child1.add(new Tree<String>(sentence.get(start)));
    String goalStr1 = (String) tagNumberer.object(newIndex);
    List<Tree<String>> child = new ArrayList<Tree<String>>();
    child.add(new Tree<String>(goalStr1, child1));
    return new Tree<String>(goalStr, child);
  }
  // check binaries first; the rule list depends only on the goal, so it is
  // looked up once rather than once per split point
  BinaryRule[] parentRules = grammar.splitRulesWithP(goal);
  for (int split = start + 1; split < end; split++) {
    for (int i = 0; i < parentRules.length; i++) {
      BinaryRule br = parentRules[i];
      double score = br.getScore(0, 0, 0)
          + iScore[start][split][br.leftChildState]
          + iScore[split][end][br.rightChildState];
      if (matches(score, bestScore)) {
        // build binary split
        Tree<String> leftChildTree = extractBestParse(br.leftChildState,
            start, split, sentence);
        Tree<String> rightChildTree = extractBestParse(br.rightChildState,
            split, end, sentence);
        List<Tree<String>> children = new ArrayList<Tree<String>>();
        children.add(leftChildTree);
        children.add(rightChildTree);
        return new Tree<String>(goalStr, children);
      }
    }
  }
  // check unaries
  UnaryRule[] unaries = grammar.getClosedViterbiUnaryRulesByParent(goal);
  for (int r = 0; r < unaries.length; r++) {
    UnaryRule ur = unaries[r];
    double score = ur.getScore(0, 0) + iScore[start][end][ur.childState];
    if (ur.childState != ur.parentState && matches(score, bestScore)) {
      // build unary
      Tree<String> childTree = extractBestParse(ur.childState, start, end,
          sentence);
      List<Tree<String>> children = new ArrayList<Tree<String>>();
      children.add(childTree);
      return new Tree<String>(goalStr, children);
    }
  }
  System.err.println("Warning: no parse found");
  return null;
}
/**
 * Re-inserts the intermediate nodes of unary chains. For every node that has
 * exactly one child and is neither a leaf nor a pre-terminal, the grammar's
 * best Viterbi path between the node's state and its child's state is looked
 * up, and every interior state on that path is spliced in as a new node
 * between the two.
 *
 * @param t tree to modify in place
 */
protected void restoreUnaries(Tree<String> t) {
  Iterator nodeIt = t.subTreeList().iterator();
  while (nodeIt.hasNext()) {
    Tree<String> current = (Tree<String>) nodeIt.next();
    boolean isUnaryInternal = !current.isLeaf() && !current.isPreTerminal()
        && current.getChildren().size() == 1;
    if (isUnaryInternal) {
      Tree<String> attachPoint = current;
      short upperState = (short) tagNumberer.number(attachPoint.getLabel());
      short lowerState = (short) tagNumberer.number(current.getChildren().get(0).getLabel());
      List<short[]> chain = grammar.getBestViterbiPath(upperState, (short) 0, lowerState, (short) 0);
      // first and last path entries are the two endpoints; only the interior
      // entries become new nodes
      for (int step = 1; step < chain.size() - 1; step++) {
        int midState = chain.get(step)[0];
        String midLabel = (String) tagNumberer.object(midState);
        Tree<String> inserted = new Tree<String>(midLabel, attachPoint.getChildren());
        List<Tree<String>> onlyChild = new ArrayList<Tree<String>>();
        onlyChild.add(inserted);
        attachPoint.setChildren(onlyChild);
        attachPoint = inserted;
      }
    }
  }
}
/** Relative tolerance used by {@link #matches(double, double)}. */
private static final double TOL = 1e-5;

/**
 * Tests whether two scores agree up to the relative tolerance TOL. The
 * difference is divided by the sum of the magnitudes plus a small constant
 * that keeps the denominator non-zero.
 */
protected boolean matches(double x, double y) {
  double gap = Math.abs(x - y);
  double scale = Math.abs(x) + Math.abs(y) + 1e-10;
  return gap / scale < TOL;
}
/**
 * Computes Viterbi (max instead of sum) inside scores for every node of an
 * already bracketed tree, bottom-up, and stores them on each node's label via
 * {@code setIScores}. Scores are combined by addition and
 * {@code Double.NEGATIVE_INFINITY} marks an impossible entry, i.e. they are
 * treated as log values.
 *
 * <p>Leaves are left untouched. Preterminals take their scores directly from
 * the lexicon. Unary and binary nodes take, for each parent substate, the
 * maximum over child substates of rule score plus child inside score(s).
 *
 * @param tree the (sub)tree to score; its labels are modified in place
 * @throws Error if a node has more than two children
 */
public void doViterbiInsideScores(Tree<StateSet> tree) {
  if (tree.isLeaf()) {
    return;
  }
  List<Tree<StateSet>> children = tree.getChildren();
  // Score the children first; a leaf child returns immediately from the call.
  for (Tree<StateSet> child : children) {
    doViterbiInsideScores(child);
  }
  StateSet parent = tree.getLabel();
  short pState = parent.getState();
  int nParentStates = grammar.numSubStates[pState];
  double[] iScores;
  if (tree.isPreTerminal()) {
    // Plays a role similar to initializeChart(): the tag's scores come from the lexicon.
    StateSet wordLabel = tree.getChildren().get(0).getLabel();
    String word = wordLabel.getWord();
    int pos = wordLabel.from;
    iScores = lexicon.score(word, pState, pos, false, false);
  } else {
    iScores = new double[nParentStates];
    Arrays.fill(iScores, Double.NEGATIVE_INFINITY);
    switch (children.size()) {
      case 0:
        break;
      case 1: {
        StateSet child = children.get(0).getLabel();
        short cState = child.getState();
        int nChildStates = child.numSubStates();
        double[][] uscores = grammar.getUnaryScore(pState, cState); // indexed [child substate][parent substate]
        for (int j = 0; j < nChildStates; j++) {
          if (uscores[j] == null) {
            continue; // no parent substate can produce this child substate
          }
          double cS = child.getIScore(j);
          if (cS == Double.NEGATIVE_INFINITY) {
            continue;
          }
          for (int i = 0; i < nParentStates; i++) {
            double rS = uscores[j][i]; // rule score
            if (rS == Double.NEGATIVE_INFINITY) {
              continue;
            }
            iScores[i] = Math.max(iScores[i], rS + cS);
          }
        }
        break;
      }
      case 2: {
        StateSet leftChild = children.get(0).getLabel();
        StateSet rightChild = children.get(1).getLabel();
        short lState = leftChild.getState();
        short rState = rightChild.getState();
        int nLeftChildStates = grammar.numSubStates[lState];
        int nRightChildStates = grammar.numSubStates[rState];
        // indexed [left substate][right substate][parent substate]
        double[][][] bscores = grammar.getBinaryScore(pState, lState, rState);
        // If the split rules for this parent list a rule with exactly these
        // children, its score table replaces the one fetched above (the last
        // matching rule wins).
        for (BinaryRule br : grammar.splitRulesWithP(pState)) {
          if (br.leftChildState != lState) continue;
          if (br.rightChildState != rState) continue;
          bscores = br.getScores2();
        }
        for (int j = 0; j < nLeftChildStates; j++) {
          double lcS = leftChild.getIScore(j);
          if (lcS == Double.NEGATIVE_INFINITY) {
            continue;
          }
          for (int k = 0; k < nRightChildStates; k++) {
            double rcS = rightChild.getIScore(k);
            if (rcS == Double.NEGATIVE_INFINITY) {
              continue;
            }
            if (bscores[j][k] == null) {
              continue; // no parent substate can produce this pair of children
            }
            for (int i = 0; i < nParentStates; i++) {
              double rS = bscores[j][k][i];
              if (rS == Double.NEGATIVE_INFINITY) {
                continue;
              }
              iScores[i] = Math.max(iScores[i], rS + lcS + rcS);
            }
          }
        }
        break;
      }
      default:
        throw new Error("Malformed tree: more than two children");
    }
  }
  parent.setIScores(iScores);
}
/**
 * Rebuilds, top-down, a tree of "tag-substate" labels that reproduces the
 * Viterbi inside scores stored on {@code tree}'s labels.
 *
 * <p>At each unary or binary node the first child substate (pair) whose rule
 * score plus child inside score(s) {@link #matches(double, double) matches}
 * the node's own score is chosen. If none matches, -1 is passed down, which
 * the recursive call treats as substate 0.
 *
 * @param tree        node to expand; leaves become a single word node
 * @param substate    substate to use at this node; -1 is treated as 0. If its
 *                    inside score is negative infinity, the best-scoring
 *                    substate of the node is used instead
 * @param outputScore whether to append " " plus the node's score to each label
 * @return the labelled derivation rooted at this node
 * @throws Error if a non-preterminal node does not have one or two children
 */
Tree<String> extractBestViterbiDerivation(Tree<StateSet> tree, int substate, boolean outputScore){
  if (tree.isLeaf()) {
    return new Tree<String>(tree.getLabel().getWord());
  }
  if (substate == -1) {
    substate = 0;
  }

  if (tree.isPreTerminal()) {
    List<Tree<String>> wordNode = new ArrayList<Tree<String>>();
    wordNode.add(extractBestViterbiDerivation(tree.getChildren().get(0), -1, outputScore));
    String tagLabel = tagNumberer.object(tree.getLabel().getState()) + "-" + substate;
    if (outputScore) {
      tagLabel = tagLabel + " " + tree.getLabel().getIScore(substate);
    }
    return new Tree<String>(tagLabel, wordNode);
  }

  StateSet node = tree.getLabel();
  short pState = node.getState();
  List<Tree<String>> rebuilt = new ArrayList<Tree<String>>();
  List<Tree<StateSet>> kids = tree.getChildren();

  double myScore = node.getIScore(substate);
  if (myScore == Double.NEGATIVE_INFINITY) {
    // The requested substate is impossible here: fall back to the best one.
    myScore = DoubleArrays.max(node.getIScores());
    substate = DoubleArrays.argMax(node.getIScores());
  }

  int arity = kids.size();
  if (arity == 1) {
    StateSet only = kids.get(0).getLabel();
    short cState = only.getState();
    int nChildStates = only.numSubStates();
    double[][] uscores = grammar.getUnaryScore(pState, cState);
    int chosen = -1;
    // Stop at the first child substate that explains this node's score.
    for (int c = 0; c < nChildStates && chosen == -1; c++) {
      if (uscores[c] == null) {
        continue;
      }
      double childScore = only.getIScore(c);
      if (childScore == Double.NEGATIVE_INFINITY) {
        continue;
      }
      double ruleScore = uscores[c][substate];
      if (ruleScore == Double.NEGATIVE_INFINITY) {
        continue;
      }
      double total = ruleScore + childScore;
      if (matches(total, myScore)) {
        chosen = c;
      }
    }
    rebuilt.add(extractBestViterbiDerivation(kids.get(0), chosen, outputScore));
  } else if (arity == 2) {
    StateSet left = kids.get(0).getLabel();
    StateSet right = kids.get(1).getLabel();
    int nLeft = left.numSubStates();
    int nRight = right.numSubStates();
    short lState = left.getState();
    short rState = right.getState();
    double[][][] bscores = grammar.getBinaryScore(pState, lState, rState);
    int chosenLeft = -1;
    int chosenRight = -1;
    // Both indices are assigned together, so checking one of them suffices
    // to know whether a matching pair has already been found.
    for (int l = 0; l < nLeft && chosenLeft == -1; l++) {
      double leftScore = left.getIScore(l);
      if (leftScore == Double.NEGATIVE_INFINITY) {
        continue;
      }
      for (int r = 0; r < nRight && chosenLeft == -1; r++) {
        double rightScore = right.getIScore(r);
        if (rightScore == Double.NEGATIVE_INFINITY) {
          continue;
        }
        if (bscores[l][r] == null) {
          continue; // no parent substate can produce these kids
        }
        double ruleScore = bscores[l][r][substate];
        if (ruleScore == Double.NEGATIVE_INFINITY) {
          continue;
        }
        double total = ruleScore + leftScore + rightScore;
        if (matches(myScore, total)) {
          chosenLeft = l;
          chosenRight = r;
        }
      }
    }
    rebuilt.add(extractBestViterbiDerivation(kids.get(0), chosenLeft, outputScore));
    rebuilt.add(extractBestViterbiDerivation(kids.get(1), chosenRight, outputScore));
  } else {
    throw new Error ("Malformed tree: more than two children");
  }

  String label = (String) tagNumberer.object(node.getState());
  if (label.endsWith("^g")) {
    // Drop the two-character "^g" suffix from the tag name.
    label = label.substring(0, label.length() - 2);
  }
  label = label + "-" + substate;
  if (outputScore) {
    label = label + " " + myScore;
  }
  return new Tree<String>(label, rebuilt);
}
/**
 * Scores {@code tree} with {@link #doViterbiInsideScores} and then extracts
 * the derivation starting from substate 0 of the root.
 *
 * @param tree        tree to annotate; its labels receive inside scores
 * @param outputScore whether node scores are appended to the output labels
 * @return the best derivation, or null if substate 0 of the root has an
 *         inside score of negative infinity (the tree is unparsable)
 */
public Tree<String> getBestViterbiDerivation(Tree<StateSet> tree, boolean outputScore){
  doViterbiInsideScores(tree);
  boolean parsable = tree.getLabel().getIScore(0) != Double.NEGATIVE_INFINITY;
  return parsable ? extractBestViterbiDerivation(tree, 0, outputScore) : null;
}
/**
 * Convenience overload: reads the normalising score and scale from substate 0
 * of the root label of {@code tree} and delegates to the five-argument
 * version.
 *
 * @param linearizer receives the per-rule and per-tag weights
 * @param probs      array handed through to the linearizer
 * @param tree       tree whose labels already carry inside/outside scores
 */
public void incrementExpectedGoldCounts(Linearizer linearizer, double[] probs, Tree<StateSet> tree){
  incrementExpectedGoldCounts(
      linearizer,
      probs,
      tree,
      tree.getLabel().getIScore(0),
      tree.getLabel().getIScale());
}
/**
 * Walks {@code tree} top-down and, for every preterminal, unary and binary
 * node, computes one weight per substate combination from the inside and
 * outside scores stored on the labels, writes the weights into
 * {@code tmpCountsArray} and hands them to {@code linearizer.increment}.
 *
 * <p>Scores are multiplied here and a value of exactly 0 is treated as
 * "impossible", so they are handled as plain (scaled) probabilities. Each
 * weight is {@code ruleScore * insideScore(s) * outsideScore / tree_score},
 * corrected by the scale factor derived from the nodes' scales minus
 * {@code tree_scale}. Weights above 1.01 are reported on stdout and replaced
 * by 0.
 *
 * <p>NOTE(review): combinations that are skipped are not written, so whatever
 * {@code tmpCountsArray} already holds at those positions is passed along.
 * Presumably {@code linearizer.increment} resets the buffer after use;
 * confirm against Linearizer.
 *
 * @param linearizer receives the weights for each tag / rule
 * @param probs      array handed through to the linearizer
 * @param tree       (sub)tree whose labels carry inside and outside scores
 * @param tree_score normalising score of the whole tree; if it is 0 when a
 *                   rule weight is about to be computed it is replaced by 1,
 *                   and that replacement is what the recursive calls receive
 * @param tree_scale scale belonging to {@code tree_score}
 * @throws Error if a node has more than two children
 */
public void incrementExpectedGoldCounts(Linearizer linearizer, double[] probs, Tree<StateSet> tree,
    double tree_score, int tree_scale) {
  if (tree.isLeaf())
    return;
  if (tree.isPreTerminal()){
    StateSet parent = tree.getLabel();
    StateSet child = tree.getChildren().get(0).getLabel();
    short tag = parent.getState();
    final int nSubStates = grammar.numSubStates[tag];
    double scalingFactor = ScalingTools.calcScaleFactor(parent.getOScale()+parent.getIScale()-tree_scale);
    for (short substate=0; substate<nSubStates; substate++) {
      // weight by the probability of seeing the tag and word together, given the sentence
      double pIS = parent.getIScore(substate); // tag inside score
      if (pIS==0) { continue; }
      double pOS = parent.getOScore(substate); // tag outside score
      if (pOS==0) { continue; }
      double weight = (pIS / tree_score) * scalingFactor * pOS;
      if (weight>1.01){
        System.out.println("Overflow when counting tags? "+weight);
        weight = 0;
      }
      tmpCountsArray[substate] = weight;
    }
    linearizer.increment(probs, child, tag, tmpCountsArray, true);
    return;
  }
  List<Tree<StateSet>> children = tree.getChildren();
  StateSet parent = tree.getLabel();
  short parentState = parent.getState();
  int nParentSubStates = grammar.numSubStates[parentState];
  switch (children.size()) {
    case 0:
      // No children to count a rule for.
      break;
    case 1: {
      StateSet child = children.get(0).getLabel();
      short childState = child.getState();
      int curInd = 0; // running position in tmpCountsArray, advanced once per (child, parent) substate pair
      int nChildSubStates = grammar.numSubStates[childState];
      UnaryRule urule = grammar.getUnaryRule(parentState, childState);
      double[][] oldUScores = urule.getScores2(); // indexed [child substate][parent substate]
      double scalingFactor = ScalingTools.calcScaleFactor(parent.getOScale()+child.getIScale()-tree_scale);
      for (short i = 0; i < nChildSubStates; i++) {
        if (oldUScores[i]==null) continue;
        double cIS = child.getIScore(i); // child inside score
        for (short j = 0; j < nParentSubStates; j++) {
          curInd++;
          if (cIS==0) { continue; }
          double pOS = parent.getOScore(j); // parent outside score
          if (pOS==0) { continue; }
          double rS = oldUScores[i][j];
          if (rS==0) { continue; }
          if (tree_score==0)
            tree_score = 1; // avoid dividing by zero; the change persists for the rest of this call
          double ruleCount = (rS * cIS / tree_score) * scalingFactor * pOS;
          if (ruleCount>1.01){
            System.out.println("Overflow when counting unaries? "+ruleCount);
            ruleCount = 0;
          }
          if (ruleCount!=0)
            tmpCountsArray[curInd-1] = ruleCount;
        }
        // NOTE(review): for parent state 0 the index skips ahead by
        // nChildSubStates-1 after every child substate; presumably this mirrors
        // how the linearizer lays out rules of the root state. Confirm.
        if (parentState==0) curInd += (nChildSubStates-1);
      }
      linearizer.increment(probs, urule, tmpCountsArray, true);
      break;
    }
    case 2: {
      StateSet leftChild = children.get(0).getLabel();
      short lChildState = leftChild.getState();
      StateSet rightChild = children.get(1).getLabel();
      short rChildState = rightChild.getState();
      int nLeftChildSubStates = grammar.numSubStates[lChildState];
      int nRightChildSubStates = grammar.numSubStates[rChildState];
      int curInd = 0; // running position in tmpCountsArray, advanced once per (left, right, parent) triple
      BinaryRule brule = grammar.getBinaryRule(parentState, lChildState, rChildState);
      double[][][] oldBScores = brule.getScores2(); // indexed [left][right][parent]
      double scalingFactor = ScalingTools.calcScaleFactor(parent.getOScale()+leftChild.getIScale()+rightChild.getIScale()-tree_scale);
      int nRuleStates = oldBScores[0][0].length;
      // When the rule table has fewer parent substates than the grammar, each
      // table entry is shared by `divisor` consecutive substate indices.
      int divisor = nParentSubStates/nRuleStates;
      if (divisor == 0)
        // The table has more parent substates than the grammar; the index
        // division below will throw ArithmeticException if it is reached.
        System.out.println("not possible");
      boolean sameResolution = (nRuleStates==nParentSubStates);
      for (short i = 0; i < nLeftChildSubStates; i++) {
        double lcIS = leftChild.getIScore(i); // left child inside score
        for (short j = 0; j < nRightChildSubStates; j++) {
          for (short k = 0; k < nParentSubStates; k++) {
            curInd++;
            if (lcIS==0) { continue; }
            double rcIS = rightChild.getIScore(j); // right child inside score
            if (rcIS==0) { continue; }
            double pOS = parent.getOScore(k); // parent outside score
            if (pOS==0) { continue; }
            double rS = sameResolution
                ? oldBScores[i][j][k]
                : oldBScores[i/divisor][j/divisor][k/divisor];
            if (rS==0) { continue; }
            if (tree_score==0)
              tree_score = 1; // avoid dividing by zero; the change persists for the rest of this call
            double ruleCount = (rS * lcIS / tree_score) * rcIS * scalingFactor * pOS;
            if (ruleCount>1.01){
              System.out.println("Overflow when counting binaries? "+ruleCount);
              ruleCount = 0;
            }
            if (ruleCount>0)
              tmpCountsArray[curInd-1] = ruleCount;
          }
        }
      }
      linearizer.increment(probs, brule, tmpCountsArray, true);
      break;
    }
    default:
      throw new Error("Malformed tree: more than two children");
  }
  for (Tree<StateSet> child : children) {
    incrementExpectedGoldCounts(linearizer, probs, child, tree_score, tree_scale);
  }
}
/**
 * Adds, for every non-leaf node of {@code tree}, the weight of each of its
 * substates to {@code cumulativePosteriors[state][substate]}. The weight is
 * {@code insideScore / tree_score * scaleFactor * outsideScore}; substates
 * whose inside or outside score is exactly 0 are skipped, and weights above
 * 1.01 are reported on stdout and counted as 0.
 *
 * @param cumulativePosteriors accumulator indexed by [state][substate]
 * @param tree                 (sub)tree whose labels carry inside/outside scores
 * @param tree_score           normalising score of the whole tree
 * @param tree_scale           scale belonging to {@code tree_score}
 */
public void countPosteriors(double[][] cumulativePosteriors, Tree<StateSet> tree, double tree_score, int tree_scale) {
  if (tree.isLeaf()) {
    return;
  }
  StateSet label = tree.getLabel();
  short stateId = label.getState();
  final int subStateCount = grammar.numSubStates[stateId];
  double scale = ScalingTools.calcScaleFactor(label.getOScale() + label.getIScale() - tree_scale);
  for (short sub = 0; sub < subStateCount; sub++) {
    double inside = label.getIScore(sub);
    if (inside != 0) {
      double outside = label.getOScore(sub);
      if (outside != 0) {
        double posterior = (inside / tree_score) * scale * outside;
        if (posterior > 1.01) {
          System.out.println("Overflow when counting tags? "+posterior);
          posterior = 0;
        }
        cumulativePosteriors[stateId][sub] += posterior;
      }
    }
  }
  List<Tree<StateSet>> kids = tree.getChildren();
  for (int c = 0; c < kids.size(); c++) {
    countPosteriors(cumulativePosteriors, kids.get(c), tree_score, tree_scale);
  }
}
}