/** * */ package edu.berkeley.nlp.discPCFG; import java.io.Serializable; import java.util.Arrays; import java.util.List; import edu.berkeley.nlp.PCFGLA.BinaryRule; import edu.berkeley.nlp.PCFGLA.ConditionalTrainer; import edu.berkeley.nlp.PCFGLA.Grammar; import edu.berkeley.nlp.PCFGLA.Rule; import edu.berkeley.nlp.PCFGLA.SimpleLexicon; import edu.berkeley.nlp.PCFGLA.SpanPredictor; import edu.berkeley.nlp.PCFGLA.UnaryRule; import edu.berkeley.nlp.syntax.StateSet; import edu.berkeley.nlp.math.DoubleArrays; import edu.berkeley.nlp.math.SloppyMath; import edu.berkeley.nlp.util.ArrayUtil; import edu.berkeley.nlp.util.Indexer; /** * @author petrov * */ public class DefaultLinearizer implements Linearizer, Serializable { Grammar grammar; SimpleLexicon lexicon; SpanPredictor spanPredictor; int[][] linearIndex; int nGrammarWeights, nLexiconWeights, nSpanWeights; int nWords, startSpanWeights; int nSubstates; int nClasses; int startIndexPrevious, startIndexNext; int startIndexFirst, startIndexLast; int startIndexBeginPair, startIndexEndPair; int startIndexPunctuation; double[] lastProbs; private static final long serialVersionUID = 2L; public DefaultLinearizer() { } /** * @param grammar * @param lexicon * @param threshold */ public DefaultLinearizer(Grammar grammar, SimpleLexicon lexicon, SpanPredictor sp) { this.grammar = grammar; this.lexicon = lexicon; this.spanPredictor = sp; this.nSubstates = (int)ArrayUtil.max(grammar.numSubStates); init(); } protected void init() { double[] tmp = null; if (!ConditionalTrainer.Options.lockGrammar){ tmp = getLinearizedGrammar(true); tmp = getLinearizedLexicon(true); } tmp = getLinearizedSpanPredictor(true); } public void delinearizeSpanPredictor(double [] probs) { if (spanPredictor==null) return; int ind = startSpanWeights, nDangerous = 0; if (spanPredictor.useFirstAndLast){ double[][] tmp = spanPredictor.firstWordScore; for (int i=0; i<tmp.length; i++){ for (int c=0; c<tmp[0].length; c++){ double val = probs[ind++]; if (Math.abs(val)>300) { nDangerous++; continue; } val = Math.exp(val); tmp[i][c] = val; } } tmp = spanPredictor.lastWordScore; for (int i=0; i<tmp.length; i++){ for (int c=0; c<tmp[0].length; c++){ double val = probs[ind++]; if (Math.abs(val)>300) { nDangerous++; continue; } val = Math.exp(val); tmp[i][c] = val; } } } if (spanPredictor.usePreviousAndNext){ double[][] tmp = spanPredictor.previousWordScore; for (int i=0; i<tmp.length; i++){ for (int c=0; c<tmp[0].length; c++){ double val = probs[ind++]; if (Math.abs(val)>300) { nDangerous++; continue; } val = Math.exp(val); tmp[i][c] = val; } } tmp = spanPredictor.nextWordScore; for (int i=0; i<tmp.length; i++){ for (int c=0; c<tmp[0].length; c++){ double val = probs[ind++]; if (Math.abs(val)>300) { nDangerous++; continue; } val = Math.exp(val); tmp[i][c] = val; } } } if (spanPredictor.useBeginAndEndPairs){ double[][] tmp = spanPredictor.beginPairScore; for (int i=0; i<tmp.length; i++){ for (int c=0; c<tmp[0].length; c++){ double val = probs[ind++]; if (Math.abs(val)>300) { nDangerous++; continue; } val = Math.exp(val); tmp[i][c] = val; } } tmp = spanPredictor.endPairScore; for (int i=0; i<tmp.length; i++){ for (int c=0; c<tmp[0].length; c++){ double val = probs[ind++]; if (Math.abs(val)>300) { nDangerous++; continue; } val = Math.exp(val); tmp[i][c] = val; } } } if (spanPredictor.usePunctuation){ double[][] tmp = spanPredictor.punctuationScores; for (int i=0; i<tmp.length; i++){ for (int c=0; c<tmp[0].length; c++){ double val = probs[ind++]; if (Math.abs(val)>300) { nDangerous++; continue; } val = Math.exp(val); tmp[i][c] = val; } } } if(nDangerous>0) System.out.println("Ignored "+nDangerous+" proposed span feature weights since they were dangerous."); } public void delinearizeGrammar(double [] probs) { int nDangerous = 0; for (BinaryRule bRule : grammar.binaryRuleMap.keySet()){ int ind = bRule.identifier;//startIndex[ruleIndexer.indexOf(bRule)]; double[][][] scores = bRule.getScores2(); for (int j=0; j<scores.length; j++){ for (int k=0; k<scores[j].length; k++){ if (scores[j][k]!=null){ for (int l=0; l<scores[j][k].length; l++){ double val = Math.exp(probs[ind++]); if (SloppyMath.isVeryDangerous(val)) { System.out.println("dangerous value for rule "+ bRule+" "+probs[ind-1]); val = 0; nDangerous++; // continue; } scores[j][k][l] = val; } } } } } if (nDangerous>0) System.out.println("Left "+nDangerous+" binary rule weights unchanged since the proposed weight was dangerous."); nDangerous = 0; for (UnaryRule uRule : grammar.unaryRuleMap.keySet()){ int ind = uRule.identifier;//startIndex[ruleIndexer.indexOf(uRule)]; if (uRule.childState==uRule.parentState) continue; double[][] scores = uRule.getScores2(); for (int j=0; j<scores.length; j++){ if (scores[j]!=null){ for (int k=0; k<scores[j].length; k++){ double val = Math.exp(probs[ind++]); //probs[ind++] if (SloppyMath.isVeryDangerous(val)) { System.out.println("dangerous value for rule "+ uRule+" "+probs[ind-1]); val = 0; nDangerous++; // continue; } scores[j][k] = val; } } } } if (nDangerous>0) System.out.println("Left "+nDangerous+" unary rule weights unchanged since the proposed weight was dangerous."); grammar.closedSumRulesWithParent = grammar.closedViterbiRulesWithParent = grammar.unaryRulesWithParent; grammar.closedSumRulesWithChild = grammar.closedViterbiRulesWithChild = grammar.unaryRulesWithC; // computePairsOfUnaries(); grammar.clearUnaryIntermediates(); grammar.makeCRArrays(); // return grammar; } public double[] getLinearizedGrammar(boolean update) { if (update){ // int nRules = grammar.binaryRuleMap.size() + grammar.unaryRuleMap.size(); nGrammarWeights = 0; for (BinaryRule bRule : grammar.binaryRuleMap.keySet()){ // ruleIndexer.add(bRule); if (!grammar.isGrammarTag[bRule.parentState]){ System.out.println("Incorrect grammar tag"); } bRule.identifier = nGrammarWeights; // ruleIndexer.indexOf(bRule); // startIndex[bRule.identifier] = ; double[][][] scores = bRule.getScores2(); for (int j=0; j<scores.length; j++){ for (int k=0; k<scores[j].length; k++){ if (scores[j][k]!=null){ nGrammarWeights += scores[j][k].length; } } } } for (UnaryRule uRule : grammar.unaryRuleMap.keySet()){ // ruleIndexer.add(uRule); uRule.identifier = nGrammarWeights; // ruleIndexer.indexOf(uRule); // startIndex[uRule.identifier] = nGrammarWeights; double[][] scores = uRule.getScores2(); for (int j=0; j<scores.length; j++){ if (scores[j]!=null){ nGrammarWeights += scores[j].length; } } } } double[] logProbs = new double[nGrammarWeights]; for (BinaryRule bRule : grammar.binaryRuleMap.keySet()){ int ind = bRule.identifier; double[][][] scores = bRule.getScores2(); for (int j=0; j<scores.length; j++){ for (int k=0; k<scores[j].length; k++){ if (scores[j][k]!=null){ for (int l=0; l<scores[j][k].length; l++){ double val = Math.log(scores[j][k][l]); if (val==Double.NEGATIVE_INFINITY) { // toBeIgnored[ind] = true; // val=Double.MIN_VALUE; } logProbs[ind++] = val; } } } } } for (UnaryRule uRule : grammar.unaryRuleMap.keySet()){ int ind = uRule.identifier; if (uRule.childState==uRule.parentState) continue; double[][] scores = uRule.getScores2(); for (int j=0; j<scores.length; j++){ if (scores[j]!=null){ for (int k=0; k<scores[j].length; k++){ double val = Math.log(scores[j][k]); if (val==Double.NEGATIVE_INFINITY) { // toBeIgnored[ind] = true; // val=Double.MIN_VALUE; } logProbs[ind++] = val; } } } } return logProbs; } public void delinearizeLexicon(double[] logProbs){ int nDangerous = 0; for (short tag=0; tag<lexicon.scores.length; tag++){ for (int word=0; word<lexicon.scores[tag][0].length; word++){ int index = linearIndex[tag][word]; for (int substate=0; substate<lexicon.numSubStates[tag]; substate++){ double val = Math.exp(logProbs[index++]); if (SloppyMath.isVeryDangerous(val)) { System.out.println("dangerous value when delinearizng lexicon "+lexicon.scores[tag][substate][word]); System.out.println("Word "+lexicon.wordIndexer.get(lexicon.tagWordIndexer[tag].get(word))+" tag "+logProbs[index-1]); val = 0; nDangerous++; // continue; } lexicon.scores[tag][substate][word] = val; } } } if (nDangerous>0) System.out.println("Left "+nDangerous+" lexicon weights unchanged since the proposed weight was dangerous."); // return lexicon; } public double[] getLinearizedLexicon(){ return getLinearizedLexicon(false); } public double[] getLinearizedLexicon(boolean update){ if (update) { nLexiconWeights = 0; for (short tag=0; tag<lexicon.scores.length; tag++){ for (int word=0; word<lexicon.scores[tag][0].length; word++){ nLexiconWeights += lexicon.numSubStates[tag]; } } } double[] logProbs = new double[nLexiconWeights]; if (update) linearIndex = new int[lexicon.expectedCounts.length][]; int index = 0; for (short tag=0; tag<lexicon.scores.length; tag++){ if (update) linearIndex[tag] = new int[lexicon.scores[tag][0].length]; for (int word=0; word<lexicon.scores[tag][0].length; word++){ if (update) linearIndex[tag][word] = index + nGrammarWeights; for (int substate=0; substate<lexicon.numSubStates[tag]; substate++){ double val = Math.log(lexicon.scores[tag][substate][word]); if (val==Double.NEGATIVE_INFINITY) { // toBeIgnored[index] = true; // val=Double.MIN_VALUE; } logProbs[index++] = val; } } } return logProbs; } // // public int getLinearIndex(Rule rule){ // return startIndex[ruleIndexer.indexOf(rule)]; // } public int getLinearIndex(String word, int tag){ return getLinearIndex(lexicon.wordIndexer.indexOf(word), tag); } public int getLinearIndex(int globalWordIndex, int tag){ int tagSpecificWordIndex = lexicon.tagWordIndexer[tag].indexOf(globalWordIndex); if (tagSpecificWordIndex==-1) return -1; return linearIndex[tag][tagSpecificWordIndex]; } public int dimension() { return nGrammarWeights + nLexiconWeights + nSpanWeights; } public void increment(double[] counts, StateSet stateSet, int tag, double[] weights, boolean isGold) { int globalSigIndex = stateSet.sigIndex; if (globalSigIndex != -1){ int startIndexWord = getLinearIndex(globalSigIndex, tag); if (startIndexWord>=0){ //System.out.println("incrementing scores for unseen signature tag"); for (int i=0; i<nSubstates; i++){ if (isGold) counts[startIndexWord++] += weights[i]; else counts[startIndexWord++] -= weights[i]; } } } int startIndexWord = getLinearIndex(stateSet.wordIndex, tag); if (startIndexWord>=0) { for (int i=0; i<nSubstates; i++){ if (isGold) counts[startIndexWord++] += weights[i]; else counts[startIndexWord++] -= weights[i]; weights[i]=0; } } else { for (int i=0; i<nSubstates; i++){ weights[i]=0; } } } public void increment(double[] counts, UnaryRule rule, double[] weights, boolean isGold) { int thisStartIndex = rule.identifier; int curInd = 0; int nSubstatesParent = (rule.parentState==0) ? 1 : nSubstates; for (int cp = 0; cp < nSubstates; cp++) { // if (scores[cp]==null) continue; for (int np = 0; np < nSubstatesParent; np++) { if (isGold) counts[thisStartIndex++] += weights[curInd]; else counts[thisStartIndex++] -= weights[curInd]; weights[curInd++]=0; } } } public void increment(double[] counts, BinaryRule rule, double[] weights, boolean isGold) { int thisStartIndex = rule.identifier; int curInd = 0; for (int lp = 0; lp < nSubstates; lp++) { for (int rp = 0; rp < nSubstates; rp++) { // if (scores[cp]==null) continue; for (int np = 0; np < nSubstates; np++) { if (isGold) counts[thisStartIndex++] += weights[curInd]; else counts[thisStartIndex++] -= weights[curInd]; weights[curInd++]=0; } } } } public void delinearizeWeights(double[] logWeights) { int nGrZ=0, nLexZ=0, nSpZ=0; int tmpI = 0; if (!ConditionalTrainer.Options.lockGrammar){ for (int i=0; i<nGrammarWeights; i++){ double val = logWeights[tmpI++]; if (val==0) nGrZ++; } delinearizeGrammar(logWeights); for (int i=0; i<nLexiconWeights; i++){ double val = logWeights[tmpI++]; if (val==0) nLexZ++; } delinearizeLexicon(logWeights); } for (int i=0; i<nSpanWeights; i++){ double val = logWeights[tmpI++]; if (val==0) nSpZ++; } delinearizeSpanPredictor(logWeights); lastProbs = logWeights.clone(); System.out.println("Proposed vector has "+(nGrZ+nLexZ+nSpZ)+"/"+(nGrammarWeights+nLexiconWeights+nSpanWeights)+ " zeros [grammar: "+nGrZ+"/"+nGrammarWeights+ ", lexicon: "+nLexZ+"/"+nLexiconWeights+", span: "+nSpZ+"/"+nSpanWeights+"]."); } public double[] getLinearizedSpanPredictor(boolean update) { if (spanPredictor==null) { nSpanWeights = 0; return new double[0]; } nWords = spanPredictor.nWords; nSpanWeights = spanPredictor.nFeatures; startSpanWeights = nGrammarWeights + nLexiconWeights; nClasses = spanPredictor.getNClasses(); double[] logProbs = new double[nSpanWeights]; int ind = 0; if (update){ startIndexFirst = startSpanWeights; startIndexLast = startIndexFirst + (nWords*nClasses); startIndexPrevious = (spanPredictor.useFirstAndLast) ? startIndexFirst+(2*nWords*nClasses) : startIndexFirst; startIndexNext = startIndexPrevious + (nWords*nClasses); startIndexBeginPair = (spanPredictor.usePreviousAndNext) ? startIndexPrevious+(2*nWords*nClasses) : startIndexPrevious; startIndexEndPair = (spanPredictor.useBeginAndEndPairs) ? startIndexBeginPair + (spanPredictor.beginPairScore.length*nClasses) : startIndexBeginPair; startIndexPunctuation = (spanPredictor.useBeginAndEndPairs) ? startIndexBeginPair+((spanPredictor.beginPairScore.length+spanPredictor.endPairScore.length)*nClasses) : startIndexBeginPair; } if (spanPredictor.useFirstAndLast){ double[][] tmp = spanPredictor.firstWordScore; for (int i=0; i<tmp.length; i++){ for (int c=0; c<tmp[0].length; c++){ logProbs[ind++] = Math.log(tmp[i][c]); } } tmp = spanPredictor.lastWordScore; for (int i=0; i<tmp.length; i++) { for (int c=0; c<tmp[0].length; c++){ logProbs[ind++] = Math.log(tmp[i][c]); } } } if (spanPredictor.usePreviousAndNext){ double[][] tmp = spanPredictor.previousWordScore; for (int i=0; i<tmp.length; i++){ for (int c=0; c<tmp[0].length; c++){ logProbs[ind++] = Math.log(tmp[i][c]); } } tmp = spanPredictor.nextWordScore; for (int i=0; i<tmp.length; i++) { for (int c=0; c<tmp[0].length; c++){ logProbs[ind++] = Math.log(tmp[i][c]); } } } if (spanPredictor.useBeginAndEndPairs){ double[][] tmp = spanPredictor.beginPairScore; for (int i=0; i<tmp.length; i++) { for (int c=0; c<tmp[0].length; c++){ logProbs[ind++] = Math.log(tmp[i][c]); } } tmp = spanPredictor.endPairScore; for (int i=0; i<tmp.length; i++) { for (int c=0; c<tmp[0].length; c++){ logProbs[ind++] = Math.log(tmp[i][c]); } } } if (spanPredictor.usePunctuation){ double[][] tmp = spanPredictor.punctuationScores; for (int i=0; i<tmp.length; i++){ for (int c=0; c<tmp[0].length; c++){ logProbs[ind++] = Math.log(tmp[i][c]); } } } return logProbs; } public double[] getLinearizedWeights() { double[] initialGrammarWeights = (ConditionalTrainer.Options.lockGrammar) ? new double[0] : getLinearizedGrammar(); double[] initialLexiconWeights = (ConditionalTrainer.Options.lockGrammar) ? new double[0] : getLinearizedLexicon(); double[] initialSpanWeights = getLinearizedSpanPredictor(); double[] curWeights = new double[dimension()]; int j=0; for (int i=0; i<initialGrammarWeights.length; i++) curWeights[j++] = initialGrammarWeights[i]; for (int i=0; i<initialLexiconWeights.length; i++) curWeights[j++] = initialLexiconWeights[i]; for (int i=0; i<initialSpanWeights.length; i++) curWeights[j++] = initialSpanWeights[i]; return curWeights; } public Grammar getGrammar() { return grammar; } public SimpleLexicon getLexicon() { return lexicon; } public SpanPredictor getSpanPredictor() { return spanPredictor; } public double[] getLinearizedGrammar() { return getLinearizedGrammar(false); } public void increment(double[] counts, List<StateSet> sentence, double[][][] weights, boolean isGold) { int length = sentence.size(); int firstIndex, lastIndex; int previousIndex=-1, nextIndex=-1; if (spanPredictor.usePunctuation){ int[][] punctSignatures = spanPredictor.getPunctuationSignatures(sentence); for (int start = 0; start <= length-spanPredictor.minSpanLength; start++) { for (int end = start + spanPredictor.minSpanLength; end <= length; end++) { int sig = punctSignatures[start][end]; if (sig==-1) continue; sig *= nClasses; for (int c=0; c<nClasses; c++){ counts[startIndexPunctuation+sig+c] -= weights[start][end][c]; } } } } for (int start = 0; start <= length-spanPredictor.minSpanLength; start++) { StateSet stateSet = sentence.get(start); firstIndex = (stateSet.sigIndex<0) ? stateSet.wordIndex : stateSet.sigIndex; if (spanPredictor.useOnlyWords) firstIndex = stateSet.wordIndex; double[] total = new double[nClasses]; for (int end = start + spanPredictor.minSpanLength; end <= length; end++) { for (int c=0; c<total.length; c++){ total[c] += weights[start][end][c]; } } int firstI = startSpanWeights + (firstIndex*nClasses); int prevI = startIndexPrevious + (previousIndex*nClasses); for (int c=0; c<total.length; c++){ double t = total[c]; if (t==0) continue; if (spanPredictor.useFirstAndLast){ counts[firstI+c] -= t; } if (spanPredictor.usePreviousAndNext && previousIndex!=-1){ counts[prevI+c] -= t; } } if (spanPredictor.useBeginAndEndPairs && previousIndex!=-1) { int beginI = (spanPredictor.getBeginIndex(previousIndex, firstIndex)*nClasses); if (beginI>=0){ beginI += startIndexBeginPair; for (int c=0; c<total.length; c++){ double t = total[c]; if (t==0) continue; counts[beginI+c] -= t; } } } previousIndex = firstIndex; } for (int end = length; end >= spanPredictor.minSpanLength; end--) { StateSet stateSet = sentence.get(end-1); lastIndex = (stateSet.sigIndex<0) ? stateSet.wordIndex : stateSet.sigIndex; if (spanPredictor.useOnlyWords) lastIndex = stateSet.wordIndex; double[] total = new double[spanPredictor.getNClasses()]; for (int start = 0; start <= end-spanPredictor.minSpanLength; start++) { for (int c=0; c<total.length; c++){ total[c] += weights[start][end][c]; } } int lastI = startIndexLast + (lastIndex*nClasses); int nextI = startIndexNext + (nextIndex*nClasses); for (int c=0; c<total.length; c++){ if (spanPredictor.useFirstAndLast){ counts[lastI+c] -= total[c]; } if (spanPredictor.usePreviousAndNext && nextIndex!=-1){ counts[nextI+c] -= total[c]; } } if (spanPredictor.useBeginAndEndPairs && nextIndex!=-1){ int endI = spanPredictor.getEndIndex(lastIndex, nextIndex)*nClasses; if (endI>=0){ endI += startIndexEndPair; for (int c=0; c<total.length; c++){ counts[endI+c] -= total[c]; } } } nextIndex = lastIndex; } } public double[] getLinearizedSpanPredictor() { return getLinearizedSpanPredictor(false); } /* (non-Javadoc) * @see edu.berkeley.nlp.classify.Linearizer#getNGrammarWeights() */ public int getNGrammarWeights() { return nGrammarWeights; } /* (non-Javadoc) * @see edu.berkeley.nlp.classify.Linearizer#getNLexiconWeights() */ public int getNLexiconWeights() { return nLexiconWeights; } /* (non-Javadoc) * @see edu.berkeley.nlp.classify.Linearizer#getNSpanWeights() */ public int getNSpanWeights() { return nSpanWeights; } }