/** * */ package edu.berkeley.nlp.PCFGLA; import java.util.Arrays; import java.util.List; import edu.berkeley.nlp.PCFGLA.smoothing.Smoother; import edu.berkeley.nlp.syntax.StateSet; import edu.berkeley.nlp.syntax.Tree; /** * @author petrov * */ public class HierarchicalFullyConnectedLexicon extends HierarchicalLexicon { private static final long serialVersionUID = 1L; protected int knownWordCount; /** * @param numSubStates * @param threshold */ public HierarchicalFullyConnectedLexicon(short[] numSubStates, int knownWordCount) { super(numSubStates, 0); this.knownWordCount = knownWordCount; } public HierarchicalFullyConnectedLexicon(short[] numSubStates, int smoothingCutoff, double[] smoothParam, Smoother smoother, StateSetTreeList trainTrees, int knownWordCount) { this(numSubStates, knownWordCount); init(trainTrees); } /** * @param previousLexicon */ public HierarchicalFullyConnectedLexicon(SimpleLexicon previousLexicon, int knownWordCount) { super(previousLexicon); this.knownWordCount = knownWordCount; } public HierarchicalFullyConnectedLexicon newInstance() { return new HierarchicalFullyConnectedLexicon(this.numSubStates,this.knownWordCount); } public void init(StateSetTreeList trainTrees){ for (Tree<StateSet> tree : trainTrees){ List<StateSet> words = tree.getYield(); for (StateSet word : words){ String sig = word.getWord(); wordIndexer.add(sig); } } wordCounter = new int[wordIndexer.size()]; for (Tree<StateSet> tree : trainTrees){ List<StateSet> words = tree.getYield(); int ind = 0; for (StateSet word : words){ String wordString = word.getWord(); wordCounter[wordIndexer.indexOf(wordString)]++; String sig = getSignature(word.getWord(), ind++); wordIndexer.add(sig); } } tagWordIndexer = new IntegerIndexer[numStates]; for (int tag=0; tag<numStates; tag++){ tagWordIndexer[tag] = new IntegerIndexer(wordIndexer.size()); } labelTrees(trainTrees); boolean[] lexTag = new boolean[numStates]; for (Tree<StateSet> tree : trainTrees){ List<StateSet> words = tree.getYield(); List<StateSet> tags = tree.getPreTerminalYield(); int ind = 0; for (StateSet word : words){ int tag = tags.get(ind).getState(); tagWordIndexer[tag].add(new Integer(word.wordIndex)); tagWordIndexer[tag].add(new Integer(word.sigIndex)); lexTag[tag] = true; ind++; } } expectedCounts = new double[numStates][][]; scores = new double[numStates][][]; for (int tag=0; tag<numStates; tag++){ if (!lexTag[tag]) { tagWordIndexer[tag] = null; continue; } // else tagWordIndexer[tag] = tagIndexer; // expectedCounts[tag] = new double[numSubStates[tag]][tagWordIndexer[tag].size()]; scores[tag] = new double[numSubStates[tag]][tagWordIndexer[tag].size()]; } nWords = wordIndexer.size(); } public double[] score(int globalWordIndex, int globalSigIndex, short tag, int loc, boolean noSmoothing, boolean isSignature) { double[] res = new double[numSubStates[tag]]; if (globalWordIndex!=-1) { int tagSpecificWordIndex = tagWordIndexer[tag].indexOf(globalWordIndex); if (tagSpecificWordIndex!=-1){ for (int i=0; i<numSubStates[tag]; i++){ res[i] = scores[tag][i][tagSpecificWordIndex]; } } else { Arrays.fill(res, 1.0); } } else { Arrays.fill(res, 1.0); } if (globalWordIndex>=0 && (wordCounter[globalWordIndex]>knownWordCount)) { // if (globalSigIndex!=-1) System.out.println("Problem: frequent word has signature!"); return res; } if (globalSigIndex!=-1) { int tagSpecificWordIndex = tagWordIndexer[tag].indexOf(globalSigIndex); if (tagSpecificWordIndex!=-1){ for (int i=0; i<numSubStates[tag]; i++){ res[i] *= scores[tag][i][tagSpecificWordIndex]; } // } else{ // System.out.println("unseen sig-tag pair"); } // } else{ // System.out.println("unseen sig"); } // if (smoother!=null) smoother.smooth(tag,res); return res; } public double[] score(StateSet stateSet, short tag, boolean noSmoothing, boolean isSignature) { if (stateSet.wordIndex == -2) { String word = stateSet.getWord(); if (isSignature){ stateSet.wordIndex = -1; stateSet.sigIndex = wordIndexer.indexOf(word); } else { stateSet.wordIndex = wordIndexer.indexOf(word); // if (stateSet.wordIndex > wordCounter.length){ // System.out.println("no count for this word: "+(String)wordIndexer.get(tagWordIndexer[tag].get(stateSet.wordIndex))); // stateSet.sigIndex = -1; // } else { if ((stateSet.wordIndex>=0 && (wordCounter[stateSet.wordIndex]>knownWordCount)) || noSmoothing) stateSet.sigIndex = -1; else if (knownWordCount > 0) stateSet.sigIndex = wordIndexer.indexOf(getSignature(word,stateSet.from)); else stateSet.wordIndex = wordIndexer.indexOf(getSignature(word,stateSet.from)); } // } } return score(stateSet.wordIndex, stateSet.sigIndex, tag, stateSet.from, noSmoothing, isSignature); } public void labelTrees(StateSetTreeList trainTrees){ for (Tree<StateSet> tree : trainTrees){ List<StateSet> words = tree.getYield(); List<StateSet> tags = tree.getPreTerminalYield(); int ind = 0; for (StateSet word : words){ word.wordIndex = wordIndexer.indexOf(word.getWord()); if (word.wordIndex<0 || word.wordIndex>=wordCounter.length){ System.out.println("Have never seen this word before: "+word.getWord()+" "+word.wordIndex); System.out.println(tree); } else if (wordCounter[word.wordIndex]<=knownWordCount){ short tag = tags.get(ind).getState(); String sig = getSignature(word.getWord(), ind); wordIndexer.add(sig); word.sigIndex = wordIndexer.indexOf(sig); tagWordIndexer[tag].add(wordIndexer.indexOf(sig)); } else word.sigIndex = -1; ind++; } } } }