/**
 *
 */
package edu.berkeley.nlp.PCFGLA;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.util.ArrayUtil;
import edu.berkeley.nlp.util.ScalingTools;

/**
 * Extension of {@link CoarseToFineMaxRuleParser} whose max-rule charts carry an
 * extra substate index: every score and back-pointer is stored per
 * (start, end, state, substate), and the extracted tree follows the best
 * (state, substate) derivation.
 *
 * <p>The charts are (re)built by {@link #doConstrainedMaxCScores} and read by the
 * {@code extractBestMaxRuleParse*} methods.
 *
 * @author petrov
 */
public class CoarseToFineMaxRuleDerivationParser extends CoarseToFineMaxRuleParser {

  /** Chart value meaning "no derivation has been found for this cell yet". */
  private static final double NO_DERIVATION = Double.NEGATIVE_INFINITY;

  /** Back-pointer value meaning "nothing recorded for this cell". */
  private static final int NO_BACK_POINTER = -1;

  /** [start][end][state][substate] -> best log score found for that cell. */
  protected double[][][][] maxcScore;
  /** Declared with the same shape as {@link #maxcScore}; not referenced in this class. */
  protected double[][][][] maxsScore;
  /** [start][end][state][substate] -> split position of the best binary rule, or -1. */
  protected int[][][][] maxcSplit;
  /** [start][end][state][substate] -> child state of the best unary rule, or -1. */
  protected int[][][][] maxcChild;
  /** [start][end][state][substate] -> child substate of the best unary rule, or -1. */
  protected int[][][][] maxcChildSub;
  /** [start][end][state][substate] -> left child state of the best binary rule, or -1. */
  protected int[][][][] maxcLeftChild;
  /** [start][end][state][substate] -> right child state of the best binary rule, or -1. */
  protected int[][][][] maxcRightChild;
  /** [start][end][state][substate] -> left child substate of the best binary rule, or -1. */
  protected int[][][][] maxcLeftChildSub;
  /** [start][end][state][substate] -> right child substate of the best binary rule, or -1. */
  protected int[][][][] maxcRightChildSub;

  /**
   * Forwards every argument unchanged to the {@link CoarseToFineMaxRuleParser}
   * constructor; see that class for their meaning.
   */
  public CoarseToFineMaxRuleDerivationParser(Grammar gr, Lexicon lex, double unaryPenalty,
      int endL, boolean viterbi, boolean sub, boolean score, boolean accurate,
      boolean variational, boolean useGoldPOS, boolean initializeCascade) {
    super(gr, lex, unaryPenalty, endL, viterbi, sub, score, accurate, variational,
        useGoldPOS, initializeCascade);
  }

  /**
   * Fills the substate-level max-rule charts for the current sentence.
   *
   * <p>Spans are visited by increasing width. For width 1 the cell is seeded from
   * the lexicon, for wider spans from binary rules; in both cases a unary pass
   * follows. Reads the inherited {@code iScore}/{@code oScore} charts (and
   * {@code iScale}/{@code oScale} when {@code scale} is set) and the
   * {@code allowedStates} mask; overwrites all {@code maxc*} charts of this class.
   *
   * @param sentence the words of the sentence being parsed
   * @param grammar grammar supplying binary/unary rules and substate counts
   * @param lexicon lexicon supplying tag-to-word scores
   * @param scale whether scaling factors from the scale charts are folded into the scores
   */
  void doConstrainedMaxCScores(List<String> sentence, Grammar grammar, Lexicon lexicon,
      final boolean scale) {
    numSubStatesArray = grammar.numSubStates;
    allocateDerivationCharts();

    // Inside score of the root cell; every rule posterior below is divided by it.
    final double normalizer = iScore[0][length][0][0];

    for (int diff = 1; diff <= length; diff++) {
      for (int start = 0; start < (length - diff + 1); start++) {
        int end = start + diff;
        if (diff > 1) {
          scoreBinaryDerivations(start, end, grammar, scale, normalizer);
        } else {
          scoreLexicalDerivations(start, end, sentence, grammar, lexicon, scale, normalizer);
        }
        scoreUnaryDerivations(start, end, grammar, scale, normalizer);
      }
    }
  }

  /**
   * Allocates all {@code maxc*} charts and, for every allowed (start, end, state),
   * a per-substate row initialised to {@link #NO_DERIVATION} / -1.
   */
  private void allocateDerivationCharts() {
    maxcScore = new double[length][length + 1][numStates][];
    maxcSplit = new int[length][length + 1][numStates][];
    maxcChild = new int[length][length + 1][numStates][];
    maxcChildSub = new int[length][length + 1][numStates][];
    maxcLeftChild = new int[length][length + 1][numStates][];
    maxcRightChild = new int[length][length + 1][numStates][];
    maxcLeftChildSub = new int[length][length + 1][numStates][];
    maxcRightChildSub = new int[length][length + 1][numStates][];

    for (int start = 0; start < length; start++) {
      for (int end = start + 1; end <= length; end++) {
        for (int state = 0; state < numSubStatesArray.length; state++) {
          if (!allowedStates[start][end][state]) {
            continue;
          }
          int nSubStates = numSubStatesArray[state];
          maxcSplit[start][end][state] = newBackPointerRow(nSubStates);
          maxcChild[start][end][state] = newBackPointerRow(nSubStates);
          maxcChildSub[start][end][state] = newBackPointerRow(nSubStates);
          maxcLeftChild[start][end][state] = newBackPointerRow(nSubStates);
          maxcRightChild[start][end][state] = newBackPointerRow(nSubStates);
          maxcLeftChildSub[start][end][state] = newBackPointerRow(nSubStates);
          maxcRightChildSub[start][end][state] = newBackPointerRow(nSubStates);

          double[] scoreRow = new double[nSubStates];
          Arrays.fill(scoreRow, NO_DERIVATION);
          maxcScore[start][end][state] = scoreRow;
        }
      }
    }
  }

  /** Returns a new array of the given size with every entry set to -1. */
  private static int[] newBackPointerRow(int size) {
    int[] row = new int[size];
    Arrays.fill(row, NO_BACK_POINTER);
    return row;
  }

  /**
   * Binary-rule pass for one span of width > 1: for every allowed parent state,
   * rule, split point and substate triple, records the best-scoring combination
   * in {@code maxcScore} together with its split and child back-pointers.
   */
  private void scoreBinaryDerivations(int start, int end, Grammar grammar, boolean scale,
      double normalizer) {
    for (int pState = 0; pState < numSubStatesArray.length; pState++) {
      if (!allowedStates[start][end][pState]) {
        continue;
      }
      BinaryRule[] parentRules = grammar.splitRulesWithP(pState);
      int nParentStates = numSubStatesArray[pState];

      for (int i = 0; i < parentRules.length; i++) {
        BinaryRule r = parentRules[i];
        int lState = r.leftChildState;
        int rState = r.rightChildState;

        // Use the extent tables to bound the split points worth trying.
        int narrowR = narrowRExtent[start][lState];
        if (narrowR >= end) {
          continue; // the left constituent leaves no space for a right constituent
        }
        int narrowL = narrowLExtent[end][rState];
        if (narrowL < narrowR) {
          continue; // the right constituent cannot fit next to the left constituent
        }
        int min = Math.max(narrowR, wideLExtent[end][rState]);
        if (min > narrowL) {
          continue; // the right constituent cannot stretch far enough to the left
        }
        int max = Math.min(wideRExtent[start][lState], narrowL);
        if (min > max) {
          continue; // the left constituent cannot stretch far enough to the right
        }

        double[][][] scores = r.getScores2();
        int nLeftChildStates = numSubStatesArray[lState];
        int nRightChildStates = numSubStatesArray[rState];

        for (int split = min; split <= max; split++) {
          if (!allowedStates[start][split][lState]) {
            continue;
          }
          if (!allowedStates[split][end][rState]) {
            continue;
          }

          double scalingFactor = 0.0;
          if (scale) {
            scalingFactor = Math.log(ScalingTools.calcScaleFactor(
                oScale[start][end][pState] + iScale[start][split][lState]
                    + iScale[split][end][rState] - iScale[0][length][0]));
          }

          for (int lp = 0; lp < nLeftChildStates; lp++) {
            double lIS = iScore[start][split][lState][lp];
            if (lIS == 0) {
              continue;
            }
            for (int rp = 0; rp < nRightChildStates; rp++) {
              if (scores[lp][rp] == null) {
                continue;
              }
              double rIS = iScore[split][end][rState][rp];
              if (rIS == 0) {
                continue;
              }
              double leftChildScore = maxcScore[start][split][lState][lp];
              double rightChildScore = maxcScore[split][end][rState][rp];
              if (leftChildScore == NO_DERIVATION || rightChildScore == NO_DERIVATION) {
                continue;
              }
              double gScore = leftChildScore + scalingFactor + rightChildScore;

              for (int np = 0; np < nParentStates; np++) {
                double pOS = oScore[start][end][pState][np];
                if (pOS == 0) {
                  continue;
                }
                double scoreToBeat = maxcScore[start][end][pState][np];
                if (gScore < scoreToBeat) {
                  continue; // no chance of finding a better derivation
                }
                double ruleS = scores[lp][rp][np];
                if (ruleS == 0) {
                  continue;
                }
                double ruleScore = (pOS * ruleS * lIS * rIS) / normalizer;
                if (ruleScore == 0) {
                  continue;
                }
                if (doVariational) {
                  ruleScore /= oScore[start][end][pState][np] / normalizer
                      * iScore[start][end][pState][np];
                }
                ruleScore = gScore + Math.log(ruleScore);

                if (ruleScore > scoreToBeat) {
                  maxcScore[start][end][pState][np] = ruleScore;
                  maxcSplit[start][end][pState][np] = split;
                  maxcLeftChild[start][end][pState][np] = lState;
                  maxcRightChild[start][end][pState][np] = rState;
                  maxcLeftChildSub[start][end][pState][np] = lp;
                  maxcRightChildSub[start][end][pState][np] = rp;
                }
              }
            }
          }
        }
      }
    }
  }

  /**
   * Lexical pass for one span of width 1. TAG --> word is treated exactly like a
   * unary rule, except that its score comes from the lexicon rather than the
   * grammar and that another unary is allowed on top of it. States for which
   * {@code grammar.isGrammarTag} is true are skipped.
   */
  private void scoreLexicalDerivations(int start, int end, List<String> sentence,
      Grammar grammar, Lexicon lexicon, boolean scale, double normalizer) {
    for (int tag = 0; tag < numSubStatesArray.length; tag++) {
      if (!allowedStates[start][end][tag]) {
        continue;
      }
      if (grammar.isGrammarTag(tag)) {
        continue;
      }
      int nTagStates = numSubStatesArray[tag];
      String word = sentence.get(start);
      double[] lexiconScoreArray = lexicon.score(word, (short) tag, start, false, false);

      // The scaling term depends only on the tag, so compute it once per tag.
      // In variational mode no scaling is applied (matches the original branch order).
      double scalingFactor = 0.0;
      if (!doVariational && scale) {
        scalingFactor = Math.log(ScalingTools.calcScaleFactor(
            oScale[start][end][tag] - iScale[0][length][0]));
      }

      for (int tp = 0; tp < nTagStates; tp++) {
        double pOS = oScore[start][end][tag][tp];
        double ruleS = lexiconScoreArray[tp];
        // The inside score of a word contributes no factor here.
        double lexiconScore = (pOS * ruleS) / normalizer;
        if (doVariational) {
          lexiconScore = 1;
        }
        maxcScore[start][end][tag][tp] = Math.log(lexiconScore) + scalingFactor;
      }
    }
  }

  /**
   * Unary pass for one span. Candidate parents are scored against the child
   * scores already in {@code maxcScore[start][end]}; improvements are written to
   * a copy that replaces {@code maxcScore[start][end]} in one step at the end, so
   * a unary never builds on another unary found in the same pass.
   *
   * <p>The closure step over variational rule scores ({@code closeVariationalRules})
   * is not applied here.
   */
  private void scoreUnaryDerivations(int start, int end, Grammar grammar, boolean scale,
      double normalizer) {
    // Replacement for maxcScore[start][end], which is updated in batch.
    double[][] updatedScores = new double[numStates][];
    for (int state = 0; state < numStates; state++) {
      if (!allowedStates[start][end][state]) {
        continue;
      }
      updatedScores[state] =
          Arrays.copyOf(maxcScore[start][end][state], numSubStatesArray[state]);
    }

    for (int pState = 0; pState < numSubStatesArray.length; pState++) {
      if (!allowedStates[start][end][pState]) {
        continue;
      }
      int nParentStates = numSubStatesArray[pState];
      UnaryRule[] unaries = grammar.getClosedSumUnaryRulesByParent(pState);
      if (doVariational) {
        unaries = grammar.getUnaryRulesByParent(pState).toArray(new UnaryRule[0]);
      }

      for (int r = 0; r < unaries.length; r++) {
        UnaryRule ur = unaries[r];
        int cState = ur.childState;
        if (pState == cState) {
          continue;
        }
        if (iScore[start][end][cState] == null) {
          continue;
        }

        double scalingFactor = 0.0;
        if (scale) {
          scalingFactor = Math.log(ScalingTools.calcScaleFactor(
              oScale[start][end][pState] + iScale[start][end][cState]
                  - iScale[0][length][0]));
        }

        double[][] scores = ur.getScores2();
        int nChildStates = numSubStatesArray[cState];

        for (int cp = 0; cp < nChildStates; cp++) {
          double cIS = iScore[start][end][cState][cp];
          if (cIS == 0) {
            continue;
          }
          double childScore = maxcScore[start][end][cState][cp];
          if (childScore == NO_DERIVATION) {
            continue;
          }
          if (scores[cp] == null) {
            continue;
          }
          double gScore = scalingFactor + childScore;

          for (int np = 0; np < nParentStates; np++) {
            double pOS = oScore[start][end][pState][np];
            // NOTE(review): the binary pass skips on pOS == 0; here a zero outside
            // score is only rejected further down via ruleScore == 0. Kept as-is.
            if (pOS < 0) {
              continue;
            }
            if (gScore < updatedScores[pState][np]) {
              continue;
            }
            double ruleS = scores[cp][np];
            if (ruleS == 0) {
              continue;
            }
            double ruleScore = (pOS * ruleS * cIS) / normalizer;
            if (ruleScore == 0) {
              continue;
            }
            if (doVariational) {
              ruleScore /= oScore[start][end][pState][np] / normalizer
                  * iScore[start][end][pState][np];
            }
            ruleScore = gScore + Math.log(ruleScore);

            if (ruleScore > updatedScores[pState][np]) {
              updatedScores[pState][np] = ruleScore;
              maxcChild[start][end][pState][np] = cState;
              maxcChildSub[start][end][pState][np] = cp;
            }
          }
        }
      }
    }
    maxcScore[start][end] = updatedScores;
  }

  /**
   * Returns the best parse of the span, rooted at state 0, substate 0.
   *
   * @param start first word index of the span (inclusive)
   * @param end last word index of the span (exclusive)
   * @param sentence the words of the sentence
   * @return the tree read off the back-pointer charts
   */
  public Tree<String> extractBestMaxRuleParse(int start, int end, List<String> sentence) {
    return extractBestMaxRuleParse1(start, end, 0, 0, sentence);
  }

  /**
   * Returns the best parse for state "state", potentially starting with a unary rule.
   *
   * <p>If a unary child is recorded for the cell, the child subtree is wrapped in
   * a node for {@code state}; when {@code grammar.getUnaryIntermediate} reports a
   * positive intermediate state, that node is re-inserted between parent and child.
   * Increments {@code totalUsedUnaries} (and {@code nTimesRestoredUnaries} when an
   * intermediate node is restored).
   */
  public Tree<String> extractBestMaxRuleParse1(int start, int end, int state, int substate,
      List<String> sentence) {
    int cState = maxcChild[start][end][state][substate];
    int cSubState = maxcChildSub[start][end][state][substate];
    if (cState == NO_BACK_POINTER) {
      return extractBestMaxRuleParse2(start, end, state, substate, sentence);
    }

    List<Tree<String>> child = new ArrayList<Tree<String>>();
    child.add(extractBestMaxRuleParse2(start, end, cState, cSubState, sentence));
    String stateStr = strippedStateLabel(state);
    totalUsedUnaries++;

    int intermediateNode = grammar.getUnaryIntermediate((short) state, (short) cState);
    if (intermediateNode > 0) {
      List<Tree<String>> restoredChild = new ArrayList<Tree<String>>();
      nTimesRestoredUnaries++;
      String intermediateStr = strippedStateLabel(intermediateNode);
      restoredChild.add(new Tree<String>(intermediateStr, child));
      return new Tree<String>(stateStr, restoredChild);
    }
    return new Tree<String>(stateStr, child);
  }

  /**
   * Returns the best parse for state "state", but cannot start with a unary.
   *
   * <p>For a single-word span the word becomes the leaf. For wider spans the
   * recorded split and child back-pointers are followed recursively. If no split
   * was recorded, a warning is written to {@code System.err} and a bare
   * {@code "ROOT"} tree is returned.
   */
  public Tree<String> extractBestMaxRuleParse2(int start, int end, int state, int substate,
      List<String> sentence) {
    List<Tree<String>> children = new ArrayList<Tree<String>>();
    String stateStr = strippedStateLabel(state);

    boolean posLevel = (end - start == 1);
    if (posLevel) {
      if (grammar.isGrammarTag(state)) {
        List<Tree<String>> childs = new ArrayList<Tree<String>>();
        childs.add(new Tree<String>(sentence.get(start)));
        // NOTE(review): maxcChild holds -1 when no unary was recorded for this cell;
        // how tagNumberer treats -1 is not visible from this file -- confirm.
        String stateStr2 = (String) tagNumberer.object(maxcChild[start][end][state][substate]);
        children.add(new Tree<String>(stateStr2, childs));
      } else {
        children.add(new Tree<String>(sentence.get(start)));
      }
    } else {
      int split = maxcSplit[start][end][state][substate];
      if (split == NO_BACK_POINTER) {
        // Arrays.toString is required: concatenating a double[] directly would
        // print its identity ("[D@...") instead of the per-substate scores.
        String cellScores = Arrays.toString(maxcScore[start][end][state]);
        System.err.println("Warning: no symbol can generate the span from " + start
            + " to " + end + ".");
        System.err.println("The score is " + cellScores
            + " and the state is supposed to be " + stateStr);
        System.err.println("The insideScores are "
            + Arrays.toString(iScore[start][end][state])
            + " and the outsideScores are " + Arrays.toString(oScore[start][end][state]));
        System.err.println("The maxcScore is " + cellScores);
        return new Tree<String>("ROOT");
      }
      int lState = maxcLeftChild[start][end][state][substate];
      int lSubState = maxcLeftChildSub[start][end][state][substate];
      int rState = maxcRightChild[start][end][state][substate];
      int rSubState = maxcRightChildSub[start][end][state][substate];

      Tree<String> leftChildTree =
          extractBestMaxRuleParse1(start, split, lState, lSubState, sentence);
      Tree<String> rightChildTree =
          extractBestMaxRuleParse1(split, end, rState, rSubState, sentence);
      children.add(leftChildTree);
      children.add(rightChildTree);
    }
    return new Tree<String>(stateStr, children);
  }

  /** Looks up the label of a state in {@code tagNumberer} and drops a trailing "^g". */
  private String strippedStateLabel(int state) {
    String label = (String) tagNumberer.object(state);
    if (label.endsWith("^g")) {
      label = label.substring(0, label.length() - 2);
    }
    return label;
  }
}