/** * */ package edu.berkeley.nlp.PCFGLA; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import edu.berkeley.nlp.discPCFG.Linearizer; import edu.berkeley.nlp.syntax.StateSet; import edu.berkeley.nlp.syntax.Tree; import edu.berkeley.nlp.math.DoubleArrays; import edu.berkeley.nlp.math.SloppyMath; import edu.berkeley.nlp.util.ArrayUtil; import edu.berkeley.nlp.util.Numberer; import edu.berkeley.nlp.util.ScalingTools; /** * @author petrov * */ public class ConstrainedTwoChartsParser extends ConstrainedArrayParser{ /** inside and outside scores; start idx, end idx, state, substate -> logProb/prob */ /** NEW: we now have two charts one before applying unaries and one after: */ protected double[][][][] iScorePreU, iScorePostU; protected double[][][][] oScorePreU, oScorePostU; protected int[][][] iScale; protected int[][][] oScale; protected double[][][] maxcScore; // start, end, state --> logProb protected double[][][] maxsScore; // start, end, state --> logProb protected int[][][] maxcSplit; // start, end, state -> split position protected int[][][] maxcChild; // start, end, state -> unary child (if any) protected int[][][] maxcLeftChild; // start, end, state -> left child protected int[][][] maxcRightChild; // start, end, state -> right child public boolean[][][][] allowedSubStates; double[] tmpCountsArray; boolean[] grammarTags; double[] unscaledScoresToAdd; int[][] goldBinaryProduction; int[][] goldUnaryParent; int[][] goldUnaryChild; int[] goldPOS; SpanPredictor spanPredictor; public double[][][] spanScores; int[] stateClass; // double edgesTouched; // int sentencesParsed; public ConstrainedTwoChartsParser(Grammar gr, Lexicon lex, SpanPredictor sp) { grammar = gr; lexicon = lex; spanPredictor = sp; if (spanPredictor!=null) stateClass = spanPredictor.getStateClass(); numSubStatesArray = grammar.numSubStates.clone(); grammarTags = grammar.isGrammarTag; numStates = grammar.numStates; scoresToAdd = new 
double[(int)ArrayUtil.max(numSubStatesArray)]; unscaledScoresToAdd = new double[scoresToAdd.length]; tmpCountsArray = new double[scoresToAdd.length*scoresToAdd.length*scoresToAdd.length]; tagNumberer = Numberer.getGlobalNumberer("tags"); arraySize = 0; } void doConstrainedInsideScores(final boolean viterbi) { double initVal = 0; //int smallestScale = 10, largestScale = -10; for (int diff = 1; diff <= length; diff++) { //smallestScale = 10; largestScale = -10; //System.out.print(diff + " "); for (int start = 0; start < (length - diff + 1); start++) { int end = start + diff; for (int pState=0; pState<numSubStatesArray.length; pState++){ if (diff==1) continue; // there are no binary rules that span over 1 symbol only if (allowedSubStates[start][end][pState]==null) continue; BinaryRule[] parentRules = grammar.splitRulesWithP(pState); int nParentStates = numSubStatesArray[pState]; boolean somethingChanged = false; for (int i = 0; i < parentRules.length; i++) { BinaryRule r = parentRules[i]; int lState = r.leftChildState; int rState = r.rightChildState; int narrowR = narrowRExtent[start][lState]; boolean iPossibleL = (narrowR < end); // can this left constituent leave space for a right constituent? if (!iPossibleL) { continue; } int narrowL = narrowLExtent[end][rState]; boolean iPossibleR = (narrowL >= narrowR); // can this right constituent fit next to the left constituent? if (!iPossibleR) { continue; } int min1 = narrowR; int min2 = wideLExtent[end][rState]; int min = (min1 > min2 ? min1 : min2); // can this right constituent stretch far enough to reach the left constituent? if (min > narrowL) { continue; } int max1 = wideRExtent[start][lState]; int max2 = narrowL; int max = (max1 < max2 ? max1 : max2); // can this left constituent stretch far enough to reach the right constituent? 
if (min > max) { continue; } // TODO switch order of loops for efficiency double[][][] scores = r.getScores2(); int nLeftChildStates = numSubStatesArray[lState]; int nRightChildStates = numSubStatesArray[rState]; int nRuleStates = scores[0][0] == null ? nParentStates : scores[0][0].length; int divisor = nParentStates/nRuleStates; for (int split = min; split <= max; split++) { boolean changeThisRound = false; if (allowedSubStates[start][split][lState] == null) continue; if (allowedSubStates[split][end][rState] == null) continue; for (int lp = 0; lp < nLeftChildStates; lp++) { double lS = iScorePostU[start][split][lState][lp]; if (lS == initVal) continue; for (int rp = 0; rp < nRightChildStates; rp++) { // if (scores[lp][rp]==null) continue; double rS = iScorePostU[split][end][rState][rp]; if (rS == initVal) continue; double tmp = lS*rS; if (nRuleStates==nParentStates){ if (scores[lp][rp] == null) continue; for (int np = 0; np < nParentStates; np++) { if (!allowedSubStates[start][end][pState][np]) continue; double pS = scores[lp][rp][np]; if (pS == initVal) continue; double thisRound = pS*tmp; if (viterbi){ unscaledScoresToAdd[np] = Math.max(unscaledScoresToAdd[np],thisRound); } else { unscaledScoresToAdd[np] += thisRound; } changeThisRound = true; } } else { for (int np = 0; np < nRuleStates; np++) { double pS = scores[lp/divisor][rp/divisor][np]; if (pS == initVal) continue; double thisRound = pS*tmp; for (int nnp=0; nnp<divisor; nnp++){ int p = np*divisor+nnp; if (!allowedSubStates[start][end][pState][p]) continue; if (viterbi){ unscaledScoresToAdd[p] = Math.max(unscaledScoresToAdd[p],thisRound); } else { unscaledScoresToAdd[p] += thisRound; } } changeThisRound = true; } } } } if (!changeThisRound) continue; somethingChanged = true; //boolean firstTime = false; int parentScale = iScale[start][end][pState]; int currentScale = iScale[start][split][lState]+iScale[split][end][rState]; currentScale = ScalingTools.scaleArray(unscaledScoresToAdd,currentScale); if 
(parentScale!=currentScale) { if (parentScale==Integer.MIN_VALUE){ // first time to build this span iScale[start][end][pState] = currentScale; } else { int newScale = Math.max(currentScale,parentScale); ScalingTools.scaleArrayToScale(unscaledScoresToAdd,currentScale,newScale); ScalingTools.scaleArrayToScale(iScorePreU[start][end][pState],parentScale,newScale); iScale[start][end][pState] = newScale; } } for (int np = 0; np < nParentStates; np++) { if (viterbi){ iScorePreU[start][end][pState][np] = Math.max(iScorePreU[start][end][pState][np],unscaledScoresToAdd[np]); } else { iScorePreU[start][end][pState][np] += unscaledScoresToAdd[np]; } } Arrays.fill(unscaledScoresToAdd,0); } } if (somethingChanged) { // apply span predictions // if (spanScores!=null){ // double val = spanScores[start][end][0]; // for (int np = 0; np < nParentStates; np++){ // iScorePreU[start][end][pState][np] *= val; // } // } if (start > narrowLExtent[end][pState]) { narrowLExtent[end][pState] = start; wideLExtent[end][pState] = start; } else { if (start < wideLExtent[end][pState]) { wideLExtent[end][pState] = start; } } if (end < narrowRExtent[start][pState]) { narrowRExtent[start][pState] = end; wideRExtent[start][pState] = end; } else { if (end > wideRExtent[start][pState]) { wideRExtent[start][pState] = end; } } } } // now do the unaries for (int pState=0; pState<numSubStatesArray.length; pState++){ if (allowedSubStates[start][end][pState] == null) continue; if (iScorePreU[start][end][pState] == null) continue; // Should be: Closure under sum-product: UnaryRule[] unaries = grammar.getClosedSumUnaryRulesByParent(pState); //UnaryRule[] unaries = grammar.getUnaryRulesByParent(pState).toArray(new UnaryRule[0]); int nParentStates = numSubStatesArray[pState];//scores[0].length; int parentScale = iScale[start][end][pState]; int scaleBeforeUnaries = parentScale; boolean somethingChanged = false; for (int r = 0; r < unaries.length; r++) { UnaryRule ur = unaries[r]; int cState = ur.childState; if 
((pState == cState)) continue; if (allowedSubStates[start][end][cState]==null) continue; if (iScorePreU[start][end][cState] == null) continue; double[][] scores = ur.getScores2(); boolean changeThisRound = false; int nChildStates = numSubStatesArray[cState];//scores[0].length; for (int cp = 0; cp < nChildStates; cp++) { if (scores[cp]==null) continue; double iS = iScorePreU[start][end][cState][cp]; if (iS == initVal) continue; for (int np = 0; np < nParentStates; np++) { if (!allowedSubStates[start][end][pState][np]) continue; if (np>scores[cp].length) System.out.println("how come?"); double pS = scores[cp][np]; if (pS == initVal) continue; double thisRound = iS*pS; if (viterbi){ unscaledScoresToAdd[np] = Math.max(unscaledScoresToAdd[np],thisRound); } else { unscaledScoresToAdd[np] += thisRound; } somethingChanged = true; changeThisRound = true; } } if (!changeThisRound) continue; //boolean firstTime = false; int currentScale = iScale[start][end][cState]; currentScale = ScalingTools.scaleArray(unscaledScoresToAdd,currentScale); if (parentScale!=currentScale) { if (parentScale==Integer.MIN_VALUE){ // first time to build this span parentScale = currentScale; } else { int newScale = Math.max(currentScale,parentScale); ScalingTools.scaleArrayToScale(unscaledScoresToAdd,currentScale,newScale); ScalingTools.scaleArrayToScale(iScorePostU[start][end][pState],parentScale,newScale); parentScale = newScale; } } for (int np = 0; np < nParentStates; np++) { if (viterbi){ iScorePostU[start][end][pState][np] = Math.max(iScorePostU[start][end][pState][np],unscaledScoresToAdd[np]); } else { iScorePostU[start][end][pState][np] += unscaledScoresToAdd[np]; } } Arrays.fill(unscaledScoresToAdd,0); } if (somethingChanged){ int newScale = Math.max(scaleBeforeUnaries,parentScale); ScalingTools.scaleArrayToScale(iScorePreU[start][end][pState],scaleBeforeUnaries,newScale); ScalingTools.scaleArrayToScale(iScorePostU[start][end][pState],parentScale,newScale); iScale[start][end][pState] = 
newScale; if (start > narrowLExtent[end][pState]) { narrowLExtent[end][pState] = start; wideLExtent[end][pState] = start; } else { if (start < wideLExtent[end][pState]) { wideLExtent[end][pState] = start; } } if (end < narrowRExtent[start][pState]) { narrowRExtent[start][pState] = end; wideRExtent[start][pState] = end; } else { if (end > wideRExtent[start][pState]) { wideRExtent[start][pState] = end; } } } // in any case copy/add the scores from before for (int np = 0; np < nParentStates; np++) { double val = iScorePreU[start][end][pState][np]; if (val>0) { if (viterbi){ iScorePostU[start][end][pState][np] = Math.max(iScorePostU[start][end][pState][np],val); } else { iScorePostU[start][end][pState][np] += val; } } } } } } } void doConstrainedOutsideScores(final boolean viterbi) { double initVal = 0; // Arrays.fill(scoresToAdd,initVal); for (int diff = length; diff >= 1; diff--) { for (int start = 0; start + diff <= length; start++) { int end = start + diff; // do unaries // apply span predictions // if (spanScores!=null){ // double val = spanScores[start][end][0]; // if (val != 1){ // for (int pState=0; pState<numSubStatesArray.length; pState++){ // if (allowedSubStates[start][end][pState]==null) continue; // for (int np = 0; np < numSubStatesArray[pState]; np++){ // oScorePreU[start][end][pState][np] *= val; // } // } // } // } for (int cState=0; cState<numSubStatesArray.length; cState++){ if (allowedSubStates[start][end][cState]==null) continue; if (end-start>1 && !grammarTags[cState]) continue; if (iScorePostU[start][end][cState]==null) continue; // Should be: Closure under sum-product: // UnaryRule[] rules = grammar.getClosedSumUnaryRulesByParent(pState); UnaryRule[] rules = grammar.getClosedSumUnaryRulesByChild(cState); //UnaryRule[] rules = grammar.getClosedViterbiUnaryRulesByParent(pState); // For now: //UnaryRule[] rules = grammar.getUnaryRulesByChild(cState).toArray(new UnaryRule[0]); int nChildStates = numSubStatesArray[cState]; boolean somethingChanged = 
false; int childScale = oScale[start][end][cState]; int scaleBeforeUnaries = childScale; for (int r = 0; r < rules.length; r++) { UnaryRule ur = rules[r]; int pState = ur.parentState; if ((pState == cState)) continue; if (allowedSubStates[start][end][pState]==null) continue; if (iScorePostU[start][end][pState]==null) continue; int nParentStates = numSubStatesArray[pState]; double[][] scores = ur.getScores2(); boolean changeThisRound = false; for (int cp = 0; cp < nChildStates; cp++) { if (scores[cp]==null) continue; if (!allowedSubStates[start][end][cState][cp]) continue; for (int np = 0; np < nParentStates; np++) { if (!allowedSubStates[start][end][pState][np]) continue; double pS = scores[cp][np]; if (pS == initVal) continue; double oS = oScorePreU[start][end][pState][np]; if (oS == initVal) continue; double thisRound = oS*pS; if (viterbi){ unscaledScoresToAdd[cp] = Math.max(unscaledScoresToAdd[cp],thisRound); } else { unscaledScoresToAdd[cp] += thisRound; } somethingChanged = true; changeThisRound = true; } } if (!changeThisRound) continue; int currentScale = oScale[start][end][pState]; currentScale = ScalingTools.scaleArray(unscaledScoresToAdd,currentScale); if (childScale!=currentScale) { if (childScale==Integer.MIN_VALUE){ // first time to build this span childScale = currentScale; } else { int newScale = Math.max(currentScale,childScale); ScalingTools.scaleArrayToScale(unscaledScoresToAdd,currentScale,newScale); ScalingTools.scaleArrayToScale(oScorePostU[start][end][cState],childScale,newScale); childScale = newScale; } } for (int cp = 0; cp < nChildStates; cp++) { if (viterbi){ oScorePostU[start][end][cState][cp] = Math.max(oScorePostU[start][end][cState][cp],unscaledScoresToAdd[cp]); } else { oScorePostU[start][end][cState][cp] += unscaledScoresToAdd[cp]; } } Arrays.fill(unscaledScoresToAdd,initVal); } if (somethingChanged){ int newScale = Math.max(scaleBeforeUnaries,childScale); 
ScalingTools.scaleArrayToScale(oScorePreU[start][end][cState],scaleBeforeUnaries,newScale); ScalingTools.scaleArrayToScale(oScorePostU[start][end][cState],childScale,newScale); oScale[start][end][cState] = newScale; } // copy/add the entries where the unaries were not useful for (int cp=0; cp<nChildStates; cp++){ double val = oScorePreU[start][end][cState][cp]; if (val>0) { if (viterbi){ oScorePostU[start][end][cState][cp] = Math.max(oScorePostU[start][end][cState][cp], val); } else { oScorePostU[start][end][cState][cp] += val; } } } } // do binaries if (diff==1) continue; // there is no space for a binary for (int pState=0; pState<numSubStatesArray.length; pState++){ if (allowedSubStates[start][end][pState] == null) continue; final int nParentStates = numSubStatesArray[pState]; BinaryRule[] rules = grammar.splitRulesWithP(pState); //BinaryRule[] rules = grammar.splitRulesWithLC(lState); for (int r = 0; r < rules.length; r++) { BinaryRule br = rules[r]; int lState = br.leftChildState; int min1 = narrowRExtent[start][lState]; if (end < min1) { continue; } int rState = br.rightChildState; int max1 = narrowLExtent[end][rState]; if (max1 < min1) { continue; } int min = min1; int max = max1; if (max - min > 2) { int min2 = wideLExtent[end][rState]; min = (min1 > min2 ? min1 : min2); if (max1 < min) { continue; } int max2 = wideRExtent[start][lState]; max = (max1 < max2 ? max1 : max2); if (max < min) { continue; } } double[][][] scores = br.getScores2(); int nLeftChildStates = numSubStatesArray[lState]; int nRightChildStates = numSubStatesArray[rState]; int nRuleStates = scores[0][0] == null ? 
nParentStates : scores[0][0].length; int divisor = nParentStates/nRuleStates; for (int split = min; split <= max; split++) { if (allowedSubStates[start][split][lState] == null) continue; if (allowedSubStates[split][end][rState] == null) continue; if (split-start>1 && !grammarTags[lState]) continue; if (end-split>1 && !grammarTags[rState]) continue; boolean somethingChanged = false; for (int lp=0; lp<nLeftChildStates; lp++){ double lS = iScorePostU[start][split][lState][lp]; if (lS==initVal) continue; for (int rp=0; rp<nRightChildStates; rp++){ // if (scores[lp][rp]==null) continue; double rS = iScorePostU[split][end][rState][rp]; if (rS==initVal) continue; if (nRuleStates==nParentStates){ if (scores[lp][rp] == null) continue; for (int np=0; np<nParentStates; np++){ double pS = scores[lp][rp][np]; if (pS == initVal) continue; double oS = oScorePostU[start][end][pState][np]; if (oS == initVal) continue; // if (!allowedSubStates[start][end][pState][np]) continue; double thisRoundL = pS*rS*oS; double thisRoundR = pS*lS*oS; if (viterbi){ scoresToAdd[lp] = Math.max(scoresToAdd[lp],thisRoundL); unscaledScoresToAdd[rp] = Math.max(unscaledScoresToAdd[rp],thisRoundR); } else { scoresToAdd[lp] += thisRoundL; unscaledScoresToAdd[rp] += thisRoundR; } somethingChanged = true; } } else { for (int np=0; np<nParentStates; np++){ double pS = scores[lp/divisor][rp/divisor][np/divisor]; if (pS == initVal) continue; double oS = oScorePostU[start][end][pState][np]; if (oS == initVal) continue; // if (!allowedSubStates[start][end][pState][np]) continue; double thisRoundL = pS*rS*oS; double thisRoundR = pS*lS*oS; if (viterbi){ scoresToAdd[lp] = Math.max(scoresToAdd[lp],thisRoundL); unscaledScoresToAdd[rp] = Math.max(unscaledScoresToAdd[rp],thisRoundR); } else { scoresToAdd[lp] += thisRoundL; unscaledScoresToAdd[rp] += thisRoundR; } somethingChanged = true; } } } } if (!somethingChanged) continue; if (DoubleArrays.max(scoresToAdd)!=0){//oScale[start][end][pState]!=Integer.MIN_VALUE && 
iScale[split][end][rState]!=Integer.MIN_VALUE){ int leftScale = oScale[start][split][lState]; int currentScale = oScale[start][end][pState]+iScale[split][end][rState]; currentScale = ScalingTools.scaleArray(scoresToAdd,currentScale); if (leftScale!=currentScale) { if (leftScale==Integer.MIN_VALUE){ // first time to build this span oScale[start][split][lState] = currentScale; } else { int newScale = Math.max(currentScale,leftScale); ScalingTools.scaleArrayToScale(scoresToAdd,currentScale,newScale); ScalingTools.scaleArrayToScale(oScorePreU[start][split][lState],leftScale,newScale); oScale[start][split][lState] = newScale; } } for (int cp=0; cp<nLeftChildStates; cp++){ if (scoresToAdd[cp] > initVal){ if (viterbi){ oScorePreU[start][split][lState][cp] = Math.max(oScorePreU[start][split][lState][cp],scoresToAdd[cp]); } else { oScorePreU[start][split][lState][cp] += scoresToAdd[cp]; } } } Arrays.fill(scoresToAdd, 0); } if (DoubleArrays.max(unscaledScoresToAdd)!=0){//oScale[start][end][pState]!=Integer.MIN_VALUE && iScale[start][split][lState]!=Integer.MIN_VALUE){ int rightScale = oScale[split][end][rState]; int currentScale = oScale[start][end][pState]+iScale[start][split][lState]; if (currentScale==Integer.MIN_VALUE) System.out.println("shhaaa"); currentScale = ScalingTools.scaleArray(unscaledScoresToAdd,currentScale); if (rightScale!=currentScale) { if (rightScale==Integer.MIN_VALUE){ // first time to build this span oScale[split][end][rState] = currentScale; } else { int newScale = Math.max(currentScale,rightScale); ScalingTools.scaleArrayToScale(unscaledScoresToAdd,currentScale,newScale); ScalingTools.scaleArrayToScale(oScorePreU[split][end][rState],rightScale,newScale); oScale[split][end][rState] = newScale; } } for (int cp=0; cp<nRightChildStates; cp++){ if (unscaledScoresToAdd[cp] > initVal) { if (viterbi){ oScorePreU[split][end][rState][cp] = Math.max(oScorePreU[split][end][rState][cp],unscaledScoresToAdd[cp]); } else { oScorePreU[split][end][rState][cp] += 
unscaledScoresToAdd[cp]; } } } Arrays.fill(unscaledScoresToAdd, 0); } } } } } } } void initializeChart(List<StateSet> sentence, boolean noSmoothing, List<String> posTags) { final boolean useGoldPOS = (posTags!=null); int start = 0; int end = start+1; for (StateSet word : sentence) { end = start+1; int goldTag = -1; if (useGoldPOS) goldTag = tagNumberer.number(posTags.get(start)); for (short tag=0; tag<numSubStatesArray.length; tag++){ if (allowedSubStates[start][end][tag] == null) continue; if (grammarTags[tag]) continue; if (useGoldPOS && tag!=goldTag) continue; narrowRExtent[start][tag] = end; narrowLExtent[end][tag] = start; wideRExtent[start][tag] = end; wideLExtent[end][tag] = start; // double[] lexiconScores = lexicon.score(word.getWord(),tag,start,noSmoothing,false); double[] lexiconScores = lexicon.score(word,tag,noSmoothing,false); //if (!logProbs) iScale[start][end][tag] = scaleArray(lexiconScores,0); iScale[start][end][tag] = 0; for (short n=0; n<lexiconScores.length; n++){ if (!allowedSubStates[start][end][tag][n]) continue; double prob = lexiconScores[n]; iScorePreU[start][end][tag][n] = prob; } /* if (start==1){ System.out.println(word+" +TAG "+(String)tagNumberer.object(tag)+" "+Arrays.toString(lexiconScores)); }*/ } start++; } } @Override protected void createArrays() { if (arraySize<length){ // if we haven't seen such a long sentence before, allocate arrays arraySize = length; iScorePreU = new double[length][length + 1][][]; iScorePostU = new double[length][length + 1][][]; oScorePreU = new double[length][length + 1][][]; oScorePostU = new double[length][length + 1][][]; iScale = new int[length][length + 1][]; oScale = new int[length][length + 1][]; for (int start = 0; start < length; start++) { for (int end = start + 1; end <= length; end++) { iScorePreU[start][end] = new double[numStates][]; iScorePostU[start][end] = new double[numStates][]; oScorePreU[start][end] = new double[numStates][]; oScorePostU[start][end] = new double[numStates][]; 
iScale[start][end] = new int[numStates]; oScale[start][end] = new int[numStates]; Arrays.fill(iScale[start][end], Integer.MIN_VALUE); Arrays.fill(oScale[start][end], Integer.MIN_VALUE); for (int state=0; state<numSubStatesArray.length;state++){ if (end-start>1 && !grammarTags[state]) continue; iScorePreU[start][end][state] = new double[numSubStatesArray[state]]; iScorePostU[start][end][state] = new double[numSubStatesArray[state]]; oScorePreU[start][end][state] = new double[numSubStatesArray[state]]; oScorePostU[start][end][state] = new double[numSubStatesArray[state]]; } } } narrowRExtent = new int[length + 1][numStates]; wideRExtent = new int[length + 1][numStates]; narrowLExtent = new int[length + 1][numStates]; wideLExtent = new int[length + 1][numStates]; for (int loc = 0; loc <= length; loc++) { Arrays.fill(narrowLExtent[loc], -1); // the rightmost left with state s ending at i that we can get is the beginning Arrays.fill(wideLExtent[loc], length + 1); // the leftmost left with state s ending at i that we can get is the end Arrays.fill(narrowRExtent[loc], length + 1); // the leftmost right with state s starting at i that we can get is the end Arrays.fill(wideRExtent[loc], -1); // the rightmost right with state s starting at i that we can get is the beginning } } } @Override public Tree<String> getBestConstrainedParse(List<String> sentence, List<String> posTags, boolean[][][][] allowedStates) { // setConstraints(allowedStates,true); boolean noSmoothing = false; List<StateSet> testSentenceStateSet = convertToTestSet(sentence); double ll = doConstrainedInsideOutsideScores(testSentenceStateSet, allowedStates, noSmoothing, null, posTags, viterbi); Tree<String> bestTree = null; if (ll==Double.NEGATIVE_INFINITY){ return new Tree<String>("ROOT"); } if (viterbi) { bestTree = extractBestViterbiParse(0, 0, 0, length, sentence, true); } else { if (spanScores==null) doConstrainedMaxCScores(testSentenceStateSet); else doConstrainedMaxCScores(testSentenceStateSet, 
spanScores); bestTree = extractBestMaxRuleParse(0, length, sentence); } maxcScore = null; maxcSplit = null; maxcChild = null; maxcLeftChild = null; maxcRightChild = null; // sentencesParsed++; // System.out.println("For parsing "+sentencesParsed+" I hat to touch "+edgesTouched/((double)sentencesParsed)+" on average."); return bestTree; } /** Assumes that inside and outside scores (sum version, not viterbi) have been computed. * In particular, the narrowRExtent and other arrays need not be updated. */ void doConstrainedMaxCScores(List<StateSet> sentence) { maxcScore = new double[length][length + 1][numStates]; maxcSplit = new int[length][length + 1][numStates]; maxcChild = new int[length][length + 1][numStates]; maxcLeftChild = new int[length][length + 1][numStates]; maxcRightChild = new int[length][length + 1][numStates]; double tree_score = iScorePostU[0][length][0][0]; int tree_scale = iScale[0][length][0]; for (int diff = 1; diff <= length; diff++) { //System.out.print(diff + " "); for (int start = 0; start < (length - diff + 1); start++) { int end = start + diff; Arrays.fill(maxcSplit[start][end], -1); Arrays.fill(maxcChild[start][end], -1); Arrays.fill(maxcLeftChild[start][end], -1); Arrays.fill(maxcRightChild[start][end], -1); if (diff > 1) { // diff > 1: Try binary rules for (int pState=0; pState<numSubStatesArray.length; pState++){ if (allowedSubStates[start][end][pState] == null) { continue; } BinaryRule[] parentRules = grammar.splitRulesWithP(pState); int nParentStates = numSubStatesArray[pState]; // == scores[0][0].length; for (int i = 0; i < parentRules.length; i++) { BinaryRule r = parentRules[i]; int lState = r.leftChildState; int rState = r.rightChildState; int narrowR = narrowRExtent[start][lState]; boolean iPossibleL = (narrowR < end); // can this left constituent leave space for a right constituent? 
if (!iPossibleL) { continue; } int narrowL = narrowLExtent[end][rState]; boolean iPossibleR = (narrowL >= narrowR); // can this right constituent fit next to the left constituent? if (!iPossibleR) { continue; } int min1 = narrowR; int min2 = wideLExtent[end][rState]; int min = (min1 > min2 ? min1 : min2); // can this right constituent stretch far enough to reach the left constituent? if (min > narrowL) { continue; } int max1 = wideRExtent[start][lState]; int max2 = narrowL; int max = (max1 < max2 ? max1 : max2); // can this left constituent stretch far enough to reach the right constituent? if (min > max) { continue; } double[][][] scores = r.getScores2(); int nLeftChildStates = numSubStatesArray[lState]; // == scores.length; int nRightChildStates = numSubStatesArray[rState]; // == scores[0].length; for (int split = min; split <= max; split++) { double ruleScore = 0; if (allowedSubStates[start][split][lState] == null) continue; if (allowedSubStates[split][end][rState] == null) continue; double scalingFactor = ScalingTools.calcScaleFactor( oScale[start][end][pState]+ iScale[start][split][lState]+ iScale[split][end][rState]-tree_scale); if (scalingFactor==0) continue; for (int lp = 0; lp < nLeftChildStates; lp++) { double lIS = iScorePostU[start][split][lState][lp]; if (lIS == 0) continue; //if (!allowedSubStates[start][split][lState][lp]) continue; for (int rp = 0; rp < nRightChildStates; rp++) { if (scores[lp][rp]==null) continue; double rIS = iScorePostU[split][end][rState][rp]; if (rIS == 0) continue; //if (!allowedSubStates[split][end][rState][rp]) continue; for (int np = 0; np < nParentStates; np++) { //if (!allowedSubStates[start][end][pState][np]) continue; double pOS = oScorePostU[start][end][pState][np]; if (pOS == 0) continue; double ruleS = scores[lp][rp][np]; if (ruleS == 0) continue; ruleScore += pOS * scalingFactor * ruleS / tree_score * lIS * rIS; } } } if (ruleScore==0) continue; // if (ruleScore==0) { // System.out.println("possible underflow 
binary"); // if (ruleScore==0){ // System.out.println("Underflow:"); // System.out.println("pOS: "+Arrays.toString(oScorePostU[start][end][pState])); // System.out.println("scalingFactor: "+scalingFactor); // System.out.println("tree_score: "+tree_score); // System.out.println("lIS: "+Arrays.toString(iScorePostU[start][split][lState])); // System.out.println("rIS: "+Arrays.toString(iScorePostU[split][end][rState])); // } // } double leftChildScore = maxcScore[start][split][lState]; double rightChildScore = maxcScore[split][end][rState]; double gScore = ruleScore * leftChildScore * rightChildScore; if (gScore > maxcScore[start][end][pState]) { maxcScore[start][end][pState] = gScore; maxcSplit[start][end][pState] = split; maxcLeftChild[start][end][pState] = lState; maxcRightChild[start][end][pState] = rState; } } } } } else { // diff == 1 // We treat TAG --> word exactly as if it was a unary rule, except the score of the rule is // given by the lexicon rather than the grammar and that we allow another unary on top of it. 
for (short tag=0; tag<numSubStatesArray.length; tag++){ if (allowedSubStates[start][end][tag]==null) continue; if (grammar.isGrammarTag(tag)) continue; // maxcScore[start][end][tag] = 1; double scalingFactor = ScalingTools.calcScaleFactor( oScale[start][end][tag]- tree_scale); if (scalingFactor==0){ continue; } int nTagStates = numSubStatesArray[tag]; // String word = sentence.get(start); StateSet word = sentence.get(start); double[] lexiconScoreArray = lexicon.score(word, tag, false, false); double lexiconScores = 0; for (int tp = 0; tp < nTagStates; tp++) { double pOS = oScorePostU[start][end][tag][tp]; if (pOS == 0) continue; double ruleS = lexiconScoreArray[tp]; if (ruleS==0) continue; lexiconScores += (pOS * ruleS) / tree_score; } if (lexiconScores==0) continue; // if (lexiconScores==0) System.out.println("possible underflow lexicon"); maxcScore[start][end][tag] = lexiconScores * scalingFactor; } } // Try unary rules // Replacement for maxcScore[start][end], which is updated in batch double[] maxcScoreStartEnd = new double[numStates]; for (int i = 0; i < numStates; i++) { maxcScoreStartEnd[i] = maxcScore[start][end][i]; } for (int pState=0; pState<numSubStatesArray.length; pState++){ if (allowedSubStates[start][end][pState] == null) { continue; } UnaryRule[] unaries = grammar.getClosedSumUnaryRulesByParent(pState); int nParentStates = numSubStatesArray[pState]; // == scores[0].length; for (int r = 0; r < unaries.length; r++) { UnaryRule ur = unaries[r]; // List<UnaryRule> urules = grammar.getUnaryRulesByParent(pState);// // for (UnaryRule ur : urules){ int cState = ur.childState; if ((pState == cState)) continue;// && (np == cp))continue; if (allowedSubStates[start][end][cState]==null) continue; double[][] scores = ur.getScores2(); int nChildStates = numSubStatesArray[cState]; // == scores.length; double ruleScore = 0; double scalingFactor = ScalingTools.calcScaleFactor( oScale[start][end][pState]+iScale[start][end][cState]-tree_scale); if (scalingFactor==0){ 
continue; } for (int cp = 0; cp < nChildStates; cp++) { double cIS = iScorePreU[start][end][cState][cp]; if (cIS == 0) continue; //if (!allowedSubStates[start][end][cState][cp]) continue; if (scores[cp]==null) continue; for (int np = 0; np < nParentStates; np++) { //if (!allowedSubStates[start][end][pState][np]) continue; double pOS = oScorePreU[start][end][pState][np]; if (pOS == 0) continue; double ruleS = scores[cp][np]; if (ruleS == 0) continue; ruleScore += pOS * scalingFactor * ruleS / tree_score * cIS ; } } if (ruleScore==0) continue; double childScore = maxcScore[start][end][cState]; double gScore = ruleScore * childScore; if (gScore > maxcScoreStartEnd[pState]) { maxcScoreStartEnd[pState] = gScore; maxcChild[start][end][pState] = cState; } } } maxcScore[start][end] = maxcScoreStartEnd; } } } public double doConstrainedInsideOutsideScores(List<StateSet> sentence, boolean[][][][] allowed, boolean noSmoothing, Tree<StateSet> goldTree, List<String> posTags, boolean viterbi){ scrubArrays(); length = (short)sentence.size(); if (allowed!=null) allowedSubStates = allowed; else setConstraints(null, false); createArrays(); initializeChart(sentence,noSmoothing,posTags); double logLikelihood = Double.NEGATIVE_INFINITY; if (spanPredictor!=null) { spanScores = spanPredictor.predictSpans(sentence); doConstrainedInsideScores(viterbi, spanScores); logLikelihood = getLikelihoodAndSetRootOutsideScore(); doConstrainedOutsideScores(viterbi, spanScores); } else { doConstrainedInsideScores(viterbi); logLikelihood = getLikelihoodAndSetRootOutsideScore(); doConstrainedOutsideScores(viterbi); } return logLikelihood; } void doConstrainedMaxCScores(List<StateSet> testSentenceStateSet, double[][][] spanScores2) { throw new Error("Currently not supported"); } void doConstrainedOutsideScores(boolean viterbi, double[][][] spanScores2) { throw new Error("Currently not supported"); } void doConstrainedInsideScores(boolean viterbi, double[][][] spanScores2) { throw new Error("Currently not 
supported"); } protected double getLikelihoodAndSetRootOutsideScore() { oScorePreU[0][length][0][0] = 1.0; oScale[0][length][0] = 0; return Math.log(iScorePostU[0][length][0][0])+ (ScalingTools.LOGSCALE*iScale[0][length][0]); } /** * */ protected void scrubArrays() { if (iScorePostU==null) return; for (int start = 0; start < length; start++) { for (int end = start + 1; end <= length; end++) { for (int state=0; state<numSubStatesArray.length;state++){ if (allowedSubStates[start][end][state] != null){ if (end-start>1 && !grammarTags[state]) continue; Arrays.fill(iScorePreU[start][end][state],0); Arrays.fill(iScorePostU[start][end][state],0); Arrays.fill(oScorePreU[start][end][state],0); Arrays.fill(oScorePostU[start][end][state],0); Arrays.fill(iScale[start][end], Integer.MIN_VALUE); Arrays.fill(oScale[start][end], Integer.MIN_VALUE); } } } } for (int loc = 0; loc <= length; loc++) { Arrays.fill(narrowLExtent[loc], -1); // the rightmost left with state s ending at i that we can get is the beginning Arrays.fill(wideLExtent[loc], length + 1); // the leftmost left with state s ending at i that we can get is the end Arrays.fill(narrowRExtent[loc], length + 1); // the leftmost right with state s starting at i that we can get is the end Arrays.fill(wideRExtent[loc], -1); // the rightmost right with state s starting at i that we can get is the beginning } } /** * @param allowedSubStates2 */ protected void setConstraints(boolean[][][][] allowedSubStates2, boolean allSubstates) { allowedSubStates = new boolean[length][length+1][][]; for (int start = 0; start < length; start++) { for (int end = start + 1; end <= length; end++) { allowedSubStates[start][end] = new boolean[numStates][]; for (int state = 0; state<numStates; state++){ if (allowedSubStates2==null){ // then we parse without constraints if (end-start>1&&!grammarTags[state]) continue; boolean[] tmp = new boolean[numSubStatesArray[state]]; Arrays.fill(tmp, true); allowedSubStates[start][end][state] = tmp; } else if 
(allowedSubStates2[start][end][state]!=null){ allowedSubStates[start][end][state] = new boolean[numSubStatesArray[state]]; if (allSubstates) Arrays.fill(allowedSubStates[start][end][state],true); else { for (int substate=0; substate<allowedSubStates2[start][end][state].length; substate++){ if (allowedSubStates2[start][end][state][substate]){ allowedSubStates[start][end][state][2*substate] = true; if (state!=0) allowedSubStates[start][end][state][2*substate+1] = true; } } } } } } } } public void incrementExpectedCounts(Linearizer linearizer, double[] probs, List<StateSet> sentence) { // numSubStatesArray = grammar.numSubStates; double tree_score = iScorePostU[0][length][0][0]; int tree_scale = iScale[0][length][0]; if (SloppyMath.isDangerous(tree_score)){ System.out.println("Training tree has zero probability - presumably underflow!"); return; // System.exit(-1); } for (int start = 0; start < length; start++) { final int lastState = numSubStatesArray.length; StateSet currentStateSet = sentence.get(start); for (int tag=0; tag<lastState; tag++){ if (grammar.isGrammarTag(tag)) continue; if (allowedSubStates[start][start+1][tag] == null) continue; double scalingFactor = ScalingTools.calcScaleFactor( oScale[start][start+1][tag]+ iScale[start][start+1][tag]- tree_scale); if (scalingFactor==0){ continue; } final int nSubStates = numSubStatesArray[tag]; // if (!combinedLexicon){ for (short substate=0; substate<nSubStates; substate++) { //weight by the probability of seeing the tag and word together, given the sentence double iS = iScorePreU[start][start+1][tag][substate]; if (iS==0) continue; double oS = oScorePostU[start][start+1][tag][substate]; if (oS==0) continue; double weight = iS / tree_score * scalingFactor * oS; if (isValidExpectation(weight)){ tmpCountsArray[substate] = weight; } } linearizer.increment(probs, currentStateSet, tag, tmpCountsArray, false); //probs[startIndexWord+substate] += weight; // linearizer.increment(probs, sigIndex, tag, tmpCountsArray); 
//probs[startIndexWord+substate] += weight; // } // else { // double[] wordScores = lexicon.scoreWord(currentStateSet, tag); // for (short substate=0; substate<nSubStates; substate++) { // //weight by the probability of seeing the tag and word together, given the sentence // double iS = wordScores[substate]; // if (iS==0) continue; // double oS = oScorePostU[start][start+1][tag][substate]; // if (oS==0) continue; // double weight = iS / tree_score * scalingFactor * oS; // tmpCountsArray[substate] = weight; // } // linearizer.increment(probs, wordIndex, tag, tmpCountsArray); //probs[startIndexWord+substate] += weight; // // double[] sigScores = lexicon.scoreSignature(currentStateSet, tag); // if (sigScores==null) continue; // for (short substate=0; substate<nSubStates; substate++) { // //weight by the probability of seeing the tag and word together, given the sentence // double iS = sigScores[substate]; // if (iS==0) continue; // double oS = oScorePostU[start][start+1][tag][substate]; // if (oS==0) continue; // double weight = iS / tree_score * scalingFactor * oS; // tmpCountsArray[substate] = weight; // } // linearizer.increment(probs, sigIndex, tag, tmpCountsArray); //probs[startIndexWord+substate] += weight; // } } } for (int diff = 1; diff <= length; diff++) { for (int start = 0; start < (length - diff + 1); start++) { int end = start + diff; final int lastState = numSubStatesArray.length; for (short pState=0; pState<lastState; pState++){ if (diff==1) continue; // there are no binary rules that span over 1 symbol only if (allowedSubStates[start][end][pState] == null) continue; final int nParentSubStates = numSubStatesArray[pState]; BinaryRule[] parentRules = grammar.splitRulesWithP(pState); for (int i = 0; i < parentRules.length; i++) { BinaryRule r = parentRules[i]; short lState = r.leftChildState; short rState = r.rightChildState; int narrowR = narrowRExtent[start][lState]; boolean iPossibleL = (narrowR < end); // can this left constituent leave space for a 
right constituent? if (!iPossibleL) { continue; } int narrowL = narrowLExtent[end][rState]; boolean iPossibleR = (narrowL >= narrowR); // can this right constituent fit next to the left constituent? if (!iPossibleR) { continue; } int min1 = narrowR; int min2 = wideLExtent[end][rState]; int min = (min1 > min2 ? min1 : min2); // can this right constituent stretch far enough to reach the left constituent? if (min > narrowL) { continue; } int max1 = wideRExtent[start][lState]; int max2 = narrowL; int max = (max1 < max2 ? max1 : max2); // can this left constituent stretch far enough to reach the right constituent? if (min > max) { continue; } double[][][] scores = r.getScores2(); boolean foundSomething = false; int nRuleStates = scores[0][0].length; int divisor = nParentSubStates/nRuleStates; for (int split = min; split <= max; split++) { if (allowedSubStates[start][split][lState] == null) continue; if (allowedSubStates[split][end][rState] == null) continue; double scalingFactor = ScalingTools.calcScaleFactor( oScale[start][end][pState]+ iScale[start][split][lState]+ iScale[split][end][rState]-tree_scale); if (scalingFactor==0){ continue; } int curInd = 0; for (int lp = 0; lp < scores.length; lp++) { double lcIS = iScorePostU[start][split][lState][lp]; if (lcIS == 0) { curInd += scores[0].length * nParentSubStates; continue; } double tmpA = lcIS / tree_score; for (int rp = 0; rp < scores[0].length; rp++) { // if (scores[lp][rp]==null) continue; double rcIS = iScorePostU[split][end][rState][rp]; if (rcIS == 0) { curInd += nParentSubStates; continue; } double tmpB = tmpA * rcIS * scalingFactor; if (nRuleStates==nParentSubStates){ for (int np = 0; np < nParentSubStates; np++) { double pOS = oScorePostU[start][end][pState][np]; if (pOS==0) { curInd++; continue; } double rS = scores[lp][rp][np]; double ruleCount = rS * tmpB * pOS; if (isValidExpectation(ruleCount)){ tmpCountsArray[curInd] += ruleCount; foundSomething = true; } curInd++; } } else { for (int np = 0; np < 
nParentSubStates; np++) { double pOS = oScorePostU[start][end][pState][np]; if (pOS==0) { curInd++; continue; } double rS = scores[lp/divisor][rp/divisor][np/divisor]; double ruleCount = rS * tmpB * pOS; if (isValidExpectation(ruleCount)){ tmpCountsArray[curInd] += ruleCount; foundSomething = true; } curInd++; } } } } } if (!foundSomething) continue; // nothing changed this round linearizer.increment(probs, r, tmpCountsArray, false); } } final int lastStateU = numSubStatesArray.length; for (short pState=0; pState<lastStateU; pState++){ if (allowedSubStates[start][end][pState] == null) continue; // List<UnaryRule> unaries = grammar.getUnaryRulesByParent(pState); int nParentSubStates = numSubStatesArray[pState]; UnaryRule[] unaries = grammar.getClosedSumUnaryRulesByParent(pState); for (UnaryRule ur : unaries) { short cState = ur.childState; if ((pState == cState)) continue;// && (np == cp))continue; if (allowedSubStates[start][end][cState] == null) continue; double scalingFactor = ScalingTools.calcScaleFactor( oScale[start][end][pState]+iScale[start][end][cState]-tree_scale); if (scalingFactor==0){ continue; } double[][] scores = ur.getScores2(); int curInd = 0; for (int cp = 0; cp < scores.length; cp++) { if (scores[cp]==null) continue; double cIS = iScorePreU[start][end][cState][cp]; if (cIS == 0) { curInd += nParentSubStates; continue; } double tmpA = cIS / tree_score * scalingFactor; for (int np = 0; np < nParentSubStates; np++) { double pOS = oScorePreU[start][end][pState][np]; if (pOS==0){ curInd++; continue; } double rS = scores[cp][np]; double ruleCount = rS * tmpA * pOS; if (isValidExpectation(ruleCount)){ tmpCountsArray[curInd] = ruleCount; } curInd++; } } linearizer.increment(probs, ur, tmpCountsArray, false); //probs[thisStartIndex + curInd-1] += ruleCount; } } } } } public boolean isValidExpectation(double val){ return (val>0 && val < 1.01); } public void updateGrammarAndLexicon(Grammar grammar2, Lexicon lexicon2) { this.grammar = grammar2; this.lexicon 
= lexicon2; } /** * @param testSentence * @param tree * @param threshold * @return */ public boolean[][][][] getPossibleStates(List<String> testSentence, Tree<StateSet> tree, double threshold, boolean[][][][] previousConstraints, StringBuilder sb) { boolean noSmoothing = false;//true;//(tree!=null); int previouslyPossibleSub = (countPossibleSubStates(previousConstraints)-1)/2; int previouslyPossible = countPossibleStates(previousConstraints); List<StateSet> testSentenceStateSet = convertToTestSet(testSentence); doConstrainedInsideOutsideScores(testSentenceStateSet, previousConstraints,noSmoothing,null,null,true); boolean[][][][] allowedStates = computeAllowedStates(threshold); if (allowedStates[0][testSentence.size()][0]==null) System.out.println("Root got pruned!"); int possibleStates = countPossibleStates(allowedStates); int possibleSubStates = countPossibleSubStates(allowedStates); if (tree!=null) { if (possibleSubStates==0) sb.append("Only gold tree is left!"); putGoldTreeBackIn(tree,allowedStates); } int possibleSubStates2 = countPossibleSubStates(allowedStates); if (possibleSubStates!=possibleSubStates2){ sb.append(", saved gold tree"); possibleSubStates = possibleSubStates2; possibleStates = countPossibleStates(allowedStates); } if (possibleSubStates2==0){ sb.append(", Parse failure! 
No pruning!"); allowedStates = previousConstraints; possibleSubStates = previouslyPossible; } sb.append(", from: "+previouslyPossibleSub+" ("+previouslyPossible+") to: "+possibleSubStates+" ("+possibleStates+") substates."); return allowedStates; } /** * @param testSentence * @return */ protected List<StateSet> convertToTestSet(List<String> testSentence) { ArrayList<StateSet> list = new ArrayList<StateSet>(testSentence.size()); short ind = 0; for (String word : testSentence){ StateSet stateSet = new StateSet((short)-1, (short)1, word, ind, (short)(ind+1)); ind++; stateSet.wordIndex = -2; stateSet.sigIndex = -2; list.add(stateSet); } return list; } /** * @param allowedStates * @return */ private int countPossibleSubStates(boolean[][][][] allowedStates) { if (allowedStates==null) return 0; int possibleStates = 0; for (int start = 0; start < allowedStates.length; start++) { for (int end = start + 1; end <= allowedStates.length; end++) { final int lastState = numSubStatesArray.length; for (int state = 0; state < lastState; state++) { if (allowedStates[start][end][state]==null) continue; for (int substate = 0; substate < allowedStates[start][end][state].length; substate++) { if (allowedStates[start][end][state][substate]) possibleStates++; } } } } return possibleStates; } private int countPossibleStates(boolean[][][][] allowedStates) { if (allowedStates==null) return 0; int possibleStates = 0; for (int start = 0; start < allowedStates.length; start++) { for (int end = start + 1; end <= allowedStates.length; end++) { final int lastState = numSubStatesArray.length; for (int state = 0; state < lastState; state++) { if (allowedStates[start][end][state]==null) continue; for (int substate = 0; substate < allowedStates[start][end][state].length; substate++) { if (allowedStates[start][end][state][substate]) { possibleStates++; break; } } } } } return possibleStates; } /** * @param tree * @param allowedStates */ private void putGoldTreeBackIn(Tree<StateSet> tree, boolean[][][][] 
allowedStates) { StateSet node = tree.getLabel(); int state = node.getState(); if (state<numStates){ boolean[] tmp = new boolean[numSubStatesArray[state]]; Arrays.fill(tmp, true); allowedStates[node.from][node.to][state] = tmp; } else { System.out.println("Haven't seen state "+node); } for (Tree<StateSet> child : tree.getChildren()){ if (!child.isLeaf()) putGoldTreeBackIn(child, allowedStates); } } /** * @param tree * @param threshold * @return */ boolean[][][][] computeAllowedStates(double threshold) { double tree_score = iScorePostU[0][length][0][0]; int tree_scale = iScale[0][length][0]; boolean[][][][] result = new boolean[length][length+1][][]; for (int start = 0; start < length; start++) { for (int end = start + 1; end <= length; end++) { result[start][end] = new boolean[numStates][]; final int lastState = numSubStatesArray.length; for (int state = 0; state < lastState; state++) { if (allowedSubStates[start][end][state]==null) continue; boolean atLeastOnePossible = false; for (int substate = 0; substate < numSubStatesArray[state]; substate++) { if (!allowedSubStates[start][end][state][substate]) continue; double iS = iScorePostU[start][end][state][substate]; if (iS==0) continue; double oS = oScorePostU[start][end][state][substate]; if (oS==0) continue; double scalingFactor = ScalingTools.calcScaleFactor( oScale[start][end][state]+iScale[start][end][state]-tree_scale); if (scalingFactor==0) continue; double tmp = Math.max(iS*oScorePreU[start][end][state][substate], iScorePreU[start][end][state][substate]*oS); double posterior = tmp / tree_score * scalingFactor; if (posterior > threshold) { if (result[start][end][state]==null) result[start][end][state] = new boolean[numSubStatesArray[state]]; result[start][end][state][substate]=true; atLeastOnePossible = true; } } if (!atLeastOnePossible) result[start][end][state]=null; } } } return result; } /** * Returns the best parse, the one with maximum expected labelled recall. 
* Assumes that the maxc* arrays have been filled. */ public Tree<String> extractBestMaxRuleParse(int start, int end, List<String> sentence ) { return extractBestMaxRuleParse1(start, end, 0, sentence); } /** * Returns the best parse for state "state", potentially starting with a unary rule */ public Tree<String> extractBestMaxRuleParse1(int start, int end, int state, List<String> sentence ) { //System.out.println(start+", "+end+";"); int cState = maxcChild[start][end][state]; if (cState == -1) { return extractBestMaxRuleParse2(start, end, state, sentence); } else { List<Tree<String>> child = new ArrayList<Tree<String>>(); child.add( extractBestMaxRuleParse2(start, end, cState, sentence) ); String stateStr = (String) tagNumberer.object(state); if (stateStr.endsWith("^g")) stateStr = stateStr.substring(0,stateStr.length()-2); //System.out.println("Adding a unary spanning from "+start+" to "+end+". P: "+stateStr+" C: "+child.get(0).getLabel()); int intermediateNode = grammar.getUnaryIntermediate((short)state,(short)cState); if (intermediateNode==0){ // System.out.println("Added a bad unary from "+start+" to "+end+". 
P: "+stateStr+" C: "+child.get(0).getLabel()); } if (intermediateNode>0){ List<Tree<String>> restoredChild = new ArrayList<Tree<String>>(); String stateStr2 = (String)tagNumberer.object(intermediateNode); if (stateStr2.endsWith("^g")) stateStr2 = stateStr2.substring(0,stateStr2.length()-2); restoredChild.add(new Tree<String>(stateStr2, child)); //System.out.println("Restored a unary from "+start+" to "+end+": "+stateStr+" -> "+stateStr2+" -> "+child.get(0).getLabel()); return new Tree<String>(stateStr,restoredChild); } return new Tree<String>(stateStr, child); } } /** * Returns the best parse for state "state", but cannot start with a unary */ public Tree<String> extractBestMaxRuleParse2(int start, int end, int state, List<String> sentence ) { List<Tree<String>> children = new ArrayList<Tree<String>>(); String stateStr = (String)tagNumberer.object(state);//+""+start+""+end; if (stateStr.endsWith("^g")) stateStr = stateStr.substring(0,stateStr.length()-2); boolean posLevel = (end - start == 1); if (posLevel) { if (grammar.isGrammarTag(state)){ List<Tree<String>> childs = new ArrayList<Tree<String>>(); childs.add(new Tree<String>(sentence.get(start))); String stateStr2 = (String)tagNumberer.object(maxcChild[start][end][state]);//+""+start+""+end; children.add(new Tree<String>(stateStr2,childs)); } else children.add(new Tree<String>(sentence.get(start))); } else { int split = maxcSplit[start][end][state]; if (split == -1) { System.err.println("Warning: no symbol can generate the span from "+ start+ " to "+end+"."); System.err.println("The score is "+maxcScore[start][end][state]+" and the state is supposed to be "+stateStr); System.err.println("The insideScores are "+Arrays.toString(iScorePostU[start][end][state])+" and the outsideScores are " +Arrays.toString(oScorePostU[start][end][state])); System.err.println("The maxcScore is "+maxcScore[start][end][state]); for (short start2=0; start2<length; start2++){ for (short tag=0; tag<numSubStatesArray.length; tag++){ if 
(grammar.isGrammarTag(tag)) continue; if (maxcScore[start2][start2+1][tag]>0) System.err.println("The maxcScore for word "+start2+" is "+maxcScore[start2][start2+1][tag]); } } //return extractBestMaxRuleParse2(start, end, maxcChild[start][end][state], sentence); return new Tree<String>("ROOT"); } int lState = maxcLeftChild[start][end][state]; int rState = maxcRightChild[start][end][state]; Tree<String> leftChildTree = extractBestMaxRuleParse1(start, split, lState, sentence); Tree<String> rightChildTree = extractBestMaxRuleParse1(split, end, rState, sentence); children.add(leftChildTree); children.add(rightChildTree); } return new Tree<String>(stateStr, children); } @Override public void projectConstraints(boolean[][][][] allowed, boolean allSubstatesAllowed) { if (allowed==null) return; for (int start = 0; start < allowed.length; start++) { for (int end = start + 1; end <= allowed.length; end++) { for (int state = 0; state<numStates; state++){ if (allowed[start][end][state]!=null){ if (numSubStatesArray[state]==allowed[start][end][state].length) continue; boolean[] tmp = new boolean[numSubStatesArray[state]]; if (allSubstatesAllowed) Arrays.fill(tmp, true); else { for (int substate=0; substate<allowed[start][end][state].length; substate++){ if (allowed[start][end][state][substate]){ if (2*substate>=tmp.length) System.out.println("too long"); tmp[2*substate] = true; if (grammar.numSubStates[state]!=1) tmp[2*substate+1] = true; } } } allowed[start][end][state]=tmp; } } } } } public void checkScores(Tree<StateSet> tree) { StateSet node = tree.getLabel(); int state = node.getState(); int from = node.from, to = node.to; int oldS = iScale[from][to][state]; int newS = ScalingTools.scaleArray(iScorePostU[from][to][state], oldS); if (oldS>newS){ System.out.println("why?? 
iscale"); } oldS = oScale[from][to][state]; newS = ScalingTools.scaleArray(oScorePostU[from][to][state], oldS); if (oldS>newS){ ScalingTools.scaleArrayToScale(oScorePostU[from][to][state], newS, oldS); System.out.println("why?? oscale"); } for (int substate=0; substate<numSubStatesArray[state]; substate++){ if ((node.getIScale()==iScale[from][to][state])&&(!SloppyMath.isGreater(iScorePostU[from][to][state][substate],node.getIScore(substate)))){ if (!allowedSubStates[from][to][state][substate]) System.out.println("This state was pruned!"); else { System.out.println("Gold iScore is higher for state "+state+" from "+from+" to "+to+"!"); System.out.println("Gold "+node.getIScore(substate) +" all "+ iScorePostU[from][to][state][substate]); } } double tmpA = node.getOScore(substate); double tmpB = oScorePostU[from][to][state][substate]; if ((node.getOScale()==oScale[from][to][state])&&(!SloppyMath.isGreater(tmpB, tmpA))){ if (!allowedSubStates[from][to][state][substate]) System.out.println("This state was pruned!"); else { System.out.println("Gold oScore is higher for state "+state+" from "+from+" to "+to+"!"); System.out.println("Gold "+node.getOScore(substate) +" all "+ oScorePostU[from][to][state][substate]); } } } for (Tree<StateSet> child : tree.getChildren()){ if (!child.isLeaf()) checkScores(child); } } // compute the loss in conditional likelihood for merges in nodes in the gold tree public void tallyConditionalLoss(Tree<StateSet> tree, double[][][] deltas, double[][] mergeWeights) { if (tree.isLeaf()) return; for (Tree<StateSet> child : tree.getChildren()) { tallyConditionalLoss(child, deltas, mergeWeights); } StateSet label = tree.getLabel(); short state = label.getState(); if (state==0) return; // nothing to be done for the ROOT int start = label.from, end = label.to; if (allowedSubStates[start][end][state]==null){ System.out.println("Gold state was pruned!!!"); } double[] goldScores = new double[label.numSubStates()]; double[] allScores = new 
double[label.numSubStates()]; double combinedGoldScore, combinedAllScore; double separatedGoldScoreSum = 0, separatedAllScoreSum = 0, tmp; //don't need to deal with scale factor because we divide below for (int i = 0; i < label.numSubStates(); i++) { // in the gold tree tmp = label.getIScore(i) * label.getOScore(i); goldScores[i] = tmp; separatedGoldScoreSum += tmp; // for all trees tmp = iScorePostU[start][end][state][i] * oScorePostU[start][end][state][i]; allScores[i] = tmp; separatedAllScoreSum += tmp; } if (separatedAllScoreSum==0) return; // for some reason this seems to happen quite often // calculate merged scores for (int i = 0; i < numSubStatesArray[state]; i=i+2) { int j = i+1; double lossInGold = 0, lossInAll = 0; int[] map = {i,j}; double[] tmp1 = new double[2], tmp2 = new double[2]; double mergeWeightSum = 0; for (int k=0; k<2; k++) { mergeWeightSum += mergeWeights[state][map[k]]; } if (mergeWeightSum==0) mergeWeightSum = 1; for (int k=0; k<2; k++) { tmp1[k] = label.getIScore(map[k])*mergeWeights[state][map[k]]/mergeWeightSum; tmp2[k] = label.getOScore(map[k]); } combinedGoldScore = (tmp1[0]+tmp1[1]) * (tmp2[0]+tmp2[1]); double combinedGoldScoreSum = separatedGoldScoreSum - goldScores[i] - goldScores[j] + combinedGoldScore; if (combinedGoldScore!=0 && separatedGoldScoreSum!=0) lossInGold = separatedGoldScoreSum/combinedGoldScoreSum; // now do the same for all trees for (int k=0; k<2; k++) { tmp1[k] = iScorePostU[start][end][state][map[k]]*mergeWeights[state][map[k]]/mergeWeightSum; tmp2[k] = oScorePostU[start][end][state][map[k]]; } combinedAllScore = (tmp1[0]+tmp1[1]) * (tmp2[0]+tmp2[1]); double combinedAllScoreSum = separatedAllScoreSum - allScores[i] - allScores[j] + combinedAllScore; if (combinedGoldScore!=0 && separatedGoldScoreSum!=0) lossInAll = separatedAllScoreSum/combinedAllScoreSum; if (SloppyMath.isDangerous(lossInAll)|| SloppyMath.isDangerous(lossInGold)){ System.out.println("too many zeros "); System.out.println("tmp1: " + 
Arrays.toString(tmp1) + "\ntmp2: " + Arrays.toString(tmp2) + "\ngoldScores: " + Arrays.toString(goldScores) + "\nallScores: " + Arrays.toString(allScores) + "\nmergeWeights: " + Arrays.toString(mergeWeights[state]) + "\nseparatedGoldScoreSum: "+separatedGoldScoreSum+ "\nseparatedAllScoreSum: "+separatedAllScoreSum + "\ncombinedGoldScoreSum: "+combinedGoldScoreSum+ "\ncombinedAllScoreSum: "+combinedAllScoreSum); } else deltas[state][i][j] += Math.log(lossInGold/lossInAll); if (Double.isNaN(deltas[state][i][j])) { System.out.println(" deltas["+tagNumberer.object(state)+"]["+i+"]["+j+"] = NaN"); System.out.println( Arrays.toString(tmp1) + " " + Arrays.toString(tmp2) + " " + combinedGoldScore+" "+Arrays.toString(mergeWeights[state])); } } } private void setGoldProductions(Tree<StateSet> tree, boolean isBinaryChild){ StateSet node = tree.getLabel(); short parentState = node.getState(); if (parentState==0){ // this is the ROOT node, initialize arrays goldBinaryProduction = new int[length][length+1]; ArrayUtil.fill(goldBinaryProduction,-1); goldUnaryParent = new int[length][length+1]; ArrayUtil.fill(goldUnaryParent,-1); goldUnaryChild = new int[length][length+1]; ArrayUtil.fill(goldUnaryChild,-1); goldPOS = new int[length]; Arrays.fill(goldPOS,-1); goldUnaryParent[0][length] = 0; } if (isBinaryChild) goldBinaryProduction[node.from][node.to] = parentState; else if (parentState!=0) goldUnaryChild[node.from][node.to] = parentState; List<Tree<StateSet>> children = tree.getChildren(); if (children.size()==2){ //binary goldBinaryProduction[node.from][node.to] = parentState; setGoldProductions(children.get(0), true); setGoldProductions(children.get(1), true); } else { // unary or POS Tree<StateSet> child = children.get(0); if (child.isLeaf()) { goldPOS[node.from] = parentState; } else { // unary goldUnaryParent[node.from][node.to] = parentState; setGoldProductions(child, false); } } } public void doPreParses(List<String> sentence, List<String> posTags, Grammar[] grammarCascade, 
Lexicon[] lexiconCascade, boolean accurate, int startLevel, int endLevel, boolean isBaseline){ throw new Error("currently not supported"); // boolean noSmoothing = false; // clearArrays(); // length = (short)sentence.size(); // double score = 0; // double[] accurateThresholds = {-8,-12,-12,-11,-12,-12,-14}; // double[] fastThresholds = {-8,-9.75,-10,-9.6,-9.66,-8.01,-7.4,-10}; // double[] pruningThreshold = null; // // // createArrays(); // // if (accurate) // pruningThreshold = accurateThresholds; // else // pruningThreshold = fastThresholds; // // //int startLevel = -1; // for (int level=startLevel; level<=endLevel; level++){ // if (level==-1) continue; // don't do the pre-pre parse // if (!isBaseline && level==endLevel) continue;// // this.grammar = grammarCascade[level-startLevel]; // this.lexicon = lexiconCascade[level-startLevel]; // this.numSubStatesArray = grammar.numSubStates; // if (level==startLevel) setConstraints(null, false); // // scrubArrays(); // // // initializeChart(sentence,noSmoothing,posTags); // final boolean viterbi = true; // doConstrainedInsideScores(viterbi); // score = iScorePostU[0][length][0][0]; // // // if (score==Double.NEGATIVE_INFINITY) continue; //// System.out.println("\nFound a parse for sentence with length "+length+". The LL is "+score+"."); // oScorePreU[0][length][0][0] = 0.0; // doConstrainedOutsideScores(viterbi); // // pruneChart(/*Double.NEGATIVE_INFINITY*/pruningThreshold[level+1], level); // } } protected void pruneChart(double threshold, int level){ int totalStates = 0, previouslyPossible = 0, nowPossible = 0; double sentenceProb = iScorePostU[0][length][0][0]; double sentenceScale = iScale[0][length][0]; if (level<1) nowPossible=totalStates=previouslyPossible=length; int startDiff = (level<0) ? 2 : 1; for (int diff = startDiff; diff <= length; diff++) { for (int start = 0; start < (length - diff + 1); start++) { int end = start + diff; int lastState = (level<0) ? 
1 : numSubStatesArray.length; for (int state = 0; state < lastState; state++) { if (diff>1&&!grammarTags[state]) continue; if (allowedSubStates[start][end][state]==null) continue; boolean nonePossible = true; int thisScale = iScale[start][end][state]+oScale[start][end][state]; double scalingFactor = 1; if (thisScale != sentenceScale){ scalingFactor *= Math.pow(ScalingTools.SCALE,thisScale-sentenceScale); } for (int substate = 0; substate < numSubStatesArray[state]; substate++) { totalStates++; if (!allowedSubStates[start][end][state][substate]) continue; previouslyPossible++; double iS = iScorePostU[start][end][state][substate]; double oS = oScorePostU[start][end][state][substate]; if (iS==0||oS==0) { allowedSubStates[start][end][state][substate] = false; continue; } double posterior = iS * scalingFactor * oS / sentenceProb; if (posterior > threshold) { allowedSubStates[start][end][state][substate]=true; nowPossible++; nonePossible=false; } else { allowedSubStates[start][end][state][substate] = false; } } if (nonePossible) allowedSubStates[start][end][state] = null; } } } String parse = ""; if (level==-1) parse = "Pre-Parse"; else if (level==0) parse = "X-Bar"; else parse = ((int)Math.pow(2,level))+"-Substates"; // System.out.print(parse+". NoPruning: " +totalStates + ". Before: "+previouslyPossible+". 
After: "+nowPossible+"."); }

  /**
   * Computes a per-span bracket posterior table.
   *
   * <p>For every span {@code [start, end)} and every allowed (state, substate) pair with
   * non-zero post-unary inside and outside scores, the larger of
   * {@code iScorePostU * oScorePreU} and {@code iScorePreU * oScorePostU} is divided by the
   * score stored in chart cell {@code [0][length][0][0]} (presumably the root symbol over the
   * whole sentence -- confirm against the chart construction), multiplied by the scale
   * correction, and accumulated into {@code result[start][end]}.
   *
   * <p>NOTE(review): if the whole-sentence score is 0 the division produces
   * {@code Infinity}/{@code NaN} entries; callers are presumably expected to call this only
   * after a successful parse.
   *
   * @return a {@code length x (length + 1)} table indexed {@code [start][end]}; cells with
   *         {@code end <= start} stay 0
   */
  public double[][] getBracketPosteriors() {
    double treeScore = iScorePostU[0][length][0][0];
    int treeScale = iScale[0][length][0];
    double[][] result = new double[length][length + 1];
    final int lastState = numSubStatesArray.length;

    for (int start = 0; start < length; start++) {
      for (int end = start + 1; end <= length; end++) {
        for (int state = 0; state < lastState; state++) {
          if (allowedSubStates[start][end][state] == null) {
            continue;
          }
          for (int substate = 0; substate < numSubStatesArray[state]; substate++) {
            if (!allowedSubStates[start][end][state][substate]) {
              continue;
            }
            double iS = iScorePostU[start][end][state][substate];
            if (iS == 0) {
              continue;
            }
            double oS = oScorePostU[start][end][state][substate];
            if (oS == 0) {
              continue;
            }
            // Inside and outside scores carry separate scale exponents; express their
            // product relative to the scale of the whole-sentence score.
            double scalingFactor = ScalingTools.calcScaleFactor(
                oScale[start][end][state] + iScale[start][end][state] - treeScale);
            if (scalingFactor == 0) {
              continue;
            }
            // Take the larger of the two ways of pairing a pre-unary with a post-unary chart.
            double tmp = Math.max(
                iS * oScorePreU[start][end][state][substate],
                iScorePreU[start][end][state][substate] * oS);
            double posterior = tmp / treeScore * scalingFactor;
            result[start][end] += posterior;
          }
        }
      }
    }
    return result;
  }

  /**
   * Return the single best parse.
   * Note that the returned tree may be missing intermediate nodes in
   * a unary chain because it parses with a unary-closed grammar.
   *
   * <p>No back-pointers are stored; the derivation is recovered by searching for a rule
   * application whose (rescaled) score {@code matches} the inside score of the goal cell.
   * Binary rules are tried before unary rules.
   *
   * @param gState       goal state
   * @param gp           goal substate
   * @param start        start index of the span (inclusive)
   * @param end          end index of the span (exclusive)
   * @param sentence     the words of the sentence
   * @param unaryAllowed if true the goal score is read from the post-unary chart, otherwise
   *                     from the pre-unary chart
   * @return the reconstructed subtree, or {@code null} (after printing a warning to stderr)
   *         if no rule application reproduces the goal score
   */
  public Tree<String> extractBestViterbiParse(int gState, int gp, int start, int end,
      List<String> sentence, boolean unaryAllowed) {
    // Find sources of inside score.
    double bestScore = (unaryAllowed)
        ? iScorePostU[start][end][gState][gp]
        : iScorePreU[start][end][gState][gp];
    String goalStr = (String) tagNumberer.object(gState);

    if (end - start == 1) {
      // If the goal state is a preterminal state, then it can't transform into
      // anything but the word below it.
      if (!grammar.isGrammarTag[gState]) {
        List<Tree<String>> child = new ArrayList<Tree<String>>();
        child.add(new Tree<String>(sentence.get(start)));
        return new Tree<String>(goalStr, child);
      }

      // Otherwise find the best-scoring unary rewrite into a preterminal.
      double veryBestScore = Double.NEGATIVE_INFINITY;
      int newIndex = -1;
      UnaryRule[] unaries = grammar.getClosedViterbiUnaryRulesByParent(gState);
      for (int r = 0; r < unaries.length; r++) {
        UnaryRule ur = unaries[r];
        int cState = ur.childState;
        double[] childInside = iScorePreU[start][end][cState];
        if (childInside == null) {
          // No chart cell for this child over this span: skip it instead of failing
          // with a NullPointerException (same guard the longer-span branches use).
          continue;
        }
        double[][] scores = ur.getScores2();
        for (int cp = 0; cp < scores.length; cp++) {
          if (scores[cp] == null) {
            continue;
          }
          double ruleScore = childInside[cp] * scores[cp][gp];
          if ((ruleScore >= veryBestScore)
              && (gState != cState || gp != cp)
              && (!grammar.isGrammarTag[ur.getChildState()])) {
            veryBestScore = ruleScore;
            newIndex = cState;
          }
        }
      }

      List<Tree<String>> child1 = new ArrayList<Tree<String>>();
      child1.add(new Tree<String>(sentence.get(start)));
      String goalStr1 = (String) tagNumberer.object(newIndex);
      if (goalStr1 == null) {
        System.out.println("goalStr1==null with newIndex=="+newIndex+" goalStr=="+goalStr);
      }
      List<Tree<String>> child = new ArrayList<Tree<String>>();
      child.add(new Tree<String>(goalStr1, child1));
      return new Tree<String>(goalStr, child);
    }

    // The rule list and the goal cell's scale depend only on the goal, not on the split.
    final BinaryRule[] parentRules = grammar.splitRulesWithP(gState);
    final int goalScale = iScale[start][end][gState];

    // Check binaries first.
    for (int split = start + 1; split < end; split++) {
      for (int i = 0; i < parentRules.length; i++) {
        BinaryRule br = parentRules[i];
        int lState = br.leftChildState;
        if (iScorePostU[start][split][lState] == null) {
          continue;
        }
        int rState = br.rightChildState;
        if (iScorePostU[split][end][rState] == null) {
          continue;
        }

        // Iterate over substates.
        double[][][] scores = br.getScores2();
        for (int lp = 0; lp < scores.length; lp++) {
          for (int rp = 0; rp < scores[lp].length; rp++) {
            if (scores[lp][rp] == null) {
              continue;
            }
            double score = ScalingTools.scaleToScale(
                scores[lp][rp][gp]
                    * iScorePostU[start][split][lState][lp]
                    * iScorePostU[split][end][rState][rp],
                iScale[start][split][lState] + iScale[split][end][rState],
                goalScale);
            if (matches(score, bestScore)) {
              // Build binary split; children may themselves start with a unary.
              Tree<String> leftChildTree =
                  extractBestViterbiParse(lState, lp, start, split, sentence, true);
              Tree<String> rightChildTree =
                  extractBestViterbiParse(rState, rp, split, end, sentence, true);
              List<Tree<String>> children = new ArrayList<Tree<String>>();
              children.add(leftChildTree);
              children.add(rightChildTree);
              return new Tree<String>(goalStr, children);
            }
          }
        }
      }
    }

    // Check unaries.
    UnaryRule[] unaries = grammar.getClosedViterbiUnaryRulesByParent(gState);
    for (int r = 0; r < unaries.length; r++) {
      UnaryRule ur = unaries[r];
      int cState = ur.childState;
      if (iScorePostU[start][end][cState] == null) {
        continue;
      }
      // The score is read from the pre-unary chart, so guard that cell as well.
      double[] childInside = iScorePreU[start][end][cState];
      if (childInside == null) {
        continue;
      }

      // Iterate over substates.
      double[][] scores = ur.getScores2();
      for (int cp = 0; cp < scores.length; cp++) {
        if (scores[cp] == null) {
          continue;
        }
        double score = ScalingTools.scaleToScale(
            scores[cp][gp] * childInside[cp],
            iScale[start][end][cState],
            goalScale);
        // Skip the identity rewrite (same state and substate).
        if ((cState != ur.parentState || cp != gp) && matches(score, bestScore)) {
          // Build unary; the child must not apply another unary on top of it.
          Tree<String> childTree =
              extractBestViterbiParse(cState, cp, start, end, sentence, false);
          List<Tree<String>> children = new ArrayList<Tree<String>>();
          children.add(childTree);
          return new Tree<String>(goalStr, children);
        }
      }
    }

    System.err.println("Warning: could not find the optimal way to build state "+goalStr+" spanning from "+ start+ " to "+end+".");
    return null;
  }

  /**
   * @return the pre-unary inside chart, indexed {@code [start][end][state][substate]};
   *         this is the parser's own array, not a copy
   */
  public double[][][][] getPreUnaryInsideScores() {
    return iScorePreU;
  }

  /**
   * @return the post-unary inside chart, indexed {@code [start][end][state][substate]};
   *         this is the parser's own array, not a copy
   */
  public double[][][][] getPostUnaryInsideScores() {
    return iScorePostU;
  }

  /**
   * @return the pre-unary outside chart, indexed {@code [start][end][state][substate]};
   *         this is the parser's own array, not a copy
   */
  public double[][][][] getPreUnaryOutsideScores() {
    return oScorePreU;
  }

  /**
   * @return the post-unary outside chart, indexed {@code [start][end][state][substate]};
   *         this is the parser's own array, not a copy
   */
  public double[][][][] getPostUnaryOutsideScores() {
    return oScorePostU;
  }

  /**
   * @return the scale exponents of the inside scores, indexed {@code [start][end][state]};
   *         this is the parser's own array, not a copy
   */
  public int[][][] getInsideScalingFactors() {
    return iScale;
  }

  /**
   * @return the scale exponents of the outside scores, indexed {@code [start][end][state]};
   *         this is the parser's own array, not a copy
   */
  public int[][][] getOutsideScalingFactors() {
    return oScale;
  }
}