/**
*
*/
package edu.berkeley.nlp.PCFGLA;
import java.util.Arrays;
import java.util.List;
import edu.berkeley.nlp.PCFGLA.smoothing.Smoother;
import edu.berkeley.nlp.syntax.StateSet;
import edu.berkeley.nlp.syntax.StateSetWithFeatures;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.util.Counter;
import edu.berkeley.nlp.util.Indexer;
import edu.berkeley.nlp.util.Numberer;
import edu.berkeley.nlp.util.Pair;
import edu.berkeley.nlp.util.PriorityQueue;
/**
* @author petrov
*
*/
/**
 * A lexicon that scores (tag, word) pairs via surface features of the word
 * (suffixes, capitalization, digit/dash indicators, unknown-word signatures,
 * and the word form itself) rather than via word identity alone. Feature
 * strings are mapped to dense indices by {@link #featureIndexer}; the
 * inherited {@code tagWordIndexer} records which features occur under each
 * preterminal tag, and each (tag, feature) pair owns one
 * {@link HierarchicalAdaptiveLexicalRule} whose sub-state scores are
 * multiplied together in {@link #score}.
 *
 * @author petrov
 */
public class HierarchicalFullyConnectedAdaptiveLexiconWithFeatures extends
        HierarchicalFullyConnectedAdaptiveLexicon {

    private static final long serialVersionUID = 1L;

    /** Maps feature strings (e.g. "SUFF-ing", "CAPS", word forms) to dense indices. */
    Indexer<String> featureIndexer;
    /** Used for its unknown-word signature computation; shares this lexicon's word counts. */
    SimpleLexicon simpleLex;
    /** Minimum training-set occurrence count for a suffix feature to be retained. */
    private final int minFeatureCount = 50;

    /**
     * Builds the lexicon from the training trees.
     * NOTE(review): smoothingCutoff, smoothParam and smoother are accepted for
     * signature compatibility but are currently unused by this subclass.
     */
    public HierarchicalFullyConnectedAdaptiveLexiconWithFeatures(short[] numSubStates, int smoothingCutoff,
            double[] smoothParam, Smoother smoother, StateSetTreeList trainTrees, int knownWordCount) {
        super(numSubStates, knownWordCount);
        simpleLex = new SimpleLexicon(numSubStates, -1);
        init(trainTrees);
    }

    /**
     * Builds all indexers and rule tables from the training trees:
     * (1) indexes every word form and counts occurrences,
     * (2) tallies candidate suffix features and keeps those seen at least
     *     {@code minFeatureCount} times,
     * (3) relabels the tree leaves with {@link StateSetWithFeatures},
     * (4) records which features occur under which preterminal tag and
     *     allocates one {@link HierarchicalAdaptiveLexicalRule} per
     *     (tag, feature) pair.
     */
    @Override
    public void init(StateSetTreeList trainTrees) {
        // Pass 1: index all word forms.
        for (Tree<StateSet> tree : trainTrees) {
            List<StateSet> words = tree.getYield();
            for (StateSet word : words) {
                String sig = word.getWord();
                wordIndexer.add(sig);
            }
        }
        wordCounter = new int[wordIndexer.size()];
        Counter<String> ixCounter = new Counter<String>();
        featureIndexer = new Indexer<String>();
        // Pass 2: count word occurrences, index unknown-word signatures, and
        // tally candidate suffix features into ixCounter.
        for (Tree<StateSet> tree : trainTrees) {
            List<StateSet> words = tree.getYield();
            int ind = 0;
            for (StateSet word : words) {
                String wordString = word.getWord();
                wordCounter[wordIndexer.indexOf(wordString)]++;
                String sig = getSignature(word.getWord(), ind++);
                wordIndexer.add(sig);
                tallyWordFeatures(word.getWord(), ixCounter);
            }
        }
        // Keep only frequent suffix features; the indexer is rebuilt from scratch
        // so that discarded features receive no index at all.
        featureIndexer = new Indexer<String>();
        for (String word : ixCounter.keySet()) {
            if (ixCounter.getCount(word) >= minFeatureCount) {
                System.out.println("keeping: \t" + word);
                featureIndexer.add(word);
            } else
                System.out.println("too rare:\t" + word);
        }
        simpleLex.wordCounter = wordCounter;
        // Replace leaf StateSets with feature-carrying ones (this also adds
        // word-form features to featureIndexer).
        labelTrees(trainTrees);
        tagWordIndexer = new IntegerIndexer[numStates];
        for (int tag = 0; tag < numStates; tag++) {
            tagWordIndexer[tag] = new IntegerIndexer(featureIndexer.size());
        }
        // Pass 3: record, per preterminal tag, which features were observed.
        boolean[] lexTag = new boolean[numStates];
        for (Tree<StateSet> tree : trainTrees) {
            List<StateSet> words = tree.getYield();
            List<StateSet> tags = tree.getPreTerminalYield();
            int ind = 0;
            for (StateSet word : words) {
                int tag = tags.get(ind).getState();
                StateSetWithFeatures wordF = (StateSetWithFeatures) word;
                for (Integer f : wordF.features) {
                    tagWordIndexer[tag].add(f);
                }
                lexTag[tag] = true;
                ind++;
            }
        }
        expectedCounts = new double[numStates][][];
        scores = new double[numStates][][];
        for (int tag = 0; tag < numStates; tag++) {
            if (!lexTag[tag]) {
                // Tag never appeared as a preterminal: it gets no lexical rules.
                tagWordIndexer[tag] = null;
                continue;
            }
            scores[tag] = new double[numSubStates[tag]][tagWordIndexer[tag].size()];
        }
        nWords = wordIndexer.size();
        // NOTE(review): the flat score arrays allocated just above are discarded
        // here — the hierarchical rule objects below replace them. Presumably
        // intentional, but the allocation looks wasted; confirm before removing.
        this.scores = null;
        this.hierarchicalScores = null;
        this.finalLevels = null;
        rules = new HierarchicalAdaptiveLexicalRule[numStates][];
        for (int tag = 0; tag < numStates; tag++) {
            if (tagWordIndexer[tag] == null) {
                rules[tag] = new HierarchicalAdaptiveLexicalRule[0];
                continue;
            }
            rules[tag] = new HierarchicalAdaptiveLexicalRule[tagWordIndexer[tag].size()];
            for (int word = 0; word < rules[tag].length; word++) {
                rules[tag][word] = new HierarchicalAdaptiveLexicalRule();
            }
        }
    }

    /**
     * Counts candidate suffix features (the last 1-3 characters) of a word into
     * {@code ixCounter}. Only words longer than 4 characters contribute.
     *
     * @param word      the surface form
     * @param ixCounter accumulates feature occurrence counts
     */
    private void tallyWordFeatures(String word, Counter<String> ixCounter) {
        int length = word.length();
        if (length > 4) {
            for (int i = 1; i < 4; i++) {
                String suffix = "SUFF-" + word.substring(length - i);
                // NOTE(review): this add is effectively a no-op — init() rebuilds
                // featureIndexer from ixCounter right after the tally pass.
                featureIndexer.add(suffix);
                ixCounter.incrementCount(suffix, 1.0);
            }
        }
    }

    /**
     * Wraps a StateSet in a {@link StateSetWithFeatures} carrying the indices of
     * all surface features that fire for its word: the unknown-word signature,
     * the generic "UNK" feature, retained suffixes, and capitalization/digit/
     * dash/plural indicators.
     *
     * @param stateSet the leaf to featurize
     * @param update   if true, unseen feature strings are added to
     *                 {@code featureIndexer}; if false, features with no index
     *                 are silently skipped
     * @return a new StateSetWithFeatures with the feature indices filled in
     */
    public StateSet tallyFeatures(StateSet stateSet, boolean update) {
        String word = stateSet.getWord();
        String lowered = word.toLowerCase();
        int loc = stateSet.from;
        String sig = simpleLex.getNewSignature(word, loc);
        StateSetWithFeatures newStateSet = new StateSetWithFeatures(stateSet);
        if (update) featureIndexer.add(sig);
        newStateSet.features.add(featureIndexer.indexOf(sig));
        if (update) featureIndexer.add("UNK");
        newStateSet.features.add(featureIndexer.indexOf("UNK"));
        // Suffix features: only those retained during init() have an index >= 0.
        int length = word.length();
        if (length > 4) {
            for (int i = 1; i < 4; i++) {
                String suffix = "SUFF-" + lowered.substring(length - i);
                int suffInd = featureIndexer.indexOf(suffix);
                if (suffInd >= 0)
                    newStateSet.features.add(suffInd);
            }
        }
        // Scan the word once for capitalization, digits and dashes.
        int wlen = word.length();
        int numCaps = 0;
        boolean hasDigit = false;
        boolean hasDash = false;
        boolean hasLower = false;
        for (int i = 0; i < wlen; i++) {
            char ch = word.charAt(i);
            if (Character.isDigit(ch)) {
                hasDigit = true;
            } else if (ch == '-') {
                hasDash = true;
            } else if (Character.isLetter(ch)) {
                if (Character.isLowerCase(ch)) {
                    hasLower = true;
                } else if (Character.isTitleCase(ch)) {
                    hasLower = true;
                    numCaps++;
                } else {
                    numCaps++;
                }
            }
        }
        // NOTE(review): charAt(0) assumes a non-empty word — confirm callers
        // never pass an empty token.
        char ch0 = word.charAt(0);
        if (Character.isUpperCase(ch0) || Character.isTitleCase(ch0)) {
            if (loc == 0 && numCaps == 1) {
                // Single sentence-initial capital: likely an ordinary word.
                if (update) featureIndexer.add("INITC");
                newStateSet.features.add(featureIndexer.indexOf("INITC"));
            } else {
                if (update) featureIndexer.add("CAPS");
                newStateSet.features.add(featureIndexer.indexOf("CAPS"));
            }
        } else if (!Character.isLetter(ch0) && numCaps > 0) {
            if (update) featureIndexer.add("CAPS");
            newStateSet.features.add(featureIndexer.indexOf("CAPS"));
        } else if (hasLower) {
            if (update) featureIndexer.add("LC");
            newStateSet.features.add(featureIndexer.indexOf("LC"));
        }
        if (hasDigit) {
            if (update) featureIndexer.add("NUM");
            newStateSet.features.add(featureIndexer.indexOf("NUM"));
        }
        if (hasDash) {
            if (update) featureIndexer.add("DASH");
            newStateSet.features.add(featureIndexer.indexOf("DASH"));
        }
        if (lowered.endsWith("s") && wlen >= 3) {
            // here length 3, so you don't miss out on ones like 80s
            char ch2 = lowered.charAt(wlen - 2);
            // not -ess suffixes or greek/latin -us, -is
            if (ch2 != 's' && ch2 != 'i' && ch2 != 'u') {
                if (update) featureIndexer.add("s");
                newStateSet.features.add(featureIndexer.indexOf("s"));
            }
        }
        // (A commented-out block of discriminating-suffix features for longer
        // words was removed here; generic suffix features are handled above.)
        return newStateSet;
    }

    /**
     * Replaces every leaf label with a {@link StateSetWithFeatures}. Rare words
     * (count at most {@code knownWordCount}) additionally receive their surface
     * features via {@link #tallyFeatures}; every word also gets its own
     * word-form feature, which is added to {@code featureIndexer} here.
     */
    @Override
    public void labelTrees(StateSetTreeList trainTrees) {
        for (Tree<StateSet> tree : trainTrees) {
            for (Tree<StateSet> word : tree.getTerminals()) {
                StateSetWithFeatures wordF = new StateSetWithFeatures(word.getLabel());
                if (wordF.wordIndex < 0 || wordF.wordIndex >= wordCounter.length) {
                    System.out.println("Have never seen this word before: " + wordF.getWord() + " " + wordF.wordIndex);
                    System.out.println(tree);
                } else if (wordCounter[wordF.wordIndex] <= knownWordCount) {
                    wordF = (StateSetWithFeatures) tallyFeatures(wordF, false);
                } else
                    wordF.sigIndex = -1;
                featureIndexer.add(wordF.getWord());
                wordF.features.add(featureIndexer.indexOf(wordF.getWord()));
                word.setLabel(wordF);
            }
        }
    }

    /**
     * Scores a word under each substate of {@code tag} by multiplying the rule
     * scores of every feature that both fires for the word and is known for the
     * tag. Features unknown globally (index < 0) or unknown for this tag are
     * skipped, so a word with no matching features scores 1.0 in every substate.
     */
    @Override
    public double[] score(StateSet stateSet, short tag, boolean noSmoothing, boolean isSignature) {
        double[] res = new double[numSubStates[tag]];
        Arrays.fill(res, 1);
        StateSetWithFeatures stateSetF = null;
        if (stateSet.wordIndex == -2) {
            // Leaf has not been featurized yet (e.g. at parse time): do it now.
            stateSetF = new StateSetWithFeatures(stateSet);
            int wordIndex = wordIndexer.indexOf(stateSet.getWord());
            if (wordIndex < 0 || (wordIndex >= 0 && (wordCounter[wordIndex] <= knownWordCount))) {
                stateSetF = (StateSetWithFeatures) tallyFeatures(stateSet, false);
            }
            int f = featureIndexer.indexOf(stateSet.getWord());
            if (f >= 0)
                stateSetF.features.add(f);
        } else {
            stateSetF = (StateSetWithFeatures) stateSet;
        }
        for (int f : stateSetF.features) {
            if (f < 0)
                continue;
            int tagF = tagWordIndexer[tag].indexOf(f);
            if (tagF < 0)
                continue;
            double[] resF = rules[tag][tagF].scores;
            for (int i = 0; i < res.length; i++) {
                res[i] *= resF[i];
            }
        }
        return res;
    }

    /**
     * Dumps all suffix-feature rules sorted by their first substate score,
     * followed by an unsorted dump of every (tag, feature) rule. Also prints
     * per-tag rule counts by hierarchy depth to stdout as a side effect.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        Numberer tagNumberer = Numberer.getGlobalNumberer("tags");
        PriorityQueue<Pair<Integer, Integer>> pQ = new PriorityQueue<Pair<Integer, Integer>>();
        // Queue all suffix features, prioritized by their substate-0 score.
        for (int tag = 0; tag < rules.length; tag++) {
            if (rules[tag].length == 0) continue;
            for (int word = 0; word < featureIndexer.size(); word++) {
                int wordT = tagWordIndexer[tag].indexOf(word);
                if (wordT < 0) continue;
                String w = featureIndexer.get(word);
                if (w.length() > 4 && w.startsWith("SUFF")) {
                    pQ.add(new Pair<Integer, Integer>(tag, word), rules[tag][wordT].scores[0]);
                }
            }
        }
        while (pQ.hasNext()) {
            Pair<Integer, Integer> p = pQ.next();
            int word = p.getSecond();
            int tag = p.getFirst();
            String tagS = (String) tagNumberer.object(tag);
            int wordT = tagWordIndexer[tag].indexOf(word);
            sb.append(tagS + " " + featureIndexer.get(word) + "\n");
            sb.append(rules[tag][wordT].toString());
            sb.append("\n\n");
        }
        sb.append("-----------Start unsorted----------\n");
        for (int tag = 0; tag < rules.length; tag++) {
            // Histogram of rule hierarchy depths for this tag (levels 1-5 printed).
            int[] counts = new int[6];
            String tagS = (String) tagNumberer.object(tag);
            if (rules[tag].length == 0) continue;
            for (int word = 0; word < featureIndexer.size(); word++) {
                int wordT = tagWordIndexer[tag].indexOf(word);
                if (wordT < 0) continue;
                sb.append(tagS + " " + featureIndexer.get(word) + "\n");
                sb.append(rules[tag][wordT].toString());
                sb.append("\n\n");
                // NOTE(review): assumes getDepth() <= 5; a deeper hierarchy would
                // throw ArrayIndexOutOfBoundsException here — confirm the bound.
                counts[rules[tag][wordT].hierarchy.getDepth()]++;
            }
            System.out.print(tagNumberer.object(tag) + ", lexical rules per level: ");
            for (int i = 1; i < 6; i++) {
                System.out.print(counts[i] + " ");
            }
            System.out.print("\n");
        }
        return sb.toString();
    }
}