/**
 * 
 */
package edu.berkeley.nlp.discPCFG;

import java.util.List;

import edu.berkeley.nlp.PCFGLA.BinaryRule;
import edu.berkeley.nlp.PCFGLA.ConditionalTrainer;
import edu.berkeley.nlp.PCFGLA.Grammar;
import edu.berkeley.nlp.PCFGLA.HierarchicalAdaptiveBinaryRule;
import edu.berkeley.nlp.PCFGLA.HierarchicalAdaptiveGrammar;
import edu.berkeley.nlp.PCFGLA.HierarchicalAdaptiveLexicalRule;
import edu.berkeley.nlp.PCFGLA.HierarchicalAdaptiveUnaryRule;
import edu.berkeley.nlp.PCFGLA.HierarchicalBinaryRule;
import edu.berkeley.nlp.PCFGLA.HierarchicalFullyConnectedAdaptiveLexicon;
import edu.berkeley.nlp.PCFGLA.HierarchicalGrammar;
import edu.berkeley.nlp.PCFGLA.HierarchicalUnaryRule;
import edu.berkeley.nlp.PCFGLA.SimpleLexicon;
import edu.berkeley.nlp.PCFGLA.SpanPredictor;
import edu.berkeley.nlp.PCFGLA.UnaryRule;
import edu.berkeley.nlp.syntax.StateSet;
import edu.berkeley.nlp.syntax.StateSetWithFeatures;
import edu.berkeley.nlp.util.ArrayUtil;
import edu.berkeley.nlp.math.DoubleArrays;
import edu.berkeley.nlp.math.SloppyMath;

/**
 * Linearizer for a {@link HierarchicalAdaptiveGrammar} and a
 * {@link HierarchicalFullyConnectedAdaptiveLexicon}.
 *
 * <p>It flattens the final-level parameters of every grammar and lexical rule
 * into {@code double[]} vectors, pushes such vectors back into the rules, and
 * scatters per-rule weights into a global count vector. Each rule's
 * {@code identifier} is used as its start offset in the vector; lexical rules
 * are placed after the {@code nGrammarWeights} grammar slots.
 *
 * @author petrov
 */
public class HiearchicalAdaptiveLinearizer extends HierarchicalLinearizer {
    private static final long serialVersionUID = 1L;

    HierarchicalAdaptiveGrammar grammar;
    HierarchicalFullyConnectedAdaptiveLexicon lexicon;

    /**
     * Builds a linearizer over the given grammar and lexicon.
     *
     * @param grammar grammar to linearize; must be a {@link HierarchicalAdaptiveGrammar}
     * @param lexicon lexicon to linearize; must be a
     *                {@link HierarchicalFullyConnectedAdaptiveLexicon}
     * @param sp      span predictor stored on this linearizer
     * @param fLevel  final hierarchy level, passed to the score computation
     * @throws ClassCastException if {@code grammar} or {@code lexicon} is not of
     *                            the adaptive type named above
     */
    public HiearchicalAdaptiveLinearizer(Grammar grammar, SimpleLexicon lexicon,
            SpanPredictor sp, int fLevel) {
        this.grammar = (HierarchicalAdaptiveGrammar) grammar;
        lexicon.explicitlyComputeScores(fLevel);
        // The plain unary rules are used directly as the closed sum/Viterbi rule
        // sets. Fields are accessed through the Grammar-typed parameter on purpose
        // (field access is bound to the static type).
        grammar.closedSumRulesWithParent = grammar.closedViterbiRulesWithParent = grammar.unaryRulesWithParent;
        grammar.closedSumRulesWithChild = grammar.closedViterbiRulesWithChild = grammar.unaryRulesWithC;
        grammar.clearUnaryIntermediates();
        grammar.makeCRArrays();
        this.lexicon = (HierarchicalFullyConnectedAdaptiveLexicon) lexicon;
        this.spanPredictor = sp;
        this.finalLevel = fLevel;
        this.nSubstates = (int) ArrayUtil.max(grammar.numSubStates);
        init();
        computeMappings();
    }

    /** @return the adaptive lexicon this linearizer operates on */
    public SimpleLexicon getLexicon() {
        return lexicon;
    }

    /** @return the adaptive grammar this linearizer operates on */
    public Grammar getGrammar() {
        return grammar;
    }

    /**
     * Collects the final-level values of all lexical rules into one vector, in
     * tag-major, word-minor order.
     *
     * @param update when {@code true}, first reassigns every lexical rule's
     *               {@code identifier} (offset by {@code nGrammarWeights}) and
     *               recomputes {@code nLexiconWeights}
     * @return a fresh array of length {@code nLexiconWeights}
     */
    public double[] getLinearizedLexicon(boolean update) {
        if (update) {
            nLexiconWeights = 0;
            for (HierarchicalAdaptiveLexicalRule[] tagRules : lexicon.rules) {
                for (HierarchicalAdaptiveLexicalRule rule : tagRules) {
                    rule.identifier = nLexiconWeights + nGrammarWeights;
                    nLexiconWeights += rule.getFinalLevel().size();
                }
            }
        }

        double[] logProbs = new double[nLexiconWeights];
        int index = 0;
        for (HierarchicalAdaptiveLexicalRule[] tagRules : lexicon.rules) {
            for (HierarchicalAdaptiveLexicalRule rule : tagRules) {
                List<Double> vals = rule.getFinalLevel();
                for (Double val : vals) {
                    logProbs[index++] = val;
                }
            }
        }
        if (index != logProbs.length) {
            System.out.println("unequal length in lexicon");
        }
        return logProbs;
    }

    /**
     * Pushes the given weight vector into every lexical rule and recomputes
     * the rule scores.
     *
     * @param logProbs          weight vector handed to each rule's {@code updateScores}
     * @param usingOnlyLastLevel forwarded to each rule's {@code explicitlyComputeScores}
     */
    public void delinearizeLexicon(double[] logProbs, boolean usingOnlyLastLevel) {
        rescoreLexicon(logProbs, usingOnlyLastLevel);
    }

    /**
     * Same as {@link #delinearizeLexicon(double[], boolean)} with
     * {@code usingOnlyLastLevel == false}.
     *
     * @param logProbs weight vector handed to each rule's {@code updateScores}
     */
    public void delinearizeLexicon(double[] logProbs) {
        rescoreLexicon(logProbs, false);
    }

    /** Updates and rescores every lexical rule from {@code logProbs}. */
    private void rescoreLexicon(double[] logProbs, boolean usingOnlyLastLevel) {
        for (HierarchicalAdaptiveLexicalRule[] tagRules : lexicon.rules) {
            for (HierarchicalAdaptiveLexicalRule rule : tagRules) {
                rule.updateScores(logProbs);
                rule.explicitlyComputeScores(finalLevel, usingOnlyLastLevel);
            }
        }
    }

    /**
     * Scatters substate weights into the count slots of the lexical rules that
     * apply to {@code tag} and {@code stateSet}, adding when {@code isGold} and
     * subtracting otherwise.
     *
     * <p>For a plain {@link StateSet} the signature rule (when
     * {@code sigIndex != -1} and known to the tag) and the word rule (when known
     * to the tag) are updated. For a {@link StateSetWithFeatures} every
     * non-negative feature known to the tag is updated. In all cases the first
     * {@code nSubstates} entries of {@code weights} are zeroed before returning.
     *
     * @param counts   global count vector, modified in place
     * @param stateSet token whose word/signature/features select the rules
     * @param tag      tag index into the lexicon
     * @param weights  per-substate weights; cleared on return
     * @param isGold   {@code true} to add, {@code false} to subtract
     */
    public void increment(double[] counts, StateSet stateSet, int tag,
            double[] weights, boolean isGold) {
        if (stateSet instanceof StateSetWithFeatures) {
            StateSetWithFeatures stateSetF = (StateSetWithFeatures) stateSet;
            for (int f : stateSetF.features) {
                if (f < 0) {
                    continue;
                }
                int tagF = lexicon.tagWordIndexer[tag].indexOf(f);
                if (tagF < 0) {
                    continue;
                }
                addLexicalCounts(counts, lexicon.rules[tag][tagF], weights, isGold);
            }
        } else {
            int globalSigIndex = stateSet.sigIndex;
            if (globalSigIndex != -1) {
                int sigRuleIndex = lexicon.tagWordIndexer[tag].indexOf(globalSigIndex);
                if (sigRuleIndex >= 0) {
                    addLexicalCounts(counts, lexicon.rules[tag][sigRuleIndex], weights, isGold);
                }
            }
            int wordRuleIndex = lexicon.tagWordIndexer[tag].indexOf(stateSet.wordIndex);
            if (wordRuleIndex >= 0) {
                addLexicalCounts(counts, lexicon.rules[tag][wordRuleIndex], weights, isGold);
            }
        }

        // Every path consumes the weights, whether or not a rule was found.
        for (int i = 0; i < nSubstates; i++) {
            weights[i] = 0;
        }
    }

    /**
     * Adds or subtracts {@code weights[i]} at the slot
     * {@code rule.identifier + rule.mapping[i]} for each of the
     * {@code nSubstates} substates. Does not modify {@code weights}.
     */
    private void addLexicalCounts(double[] counts, HierarchicalAdaptiveLexicalRule rule,
            double[] weights, boolean isGold) {
        int startIndexWord = rule.identifier;
        short[] mapping = rule.mapping;
        for (int i = 0; i < nSubstates; i++) {
            if (isGold) {
                counts[startIndexWord + mapping[i]] += weights[i];
            } else {
                counts[startIndexWord + mapping[i]] -= weights[i];
            }
        }
    }

    /**
     * Moves the positive entries of {@code weights} into the count slots of a
     * binary rule, adding when {@code isGold} and subtracting otherwise. Only
     * the first {@code nParam} entries are examined; consumed entries are zeroed.
     *
     * @param counts  global count vector, modified in place
     * @param rule    must be a {@link HierarchicalAdaptiveBinaryRule}
     * @param weights per-parameter weights of the rule
     * @param isGold  {@code true} to add, {@code false} to subtract
     */
    public void increment(double[] counts, BinaryRule rule, double[] weights, boolean isGold) {
        HierarchicalAdaptiveBinaryRule hr = (HierarchicalAdaptiveBinaryRule) rule;
        addRuleCounts(counts, hr.identifier, hr.nParam, weights, isGold);
    }

    /**
     * Moves the positive entries of {@code weights} into the count slots of a
     * unary rule, adding when {@code isGold} and subtracting otherwise. Only
     * the first {@code nParam} entries are examined; consumed entries are zeroed.
     *
     * @param counts  global count vector, modified in place
     * @param rule    must be a {@link HierarchicalAdaptiveUnaryRule}
     * @param weights per-parameter weights of the rule
     * @param isGold  {@code true} to add, {@code false} to subtract
     */
    public void increment(double[] counts, UnaryRule rule, double[] weights, boolean isGold) {
        HierarchicalAdaptiveUnaryRule hr = (HierarchicalAdaptiveUnaryRule) rule;
        addRuleCounts(counts, hr.identifier, hr.nParam, weights, isGold);
    }

    /**
     * For each of the first {@code nParam} weights that is strictly positive,
     * zeroes it and adds/subtracts its value at {@code counts[startIndex + i]}.
     * Non-positive (and NaN) weights are left untouched.
     */
    private static void addRuleCounts(double[] counts, int startIndex, int nParam,
            double[] weights, boolean isGold) {
        for (int curInd = 0; curInd < nParam; curInd++) {
            double val = weights[curInd];
            if (val > 0) {
                weights[curInd] = 0;
                if (isGold) {
                    counts[startIndex + curInd] += val;
                } else {
                    counts[startIndex + curInd] -= val;
                }
            }
        }
    }

    /**
     * Pushes the given weight vector into every binary and unary rule, then
     * recomputes the grammar scores and rebuilds the unary closures and
     * rule arrays.
     *
     * @param probs weight vector handed to each rule's {@code updateScores}
     */
    public void delinearizeGrammar(double[] probs) {
        for (BinaryRule bRule : grammar.binaryRuleMap.keySet()) {
            HierarchicalAdaptiveBinaryRule hRule = (HierarchicalAdaptiveBinaryRule) bRule;
            hRule.updateScores(probs);
        }
        for (UnaryRule uRule : grammar.unaryRuleMap.keySet()) {
            HierarchicalAdaptiveUnaryRule hRule = (HierarchicalAdaptiveUnaryRule) uRule;
            hRule.updateScores(probs);
        }

        grammar.explicitlyComputeScores(finalLevel);
        grammar.closedSumRulesWithParent = grammar.closedViterbiRulesWithParent = grammar.unaryRulesWithParent;
        grammar.closedSumRulesWithChild = grammar.closedViterbiRulesWithChild = grammar.unaryRulesWithC;
        grammar.clearUnaryIntermediates();
        grammar.makeCRArrays();
    }

    /**
     * Collects the final-level values of all binary and unary rules into one
     * vector, each rule starting at its {@code identifier}.
     *
     * <p>Unary rules whose child and parent state coincide are skipped when
     * copying values, so their slots keep the default {@code 0.0}.
     *
     * @param update when {@code true}, first reassigns every rule's
     *               {@code identifier} (binary rules first, then unary rules)
     *               and recomputes {@code nGrammarWeights}
     * @return a fresh array of length {@code nGrammarWeights}
     */
    public double[] getLinearizedGrammar(boolean update) {
        if (update) {
            nGrammarWeights = 0;
            for (BinaryRule bRule : grammar.binaryRuleMap.keySet()) {
                HierarchicalAdaptiveBinaryRule hRule = (HierarchicalAdaptiveBinaryRule) bRule;
                if (!grammar.isGrammarTag[bRule.parentState]) {
                    System.out.println("Incorrect grammar tag");
                }
                bRule.identifier = nGrammarWeights;
                nGrammarWeights += hRule.nParam;
            }
            for (UnaryRule uRule : grammar.unaryRuleMap.keySet()) {
                HierarchicalAdaptiveUnaryRule hRule = (HierarchicalAdaptiveUnaryRule) uRule;
                uRule.identifier = nGrammarWeights;
                nGrammarWeights += hRule.nParam;
            }
        }

        double[] logProbs = new double[nGrammarWeights];
        for (BinaryRule bRule : grammar.binaryRuleMap.keySet()) {
            HierarchicalAdaptiveBinaryRule hRule = (HierarchicalAdaptiveBinaryRule) bRule;
            int ind = hRule.identifier;
            List<Double> vals = hRule.getFinalLevel();
            for (Double val : vals) {
                logProbs[ind++] = val;
            }
        }
        for (UnaryRule uRule : grammar.unaryRuleMap.keySet()) {
            HierarchicalAdaptiveUnaryRule hRule = (HierarchicalAdaptiveUnaryRule) uRule;
            int ind = hRule.identifier;
            if (uRule.childState == uRule.parentState) {
                continue;
            }
            List<Double> vals = hRule.getFinalLevel();
            for (Double val : vals) {
                logProbs[ind++] = val;
            }
        }
        return logProbs;
    }

    /**
     * Pushes the given weight vector into the lexicon using only the last
     * hierarchy level, i.e. calls
     * {@link #delinearizeLexicon(double[], boolean)} with {@code true}.
     *
     * <p>An earlier version also walked the grammar and lexicon sections of the
     * vector to count zero entries, but the counts were never used; that pass
     * has been dropped.
     *
     * @param logWeights weight vector handed on to the lexical rules
     */
    public void delinearizeLexiconWeights(double[] logWeights) {
        delinearizeLexicon(logWeights, true);
    }
}