package edu.berkeley.nlp.PCFGLA;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import edu.berkeley.nlp.syntax.StateSet;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.util.ArrayUtil;
import edu.berkeley.nlp.util.Counter;
import edu.berkeley.nlp.util.Numberer;
import edu.berkeley.nlp.util.PriorityQueue;
import edu.berkeley.nlp.util.StringUtils;
public class ConstrainedArrayParser extends ArrayParser implements Callable{
/** Allowed state indices per span, indexed [start][end]; only consulted when noConstrains is false. */
List<Integer>[][] possibleStates;
/** inside scores; indexed [start idx][end idx][state][substate] -> logProb */
protected double[][][][] iScore;
/** outside scores; indexed [start idx][end idx][state][substate] -> logProb */
protected double[][][][] oScore;
/** Number of substates for each state; used to size the innermost chart arrays. */
protected short[] numSubStatesArray;
/** Number of unary nodes emitted by extractBestStateSetTree. */
public long totalUsedUnaries;
/** Rule applications examined by the inside pass, and how many of those had a score of -Infinity. */
public long nRules, nRulesInf;
//the chart is now using scaled probabilities, NOT log-probs.
// NOTE(review): within this class the inside/outside passes switch the grammar and lexicon
// to logarithm mode and never read iScale/oScale, so the note above may be stale -- confirm.
protected int[][][] iScale; // for each (start,end) span there is a scaling factor
protected int[][][] oScale;
// Not referenced anywhere else in this class.
Binarization binarization;
/** Per-state tallies accumulated by tallyStatesAndRules. */
Counter<String> stateCounter = new Counter<String>();
/** Per-rule tallies accumulated by tallyStatesAndRules. */
Counter<String> ruleCounter = new Counter<String>();
// Not read anywhere in this class.
public boolean viterbi = false;
/** number of times we restored unaries (the code that incremented it is currently commented out) */
public int nTimesRestoredUnaries;
/** When true, every state is considered for every span and possibleStates is ignored. */
boolean noConstrains = false;
/** Sentence that the next call() will parse; reset to null once parsed. */
protected List<String> nextSentence;
/** Id of nextSentence; its negation is the priority used when the result is queued. */
protected int nextSentenceID;
/** Identifier assigned through setID. */
int myID;
/** Queue that call() adds finished parses to (and notifies on). */
PriorityQueue<List<Tree<String>>> queue;
/**
 * Stores this parser's identifier and the queue that {@link #call()} adds its results to.
 *
 * @param i identifier for this parser instance
 * @param q queue that receives the finished parses
 */
public void setID(int i, PriorityQueue<List<Tree<String>>> q){
    this.queue = q;
    this.myID = i;
}
/**
 * Hands over the sentence that the next {@link #call()} will parse.
 *
 * @param nextS  the words of the sentence
 * @param nextID id of the sentence; its negation is used as the queue priority of the result
 */
public void setNextSentence(List<String> nextS, int nextID){
    this.nextSentenceID = nextID;
    this.nextSentence = nextS;
}
/**
 * Parses the pending sentence and publishes the result on the queue.
 * The parse is wrapped in a one-element list, added with priority
 * {@code -nextSentenceID}, and all threads waiting on the queue are notified.
 *
 * @return always {@code null}; the result is delivered through the queue
 */
public synchronized Object call() {
    final Tree<String> bestParse = getBestParse(nextSentence);
    nextSentence = null;

    final ArrayList<Tree<String>> wrapped = new ArrayList<Tree<String>>();
    wrapped.add(bestParse);

    synchronized (queue) {
        queue.add(wrapped, -nextSentenceID);
        queue.notifyAll();
    }
    return null;
}
/**
 * Creates another parser over the same grammar, lexicon and substate counts.
 *
 * @return a freshly constructed parser with its own chart and counters
 */
public ConstrainedArrayParser newInstance(){
    return new ConstrainedArrayParser(grammar, lexicon, numSubStatesArray);
}
/**
 * Not implemented: prints a notice on standard output.
 *
 * @param t ignored
 * @return always {@link Double#NEGATIVE_INFINITY}
 */
public double getLogLikelihood(Tree<String> t){
    final String notice = "Unsuported for now!";
    System.out.println(notice);
    return Double.NEGATIVE_INFINITY;
}
/**
 * Sampling is not implemented by this parser.
 *
 * @param sentence ignored
 * @param pStates  ignored
 * @param n        ignored
 * @return always {@code null}
 */
public Tree<String>[] getSampledTrees(List<String> sentence, List<Integer>[][] pStates, int n){
return null;
}
/**
 * Switches span constraints on or off.
 * Note that {@code getBestConstrainedParse} sets the flag to {@code true} itself
 * before parsing.
 *
 * @param noC {@code true} to consider every state for every span
 */
public void setNoConstraints(boolean noC){
    noConstrains = noC;
}
/**
 * K-best parsing is not implemented by this parser.
 *
 * @param sentence ignored
 * @param posTags  ignored
 * @param k        ignored
 * @return always {@code null}
 */
public List<Tree<String>> getKBestConstrainedParses(List<String> sentence, List<String> posTags, int k) {
return null;
}
/**
 * No-argument constructor; does no initialisation of its own, so the substate
 * array is left unset.
 */
public ConstrainedArrayParser(){
}
/**
 * Builds a parser for the given grammar and lexicon.
 *
 * @param gr   grammar handed to the superclass
 * @param lex  lexicon handed to the superclass
 * @param nSub number of substates per state (stored by reference, not copied)
 */
public ConstrainedArrayParser(Grammar gr, Lexicon lex, short[] nSub) {
    super(gr, lex);
    numSubStatesArray = nSub;
    // every statistic starts from zero
    nRules = nRulesInf = 0L;
    totalUsedUnaries = 0L;
    nTimesRestoredUnaries = 0;
}
// public Tree<String> getBestConstrainedParse(List<String> sentence, List<Integer>[][] pStates) {
// length = (short)sentence.size();
// this.possibleStates = pStates;
// createArrays();
// initializeChart(sentence);
//
// doConstrainedInsideScores();
// //showScores(iScore, "Inside scores:");
///* oScore[0][length][0][0] = 0;
// doConstrainedOutsideScores();
//
// List<Integer> possibleParentSt = possibleStates[12][13];
// for (int pState : possibleParentSt){
// System.out.println(pState + " " + (String) tagNumberer.object(pState) + " iScore "+ Arrays.toString(iScore[12][13][pState]) + " oScore "+ Arrays.toString(oScore[12][13][pState]));
// }
//
//*/
// Tree<String> bestTree = new Tree<String>("ROOT");
// double score = iScore[0][length][0][0];
// if (score > Double.NEGATIVE_INFINITY) {
// //System.out.println("\nFound a parse for sentence with length "+length+". The LL is "+score+".");
// Tree<StateSet> bestStateSetTree = extractBestStateSetTree(zero, zero, zero, length, sentence);
// tallyStatesAndRules(bestStateSetTree);
// bestTree = restoreStateSetTreeUnaries(bestStateSetTree);
// //bestTree = extractBestParse(0, 0, 0, length, sentence);
// //restoreUnaries(bestTree);
// }
// else {
// System.out.println("()\nDid NOT find a parse for sentence with length "+length+".");
// }
//
//
// return bestTree;
// }
/**
 * Builds the tally key for a node of a StateSet tree: the state's name from the
 * tag numberer, an ampersand, and the node's first iScore entry truncated to a
 * short (the extracted trees store the substate index in that slot).
 */
private String getStateString(Tree<StateSet> tree) {
    final Object stateName = tagNumberer.object(tree.getLabel().getState());
    final short subState = (short) tree.getLabel().getIScore(0);
    return stateName + "&" + subState;
}
/**
 * Recursively counts how often each state and each rule occurs in a tree.
 * Leaves and preterminals are skipped. For every other node the state key is
 * counted once, and a rule key of the form {@code parent->|child1|child2...}
 * is counted once after the children have been visited.
 *
 * @param bestStateSetTree the (sub)tree to tally
 */
private void tallyStatesAndRules(Tree<StateSet> bestStateSetTree) {
    final boolean belowRuleLevel = bestStateSetTree.isLeaf() || bestStateSetTree.isPreTerminal();
    if (belowRuleLevel) {
        return;
    }
    final String parentKey = getStateString(bestStateSetTree);
    stateCounter.incrementCount(parentKey, 1);

    final StringBuilder ruleKey = new StringBuilder(parentKey).append("->");
    for (Tree<StateSet> kid : bestStateSetTree.getChildren()) {
        tallyStatesAndRules(kid);
        ruleKey.append("|").append(getStateString(kid));
    }
    ruleCounter.incrementCount(ruleKey.toString(), 1);
}
/**
 * Writes the collected state and rule counts to standard output, one
 * {@code key count} pair per line under a heading for each table.
 */
public void printStateAndRuleTallies() {
    System.out.println("STATE TALLIES");
    for (String stateKey : stateCounter.keySet()) {
        final String line = stateKey + " " + stateCounter.getCount(stateKey);
        System.out.println(line);
    }

    System.out.println("RULE TALLIES");
    for (String ruleKey : ruleCounter.keySet()) {
        final String line = ruleKey + " " + ruleCounter.getCount(ruleKey);
        System.out.println(line);
    }
}
/**
 * Allocates the inside/outside charts, the scale tables and the extent tables
 * for a sentence of the current {@code length}.
 *
 * Width-1 spans get a cell for every state. Wider spans only get cells for the
 * states allowed by {@code possibleStates[start][end]}, or for every state when
 * {@code noConstrains} is set. All allocated score cells start at negative infinity.
 */
protected void createArrays() {
    // zero out some stuff first in case we recently ran out of memory and are reallocating
    clearArrays();

    // allocate just the parts of iScore and oScore used (end > start, etc.)
    iScore = new double[length][length + 1][][];
    oScore = new double[length][length + 1][][];
    iScale = new int[length][length + 1][];
    oScale = new int[length][length + 1][];

    // initialize for all POS tags so that we can use the lexicon
    for (int start = 0; start < length; start++) {
        int end = start + 1;
        iScore[start][end] = new double[numStates][];
        oScore[start][end] = new double[numStates][];
        iScale[start][end] = new int[numStates];
        oScale[start][end] = new int[numStates];
        for (int state = 0; state < numStates; state++) {
            iScore[start][end][state] = new double[numSubStatesArray[state]];
            oScore[start][end][state] = new double[numSubStatesArray[state]];
            Arrays.fill(iScore[start][end][state], Double.NEGATIVE_INFINITY);
            Arrays.fill(oScore[start][end][state], Double.NEGATIVE_INFINITY);
        }
    }

    // Without constraints the candidate list is the same for every span, so it is
    // built once here instead of once per (start, end) pair.
    List<Integer> allStates = null;
    if (noConstrains) {
        allStates = new ArrayList<Integer>(numStates);
        for (int i = 0; i < numStates; i++) {
            allStates.add(i);
        }
    }

    for (int start = 0; start < length; start++) {
        for (int end = start + 2; end <= length; end++) {
            iScore[start][end] = new double[numStates][];
            oScore[start][end] = new double[numStates][];
            iScale[start][end] = new int[numStates];
            oScale[start][end] = new int[numStates];

            List<Integer> pStates = noConstrains ? allStates : possibleStates[start][end];
            for (int state : pStates) {
                iScore[start][end][state] = new double[numSubStatesArray[state]];
                oScore[start][end][state] = new double[numSubStatesArray[state]];
                Arrays.fill(iScore[start][end][state], Double.NEGATIVE_INFINITY);
                Arrays.fill(oScore[start][end][state], Double.NEGATIVE_INFINITY);
            }

            // sanity report for the span covering the whole sentence
            if (start == 0 && end == length) {
                if (pStates.size() == 0)
                    System.out.println("no states span the entire tree!");
                if (iScore[start][end][0] == null)
                    System.out.println("ROOT does not span the entire tree!");
            }
        }
    }

    narrowRExtent = new int[length + 1][numStates];
    wideRExtent = new int[length + 1][numStates];
    narrowLExtent = new int[length + 1][numStates];
    wideLExtent = new int[length + 1][numStates];
    for (int loc = 0; loc <= length; loc++) {
        Arrays.fill(narrowLExtent[loc], -1); // the rightmost left with state s ending at i that we can get is the beginning
        Arrays.fill(wideLExtent[loc], length + 1); // the leftmost left with state s ending at i that we can get is the end
        Arrays.fill(narrowRExtent[loc], length + 1); // the leftmost right with state s starting at i that we can get is the end
        Arrays.fill(wideRExtent[loc], -1); // the rightmost right with state s starting at i that we can get is the beginning
    }
}
/**
 * Seeds the width-1 cells of the inside chart from the lexicon.
 * For every word position and every tag that is not a grammar tag, the four
 * extent tables are set to the one-word span and each substate's inside score
 * is set to the score returned by the lexicon.
 *
 * @param sentence the words to parse, in order
 */
void initializeChart(List<String> sentence) {
    // the lexicon is queried with the word string; the chart is indexed by tag number
    int position = 0;
    for (String word : sentence) {
        final int next = position + 1;
        for (short tag = 0; tag < grammar.numSubStates.length; tag++) {
            if (!grammar.isGrammarTag[tag]) {
                narrowRExtent[position][tag] = next;
                narrowLExtent[next][tag] = position;
                wideRExtent[position][tag] = next;
                wideLExtent[next][tag] = position;

                final double[] lexiconScores = lexicon.score(word, tag, position, false, false);
                for (short sub = 0; sub < numSubStatesArray[tag]; sub++) {
                    iScore[position][next][tag][sub] = lexiconScores[sub];
                }
            }
        }
        position++;
    }
}
/**
 * Convenience overload that ignores {@code allowedS} and delegates to the
 * two-argument version.
 *
 * @param sentence the words to parse
 * @param posTags  passed through unchanged
 * @param allowedS ignored
 * @return the result of the two-argument overload
 */
public Tree<String> getBestConstrainedParse(List<String> sentence, List<String> posTags, boolean[][][][] allowedS){
    final Tree<String> parse = getBestConstrainedParse(sentence, posTags);
    return parse;
}
/**
 * Parses a sentence and returns the best tree found.
 *
 * This method always sets {@code noConstrains} to {@code true}, so every state
 * is considered for every span regardless of {@code possibleStates}. It rebuilds
 * the chart, seeds it from the lexicon, runs the inside pass and, if state 0 /
 * substate 0 has a finite inside score over the whole sentence, extracts the tree.
 *
 * @param sentence the words to parse
 * @param posTags  not used by this implementation
 * @return the best parse, or a bare {@code ROOT} node when no parse exists
 *         (including for an empty sentence) or when tree extraction fails
 */
public Tree<String> getBestConstrainedParse(List<String> sentence, List<String> posTags){
    length = (short)sentence.size();
    noConstrains = true;
    createArrays();
    initializeChart(sentence);
    doConstrainedInsideScores();

    Tree<String> bestTree = new Tree<String>("ROOT");
    // An empty sentence has no chart cell at [0][length]; treat it as "no parse"
    // instead of indexing into a zero-length array.
    double score = (length > 0) ? iScore[0][length][0][0] : Double.NEGATIVE_INFINITY;
    if (score > Double.NEGATIVE_INFINITY) {
        Tree<StateSet> bestStateSetTree = extractBestStateSetTree(zero, zero, zero, length, sentence);
        // extractBestStateSetTree reports failure by returning null; keep the bare ROOT in that case
        if (bestStateSetTree != null) {
            bestTree = restoreStateSetTreeUnaries(bestStateSetTree);
        }
    }
    else {
        System.out.println("()\nDid NOT find a parse for sentence with length "+length+".");
    }
    return bestTree;
}
/**
 * Fills in the iScore array of each category over each span, from width 1 up
 * to the whole sentence.
 *
 * The grammar and lexicon are switched to logarithm mode, so scores are added
 * and the best (maximum) derivation per (span, state, substate) is kept. For
 * each span the binary rules of every candidate parent state are applied first,
 * then the closed Viterbi unary rules. The first time a substate of a state
 * becomes reachable over a span, the narrow/wide extent tables are updated so
 * that later binary rules can bound their split points.
 *
 * Side effects: prints each span width to standard output and updates the
 * {@code nRules} / {@code nRulesInf} counters.
 */
void doConstrainedInsideScores() {
    grammar.logarithmMode();
    lexicon.logarithmMode();

    // Without constraints the candidate list is the same for every span:
    // build it once rather than once per (start, end) pair.
    List<Integer> allStates = null;
    if (noConstrains) {
        allStates = new ArrayList<Integer>(numStates);
        for (int i = 0; i < numStates; i++) {
            allStates.add(i);
        }
    }

    for (int diff = 1; diff <= length; diff++) {
        System.out.print(diff + " ");
        for (int start = 0; start < (length - diff + 1); start++) {
            int end = start + diff;
            List<Integer> possibleSt = noConstrains ? allStates : possibleStates[start][end];

            // ---- binary rules ----
            for (int pState : possibleSt) {
                BinaryRule[] parentRules = grammar.splitRulesWithP(pState);
                for (int i = 0; i < parentRules.length; i++) {
                    BinaryRule r = parentRules[i];
                    int lState = r.leftChildState;
                    int rState = r.rightChildState;

                    int narrowR = narrowRExtent[start][lState];
                    // can this left constituent leave space for a right constituent?
                    if (narrowR >= end) { continue; }
                    int narrowL = narrowLExtent[end][rState];
                    // can this right constituent fit next to the left constituent?
                    if (narrowL < narrowR) { continue; }
                    int min1 = narrowR;
                    int min2 = wideLExtent[end][rState];
                    int min = (min1 > min2 ? min1 : min2);
                    // can this right constituent stretch far enough to reach the left constituent?
                    if (min > narrowL) { continue; }
                    int max1 = wideRExtent[start][lState];
                    int max2 = narrowL;
                    int max = (max1 < max2 ? max1 : max2);
                    // can this left constituent stretch far enough to reach the right constituent?
                    if (min > max) { continue; }

                    // loop over all substates
                    double[][][] scores = r.getScores2();
                    int nParentSubStates = numSubStatesArray[pState];
                    for (int np = 0; np < nParentSubStates; np++) {
                        double oldIScore = iScore[start][end][pState][np];
                        double bestIScore = oldIScore;
                        for (int split = min; split <= max; split++) {
                            if (iScore[start][split][lState] == null) continue;
                            if (iScore[split][end][rState] == null) continue;
                            for (int lp = 0; lp < scores.length; lp++) {
                                double lS = iScore[start][split][lState][lp];
                                if (lS == Double.NEGATIVE_INFINITY) continue;
                                for (int rp = 0; rp < scores[0].length; rp++) {
                                    nRules++;
                                    double pS = Double.NEGATIVE_INFINITY;
                                    if (scores[lp][rp] != null) pS = scores[lp][rp][np];
                                    if (pS == Double.NEGATIVE_INFINITY) {
                                        nRulesInf++;
                                        continue;
                                    }
                                    double rS = iScore[split][end][rState][rp];
                                    if (rS == Double.NEGATIVE_INFINITY) continue;
                                    double tot = pS + lS + rS;
                                    if (tot >= bestIScore) { bestIScore = tot; }
                                }
                            }
                        }
                        if (bestIScore > oldIScore) { // this way of making "parentState" is better than previous
                            iScore[start][end][pState][np] = bestIScore;
                            if (oldIScore == Double.NEGATIVE_INFINITY) {
                                noteNewlyReachableSpan(start, end, pState);
                            }
                        }
                    }
                }
            }

            // ---- unary rules (closed Viterbi unaries) ----
            for (int pState : possibleSt) {
                UnaryRule[] unaries = grammar.getClosedViterbiUnaryRulesByParent(pState);
                for (int r = 0; r < unaries.length; r++) {
                    UnaryRule ur = unaries[r];
                    int cState = ur.childState;
                    if (iScore[start][end][cState] == null) continue;

                    // loop over all substates
                    double[][] scores = ur.getScores2();
                    int nParentSubStates = numSubStatesArray[pState];
                    for (int np = 0; np < nParentSubStates; np++) {
                        double oldIScore = iScore[start][end][pState][np];
                        double bestIScore = oldIScore;
                        for (int cp = 0; cp < scores.length; cp++) {
                            double pS = Double.NEGATIVE_INFINITY;
                            if (scores[cp] != null) pS = scores[cp][np];
                            nRules++;
                            if (pS == Double.NEGATIVE_INFINITY) {
                                nRulesInf++;
                                continue;
                            }
                            double iS = iScore[start][end][cState][cp];
                            if (iS == Double.NEGATIVE_INFINITY) continue;
                            double tot = iS + pS;
                            if (tot >= bestIScore) { bestIScore = tot; }
                        }
                        if (bestIScore > oldIScore) {
                            iScore[start][end][pState][np] = bestIScore;
                            if (oldIScore == Double.NEGATIVE_INFINITY) {
                                noteNewlyReachableSpan(start, end, pState);
                            }
                        }
                    }
                }
            }
        }
    }
}

/** Updates the narrow/wide extent tables after {@code state} first became reachable over [start, end]. */
private void noteNewlyReachableSpan(int start, int end, int state) {
    if (start > narrowLExtent[end][state]) {
        narrowLExtent[end][state] = start;
        wideLExtent[end][state] = start;
    } else if (start < wideLExtent[end][state]) {
        wideLExtent[end][state] = start;
    }
    if (end < narrowRExtent[start][state]) {
        narrowRExtent[start][state] = end;
        wideRExtent[start][state] = end;
    } else if (end > wideRExtent[start][state]) {
        wideRExtent[start][state] = end;
    }
}
/**
 * Fills in the oScore array, working from the widest span down to width 1.
 *
 * The grammar and lexicon are switched to logarithm mode. For each span the
 * closed Viterbi unary rules push the parent's outside score down to the child
 * at the same span; then every binary rule pushes the parent's outside score
 * (plus the sibling's inside score) down to both children at every admissible
 * split point. Only improvements (strictly larger scores) are stored.
 *
 * The caller is expected to have run the inside pass and to have seeded the
 * outside score of the root cell beforehand; this method does neither.
 */
void doConstrainedOutsideScores() {
    grammar.logarithmMode();
    lexicon.logarithmMode();

    // Without constraints the candidate list is the same for every span:
    // build it once rather than once per (start, end) pair.
    List<Integer> allStates = null;
    if (noConstrains) {
        allStates = new ArrayList<Integer>(numStates);
        for (int i = 0; i < numStates; i++) {
            allStates.add(i);
        }
    }

    for (int diff = length; diff >= 1; diff--) {
        for (int start = 0; start + diff <= length; start++) {
            int end = start + diff;

            // ---- unaries ----
            List<Integer> possibleParentSt = noConstrains ? allStates : possibleStates[start][end];
            for (int pState : possibleParentSt) {
                // oScore[start][end][pState] is expected to be allocated for every candidate
                // parent; a NullPointerException below means the arrays were not set up properly.
                UnaryRule[] rules = grammar.getClosedViterbiUnaryRulesByParent(pState);
                for (int r = 0; r < rules.length; r++) {
                    UnaryRule ur = rules[r];
                    int cState = ur.childState;
                    if (oScore[start][end][cState] == null) { continue; }

                    // loop over all substates
                    double[][] scores = ur.getScores2();
                    for (int cp = 0; cp < scores.length; cp++) {
                        // A null row means no rule score for this child substate. Skipping it
                        // also avoids dereferencing a null row for the loop bound below.
                        if (scores[cp] == null) { continue; }
                        double oldOScore = oScore[start][end][cState][cp];
                        double bestOScore = oldOScore;
                        double iS = iScore[start][end][cState][cp];
                        if (iS == Double.NEGATIVE_INFINITY) { continue; }
                        for (int np = 0; np < scores[cp].length; np++) {
                            double oS = oScore[start][end][pState][np];
                            double pS = scores[cp][np];
                            double tot = oS + pS;
                            if (tot > bestOScore) {
                                bestOScore = tot;
                            }
                        }
                        if (bestOScore > oldOScore) {
                            oScore[start][end][cState][cp] = bestOScore;
                        }
                    }
                }
            }

            // ---- binaries ----
            for (int pState = 0; pState < numStates; pState++) {
                BinaryRule[] rules = grammar.splitRulesWithP(pState);
                for (int r = 0; r < rules.length; r++) {
                    BinaryRule br = rules[r];
                    if (oScore[start][end][br.parentState] == null) { continue; }

                    int lState = br.leftChildState;
                    int min1 = narrowRExtent[start][lState];
                    if (end < min1) { continue; }
                    int rState = br.rightChildState;
                    int max1 = narrowLExtent[end][rState];
                    if (max1 < min1) { continue; }
                    int min = min1;
                    int max = max1;
                    if (max - min > 2) {
                        int min2 = wideLExtent[end][rState];
                        min = (min1 > min2 ? min1 : min2);
                        if (max1 < min) { continue; }
                        int max2 = wideRExtent[start][lState];
                        max = (max1 < max2 ? max1 : max2);
                        if (max < min) { continue; }
                    }

                    double[][][] scores = br.getScores2();
                    for (int split = min; split <= max; split++) {
                        if (oScore[start][split][lState] == null) continue;
                        if (oScore[split][end][rState] == null) continue;
                        for (int lp = 0; lp < scores.length; lp++) {
                            double lS = iScore[start][split][lState][lp];
                            if (lS == Double.NEGATIVE_INFINITY) { continue; }
                            for (int rp = 0; rp < scores[lp].length; rp++) {
                                double rS = iScore[split][end][rState][rp];
                                if (rS == Double.NEGATIVE_INFINITY) { continue; }
                                if (scores[lp][rp] == null) continue;
                                for (int np = 0; np < scores[lp][rp].length; np++) {
                                    double oS = oScore[start][end][br.parentState][np];
                                    double pS = scores[lp][rp][np];
                                    double totL = pS + rS + oS;
                                    if (totL > oScore[start][split][lState][lp]) {
                                        oScore[start][split][lState][lp] = totL;
                                    }
                                    double totR = pS + lS + oS;
                                    if (totR > oScore[split][end][rState][rp]) {
                                        oScore[split][end][rState][rp] = totR;
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}
/**
 * Dumps a chart to standard output: the title, then one line per span
 * (shortest spans first) listing every allocated state/substate score as
 * {@code (STATE[substate] score)}.
 *
 * @param scores chart indexed [start][end][state][substate]
 * @param title  heading printed before the dump
 */
public void showScores(double[][][][] scores, String title) {
    System.out.println(title);
    for (int width = 1; width <= length; width++) {
        final int lastStart = length - width;
        for (int begin = 0; begin <= lastStart; begin++) {
            final int finish = begin + width;
            System.out.print("[" + begin + " " + finish + "]: ");

            // candidate states for this span
            final List<Integer> candidates;
            if (!noConstrains) {
                candidates = possibleStates[begin][finish];
            } else {
                candidates = new ArrayList<Integer>();
                for (int s = 0; s < numStates; s++) {
                    candidates.add(s);
                }
            }

            for (int state : candidates) {
                if (scores[begin][finish][state] == null) {
                    continue;
                }
                for (int sub = 0; sub < grammar.numSubStates[state]; sub++) {
                    final Numberer numberer = grammar.tagNumberer;
                    final String tagName =
                        StringUtils.escapeString(numberer.object(state).toString(), new char[]{'\"'}, '\\');
                    System.out.print("(" + tagName + "[" + sub + "] " + scores[begin][finish][state][sub] + ")");
                }
            }
            System.out.println();
        }
    }
}
/**
* Return the single best parse for one goal (state, substate) over one span.
* Note that the returned tree may be missing intermediate nodes in
* a unary chain because it parses with a unary-closed grammar.
*
* The chart stores no back pointers. Instead, the method looks for a rule
* application whose score (rule score plus the children's inside scores)
* matches the stored inside score of the goal, and recurses on the children
* of the first match. Binary rules are tried before unary rules.
*
* @param gState   goal state
* @param gp       goal substate
* @param start    start index of the span
* @param end      end index of the span
* @param sentence the words being parsed
* @return the subtree rooted at the goal, or null (after a warning on stderr)
*         when no rule application reproduces the stored score
*/
public Tree<String> extractBestParse(int gState, int gp, int start, int end, List<String> sentence ) {
// find sources of inside score
// no backtraces so we can speed up the parsing for its primary use
double bestScore = iScore[start][end][gState][gp];
String goalStr = (String)tagNumberer.object(gState);
//System.out.println("Looking for "+goalStr+" from "+start+" to "+end+" with score "+ bestScore+".");
if (end - start == 1) {
// if the goal state is a preterminal state, then it can't transform into
// anything but the word below it
// if (lexicon.getAllTags().contains(gState)) {
if (!grammar.isGrammarTag[gState]){
List<Tree<String>> child = new ArrayList<Tree<String>>();
child.add(new Tree<String>(sentence.get(start)));
return new Tree<String>(goalStr, child);
}
// if the goal state is not a preterminal state, then find a way to
// transform it into one
else {
// pick the highest-scoring unary whose child is not a grammar tag
double veryBestScore = Double.NEGATIVE_INFINITY;
int newIndex = -1;
UnaryRule[] unaries = grammar.getClosedViterbiUnaryRulesByParent(gState);
for (int r = 0; r < unaries.length; r++) {
UnaryRule ur = unaries[r];
int cState = ur.childState;
double[][] scores = ur.getScores2();
for (int cp=0; cp<scores.length; cp++){
if (scores[cp]==null) continue;
double ruleScore = iScore[start][end][cState][cp] + scores[cp][gp];
if ((ruleScore >= veryBestScore) && (gState != cState || gp != cp)
&& (!grammar.isGrammarTag[ur.getChildState()])){
// && lexicon.getAllTags().contains(cState)) {
veryBestScore = ruleScore;
newIndex = cState;
}
}
}
// NOTE(review): if no unary qualified, newIndex is still -1 when it is looked up below;
// what tagNumberer.object(-1) does is not visible from this file.
List<Tree<String>> child1 = new ArrayList<Tree<String>>();
child1.add(new Tree<String>(sentence.get(start)));
String goalStr1 = (String) tagNumberer.object(newIndex);
if (goalStr1==null)
System.out.println("goalStr1==null with newIndex=="+newIndex+" goalStr=="+goalStr);
// result shape: goal -> chosen tag -> word
List<Tree<String>> child = new ArrayList<Tree<String>>();
child.add(new Tree<String>(goalStr1, child1));
return new Tree<String>(goalStr, child);
}
}
// check binaries first
for (int split = start + 1; split < end; split++) {
//for (Iterator binaryI = grammar.bRuleIteratorByParent(gState, gp); binaryI.hasNext();) {
//BinaryRule br = (BinaryRule) binaryI.next();
BinaryRule[] parentRules = grammar.splitRulesWithP(gState);
for (int i = 0; i < parentRules.length; i++) {
BinaryRule br = parentRules[i];
int lState = br.leftChildState;
// skip children that have no chart cell over their span
if (iScore[start][split][lState]==null) continue;
int rState = br.rightChildState;
if (iScore[split][end][rState]==null) continue;
//new: iterate over substates
double[][][] scores = br.getScores2();
for (int lp=0; lp<scores.length; lp++){
for (int rp=0; rp<scores[lp].length; rp++){
if (scores[lp][rp]==null) continue;
double score = scores[lp][rp][gp] + iScore[start][split][lState][lp]
+ iScore[split][end][rState][rp];
// the first application that reproduces the goal's inside score wins
if (matches(score, bestScore)) {
// build binary split
Tree<String> leftChildTree = extractBestParse(lState, lp, start, split, sentence);
Tree<String> rightChildTree = extractBestParse(rState, rp, split, end, sentence);
List<Tree<String>> children = new ArrayList<Tree<String>>();
children.add(leftChildTree);
children.add(rightChildTree);
Tree<String> result = new Tree<String>(goalStr, children);
//System.out.println("Binary node: "+result);
//result.setScore(score);
return result;
}
}
}
}
}
// check unaries
//for (Iterator unaryI = grammar.uRuleIteratorByParent(gState, gp); unaryI.hasNext();) {
//UnaryRule ur = (UnaryRule) unaryI.next();
UnaryRule[] unaries = grammar.getClosedViterbiUnaryRulesByParent(gState);
for (int r = 0; r < unaries.length; r++) {
UnaryRule ur = unaries[r];
int cState = ur.childState;
if (iScore[start][end][cState]==null) continue;
//new: iterate over substates
double[][] scores = ur.getScores2();
for (int cp=0; cp<scores.length; cp++){
if (scores[cp]==null) continue;
double score = scores[cp][gp] + iScore[start][end][cState][cp];
// a rule from a (state, substate) to itself is never followed, which would recurse forever
if ((cState != ur.parentState || cp != gp) && matches(score, bestScore)) {
// build unary
Tree<String> childTree = extractBestParse(cState, cp, start, end, sentence);
List<Tree<String>> children = new ArrayList<Tree<String>>();
children.add(childTree);
Tree<String> result = new Tree<String>(goalStr, children);
//System.out.println("Unary node: "+result);
//result.setScore(score);
return result;
}
}
}
System.err.println("Warning: could not find the optimal way to build state "+goalStr+" spanning from "+ start+ " to "+end+".");
return null;
}
/**
* Return the single best parse for one goal (state, substate) over one span,
* as a tree of StateSet nodes.
* Note that the returned tree may be missing intermediate nodes in
* a unary chain because it parses with a unary-closed grammar.
* A StateSet tree is returned, but the subState array is used in a
* different way:
* it has only one entry, whose value is the substate! - dirty hack...
*
* As in extractBestParse, no back pointers are stored: the method searches
* for a rule application whose score matches the goal's stored inside score,
* trying binary rules before unary rules. The best binary and unary scores
* seen are tracked only for the diagnostic printed on failure.
*
* @param gState   goal state
* @param gp       goal substate
* @param start    start index of the span
* @param end      end index of the span
* @param sentence the words being parsed
* @return the subtree rooted at the goal, or null when no rule application
*         reproduces the stored score (a warning and the full inside chart
*         are printed in that case)
*/
public Tree<StateSet> extractBestStateSetTree(short gState, short gp, short start, short end, List<String> sentence ) {
// find sources of inside score
// no backtraces so we can speed up the parsing for its primary use
double bestScore = iScore[start][end][gState][gp];
//Numberer tagNumberer = Numberer.getGlobalNumberer("tags");
//System.out.println("Looking for "+(String)tagNumberer.object(gState)+" from "+start+" to "+end+" with score "+ bestScore+".");
if (end - start == 1) {
// if the goal state is a preterminal state, then it can't transform into
// anything but the word below it
if (!grammar.isGrammarTag(gState)) {
List<Tree<StateSet>> child = new ArrayList<Tree<StateSet>>();
StateSet node = new StateSet(zero,zero,sentence.get(start),start,end);
child.add(new Tree<StateSet>(node));
StateSet root = new StateSet(gState,one,null,start,end);
root.allocate();
// slot 0 of the iScore array carries the substate index, not a score
root.setIScore(0,gp);
return new Tree<StateSet>(root, child);
}
// if the goal state is not a preterminal state, then find a way to
// transform it into one
else {
// pick the highest-scoring unary whose child is not a grammar tag
double veryBestScore = Double.NEGATIVE_INFINITY;
short newIndex = -1;
short newSubstate = -1;
UnaryRule[] unaries = grammar.getClosedViterbiUnaryRulesByParent(gState);
for (int r = 0; r < unaries.length; r++) {
UnaryRule ur = unaries[r];
short cState = ur.childState;
double[][] scores = ur.getScores2();
for (short cp=0; cp<scores.length; cp++){
if (scores[cp]==null) continue;
if (iScore[start][end][cState]==null) continue;
double ruleScore = iScore[start][end][cState][cp] + scores[cp][gp];
if ((ruleScore >= veryBestScore) && (gState != cState || gp != cp)
&& !grammar.isGrammarTag(cState)) {
veryBestScore = ruleScore;
newIndex = cState;
newSubstate = cp;
}
}
}
List<Tree<StateSet>> child1 = new ArrayList<Tree<StateSet>>();
StateSet node1 = new StateSet(zero,zero,sentence.get(start),start,end);
child1.add(new Tree<StateSet>(node1));
// NOTE(review): when no unary qualified, only a message is printed; the node below is
// still built with state -1 and substate -1.
if (newIndex==-1)
System.out.println("goalStr1==null with newIndex=="+newIndex+" goalState=="+gState);
// result shape: goal -> chosen tag -> word
List<Tree<StateSet>> child = new ArrayList<Tree<StateSet>>();
StateSet node = new StateSet(newIndex,one, null, start, end);
node.allocate();
node.setIScore(0,newSubstate);
child.add(new Tree<StateSet>(node,child1));
StateSet root = new StateSet(gState,one, null, start, end);
root.allocate();
root.setIScore(0,gp);
//totalUsedUnaries++;
return new Tree<StateSet>(root, child);
}
}
// check binaries first
// best binary score seen so far; used only in the failure diagnostic at the end
double bestBScore = Double.NEGATIVE_INFINITY;
// BinaryRule bestBRule = null;
// short bestBLp, bestBRp;
//TODO: fix parsing
for (int split = start + 1; split < end; split++) {
BinaryRule[] parentRules = grammar.splitRulesWithP(gState);
for (short i = 0; i < parentRules.length; i++) {
BinaryRule br = parentRules[i];
short lState = br.leftChildState;
// skip children that have no chart cell over their span
if (iScore[start][split][lState]==null) continue;
short rState = br.rightChildState;
if (iScore[split][end][rState]==null) continue;
//new: iterate over substates
double[][][] scores = br.getScores2();
for (short lp=0; lp<scores.length; lp++){
for (short rp=0; rp<scores[lp].length; rp++){
if (scores[lp][rp]==null) continue;
double score = scores[lp][rp][gp] + iScore[start][split][lState][lp]
+ iScore[split][end][rState][rp];
if (score > bestBScore)
bestBScore = score;
// the first application that reproduces the goal's inside score wins
if (matches(score, bestScore)) {
// build binary split
Tree<StateSet> leftChildTree = extractBestStateSetTree(lState, lp, start, (short)split, sentence);
Tree<StateSet> rightChildTree = extractBestStateSetTree(rState, rp, (short)split, end, sentence);
List<Tree<StateSet>> children = new ArrayList<Tree<StateSet>>();
children.add(leftChildTree);
children.add(rightChildTree);
StateSet root = new StateSet(gState,one, null, start, end);
root.allocate();
root.setIScore(0,gp);
Tree<StateSet> result = new Tree<StateSet>(root, children);
//System.out.println("Binary node: "+result);
//result.setScore(score);
return result;
}
}
}
}
}
// best unary score seen so far; used only in the failure diagnostic at the end
double bestUScore = Double.NEGATIVE_INFINITY;
// check unaries
UnaryRule[] unaries = grammar.getClosedViterbiUnaryRulesByParent(gState);
for (short r = 0; r < unaries.length; r++) {
UnaryRule ur = unaries[r];
short cState = ur.childState;
if (iScore[start][end][cState]==null) continue;
//new: iterate over substates
double[][] scores = ur.getScores2();
for (short cp=0; cp<scores.length; cp++){
if (scores[cp]==null) continue;
double rScore = scores[cp][gp];
double score = rScore + iScore[start][end][cState][cp];
if (score > bestUScore)
bestUScore = score;
// a rule from a (state, substate) to itself is never followed, which would recurse forever
if ((cState != ur.parentState || cp != gp) && matches(score, bestScore)) {
// build unary
Tree<StateSet> childTree = extractBestStateSetTree(cState, cp, start, end, sentence);
List<Tree<StateSet>> children = new ArrayList<Tree<StateSet>>();
children.add(childTree);
StateSet root = new StateSet(gState,one, null, start, end);
root.allocate();
root.setIScore(0,gp);
Tree<StateSet> result = new Tree<StateSet>(root, children);
//System.out.println("Unary node: "+result);
//result.setScore(score);
totalUsedUnaries++;
return result;
}
}
}
System.err.println("Warning: could not find the optimal way to build state "+gState+" spanning from "+ start+ " to "+end+".");
System.err.println("The goal score was "+bestScore+", but the best we found was a binary rule giving "+bestBScore+" and a unary rule giving "+bestUScore);
showScores(iScore,"iScores");
return null;
}
/**
 * Converts a StateSet tree into a tree of plain label strings.
 *
 * Every node is relabelled with the name of its state (looked up through the
 * tag numberer); preterminals keep the word stored in their single child.
 * Substate information (kept in the first iScore slot of each StateSet) is
 * not carried over.
 *
 * Despite the name, intermediate states of a collapsed unary chain are NOT
 * re-inserted at the moment: the lookup of the best Viterbi path between
 * parent and child was disabled, so a unary node is copied as it is.
 *
 * @param t the StateSet (sub)tree to convert
 * @return the converted tree, or {@code null} (after a message on stderr) if
 *         {@code t} is a leaf
 */
protected Tree<String> restoreStateSetTreeUnaries(Tree<StateSet> t) {
    if (t.isLeaf()) { // shouldn't happen
        System.err.println("Tried to restore unary from a leaf...");
        return null;
    }

    if (t.isPreTerminal()) { // preterminal unaries have already been restored
        List<Tree<String>> word = new ArrayList<Tree<String>>();
        word.add(new Tree<String>(t.getChildren().get(0).getLabel().getWord()));
        return new Tree<String>((String)tagNumberer.object(t.getLabel().getState()), word);
    }

    if (t.getChildren().size() != 1) { // binary node: nothing to restore
        Tree<String> leftChildTree = restoreStateSetTreeUnaries(t.getChildren().get(0));
        Tree<String> rightChildTree = restoreStateSetTreeUnaries(t.getChildren().get(1));
        List<Tree<String>> children = new ArrayList<Tree<String>>();
        children.add(leftChildTree);
        children.add(rightChildTree);
        return new Tree<String>((String)tagNumberer.object(t.getLabel().getState()), children);
    }

    // unary node: convert the only child and hang it under the parent's label
    short pLabel = t.getLabel().getState();
    List<Tree<String>> goodChild = new ArrayList<Tree<String>>();
    goodChild.add(restoreStateSetTreeUnaries(t.getChildren().get(0)));
    return new Tree<String>((String)tagNumberer.object(pLabel), goodChild);
}
/**
 * @return a copy of the inside chart, produced by {@code ArrayUtil.clone}
 */
public double[][][][] getInsideScores() {
    final double[][][][] copy = ArrayUtil.clone(iScore);
    return copy;
}
/**
 * @return a copy of the outside chart, produced by {@code ArrayUtil.clone}
 */
public double[][][][] getOutsideScores() {
    final double[][][][] copy = ArrayUtil.clone(oScore);
    return copy;
}
/**
 * Prints the unary and rule counters collected so far to standard output.
 */
public void printUnaryStats(){
    final String usedLine = " Used a total of " + totalUsedUnaries + " unary productions.";
    final String restoredLine = " restored unaries " + nTimesRestoredUnaries;
    final String ruleLine = " Out of " + nRules + " rules " + nRulesInf + " had probability=-Inf.";
    System.out.println(usedLine);
    System.out.println(restoredLine);
    System.out.println(ruleLine);
}
/**
 * Not supported by this parser: only prints a message on standard error.
 *
 * @param allowed             ignored
 * @param allSubstatesAllowed ignored
 */
public void projectConstraints(boolean[][][][] allowed, boolean allSubstatesAllowed){
System.err.println("Not supported!\nThis parser cannot project constraints!");
}
/**
 * Returns the per-state substate counts this parser was built with.
 * The internal array itself is returned, not a copy.
 *
 * @return the numSubStatesArray
 */
public short[] getNumSubStatesArray() {
    return this.numSubStatesArray;
}
}