/**
*
*/
package edu.berkeley.nlp.discPCFG;
import java.io.Serializable;
import edu.berkeley.nlp.PCFGLA.BinaryRule;
import edu.berkeley.nlp.PCFGLA.Grammar;
import edu.berkeley.nlp.PCFGLA.HierarchicalBinaryRule;
import edu.berkeley.nlp.PCFGLA.HierarchicalGrammar;
import edu.berkeley.nlp.PCFGLA.HierarchicalLexicon;
import edu.berkeley.nlp.PCFGLA.HierarchicalUnaryRule;
import edu.berkeley.nlp.PCFGLA.Rule;
import edu.berkeley.nlp.PCFGLA.SimpleLexicon;
import edu.berkeley.nlp.PCFGLA.SpanPredictor;
import edu.berkeley.nlp.PCFGLA.UnaryRule;
import edu.berkeley.nlp.syntax.StateSet;
import edu.berkeley.nlp.math.DoubleArrays;
import edu.berkeley.nlp.math.SloppyMath;
import edu.berkeley.nlp.util.ArrayUtil;
/**
* Similar to the cascading linearizer, but does not compute the grammars explicitly;
* instead it uses hierarchical rules and merges back unused splits.
* @author petrov
*
*/
public class HierarchicalLinearizer extends DefaultLinearizer {
  private static final long serialVersionUID = 1L;

  /** Grammar whose hierarchical rules' last levels form the grammar section of the parameter vector. */
  HierarchicalGrammar grammar;
  /** Lexicon whose hierarchical scores' last levels form the lexicon section of the parameter vector. */
  HierarchicalLexicon lexicon;
  /** Deepest split level; the mapping tables below are built for levels 0..finalLevel. */
  int finalLevel;
  /**
   * lexiconMapping[level][s] = s / 2^(finalLevel - level): index of the level-{@code level}
   * ancestor of full-resolution substate s. Row 0 is all zeros (everything collapses to one weight).
   */
  int[][] lexiconMapping;
  /**
   * unaryMapping[level][c][p]: row-major offset (c' * 2^level + p') where c', p' are the
   * level-{@code level} ancestors of substates c and p.
   */
  int[][][] unaryMapping;
  /**
   * binaryMapping[level][a][b][c]: row-major offset ((a' * 2^level + b') * 2^level + c') where
   * a', b', c' are the level-{@code level} ancestors of substates a, b and c.
   */
  int[][][][] binaryMapping;

  public HierarchicalLinearizer() {}

  /**
   * Creates a linearizer over a hierarchical grammar and lexicon.
   *
   * @param grammar the grammar; must actually be a {@link HierarchicalGrammar} (it is cast)
   * @param lexicon the lexicon; must actually be a {@link HierarchicalLexicon} (it is cast)
   * @param sp the span predictor, stored as-is
   * @param fLevel the deepest split level for which index mappings are precomputed
   */
  public HierarchicalLinearizer(Grammar grammar, SimpleLexicon lexicon, SpanPredictor sp, int fLevel) {
    this.grammar = (HierarchicalGrammar) grammar;
    this.lexicon = (HierarchicalLexicon) lexicon;
    this.spanPredictor = sp;
    this.finalLevel = fLevel;
    this.nSubstates = (int) ArrayUtil.max(grammar.numSubStates);
    init();
    computeMappings();
  }

  /**
   * Builds {@link #lexiconMapping}, {@link #unaryMapping} and {@link #binaryMapping}, which map
   * full-resolution substate indices to offsets inside a rule's (or word's) last-level weight
   * block at each level 1..finalLevel. Level 0 is left at the default all-zero mapping.
   */
  protected void computeMappings() {
    lexiconMapping = new int[finalLevel + 1][nSubstates];
    unaryMapping = new int[finalLevel + 1][nSubstates][nSubstates];
    binaryMapping = new int[finalLevel + 1][nSubstates][nSubstates][nSubstates];
    for (int level = 1; level <= finalLevel; level++) {
      // How many full-resolution substates share a single ancestor at this level.
      int div = 1 << (finalLevel - level);
      // Number of distinct substates at this level.
      int l = 1 << level;
      // Row-major offset tables: tmpU[i][j] = i*l + j and tmpB[i][j][k] = (i*l + j)*l + k.
      // Indexing through them (rather than computing the offset directly) keeps the implicit
      // bounds check: an ancestor index >= l fails fast with ArrayIndexOutOfBoundsException.
      int[][] tmpU = new int[l][l];
      int[][][] tmpB = new int[l][l][l];
      int indU = 0, indB = 0;
      for (int i = 0; i < l; i++) {
        for (int j = 0; j < l; j++) {
          tmpU[i][j] = indU++;
          for (int k = 0; k < l; k++) {
            tmpB[i][j][k] = indB++;
          }
        }
      }
      for (int i = 0; i < nSubstates; i++) {
        lexiconMapping[level][i] = i / div;
        for (int j = 0; j < nSubstates; j++) {
          unaryMapping[level][i][j] = tmpU[i / div][j / div];
          for (int k = 0; k < nSubstates; k++) {
            binaryMapping[level][i][j][k] = tmpB[i / div][j / div][k / div];
          }
        }
      }
    }
  }

  /**
   * Writes the grammar section of {@code probs} back into the last level of every hierarchical
   * rule, reading each rule's block starting at its {@code identifier} (as assigned by
   * {@link #getLinearizedGrammar(boolean)} with {@code update == true}). Values for which
   * {@code SloppyMath.isVeryDangerous} is true are skipped, leaving the old weight in place, and
   * are reported. Afterwards the explicit scores are recomputed at {@link #finalLevel} and the
   * grammar's rule arrays are rebuilt.
   *
   * @param probs full parameter vector whose grammar weights start at index 0
   */
  public void delinearizeGrammar(double[] probs) {
    int nDangerous = 0;
    for (BinaryRule bRule : grammar.binaryRuleMap.keySet()) {
      HierarchicalBinaryRule hRule = (HierarchicalBinaryRule) bRule;
      int ind = hRule.identifier;
      double[][][] scores = hRule.getLastLevel();
      for (int j = 0; j < scores.length; j++) {
        for (int k = 0; k < scores[j].length; k++) {
          if (scores[j][k] != null) {
            for (int l = 0; l < scores[j][k].length; l++) {
              // Always advance the cursor, even when the value is rejected.
              double val = probs[ind++];
              if (SloppyMath.isVeryDangerous(val)) {
                nDangerous++;
                continue;
              }
              scores[j][k][l] = val;
            }
          }
        }
      }
    }
    if (nDangerous > 0) System.out.println("Left " + nDangerous + " binary rule weights unchanged since the proposed weight was dangerous.");
    nDangerous = 0;
    for (UnaryRule uRule : grammar.unaryRuleMap.keySet()) {
      HierarchicalUnaryRule hRule = (HierarchicalUnaryRule) uRule;
      int ind = hRule.identifier;
      // Self-loop unaries (X -> X) are never read back; their slots in probs are ignored.
      if (uRule.childState == uRule.parentState) continue;
      double[][] scores = hRule.getLastLevel();
      for (int j = 0; j < scores.length; j++) {
        if (scores[j] != null) {
          for (int k = 0; k < scores[j].length; k++) {
            double val = probs[ind++];
            if (SloppyMath.isVeryDangerous(val)) {
              nDangerous++;
              continue;
            }
            scores[j][k] = val;
          }
        }
      }
    }
    if (nDangerous > 0) System.out.println("Left " + nDangerous + " unary rule weights unchanged since the proposed weight was dangerous.");
    grammar.explicitlyComputeScores(finalLevel);
    // The "closed" unary rule arrays are aliased to the plain unary rule arrays;
    // no unary closure is computed here.
    grammar.closedSumRulesWithParent = grammar.closedViterbiRulesWithParent = grammar.unaryRulesWithParent;
    grammar.closedSumRulesWithChild = grammar.closedViterbiRulesWithChild = grammar.unaryRulesWithC;
    grammar.clearUnaryIntermediates();
    grammar.makeCRArrays();
  }

  /**
   * Writes the lexicon section of {@code logProbs} back into the last level of every
   * (tag, word) entry, reading each entry's block starting at {@code linearIndex[tag][word]}.
   * Values for which {@code SloppyMath.isVeryDangerous} is true are skipped and reported.
   * Afterwards the lexicon's explicit scores are recomputed at {@link #finalLevel}.
   *
   * @param logProbs parameter vector indexed consistently with {@code linearIndex}, whose
   *     entries include the {@code nGrammarWeights} offset
   */
  public void delinearizeLexicon(double[] logProbs) {
    int nDangerous = 0;
    for (short tag = 0; tag < lexicon.hierarchicalScores.length; tag++) {
      for (int word = 0; word < lexicon.hierarchicalScores[tag].length; word++) {
        int index = linearIndex[tag][word];
        double[] vals = lexicon.getLastLevel(tag, word);
        for (int substate = 0; substate < vals.length; substate++) {
          double val = logProbs[index++];
          if (SloppyMath.isVeryDangerous(val)) {
            nDangerous++;
            continue;
          }
          vals[substate] = val;
        }
      }
    }
    if (nDangerous > 0) System.out.println("Left " + nDangerous + " lexicon weights unchanged since the proposed weight was dangerous.");
    lexicon.explicitlyComputeScores(finalLevel);
  }

  /**
   * Copies the last-level scores of all binary and unary rules into a new array.
   *
   * <p>When {@code update} is true, first lays out the grammar section: every rule's
   * {@code identifier} is set to its start offset and {@code nGrammarWeights} is recomputed.
   * Slots are reserved for self-loop unaries (X -> X), but those slots are never filled and stay 0.
   * NOTE(review): reserving unused slots looks accidental; changing it would alter the vector
   * layout and {@link #dimension()}, so it is left as-is.
   *
   * @param update whether to recompute the layout before copying
   * @return array of length {@code nGrammarWeights}
   */
  public double[] getLinearizedGrammar(boolean update) {
    if (update) {
      nGrammarWeights = 0;
      for (BinaryRule bRule : grammar.binaryRuleMap.keySet()) {
        HierarchicalBinaryRule hRule = (HierarchicalBinaryRule) bRule;
        if (!grammar.isGrammarTag[bRule.parentState]) {
          System.out.println("Incorrect grammar tag");
        }
        bRule.identifier = nGrammarWeights;
        double[][][] scores = hRule.getLastLevel();
        for (int j = 0; j < scores.length; j++) {
          for (int k = 0; k < scores[j].length; k++) {
            if (scores[j][k] != null) {
              nGrammarWeights += scores[j][k].length;
            }
          }
        }
      }
      for (UnaryRule uRule : grammar.unaryRuleMap.keySet()) {
        HierarchicalUnaryRule hRule = (HierarchicalUnaryRule) uRule;
        uRule.identifier = nGrammarWeights;
        double[][] scores = hRule.getLastLevel();
        for (int j = 0; j < scores.length; j++) {
          if (scores[j] != null) {
            nGrammarWeights += scores[j].length;
          }
        }
      }
    }
    double[] logProbs = new double[nGrammarWeights];
    for (BinaryRule bRule : grammar.binaryRuleMap.keySet()) {
      HierarchicalBinaryRule hRule = (HierarchicalBinaryRule) bRule;
      int ind = hRule.identifier;
      double[][][] scores = hRule.getLastLevel();
      for (int j = 0; j < scores.length; j++) {
        for (int k = 0; k < scores[j].length; k++) {
          if (scores[j][k] != null) {
            for (int l = 0; l < scores[j][k].length; l++) {
              logProbs[ind++] = scores[j][k][l];
            }
          }
        }
      }
    }
    for (UnaryRule uRule : grammar.unaryRuleMap.keySet()) {
      HierarchicalUnaryRule hRule = (HierarchicalUnaryRule) uRule;
      int ind = hRule.identifier;
      if (uRule.childState == uRule.parentState) continue;
      double[][] scores = hRule.getLastLevel();
      for (int j = 0; j < scores.length; j++) {
        if (scores[j] != null) {
          for (int k = 0; k < scores[j].length; k++) {
            logProbs[ind++] = scores[j][k];
          }
        }
      }
    }
    return logProbs;
  }

  /**
   * Copies the last-level lexicon scores of every (tag, word) entry into a new array.
   *
   * <p>When {@code update} is true, recomputes {@code nLexiconWeights} and rebuilds
   * {@code linearIndex}. Each {@code linearIndex[tag][word]} is the entry's offset in the returned
   * array plus {@code nGrammarWeights}, i.e. an index into a vector where the lexicon section follows
   * the grammar section. Because of that offset, {@code nGrammarWeights} should already be current
   * (presumably via {@link #getLinearizedGrammar(boolean)} with {@code update == true}; confirm at
   * call sites).
   *
   * @param update whether to recompute sizes and the index table before copying
   * @return array of length {@code nLexiconWeights}
   */
  public double[] getLinearizedLexicon(boolean update) {
    if (update) {
      nLexiconWeights = 0;
      for (short tag = 0; tag < lexicon.hierarchicalScores.length; tag++) {
        for (int word = 0; word < lexicon.hierarchicalScores[tag].length; word++) {
          nLexiconWeights += lexicon.getLastLevel(tag, word).length;
        }
      }
    }
    double[] logProbs = new double[nLexiconWeights];
    if (update) linearIndex = new int[lexicon.hierarchicalScores.length][];
    int index = 0;
    for (short tag = 0; tag < lexicon.hierarchicalScores.length; tag++) {
      if (update) linearIndex[tag] = new int[lexicon.hierarchicalScores[tag].length];
      for (int word = 0; word < lexicon.hierarchicalScores[tag].length; word++) {
        if (update) linearIndex[tag][word] = index + nGrammarWeights;
        double[] vals = lexicon.getLastLevel(tag, word);
        for (int substate = 0; substate < vals.length; substate++) {
          logProbs[index++] = vals[substate];
        }
      }
    }
    if (index != logProbs.length) {
      System.out.println("unequal length in lexicon");
    }
    return logProbs;
  }

  /**
   * Returns the start offset of the (tag, word) weight block in the parameter vector.
   *
   * @param globalWordIndex global word (or signature) index
   * @param tag the tag
   * @return the offset, or -1 if the tag's word indexer does not contain the word
   */
  public int getLinearIndex(int globalWordIndex, int tag) {
    int tagSpecificWordIndex = lexicon.tagWordIndexer[tag].indexOf(globalWordIndex);
    if (tagSpecificWordIndex == -1) return -1;
    return linearIndex[tag][tagSpecificWordIndex];
  }

  /** Returns the total number of parameters: grammar + lexicon + span weights. */
  public int dimension() {
    return nGrammarWeights + nLexiconWeights + nSpanWeights;
  }

  /**
   * Accumulates lexical expected counts for one tag: {@code weights[i]} is added to
   * {@code counts} when {@code isGold}, and subtracted otherwise. The counts go to the signature's
   * block (when {@code stateSet.sigIndex != -1} and it is indexed for the tag) and to the word's
   * block (when it is indexed for the tag). Each substate is folded onto its ancestor at the
   * entry's final level. {@code weights[0..nSubstates-1]} is zeroed on return in every case.
   */
  public void increment(double[] counts, StateSet stateSet, int tag, double[] weights, boolean isGold) {
    int globalSigIndex = stateSet.sigIndex;
    if (globalSigIndex != -1) {
      int sigStart = getLinearIndex(globalSigIndex, tag);
      if (sigStart >= 0) {
        addLexiconCounts(counts, sigStart, lexicon.getFinalLevel(globalSigIndex, tag), weights, isGold);
      }
    }
    int globalWordIndex = stateSet.wordIndex;
    int wordStart = getLinearIndex(globalWordIndex, tag);
    if (wordStart >= 0) {
      addLexiconCounts(counts, wordStart, lexicon.getFinalLevel(globalWordIndex, tag), weights, isGold);
    }
    for (int i = 0; i < nSubstates; i++) {
      weights[i] = 0;
    }
  }

  /** Adds ({@code isGold}) or subtracts {@code weights[i]} at {@code counts[start + lexiconMapping[level][i]]} for each substate i. */
  private void addLexiconCounts(double[] counts, int start, int level, double[] weights, boolean isGold) {
    int[] mapping = lexiconMapping[level];
    for (int i = 0; i < nSubstates; i++) {
      if (isGold) counts[start + mapping[i]] += weights[i];
      else counts[start + mapping[i]] -= weights[i];
    }
  }

  /**
   * Accumulates expected counts for a unary rule into its block starting at the rule's
   * {@code identifier}. Only positive weights are used; each used weight is then reset to 0.
   * {@code weights} is read row-major as [child substate][parent substate], except when the
   * parent state is 0, where it is read as a 1-D array over child substates.
   */
  public void increment(double[] counts, UnaryRule rule, double[] weights, boolean isGold) {
    HierarchicalUnaryRule hr = (HierarchicalUnaryRule) rule;
    int thisStartIndex = hr.identifier;
    int level = hr.lastLevel;
    if (rule.parentState == 0) {
      // NOTE(review): state 0 is presumably ROOT with a single substate, so only the child
      // dimension is split and the 1-D lexiconMapping gives the offset.
      for (int cp = 0; cp < nSubstates; cp++) {
        double val = weights[cp];
        if (val > 0) {
          if (isGold) counts[thisStartIndex + lexiconMapping[level][cp]] += val;
          else counts[thisStartIndex + lexiconMapping[level][cp]] -= val;
          weights[cp] = 0;
        }
      }
      return;
    }
    int curInd = 0;
    for (int cp = 0; cp < nSubstates; cp++) {
      for (int np = 0; np < nSubstates; np++) {
        double val = weights[curInd];
        if (val > 0) {
          if (isGold) counts[thisStartIndex + unaryMapping[level][cp][np]] += val;
          else counts[thisStartIndex + unaryMapping[level][cp][np]] -= val;
          weights[curInd] = 0;
        }
        curInd++;
      }
    }
  }

  /**
   * Accumulates expected counts for a binary rule into its block starting at the rule's
   * {@code identifier}. {@code weights} is read row-major as [lp][rp][np], using the loop-variable
   * names of this method. Only positive weights are used; each used weight is then reset to 0.
   */
  public void increment(double[] counts, BinaryRule rule, double[] weights, boolean isGold) {
    HierarchicalBinaryRule hr = (HierarchicalBinaryRule) rule;
    int thisStartIndex = hr.identifier;
    int level = hr.lastLevel;
    int curInd = 0;
    for (int lp = 0; lp < nSubstates; lp++) {
      for (int rp = 0; rp < nSubstates; rp++) {
        for (int np = 0; np < nSubstates; np++) {
          double val = weights[curInd];
          if (val > 0) {
            if (isGold) counts[thisStartIndex + binaryMapping[level][lp][rp][np]] += val;
            else counts[thisStartIndex + binaryMapping[level][lp][rp][np]] -= val;
            weights[curInd] = 0;
          }
          curInd++;
        }
      }
    }
  }

  /** Returns the underlying (hierarchical) grammar. */
  public Grammar getGrammar() {
    return grammar;
  }

  /** Returns the underlying (hierarchical) lexicon. */
  public SimpleLexicon getLexicon() {
    return lexicon;
  }

  /** Returns the span predictor passed at construction. */
  public SpanPredictor getSpanPredictor() {
    return spanPredictor;
  }
}