/** * */ package edu.berkeley.nlp.PCFGLA; import java.util.Arrays; import java.util.List; import edu.berkeley.nlp.PCFGLA.smoothing.Smoother; import edu.berkeley.nlp.syntax.StateSet; import edu.berkeley.nlp.syntax.StateSetWithFeatures; import edu.berkeley.nlp.syntax.Tree; import edu.berkeley.nlp.util.Counter; import edu.berkeley.nlp.util.Indexer; import edu.berkeley.nlp.util.Numberer; import edu.berkeley.nlp.util.Pair; import edu.berkeley.nlp.util.PriorityQueue; /** * @author petrov * */ public class HierarchicalFullyConnectedAdaptiveLexiconWithFeatures extends HierarchicalFullyConnectedAdaptiveLexicon { private static final long serialVersionUID = 1L; Indexer<String> featureIndexer; SimpleLexicon simpleLex; private final int minFeatureCount = 50; public HierarchicalFullyConnectedAdaptiveLexiconWithFeatures(short[] numSubStates, int smoothingCutoff, double[] smoothParam, Smoother smoother, StateSetTreeList trainTrees, int knownWordCount) { super(numSubStates, knownWordCount);//smoothingCutoff, smoothParam, smoother, trainTrees, knownWordCount); simpleLex = new SimpleLexicon(numSubStates,-1); init(trainTrees); // super.init(trainTrees); } // public HierarchicalFullyConnectedAdaptiveLexiconWithFeatures newInstance() { // return new HierarchicalFullyConnectedAdaptiveLexiconWithFeatures(this.numSubStates,this.knownWordCount); // } @Override public void init(StateSetTreeList trainTrees){ for (Tree<StateSet> tree : trainTrees){ List<StateSet> words = tree.getYield(); for (StateSet word : words){ String sig = word.getWord(); wordIndexer.add(sig); } } wordCounter = new int[wordIndexer.size()]; Counter<String> ixCounter = new Counter<String>(); featureIndexer = new Indexer<String>(); for (Tree<StateSet> tree : trainTrees){ List<StateSet> words = tree.getYield(); int ind = 0; for (StateSet word : words){ String wordString = word.getWord(); wordCounter[wordIndexer.indexOf(wordString)]++; String sig = getSignature(word.getWord(), ind++); wordIndexer.add(sig); tallyWordFeatures(word.getWord(), ixCounter); } } featureIndexer = new Indexer<String>(); for (String word : ixCounter.keySet()){ if (ixCounter.getCount(word) >= minFeatureCount){ System.out.println("keeping: \t"+word); featureIndexer.add(word); } else System.out.println("too rare:\t"+word); } simpleLex.wordCounter = wordCounter; labelTrees(trainTrees); tagWordIndexer = new IntegerIndexer[numStates]; for (int tag=0; tag<numStates; tag++){ tagWordIndexer[tag] = new IntegerIndexer(featureIndexer.size()); } boolean[] lexTag = new boolean[numStates]; for (Tree<StateSet> tree : trainTrees){ List<StateSet> words = tree.getYield(); List<StateSet> tags = tree.getPreTerminalYield(); int ind = 0; for (StateSet word : words){ int tag = tags.get(ind).getState(); StateSetWithFeatures wordF = (StateSetWithFeatures)word; for (Integer f : wordF.features){ tagWordIndexer[tag].add(f); } lexTag[tag] = true; ind++; } } expectedCounts = new double[numStates][][]; scores = new double[numStates][][]; for (int tag=0; tag<numStates; tag++){ if (!lexTag[tag]) { tagWordIndexer[tag] = null; continue; } // else tagWordIndexer[tag] = tagIndexer; // expectedCounts[tag] = new double[numSubStates[tag]][tagWordIndexer[tag].size()]; scores[tag] = new double[numSubStates[tag]][tagWordIndexer[tag].size()]; } nWords = wordIndexer.size(); this.scores = null; this.hierarchicalScores = null; this.finalLevels = null; rules = new HierarchicalAdaptiveLexicalRule[numStates][]; for (int tag=0; tag<numStates; tag++){ if (tagWordIndexer[tag]==null) { rules[tag] = new HierarchicalAdaptiveLexicalRule[0]; continue; } rules[tag] = new HierarchicalAdaptiveLexicalRule[tagWordIndexer[tag].size()]; for (int word=0; word<rules[tag].length; word++){ rules[tag][word] = new HierarchicalAdaptiveLexicalRule(); } } } /** * @param word * @param ixCounter */ private void tallyWordFeatures(String word, Counter<String> ixCounter) { int length = word.length(); if (length>4){ for (int i=1; i<4; i++){ // String prefix = "PREF-"+word.substring(0,i); // featureIndexer.add(prefix); // ixCounter.incrementCount(prefix, 1.0); String suffix = "SUFF-"+word.substring(length-i); featureIndexer.add(suffix); ixCounter.incrementCount(suffix, 1.0); } } } public StateSet tallyFeatures(StateSet stateSet, boolean update) { String word = stateSet.getWord(); String lowered = word.toLowerCase(); int loc = stateSet.from; String sig = simpleLex.getNewSignature(word, loc); StateSetWithFeatures newStateSet = new StateSetWithFeatures(stateSet); if (update) featureIndexer.add(sig); newStateSet.features.add(featureIndexer.indexOf(sig)); if (update) featureIndexer.add("UNK"); newStateSet.features.add(featureIndexer.indexOf("UNK")); int length = word.length(); if (length>4){ for (int i=1; i<4; i++){ // String prefix = "PREF-"+lowered.substring(0,i); // int prefInd = featureIndexer.indexOf(prefix); // if (prefInd>=0) // newStateSet.features.add(prefInd); String suffix = "SUFF-"+lowered.substring(length-i); int suffInd = featureIndexer.indexOf(suffix); if (suffInd>=0) newStateSet.features.add(suffInd); } } int wlen = word.length(); int numCaps = 0; boolean hasDigit = false; boolean hasDash = false; boolean hasLower = false; for (int i = 0; i < wlen; i++) { char ch = word.charAt(i); if (Character.isDigit(ch)) { hasDigit = true; } else if (ch == '-') { hasDash = true; } else if (Character.isLetter(ch)) { if (Character.isLowerCase(ch)) { hasLower = true; } else if (Character.isTitleCase(ch)) { hasLower = true; numCaps++; } else { numCaps++; } } } char ch0 = word.charAt(0); if (Character.isUpperCase(ch0) || Character.isTitleCase(ch0)) { if (loc == 0 && numCaps == 1) { if (update) featureIndexer.add("INITC"); newStateSet.features.add(featureIndexer.indexOf("INITC")); // if (isKnown(lowered)) { // sb.append("-KNOWNLC"); // } } else { if (update) featureIndexer.add("CAPS"); newStateSet.features.add(featureIndexer.indexOf("CAPS")); } } else if (!Character.isLetter(ch0) && numCaps > 0) { if (update) featureIndexer.add("CAPS"); newStateSet.features.add(featureIndexer.indexOf("CAPS")); } else if (hasLower) { // (Character.isLowerCase(ch0)) { if (update) featureIndexer.add("LC"); newStateSet.features.add(featureIndexer.indexOf("LC")); } if (hasDigit) { if (update) featureIndexer.add("NUM"); newStateSet.features.add(featureIndexer.indexOf("NUM")); } if (hasDash) { if (update) featureIndexer.add("DASH"); newStateSet.features.add(featureIndexer.indexOf("DASH")); } if (lowered.endsWith("s") && wlen >= 3) { // here length 3, so you don't miss out on ones like 80s char ch2 = lowered.charAt(wlen - 2); // not -ess suffixes or greek/latin -us, -is if (ch2 != 's' && ch2 != 'i' && ch2 != 'u') { if (update) featureIndexer.add("s"); newStateSet.features.add(featureIndexer.indexOf("s")); } } else if (word.length() >= 5 && !hasDash && !(hasDigit && numCaps > 0)) { // don't do for very short words; // Implement common discriminating suffixes /* if (Corpus.myLanguage==Corpus.GERMAN){ sb.append(lowered.substring(lowered.length()-1)); }else{*/ // if (lowered.endsWith("ed")) { // sb.append("-ed"); // } else if (lowered.endsWith("ing")) { // sb.append("-ing"); // } else if (lowered.endsWith("ion")) { // sb.append("-ion"); // } else if (lowered.endsWith("er")) { // sb.append("-er"); // } else if (lowered.endsWith("est")) { // sb.append("-est"); // } else if (lowered.endsWith("ly")) { // sb.append("-ly"); // } else if (lowered.endsWith("ity")) { // sb.append("-ity"); // } else if (lowered.endsWith("y")) { // sb.append("-y"); // } else if (lowered.endsWith("al")) { // sb.append("-al"); // } else if (lowered.endsWith("ble")) { // sb.append("-ble"); // } else if (lowered.endsWith("e")) { // sb.append("-e"); } return newStateSet; } @Override public void labelTrees(StateSetTreeList trainTrees){ for (Tree<StateSet> tree : trainTrees){ // List<StateSet> words = tree.getYield(); int ind = 0; for (Tree<StateSet> word : tree.getTerminals()){ StateSetWithFeatures wordF = new StateSetWithFeatures(word.getLabel()); // wordF.wordIndex = wordIndexer.indexOf(word.getWord()); if (wordF.wordIndex<0 || wordF.wordIndex>=wordCounter.length){ System.out.println("Have never seen this word before: "+wordF.getWord()+" "+wordF.wordIndex); System.out.println(tree); } else if (wordCounter[wordF.wordIndex]<=knownWordCount){ wordF = (StateSetWithFeatures) tallyFeatures(wordF, false); } else wordF.sigIndex = -1; featureIndexer.add(wordF.getWord()); wordF.features.add(featureIndexer.indexOf(wordF.getWord())); word.setLabel(wordF); ind++; } } } // StateSetWithFeatures lastStateSet; @Override public double[] score(StateSet stateSet, short tag, boolean noSmoothing, boolean isSignature) { double[] res = new double[numSubStates[tag]]; Arrays.fill(res,1); StateSetWithFeatures stateSetF = null; if (stateSet.wordIndex == -2) { stateSetF = new StateSetWithFeatures(stateSet); int wordIndex = wordIndexer.indexOf(stateSet.getWord()); if (wordIndex<0||(wordIndex>=0 && (wordCounter[wordIndex]<=knownWordCount))){ stateSetF = (StateSetWithFeatures)tallyFeatures(stateSet, false); } int f = featureIndexer.indexOf(stateSet.getWord()); if (f>=0) stateSetF.features.add(f); // stateSetF.wordIndex = -3; // stateSet = lastStateSet; // } else if (stateSet.wordIndex == -3){ // stateSet = lastStateSet; } else { stateSetF = (StateSetWithFeatures) stateSet; } boolean noFeat = true; for (int f : stateSetF.features){ // if (f>tagWordIndexer[tag].size()) // System.out.println("hier"); if (f<0) continue; int tagF = tagWordIndexer[tag].indexOf(f); if (tagF<0) continue; noFeat = false; double[] resF = rules[tag][tagF].scores; for (int i=0; i<res.length; i++){ res[i] *= resF[i]; } } // if (noFeat) { // System.out.println("No features for word "+stateSet.getWord()+" "+wordIndexer.indexOf(stateSet.getWord())); // } return res; } @Override public String toString() { StringBuffer sb = new StringBuffer(); Numberer tagNumberer = Numberer.getGlobalNumberer("tags"); PriorityQueue<Pair<Integer,Integer>> pQ = new PriorityQueue<Pair<Integer,Integer>>(); for (int tag=0; tag<rules.length; tag++){ int[] counts = new int[6]; String tagS = (String)tagNumberer.object(tag); if (rules[tag].length==0) continue; for (int word=0; word<featureIndexer.size(); word++){ int wordT = tagWordIndexer[tag].indexOf(word); if (wordT<0) continue; String w = featureIndexer.get(word); if (w.length()>4 && w.substring(0, 4).equals("SUFF")){ pQ.add(new Pair(tag,word), rules[tag][wordT].scores[0]); } } } while (pQ.hasNext()){ Pair<Integer,Integer> p = pQ.next(); int word = p.getSecond(); int tag = p.getFirst(); String tagS = (String)tagNumberer.object(tag); int wordT = tagWordIndexer[tag].indexOf(word); sb.append(tagS+" "+ featureIndexer.get(word)+"\n"); sb.append(rules[tag][wordT].toString()); sb.append("\n\n"); } sb.append("-----------Start unsorted----------\n"); for (int tag=0; tag<rules.length; tag++){ int[] counts = new int[6]; String tagS = (String)tagNumberer.object(tag); if (rules[tag].length==0) continue; for (int word=0; word<featureIndexer.size(); word++){ int wordT = tagWordIndexer[tag].indexOf(word); if (wordT<0) continue; sb.append(tagS+" "+ featureIndexer.get(word)+"\n"); sb.append(rules[tag][wordT].toString()); sb.append("\n\n"); counts[rules[tag][wordT].hierarchy.getDepth()]++; } System.out.print(tagNumberer.object(tag)+", lexical rules per level: "); for (int i=1; i<6; i++){ System.out.print(counts[i]+" "); } System.out.print("\n"); } return sb.toString(); } }