/**
*
*/
package edu.berkeley.nlp.PCFGLA;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.util.ArrayUtil;
import edu.berkeley.nlp.util.ScalingTools;
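/**
* N-best variant of {@link CoarseToFineMaxRuleParser}: for every span and state it keeps
* lazily expanded k-best lists ({@link LazyList}) of {@link HyperEdge}s, one chart before and
* one chart after unary rules, and extracts up to k trees from them.
*
* @author petrov
*/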
public class CoarseToFineNBestParser extends CoarseToFineMaxRuleParser{
LazyList[][][] chartBeforeU;
LazyList[][][] chartAfterU;
int k;
List<Double> maxRuleScores;
int tmp_k;
/**
 * Builds an n-best parser. Every argument except {@code k} is handed unchanged to the
 * {@link CoarseToFineMaxRuleParser} constructor; {@code k} is kept in this class.
 *
 * @param gr grammar forwarded to the superclass
 * @param lex lexicon forwarded to the superclass
 * @param k number of parses to produce per sentence; stored in the {@code k} field, which
 *        {@link #call()} passes to {@link #getKBestConstrainedParses}
 * @param unaryPenalty forwarded to the superclass
 * @param endL forwarded to the superclass ({@link #newInstance()} supplies the {@code endLevel} field here)
 * @param viterbi forwarded to the superclass ({@link #newInstance()} supplies {@code viterbiParse})
 * @param sub forwarded to the superclass ({@link #newInstance()} supplies {@code outputSub})
 * @param score forwarded to the superclass ({@link #newInstance()} supplies {@code outputScore})
 * @param accurate forwarded to the superclass
 * @param variational forwarded to the superclass ({@link #newInstance()} supplies {@code doVariational})
 * @param useGoldPOS forwarded to the superclass
 * @param initCascade forwarded to the superclass; {@link #newInstance()} passes {@code false} and
 *        then calls {@code initCascade(this)} on the new parser instead
 */
public CoarseToFineNBestParser(
        Grammar gr,
        Lexicon lex,
        int k,
        double unaryPenalty,
        int endL,
        boolean viterbi,
        boolean sub,
        boolean score,
        boolean accurate,
        boolean variational,
        boolean useGoldPOS,
        boolean initCascade) {
    super(gr, lex, unaryPenalty, endL, viterbi, sub, score, accurate, variational, useGoldPOS, initCascade);
    this.k = k;
}
/**
* Fills the two k-best charts {@code chartBeforeU} and {@code chartAfterU} bottom-up over all spans.
* <p>
* For every span and every state permitted by {@code allowedStates}, candidate {@link HyperEdge}s are
* scored as follows: (outside * rule * inside) / iScore[0][length][0][0] is summed over substates, the
* log of that sum is taken, and the score of the best (rank 0) child edge(s) is added.
* <ul>
* <li>{@code chartBeforeU} receives binary-rule edges for spans longer than one word and lexical
* (tag over word) edges for one-word spans.</li>
* <li>{@code chartAfterU} receives unary-rule edges built on top of the best {@code chartBeforeU}
* entry of the child state, plus a "self" edge (parent state == child state, rule score 0) that
* carries the best {@code chartBeforeU} entry of the same state through.</li>
* </ul>
* Assumes that inside and outside scores (sum version, not viterbi) have been computed.
* The narrowRExtent / narrowLExtent / wideLExtent / wideRExtent arrays are read here but not updated.
*
* @param sentence the words; {@code sentence.get(start)} is scored by the lexicon for one-word spans
* @param grammar source of the substate counts, the binary rules by parent and the closed unary rules by parent
* @param lexicon source of the per-substate word scores
* @param scale when true, the log of a scale factor computed from {@code oScale}/{@code iScale} is added to every edge score
*/
void doConstrainedMaxCScores(List<String> sentence, Grammar grammar, Lexicon lexicon, final boolean scale) {
numSubStatesArray = grammar.numSubStates;
// Sentinel meaning "no derivation": child scores equal to this value are skipped below.
double initVal = Double.NEGATIVE_INFINITY;
chartBeforeU = new LazyList[length][length + 1][numStates];
chartAfterU = new LazyList[length][length + 1][numStates];
// NOTE(review): despite its name this is the raw inside score of the whole sentence for state 0 /
// substate 0; it is used as a plain divisor below, never as a log value.
double logNormalizer = iScore[0][length][0][0];
// double thresh2 = threshold*logNormalizer;
// Spans are visited in order of increasing width so that all sub-spans are complete first.
for (int diff = 1; diff <= length; diff++) {
//System.out.print(diff + " ");
for (int start = 0; start < (length - diff + 1); start++) {
int end = start + diff;
if (diff > 1) {
// diff > 1: Try binary rules
for (int pState=0; pState<numSubStatesArray.length; pState++){
if (!allowedStates[start][end][pState]) continue;
chartBeforeU[start][end][pState] = new LazyList(grammar.isGrammarTag);
BinaryRule[] parentRules = grammar.splitRulesWithP(pState);
int nParentStates = numSubStatesArray[pState]; // == scores[0][0].length;
// Best edge over all rules and split points; only used for width-2 spans (see below).
double bestScore = Double.NEGATIVE_INFINITY;
HyperEdge bestElement = null;
for (int i = 0; i < parentRules.length; i++) {
BinaryRule r = parentRules[i];
int lState = r.leftChildState;
int rState = r.rightChildState;
// Use the extent arrays to narrow the range [min, max] of split points worth trying.
int narrowR = narrowRExtent[start][lState];
boolean iPossibleL = (narrowR < end); // can this left constituent leave space for a right constituent?
if (!iPossibleL) { continue; }
int narrowL = narrowLExtent[end][rState];
boolean iPossibleR = (narrowL >= narrowR); // can this right constituent fit next to the left constituent?
if (!iPossibleR) { continue; }
int min1 = narrowR;
int min2 = wideLExtent[end][rState];
int min = (min1 > min2 ? min1 : min2); // can this right constituent stretch far enough to reach the left constituent?
if (min > narrowL) { continue; }
int max1 = wideRExtent[start][lState];
int max2 = narrowL;
int max = (max1 < max2 ? max1 : max2); // can this left constituent stretch far enough to reach the right constituent?
if (min > max) { continue; }
double[][][] scores = r.getScores2();
int nLeftChildStates = numSubStatesArray[lState]; // == scores.length;
int nRightChildStates = numSubStatesArray[rState]; // == scores[0].length;
for (int split = min; split <= max; split++) {
double ruleScore = 0;
if (!allowedStates[start][split][lState]) continue;
if (!allowedStates[split][end][rState]) continue;
// Rank-0 entries of the children after unary rules; a missing entry counts as "no derivation".
HyperEdge bestLeft = chartAfterU[start][split][lState].getKbest(0);
double leftChildScore = (bestLeft==null) ? Double.NEGATIVE_INFINITY : bestLeft.score;
HyperEdge bestRight = chartAfterU[split][end][rState].getKbest(0);
double rightChildScore = (bestRight==null) ? Double.NEGATIVE_INFINITY : bestRight.score;
// double leftChildScore = maxcScore[start][split][lState];
// double rightChildScore = maxcScore[split][end][rState];
if (leftChildScore==initVal||rightChildScore==initVal) continue;
double scalingFactor = 0.0;
if (scale) scalingFactor = Math.log(ScalingTools.calcScaleFactor(
oScale[start][end][pState]+iScale[start][split][lState]+
iScale[split][end][rState]-iScale[0][length][0]));
double gScore = leftChildScore + scalingFactor + rightChildScore;
if (gScore == Double.NEGATIVE_INFINITY) continue; // no chance of finding a better derivation
// Sum outside(parent) * rule * inside(left) * inside(right) over all substate triples,
// each term divided by the sentence-level inside score.
for (int lp = 0; lp < nLeftChildStates; lp++) {
double lIS = iScore[start][split][lState][lp];
if (lIS == 0) continue;
// if (lIS < thresh2) continue;
//if (!allowedSubStates[start][split][lState][lp]) continue;
for (int rp = 0; rp < nRightChildStates; rp++) {
if (scores[lp][rp]==null) continue;
double rIS = iScore[split][end][rState][rp];
if (rIS == 0) continue;
// if (rIS < thresh2) continue;
//if (!allowedSubStates[split][end][rState][rp]) continue;
for (int np = 0; np < nParentStates; np++) {
//if (!allowedSubStates[start][end][pState][np]) continue;
double pOS = oScore[start][end][pState][np];
if (pOS == 0) continue;
// if (pOS < thresh2) continue;
double ruleS = scores[lp][rp][np];
if (ruleS == 0) continue;
ruleScore += (pOS * ruleS * lIS * rIS) / logNormalizer;
}
}
}
if (ruleScore==0) continue;
ruleScore = Math.log(ruleScore);
gScore += ruleScore;
if (gScore > Double.NEGATIVE_INFINITY) {
// The three zeros sit where updateConstrainedMaxCScores passes (parent rank, left rank, right rank).
HyperEdge newElement = new HyperEdge(pState,lState,rState,0,0,0,start,split,end,gScore,ruleScore);
if (gScore>bestScore){
bestScore = gScore;
bestElement = newElement;
}
// Spans wider than two words keep every candidate edge on the fringe ...
if (diff>2) chartBeforeU[start][end][pState].addToFringe(newElement);
}
}
}
// ... whereas a two-word span keeps only its single best edge.
if (diff==2&&bestElement!=null) chartBeforeU[start][end][pState].addToFringe(bestElement);
// chartBeforeU[start][end][pState].expandNextBest();
}
} else { // diff == 1
// We treat TAG --> word exactly as if it was a unary rule, except the score of the rule is
// given by the lexicon rather than the grammar and that we allow another unary on top of it.
//for (int tag : lexicon.getAllTags()){
for (int tag=0; tag<numSubStatesArray.length; tag++){
if (!allowedStates[start][end][tag]) continue;
chartBeforeU[start][end][tag] = new LazyList(grammar.isGrammarTag);
int nTagStates = numSubStatesArray[tag];
String word = sentence.get(start);
//System.out.print("Attempting");
// Grammar tags get an (empty) list but no lexical edge.
if (grammar.isGrammarTag(tag)) continue;
//System.out.println("Computing maxcScore for span " +start + " to "+end);
double[] lexiconScoreArray = lexicon.score(word, (short) tag, start, false,false);
double lexiconScores = 0;
for (int tp = 0; tp < nTagStates; tp++) {
double pOS = oScore[start][end][tag][tp];
// if (pOS < thresh2) continue;
double ruleS = lexiconScoreArray[tp];
lexiconScores += (pOS * ruleS) / logNormalizer; // The inside score of a word is 0.0f
}
double scalingFactor = 0.0;
if (scale) scalingFactor = Math.log(ScalingTools.calcScaleFactor(
oScale[start][end][tag]-iScale[0][length][0]));
// No zero check here: a sum of 0 becomes a score of negative infinity and the edge is still added.
lexiconScores = Math.log(lexiconScores);
double gScore = lexiconScores + scalingFactor;
// Lexical edge: child states are -1 and the split argument is set to start.
HyperEdge newElement = new HyperEdge(tag,-1,-1,0,0,0,start,start,end,gScore, lexiconScores);
chartBeforeU[start][end][tag].addToFringe(newElement);
// chartBeforeU[start][end][tag].expandNextBest();
}
}
// Try unary rules
// Replacement for maxcScore[start][end], which is updated in batch
// double[] maxcScoreStartEnd = new double[numStates];
// for (int i = 0; i < numStates; i++) {
// maxcScoreStartEnd[i] = maxcScore[start][end][i];
// }
for (int pState=0; pState<numSubStatesArray.length; pState++){
if (!allowedStates[start][end][pState]) continue;
chartAfterU[start][end][pState] = new LazyList(grammar.isGrammarTag);
int nParentStates = numSubStatesArray[pState]; // == scores[0].length;
UnaryRule[] unaries = grammar.getClosedSumUnaryRulesByParent(pState);
HyperEdge bestElement = null;
double bestScore = Double.NEGATIVE_INFINITY;
for (int r = 0; r < unaries.length; r++) {
UnaryRule ur = unaries[r];
int cState = ur.childState;
// The identity case is handled by the explicit "self" edge after this loop.
if ((pState == cState)) continue;// && (np == cp))continue;
if (iScore[start][end][cState]==null) continue;
double childScore = Double.NEGATIVE_INFINITY;
if (chartBeforeU[start][end][cState]!=null){
HyperEdge bestChild = chartBeforeU[start][end][cState].getKbest(0);
childScore = (bestChild==null) ? Double.NEGATIVE_INFINITY : bestChild.score;
}
// double childScore = maxcScore[start][end][cState];
if (childScore==initVal) continue;
double scalingFactor = 0.0;
if (scale) scalingFactor = Math.log(ScalingTools.calcScaleFactor(
oScale[start][end][pState]+iScale[start][end][cState]
-iScale[0][length][0]));
double gScore = scalingFactor + childScore;
// if (gScore < maxcScoreStartEnd[pState]) continue;
double[][] scores = ur.getScores2();
int nChildStates = numSubStatesArray[cState]; // == scores.length;
double ruleScore = 0;
for (int cp = 0; cp < nChildStates; cp++) {
double cIS = iScore[start][end][cState][cp];
if (cIS == 0) continue;
// if (cIS < thresh2) continue;
//if (!allowedSubStates[start][end][cState][cp]) continue;
if (scores[cp]==null) continue;
for (int np = 0; np < nParentStates; np++) {
//if (!allowedSubStates[start][end][pState][np]) continue;
double pOS = oScore[start][end][pState][np];
// NOTE(review): the binary case above skips pOS == 0; here the test is pOS < 0 - confirm this is intended.
if (pOS < 0) continue;
// if (pOS < thresh2) continue;
double ruleS = scores[cp][np];
if (ruleS == 0) continue;
ruleScore += (pOS * ruleS * cIS) / logNormalizer;
}
}
if (ruleScore==0) continue;
ruleScore = Math.log(ruleScore);
gScore += ruleScore;
if (gScore > Double.NEGATIVE_INFINITY) {
HyperEdge newElement = new HyperEdge(pState,cState,0,0,start,end,gScore, ruleScore);
if (gScore>bestScore){
bestScore = gScore;
bestElement = newElement;
}
// Multi-word spans keep every unary candidate ...
if (diff>1) chartAfterU[start][end][pState].addToFringe(newElement);
}
}
// ... one-word spans keep only the best unary candidate.
if (diff==1&&bestElement!=null) chartAfterU[start][end][pState].addToFringe(bestElement);
// "Self" edge: no unary rule applied, score copied from the best pre-unary entry, rule score 0.
if (chartBeforeU[start][end][pState]!=null){
HyperEdge bestSelf = chartBeforeU[start][end][pState].getKbest(0);
if (bestSelf != null){
HyperEdge selfRule = new HyperEdge(pState,pState,0,0,start,end,bestSelf.score,0);
chartAfterU[start][end][pState].addToFringe(selfRule);
}
}
// chartAfterU[start][end][pState].expandNextBest();
}
// maxcScore[start][end] = maxcScoreStartEnd;
}
}
}
/**
 * Returns the top-ranked parse for the span {@code [start, end)} rooted in state 0.
 * Simply asks {@link #extractBestMaxRuleParse1} for the derivation with zero suboptimalities,
 * so the k-best charts must already have been filled.
 *
 * @param start first word index of the span
 * @param end index one past the last word of the span
 * @param sentence the words of the sentence
 * @return the extracted tree, or {@code null} if the chart holds no derivation for that cell
 */
public Tree<String> extractBestMaxRuleParse(int start, int end, List<String> sentence) {
    final int rootState = 0;
    final int rank = 0;
    return extractBestMaxRuleParse1(start, end, rootState, rank, sentence);
}
/**
 * Extracts up to {@code k} parses for the span {@code [start, end)} rooted in state 0, best first.
 * <p>
 * Side effects: resets {@code maxRuleScores} and {@code tmp_k}; the score of every returned tree is
 * appended to {@code maxRuleScores} in the same order, which is what {@link #getModelScore} later
 * hands out one by one.
 * <p>
 * Extraction stops at the first rank for which {@link #extractBestMaxRuleParse1} returns
 * {@code null}, so the result may hold fewer than {@code k} trees.
 *
 * @param start first word index of the span
 * @param end index one past the last word of the span
 * @param sentence the words of the sentence
 * @param k maximum number of trees to extract
 * @return the extracted trees in rank order (possibly empty, never {@code null})
 */
public List<Tree<String>> extractKBestMaxRuleParses(int start, int end, List<String> sentence, int k) {
    List<Tree<String>> list = new ArrayList<Tree<String>>(k);
    maxRuleScores = new ArrayList<Double>(k);
    tmp_k = 0;
    for (int i = 0; i < k; i++) {
        Tree<String> tmp = extractBestMaxRuleParse1(start, end, 0, i, sentence);
        if (tmp == null) {
            break; // the chart has no i-th best derivation
        }
        // Read the score from the very cell the tree was extracted from. The previous code always
        // looked at chartAfterU[0][length][0], which only matches when the whole sentence is requested.
        maxRuleScores.add(chartAfterU[start][end][0].getKbest(i).score);
        list.add(tmp);
    }
    return list;
}
/**
 * Returns the next score recorded by {@link #extractKBestMaxRuleParses} and advances the
 * internal cursor {@code tmp_k}. The argument is not inspected: scores are handed out strictly
 * in the order in which the trees were extracted.
 *
 * @param parsedTree ignored
 * @return the score stored at the current cursor position
 */
public double getModelScore(Tree<String> parsedTree) {
    // Advance the cursor before the lookup so that it moves even if the lookup throws.
    final int position = tmp_k;
    tmp_k = position + 1;
    return maxRuleScores.get(position);
}
/**
 * Builds the tree for the {@code suboptimalities}-th best derivation of {@code state} over the span
 * {@code [start, end)}, potentially starting with a unary rule.
 * <p>
 * The entry is looked up in {@code chartAfterU}; its child is looked up in {@code chartBeforeU}.
 * One-word spans get the word as the only child, longer spans recurse into the left and right
 * children of the binary edge. After the children are built, both edges are passed to
 * {@link #updateConstrainedMaxCScores} so that their successors are put on the fringe.
 * A trailing {@code "^g"} on a state name is removed from the node labels.
 *
 * @param start first word index of the span
 * @param end index one past the last word of the span
 * @param state state to root the tree in
 * @param suboptimalities rank of the requested derivation (0 = best)
 * @param sentence the words of the sentence
 * @return the tree, or {@code null} when the chart has no derivation of that rank
 */
public Tree<String> extractBestMaxRuleParse1(int start, int end, int state, int suboptimalities, List<String> sentence ) {
    final HyperEdge unaryEdge = chartAfterU[start][end][state].getKbest(suboptimalities);
    if (unaryEdge == null) {
        System.out.println("Don't have a "+(suboptimalities+1)+"-best tree.");
        return null;
    }

    final int childState = unaryEdge.childState;
    final HyperEdge coreEdge = chartBeforeU[start][end][childState].getKbest(unaryEdge.childBest);

    final List<Tree<String>> kids = new ArrayList<Tree<String>>();
    String label = (String) tagNumberer.object(childState);
    if (label.endsWith("^g")) {
        label = label.substring(0, label.length() - 2);
    }

    if (end - start == 1) {
        // Part-of-speech level: the only child is the word itself.
        kids.add(new Tree<String>(sentence.get(start)));
    } else {
        final int split = coreEdge.split;
        if (split == -1) {
            System.err.println("Warning: no symbol can generate the span from "+ start+ " to "+end+".");
            System.err.println("The score is "+maxcScore[start][end][state]+" and the state is supposed to be "+label);
            System.err.println("The insideScores are "+Arrays.toString(iScore[start][end][state])+" and the outsideScores are " +Arrays.toString(oScore[start][end][state]));
            System.err.println("The maxcScore is "+maxcScore[start][end][state]);
            return new Tree<String>("ROOT");
        }
        final int leftState = coreEdge.lChildState;
        final int rightState = coreEdge.rChildState;
        final Tree<String> leftTree = extractBestMaxRuleParse1(start, split, leftState, coreEdge.lChildBest, sentence);
        final Tree<String> rightTree = extractBestMaxRuleParse1(split, end, rightState, coreEdge.rChildBest, sentence);
        kids.add(leftTree);
        kids.add(rightTree);
    }

    final boolean scale = false;
    // Queue the successor of the pre-unary edge now that it has been consumed.
    updateConstrainedMaxCScores(sentence, scale, coreEdge);

    Tree<String> tree = new Tree<String>(label, kids);
    if (childState != state) {
        // A unary rule sits on top: wrap the subtree in the parent label.
        String parentLabel = (String) tagNumberer.object(state);
        if (parentLabel.endsWith("^g")) {
            parentLabel = parentLabel.substring(0, parentLabel.length() - 2);
        }
        final int intermediate = grammar.getUnaryIntermediate((short) state, (short) childState);
        if (intermediate > 0) {
            // Re-insert the intermediate node reported by the grammar for this unary rule.
            String middleLabel = (String) tagNumberer.object(intermediate);
            if (middleLabel.endsWith("^g")) {
                middleLabel = middleLabel.substring(0, middleLabel.length() - 2);
            }
            final List<Tree<String>> middleKids = new ArrayList<Tree<String>>();
            middleKids.add(tree);
            tree = new Tree<String>(middleLabel, middleKids);
        }
        final List<Tree<String>> parentKids = new ArrayList<Tree<String>>();
        parentKids.add(tree);
        tree = new Tree<String>(parentLabel, parentKids);
    }

    // Queue the successor of the post-unary edge as well.
    updateConstrainedMaxCScores(sentence, scale, unaryEdge);
    return tree;
}
/**
 * Puts the successor(s) of an already extracted edge on the fringe of its chart cell so that the
 * next-best derivation can be found later. Does nothing if the edge is flagged {@code alreadyExpanded}.
 * <ul>
 * <li>Binary edge: builds one candidate that advances the left child's rank (only if the left span
 * is longer than one word) and one that advances the right child's rank (only if the right span is
 * longer than one word). If both exist, only the higher-scoring one is added to
 * {@code chartBeforeU}; the left one wins only when strictly better.</li>
 * <li>Unary edge over a span longer than one word: advances the child's rank in
 * {@code chartBeforeU} and adds the resulting edge to {@code chartAfterU}. Unary edges over
 * one-word spans are left untouched (and not flagged).</li>
 * </ul>
 *
 * @param sentence not used by this method
 * @param scale not used by this method
 * @param parent the edge whose successors should be generated
 */
void updateConstrainedMaxCScores(List<String> sentence, final boolean scale, HyperEdge parent) {
    if (parent.alreadyExpanded) {
        return;
    }

    final int from = parent.start;
    final int to = parent.end;
    final int pState = parent.parentState;
    final int nextRank = parent.parentBest + 1;
    final double ruleScore = parent.ruleScore;

    if (parent.isUnary) {
        if (to - from <= 1) {
            return; // one-word span: nothing to advance
        }
        final int cState = parent.childState;
        final int nextChildRank = parent.childBest + 1;
        final HyperEdge nextChild = chartBeforeU[from][to][cState].getKbest(nextChildRank);
        if (nextChild != null) {
            final double advancedScore = nextChild.score + ruleScore;
            chartAfterU[from][to][pState].addToFringe(
                    new HyperEdge(pState, cState, nextRank, nextChildRank, from, to, advancedScore, ruleScore));
        }
        parent.alreadyExpanded = true;
        return;
    }

    // Binary edge.
    final int lState = parent.lChildState;
    final int rState = parent.rChildState;
    final int split = parent.split;

    HyperEdge advanceLeft = null;
    if (split - from > 1) { // left child is not a POS
        final int lRank = parent.lChildBest + 1;
        final HyperEdge left = chartAfterU[from][split][lState].getKbest(lRank);
        if (left != null) {
            final int rRank = parent.rChildBest;
            final HyperEdge right = chartAfterU[split][to][rState].getKbest(rRank);
            final double advancedScore = left.score + right.score + ruleScore;
            advanceLeft = new HyperEdge(pState, lState, rState, nextRank, lRank, rRank, from, split, to, advancedScore, ruleScore);
        }
    }

    HyperEdge advanceRight = null;
    if (to - split > 1) { // right child is not a POS
        final int rRank = parent.rChildBest + 1;
        final HyperEdge right = chartAfterU[split][to][rState].getKbest(rRank);
        if (right != null) {
            final int lRank = parent.lChildBest;
            final HyperEdge left = chartAfterU[from][split][lState].getKbest(lRank);
            final double advancedScore = left.score + right.score + ruleScore;
            advanceRight = new HyperEdge(pState, lState, rState, nextRank, lRank, rRank, from, split, to, advancedScore, ruleScore);
        }
    }

    // At most one successor goes on the fringe: the better of the two when both exist.
    final HyperEdge chosen;
    if (advanceLeft != null && advanceRight != null) {
        chosen = (advanceLeft.score > advanceRight.score) ? advanceLeft : advanceRight;
    } else if (advanceLeft != null) {
        chosen = advanceLeft;
    } else {
        chosen = advanceRight;
    }
    if (chosen != null) {
        chartBeforeU[from][to][pState].addToFringe(chosen);
    }
    parent.alreadyExpanded = true;
}
/**
 * Parses a sentence and returns up to {@code k} trees, best first.
 * <p>
 * Steps: run the pre-parses, take the grammar and lexicon at index
 * {@code endLevel - startLevel + 1} of the cascades, compute constrained inside scores and store the
 * resulting score in {@code logLikelihood}. If that score is finite (and this is not a Viterbi
 * parse) the outside scores and the k-best charts are computed without scaling; if it is negative
 * infinity everything is recomputed with the scaled routines. Finally the {@code grammar} and
 * {@code lexicon} fields are switched to the ones used and the trees are extracted.
 * <p>
 * NOTE(review): when {@code viterbiParse} is true and the score is finite, the k-best charts are not
 * refilled before extraction - confirm that this mode is never used with this parser.
 *
 * @param sentence the words to parse; an empty sentence yields a single bare {@code ROOT} tree
 * @param posTags passed through to the pre-parse and chart initialisation ({@link #call()} passes {@code null})
 * @param k maximum number of trees to return
 * @return the extracted trees in rank order
 */
public List<Tree<String>> getKBestConstrainedParses(List<String> sentence, List<String> posTags, int k) {
    if (sentence.isEmpty()) {
        final ArrayList<Tree<String>> onlyRoot = new ArrayList<Tree<String>>();
        onlyRoot.add(new Tree<String>("ROOT"));
        return onlyRoot;
    }

    doPreParses(sentence, null, false, posTags);

    final int cascadeIndex = endLevel - startLevel + 1;
    final Grammar curGrammar = grammarCascade[cascadeIndex];
    final Lexicon curLexicon = lexiconCascade[cascadeIndex];

    final double initVal = viterbiParse ? Double.NEGATIVE_INFINITY : 0;
    final int level = isBaseline ? 1 : endLevel;
    createArrays(false, curGrammar.numStates, curGrammar.numSubStates, level, initVal, false);
    initializeChart(sentence, curLexicon, false, false, posTags, false);
    doConstrainedInsideScores(curGrammar, viterbiParse, viterbiParse);

    double score = iScore[0][length][0][0];
    if (!viterbiParse) {
        score = Math.log(score);
    }
    logLikelihood = score;

    if (score == Double.NEGATIVE_INFINITY) {
        // The unscaled pass produced no usable score: redo inside/outside with scaling.
        setupScaling();
        initializeChart(sentence, curLexicon, false, false, posTags, true);
        doScaledConstrainedInsideScores(curGrammar);
        // The rescaled score is computed as in the original code but kept local only.
        score = iScore[0][length][0][0];
        if (!viterbiParse) {
            score = Math.log(score) + (100 * iScale[0][length][0]);
        }
        oScore[0][length][0][0] = 1.0;
        oScale[0][length][0] = 0;
        doScaledConstrainedOutsideScores(curGrammar);
        doConstrainedMaxCScores(sentence, curGrammar, curLexicon, true);
    } else if (!viterbiParse) {
        oScore[0][length][0][0] = 1.0;
        doConstrainedOutsideScores(curGrammar, viterbiParse, false);
        doConstrainedMaxCScores(sentence, curGrammar, curLexicon, false);
    }

    grammar = curGrammar;
    lexicon = curLexicon;
    return extractKBestMaxRuleParses(0, length, sentence, k);
}
/**
 * Creates another parser configured from this parser's current fields. The new parser is
 * constructed with {@code initCascade == false} and then receives {@code initCascade(this)}.
 *
 * @return the freshly created parser
 */
public CoarseToFineNBestParser newInstance(){
    final boolean buildOwnCascade = false;
    final CoarseToFineNBestParser copy = new CoarseToFineNBestParser(
            grammar, lexicon, k, unaryPenalty, endLevel,
            viterbiParse, outputSub, outputScore, accurate, doVariational, useGoldPOS,
            buildOwnCascade);
    copy.initCascade(this);
    return copy;
}
/**
 * Parses {@code nextSentence} into up to {@code k} trees, clears {@code nextSentence}, and then,
 * while holding the monitor of {@code queue}, adds the result to {@code queue} with
 * {@code -nextSentenceID} as second argument and wakes all threads waiting on the queue.
 *
 * @return always {@code null}; the result is delivered through {@code queue}
 */
public synchronized Object call() {
    final List<Tree<String>> parses = getKBestConstrainedParses(nextSentence, null, k);
    nextSentence = null;
    synchronized (queue) {
        queue.add(parses, -nextSentenceID);
        queue.notifyAll();
    }
    return null;
}
}