/**
*
*/
package edu.berkeley.nlp.PCFGLA;
import java.util.Arrays;
import edu.berkeley.nlp.PCFGLA.SimpleLexicon.IntegerIndexer;
import edu.berkeley.nlp.PCFGLA.smoothing.Smoother;
import edu.berkeley.nlp.math.DoubleArrays;
import edu.berkeley.nlp.util.ArrayUtil;
import edu.berkeley.nlp.util.Numberer;
import edu.berkeley.nlp.util.PriorityQueue;
/**
 * Fully connected hierarchical lexicon whose per-(tag, word) parameters are stored in
 * {@link HierarchicalAdaptiveLexicalRule} objects, one per tag-specific word/signature index,
 * rather than in the inherited {@code scores} / {@code hierarchicalScores} arrays (which are
 * cleared when the rules are allocated).
 *
 * @author petrov
 */
public class HierarchicalFullyConnectedAdaptiveLexicon extends HierarchicalFullyConnectedLexicon {
  private static final long serialVersionUID = 1L;

  /**
   * Lexical rules indexed as {@code rules[tag][tagSpecificIndex]}, where {@code tagSpecificIndex}
   * is the position assigned by {@code tagWordIndexer[tag]} (both words and signatures are looked
   * up through that indexer). Tags whose indexer is {@code null} get an empty array.
   */
  public HierarchicalAdaptiveLexicalRule[][] rules;

  /**
   * Builds the lexicon from training trees and allocates one rule per tag-specific index.
   *
   * <p>NOTE(review): {@code smoothingCutoff}, {@code smoothParam} and {@code smoother} are not
   * used by this constructor (presumably kept for signature compatibility with sibling lexicons);
   * in particular {@code smoother} is not stored — confirm this is intended.
   *
   * @param numSubStates number of substates per tag, indexed by tag
   * @param smoothingCutoff unused
   * @param smoothParam unused
   * @param smoother unused
   * @param trainTrees training trees handed to the parent's {@code init}
   * @param knownWordCount count threshold handed to the parent constructor; see {@link #score}
   */
  public HierarchicalFullyConnectedAdaptiveLexicon(short[] numSubStates, int smoothingCutoff,
      double[] smoothParam, Smoother smoother, StateSetTreeList trainTrees, int knownWordCount) {
    super(numSubStates, knownWordCount);
    super.init(trainTrees);
    init();
  }

  /**
   * Creates an empty lexicon. {@link #rules} is left {@code null}; it must be populated before
   * scoring, splitting, merging or printing.
   *
   * @param numSubStates number of substates per tag, indexed by tag
   * @param knownWordCount count threshold handed to the parent constructor
   */
  public HierarchicalFullyConnectedAdaptiveLexicon(short[] numSubStates, int knownWordCount) {
    super(numSubStates, knownWordCount);
  }

  /**
   * Drops the inherited dense score tables and allocates a fresh
   * {@link HierarchicalAdaptiveLexicalRule} for every index known to each tag's word indexer.
   */
  private void init() {
    this.scores = null;
    this.hierarchicalScores = null;
    this.finalLevels = null;
    rules = new HierarchicalAdaptiveLexicalRule[numStates][];
    for (int tag = 0; tag < numStates; tag++) {
      if (tagWordIndexer[tag] == null) {
        rules[tag] = new HierarchicalAdaptiveLexicalRule[0];
        continue;
      }
      rules[tag] = new HierarchicalAdaptiveLexicalRule[tagWordIndexer[tag].size()];
      for (int word = 0; word < rules[tag].length; word++) {
        rules[tag][word] = new HierarchicalAdaptiveLexicalRule();
      }
    }
  }

  /**
   * Returns the per-substate score of a word (and/or its signature) under {@code tag}.
   *
   * <p>The result is built as follows:
   * <ol>
   *   <li>If the tag has no word indexer, an all-zero array is returned.</li>
   *   <li>If {@code globalWordIndex != -1} and the word is indexed for this tag, the word rule's
   *       scores are copied in. Otherwise the array is filled with {@code 1.0} when
   *       {@code knownWordCount > 0}, and left at zero when it is not.</li>
   *   <li>If {@code globalWordIndex >= 0} and {@code wordCounter[globalWordIndex]} exceeds
   *       {@code knownWordCount}, the array from step 2 is returned unchanged.</li>
   *   <li>Otherwise, if {@code globalSigIndex > -1} and the signature is indexed for this tag, the
   *       array is multiplied elementwise by the signature rule's scores.</li>
   * </ol>
   * No smoothing is applied, even when a smoother is configured.
   *
   * @param globalWordIndex global word index, or -1 when there is no word
   * @param globalSigIndex global signature index, or a negative value when there is none
   * @param tag tag to score under
   * @param loc unused
   * @param noSmoothing unused
   * @param isSignature unused
   * @return a newly allocated array of length {@code numSubStates[tag]}
   */
  public double[] score(int globalWordIndex, int globalSigIndex, short tag, int loc,
      boolean noSmoothing, boolean isSignature) {
    double[] res = new double[numSubStates[tag]];
    if (tagWordIndexer[tag] == null) {
      return res;
    }

    int wordIndex = (globalWordIndex != -1) ? tagWordIndexer[tag].indexOf(globalWordIndex) : -1;
    if (wordIndex != -1) {
      HierarchicalAdaptiveLexicalRule wordRule = rules[tag][wordIndex];
      for (int i = 0; i < numSubStates[tag]; i++) {
        res[i] = wordRule.scores[i];
      }
    } else if (knownWordCount > 0) {
      // Neutral factor so that the signature score below determines the result.
      Arrays.fill(res, 1.0);
    }

    // Frequent words are scored by their own rule only.
    // NOTE(review): no bounds check against wordCounter.length is done; callers presumably
    // never pass an index beyond it.
    if (globalWordIndex >= 0 && wordCounter[globalWordIndex] > knownWordCount) {
      return res;
    }

    if (globalSigIndex > -1) {
      int sigIndex = tagWordIndexer[tag].indexOf(globalSigIndex);
      if (sigIndex != -1) {
        HierarchicalAdaptiveLexicalRule sigRule = rules[tag][sigIndex];
        for (int i = 0; i < numSubStates[tag]; i++) {
          res[i] *= sigRule.scores[i];
        }
      }
    }
    return res;
  }

  /**
   * Returns this lexicon itself.
   *
   * <p>NOTE(review): despite the name, no copy is made. The returned object aliases this instance,
   * so any mutation through it (splitting, merging, rescoring) also affects this lexicon.
   *
   * @return {@code this}
   */
  public HierarchicalLexicon copyLexicon() {
    return this;
  }

  /**
   * Doubles the substate count of every tag, splits each lexical rule accordingly, and recomputes
   * its explicit scores.
   *
   * <p>The level passed to {@code explicitlyComputeScores} is
   * {@code floor(log2(max(numSubStates))) + 1}, evaluated on the counts before doubling.
   * {@code numSubStates} is modified in place.
   *
   * @param counts unused
   * @param moreSubstatesThanCounts unused
   * @param mode unused
   * @return {@code this}
   */
  public HierarchicalLexicon splitAllStates(int[] counts, boolean moreSubstatesThanCounts,
      int mode) {
    int finalLevel = (int) (Math.log((int) ArrayUtil.max(numSubStates)) / Math.log(2)) + 1;
    for (int tag = 0; tag < numStates; tag++) {
      numSubStates[tag] *= 2;
      for (HierarchicalAdaptiveLexicalRule rule : rules[tag]) {
        rule.splitRule(numSubStates[tag]);
        rule.explicitlyComputeScores(finalLevel, false);
      }
    }
    return this;
  }

  /**
   * Calls {@code mergeRule()} on every lexical rule and prints the total number of parameters
   * those calls report as removed.
   */
  public void mergeLexicon() {
    int removedParam = 0;
    for (int tag = 0; tag < numStates; tag++) {
      for (HierarchicalAdaptiveLexicalRule rule : rules[tag]) {
        removedParam += rule.mergeRule();
      }
    }
    System.out.println("Removed " + removedParam + " parameters from the lexicon.");
  }

  /**
   * Renders every lexical rule as {@code "<tag> <word>\n<rule>\n\n"}.
   *
   * <p>As a side effect, this prints to {@code System.out}, for each tag that has rules, how many
   * rules have a hierarchy of each depth. Levels 1..5 are always printed; deeper levels are
   * appended when any rule reaches them.
   *
   * @return the concatenated rule descriptions
   */
  @Override
  public String toString() {
    StringBuilder sb = new StringBuilder();
    Numberer tagNumberer = Numberer.getGlobalNumberer("tags");
    for (int tag = 0; tag < rules.length; tag++) {
      if (rules[tag].length == 0) {
        continue;
      }
      String tagS = (String) tagNumberer.object(tag);
      // counts[d] = number of rules whose hierarchy has depth d. The array grows on demand, so
      // deep hierarchies cannot index past its end.
      int[] counts = new int[7];
      int maxDepth = 0;
      for (int word = 0; word < rules[tag].length; word++) {
        HierarchicalAdaptiveLexicalRule rule = rules[tag][word];
        sb.append(tagS).append(' ')
            .append(wordIndexer.get(tagWordIndexer[tag].get(word))).append('\n');
        sb.append(rule.toString());
        sb.append("\n\n");
        int depth = rule.hierarchy.getDepth();
        if (depth >= counts.length) {
          counts = Arrays.copyOf(counts, depth + 1);
        }
        counts[depth]++;
        maxDepth = Math.max(maxDepth, depth);
      }
      System.out.print(tagS + ", lexical rules per level: ");
      int lastLevel = Math.max(5, maxDepth);
      for (int level = 1; level <= lastLevel; level++) {
        System.out.print(counts[level] + " ");
      }
      System.out.print("\n");
    }
    return sb.toString();
  }

  /**
   * Recomputes the explicit scores of every lexical rule at {@code finalLevel}.
   *
   * @param finalLevel hierarchy level passed to each rule's {@code explicitlyComputeScores}
   */
  public void explicitlyComputeScores(int finalLevel) {
    for (int tag = 0; tag < rules.length; tag++) {
      for (HierarchicalAdaptiveLexicalRule rule : rules[tag]) {
        rule.explicitlyComputeScores(finalLevel, false);
      }
    }
  }
}