/**
*
*/
package edu.berkeley.nlp.PCFGLA;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import edu.berkeley.nlp.util.ArrayUtil;
import edu.berkeley.nlp.util.Numberer;
/**
 * Lexicon that stores its emission scores hierarchically.
 *
 * <p>For every (tag, word) pair, {@link #hierarchicalScores} holds a list of score arrays, one per
 * split level. Level 0 holds the log of the baseline (unsplit) score. The array added at level
 * {@code i} by {@link #splitAllStates} has {@code 2^i} entries. The flat per-substate
 * {@code scores} table is rebuilt from the hierarchy by {@link #explicitlyComputeScores(int)},
 * which sums the per-level log scores and exponentiates the result.
 *
 * @author petrov
 */
public class HierarchicalLexicon extends SimpleLexicon {
  private static final long serialVersionUID = 1L;

  /** For each tag and tag-specific word index, the list of per-level (log-domain) score arrays. */
  public List<double[]>[][] hierarchicalScores;

  /** For each tag and tag-specific word index, the index of the deepest level currently in use. */
  public int[][] finalLevels;

  /**
   * Creates an empty hierarchical lexicon. Only the per-tag outer array of
   * {@link #hierarchicalScores} is allocated; {@link #finalLevels} is left {@code null}.
   *
   * @param numSubStates number of substates per tag, forwarded to {@code SimpleLexicon}
   * @param threshold forwarded unchanged to {@code SimpleLexicon}
   */
  @SuppressWarnings("unchecked")
  public HierarchicalLexicon(short[] numSubStates, double threshold) {
    super(numSubStates, threshold);
    hierarchicalScores = new List[numStates][];
  }

  /**
   * Builds a hierarchical lexicon from a {@code SimpleLexicon}. The tag-word indexers are
   * copied. The word indexer, word counter, word count and smoother are shared. Level 0 of
   * every (tag, word) hierarchy is initialised with the log of {@code lex.scores[tag][0][word]}.
   * The flat {@code scores} table of the new lexicon is left {@code null}.
   *
   * @param lex source lexicon; presumably an unsplit baseline, since only substate 0 is read
   */
  public HierarchicalLexicon(SimpleLexicon lex) {
    super(lex.numSubStates, lex.threshold);
    this.expectedCounts = new double[numStates][][];
    this.tagWordIndexer = new IntegerIndexer[numStates];
    this.wordIndexer = lex.wordIndexer;
    this.wordCounter = lex.wordCounter;
    for (int tag = 0; tag < numStates; tag++) {
      this.tagWordIndexer[tag] = lex.tagWordIndexer[tag].copy();
    }
    this.nWords = lex.nWords;
    this.smoother = lex.smoother;
    makeHierarchicalScores(lex.scores);
    this.scores = null;
  }

  /**
   * Initialises {@link #hierarchicalScores} with a single level per (tag, word) containing
   * {@code log(baseScores[tag][0][word])}, and sets every final level to 0. Only substate 0 of
   * {@code baseScores} is read, so this assumes an unsplit baseline grammar.
   */
  @SuppressWarnings("unchecked")
  private void makeHierarchicalScores(double[][][] baseScores) {
    hierarchicalScores = new List[numStates][];
    finalLevels = new int[numStates][];
    for (int tag = 0; tag < numStates; tag++) {
      int words = tagWordIndexer[tag].size();
      hierarchicalScores[tag] = new List[words];
      finalLevels[tag] = new int[words]; // zero-initialised: every word starts at level 0
      for (int word = 0; word < words; word++) {
        List<double[]> levels = new ArrayList<double[]>();
        levels.add(new double[] {Math.log(baseScores[tag][0][word])});
        hierarchicalScores[tag][word] = levels;
      }
    }
  }

  /**
   * Rebuilds the flat {@code scores} table from the hierarchy.
   *
   * <p>Every tag gets {@code 2^finalLevel} substate rows. For each (tag, word, substate), the
   * log scores of all levels up to {@code min(finalLevel, finalLevels[tag][word])} are summed,
   * and the sum is exponentiated. A level with {@code k} entries contributes entry
   * {@code substate / (2^finalLevel / k)}.
   *
   * @param finalLevel deepest level to include; also determines the number of substate rows
   */
  public void explicitlyComputeScores(int finalLevel) {
    this.scores = new double[numStates][][];
    int nSubstates = (int) Math.pow(2, finalLevel);
    for (int tag = 0; tag < numStates; tag++) {
      int words = hierarchicalScores[tag].length;
      double[][] tagScores = new double[nSubstates][words];
      this.scores[tag] = tagScores;
      for (int word = 0; word < words; word++) {
        List<double[]> scoreHierarchy = hierarchicalScores[tag][word];
        int lastLevel = Math.min(finalLevel, finalLevels[tag][word]);
        for (int level = 0; level <= lastLevel; level++) {
          double[] scoresThisLevel = scoreHierarchy.get(level);
          // Each entry of a coarser level is shared by a contiguous block of fine substates.
          int divisor = nSubstates / scoresThisLevel.length;
          for (int substate = 0; substate < nSubstates; substate++) {
            tagScores[substate][word] += scoresThisLevel[substate / divisor];
          }
        }
        for (int substate = 0; substate < nSubstates; substate++) {
          tagScores[substate][word] = Math.exp(tagScores[substate][word]);
        }
      }
    }
  }

  /**
   * Returns a new lexicon in which every tag except tag 0 has its substate count doubled. A tag
   * is not split when {@code moreSubstatesThanCounts} is false and it already has at least
   * {@code counts[tag]} substates. Every (tag, word) hierarchy is deep-copied. Where the new
   * substate count allows it, the hierarchy is extended by one level of small random values
   * (uniform in [0, 0.01)). Words that are not extended keep their current final level.
   *
   * @param counts per-tag count used to cap splitting
   * @param moreSubstatesThanCounts if true, split every tag regardless of {@code counts}
   * @param mode not used by this implementation
   * @return the split lexicon; its flat {@code scores} table is {@code null}
   */
  @SuppressWarnings("unchecked")
  public HierarchicalLexicon splitAllStates(int[] counts, boolean moreSubstatesThanCounts, int mode) {
    short[] newNumSubStates = new short[numSubStates.length];
    newNumSubStates[0] = 1;
    for (int i = 1; i < numSubStates.length; i++) {
      // don't split a state into more substates than times it was actually seen
      if (!moreSubstatesThanCounts && numSubStates[i] >= counts[i]) {
        newNumSubStates[i] = numSubStates[i];
      } else {
        newNumSubStates[i] = (short) (numSubStates[i] * 2);
      }
    }
    HierarchicalLexicon newLex = newInstance();
    newLex.numSubStates = newNumSubStates;
    Random random = GrammarTrainer.RANDOM;
    newLex.expectedCounts = new double[numStates][][];
    newLex.tagWordIndexer = new IntegerIndexer[numStates];
    newLex.wordIndexer = this.wordIndexer;
    for (int tag = 0; tag < numStates; tag++) {
      newLex.tagWordIndexer[tag] = tagWordIndexer[tag].copy();
    }
    newLex.nWords = this.nWords;
    newLex.smoother = this.smoother;
    List<double[]>[][] hS = new List[numStates][];
    newLex.finalLevels = new int[numStates][];
    for (int tag = 0; tag < numStates; tag++) {
      int words = tagWordIndexer[tag].size();
      hS[tag] = new List[words];
      newLex.finalLevels[tag] = new int[words];
      for (int word = 0; word < words; word++) {
        hS[tag][word] = copyLevels(hierarchicalScores[tag][word]);
        int currentLevel = this.finalLevels[tag][word];
        // Carry the current depth over so the copied levels stay in use even when this word
        // cannot be refined further; it is bumped below only if a new level is actually added.
        newLex.finalLevels[tag][word] = currentLevel;
        int fLevel = currentLevel + 1;
        int nSub = (int) Math.pow(2, fLevel);
        if (nSub > newNumSubStates[tag]) {
          continue;
        }
        double[] newScores = new double[nSub];
        for (int i = 0; i < nSub; i++) {
          newScores[i] = random.nextDouble() / 100.0;
        }
        hS[tag][word].add(newScores);
        newLex.finalLevels[tag][word] = fLevel;
      }
    }
    newLex.scores = null;
    newLex.hierarchicalScores = hS;
    newLex.wordCounter = wordCounter;
    return newLex;
  }

  /**
   * Returns a fresh lexicon built from this lexicon's substate counts and threshold. Only those
   * two constructor arguments are carried over; no scores or indexers are copied here.
   *
   * @return a new, otherwise empty {@code HierarchicalLexicon}
   */
  public HierarchicalLexicon newInstance() {
    return new HierarchicalLexicon(this.numSubStates, this.threshold);
  }

  /**
   * Returns the final hierarchy level of a word under a tag.
   *
   * @param globalWordIndex word index in the global word indexer
   * @param tag tag index
   * @return {@code finalLevels[tag][i]}, where {@code i} is the tag-specific index of the word.
   *     NOTE(review): if the word is unknown for this tag, {@code indexOf} presumably returns a
   *     negative value and the array access throws — confirm against {@code IntegerIndexer}.
   */
  public int getFinalLevel(int globalWordIndex, int tag) {
    int tagSpecificWordIndex = tagWordIndexer[tag].indexOf(globalWordIndex);
    return finalLevels[tag][tagSpecificWordIndex];
  }

  /**
   * Prunes, for each (tag, word), the deepest level if all of its entries are exactly 0, and
   * decrements that word's final level. Such a level adds nothing to the summed log scores.
   * Level 0 (the baseline score) is never removed. At most one level per word is removed per
   * call. A summary line is printed to stdout.
   */
  public void mergeLexicon() {
    int nRemovedParam = 0, nRemovedArrays = 0;
    for (int tag = 0; tag < numStates; tag++) {
      int words = hierarchicalScores[tag].length;
      for (int word = 0; word < words; word++) {
        int level = finalLevels[tag][word];
        // Removing the baseline level would leave an empty hierarchy with final level -1,
        // which breaks getLastLevel, getFinalLevel and splitAllStates.
        if (level <= 0) {
          continue;
        }
        List<double[]> scoreHierarchy = hierarchicalScores[tag][word];
        double[] scoresThisLevel = scoreHierarchy.get(level);
        if (scoresThisLevel == null || !isAllZero(scoresThisLevel)) {
          continue;
        }
        scoreHierarchy.remove(level); // int argument: removes by index
        finalLevels[tag][word]--;
        nRemovedParam += scoresThisLevel.length;
        nRemovedArrays++;
      }
    }
    System.out.println("Removed "+nRemovedParam+" parameters in the lexicon by setting "+nRemovedArrays+" arrays to null.");
  }

  /**
   * Returns the score array of the deepest level currently in use for a (tag, word) pair.
   *
   * @param tag tag index
   * @param word tag-specific word index
   * @return the live array stored in the hierarchy (not a copy)
   */
  public double[] getLastLevel(int tag, int word) {
    return hierarchicalScores[tag][word].get(finalLevels[tag][word]);
  }

  /**
   * Returns a copy of this lexicon. The flat scores, the score hierarchy and the final levels
   * are all copied. The tag-word indexers and the word counter are copied. Fresh zeroed
   * expected-count arrays are allocated. The word indexer and smoother are shared.
   *
   * @return an independent copy whose hierarchy can be merged or split without affecting this
   */
  public HierarchicalLexicon copyLexicon() {
    HierarchicalLexicon copy = newInstance();
    copy.expectedCounts = new double[numStates][][];
    // scores is null after construction or splitAllStates, so only clone when present.
    copy.scores = (scores == null) ? null : ArrayUtil.clone(scores);
    // Deep copy: sharing these mutable lists would let mergeLexicon on one lexicon
    // desynchronise the other's (separately cloned) finalLevels.
    copy.hierarchicalScores = copyHierarchy(this.hierarchicalScores);
    copy.tagWordIndexer = new IntegerIndexer[numStates];
    copy.wordIndexer = this.wordIndexer;
    for (int tag = 0; tag < numStates; tag++) {
      copy.tagWordIndexer[tag] = tagWordIndexer[tag].copy();
      copy.expectedCounts[tag] = new double[numSubStates[tag]][tagWordIndexer[tag].size()];
    }
    if (this.wordCounter != null) {
      copy.wordCounter = this.wordCounter.clone();
    }
    copy.nWords = this.nWords;
    copy.smoother = this.smoother;
    if (finalLevels != null) {
      copy.finalLevels = ArrayUtil.clone(this.finalLevels);
    }
    return copy;
  }

  /**
   * Renders every (tag, word) entry: the flat substate scores followed by each hierarchy level.
   * As a side effect, prints to stdout a per-tag histogram of final levels (levels 1 and up).
   */
  @Override
  public String toString() {
    StringBuilder sb = new StringBuilder();
    Numberer tagNumberer = Numberer.getGlobalNumberer("tags");
    for (int tag = 0; tag < scores.length; tag++) {
      // Sized for levels 0-5 up front and grown on demand for deeper hierarchies.
      int[] counts = new int[6];
      String tagS = (String) tagNumberer.object(tag);
      if (tagWordIndexer[tag].size() == 0) {
        continue;
      }
      for (int word = 0; word < scores[tag][0].length; word++) {
        sb.append(tagS).append(" ").append(wordIndexer.get(tagWordIndexer[tag].get(word))).append(" ");
        for (int sub = 0; sub < numSubStates[tag]; sub++) {
          sb.append(" ").append(scores[tag][sub][word]);
        }
        for (double[] d : hierarchicalScores[tag][word]) {
          sb.append("\n").append(Arrays.toString(d));
        }
        int level = finalLevels[tag][word];
        if (level >= counts.length) {
          counts = Arrays.copyOf(counts, level + 1);
        }
        counts[level]++;
        sb.append("\n\n");
      }
      System.out.print(tagS + ", word,tag pairs per level: ");
      for (int i = 1; i < counts.length; i++) {
        System.out.print(counts[i] + " ");
      }
      System.out.print("\n");
    }
    return sb.toString();
  }

  /** Returns a copy of one (tag, word) hierarchy with every level array cloned. */
  private static List<double[]> copyLevels(List<double[]> levels) {
    List<double[]> copy = new ArrayList<double[]>(levels.size());
    for (double[] level : levels) {
      copy.add(level == null ? null : level.clone());
    }
    return copy;
  }

  /** Deep-copies a full tag x word hierarchy, preserving {@code null} rows and entries. */
  @SuppressWarnings("unchecked")
  private static List<double[]>[][] copyHierarchy(List<double[]>[][] source) {
    if (source == null) {
      return null;
    }
    List<double[]>[][] copy = new List[source.length][];
    for (int tag = 0; tag < source.length; tag++) {
      if (source[tag] == null) {
        continue;
      }
      copy[tag] = new List[source[tag].length];
      for (int word = 0; word < source[tag].length; word++) {
        if (source[tag][word] != null) {
          copy[tag][word] = copyLevels(source[tag][word]);
        }
      }
    }
    return copy;
  }

  /** Returns true iff every entry of {@code values} compares equal to 0. */
  private static boolean isAllZero(double[] values) {
    for (double v : values) {
      if (v != 0) {
        return false;
      }
    }
    return true;
  }
}