/**
*
*/
package edu.berkeley.nlp.PCFGLA;
import java.io.Serializable;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import edu.berkeley.nlp.discPCFG.WordInSentence;
import edu.berkeley.nlp.syntax.StateSet;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.math.DoubleArrays;
import edu.berkeley.nlp.math.SloppyMath;
import edu.berkeley.nlp.util.ArrayUtil;
import edu.berkeley.nlp.util.Counter;
import edu.berkeley.nlp.util.Indexer;
import edu.berkeley.nlp.util.Numberer;
import edu.berkeley.nlp.util.Pair;
import edu.berkeley.nlp.util.PriorityQueue;
/**
* @author petrov
*
*/
public class SpanPredictor implements Serializable{
private static final long serialVersionUID = 1L;
public final boolean useFirstAndLast;
public final boolean usePreviousAndNext; // can only be on if useFirstAndLast is on
public final boolean useBeginAndEndPairs;
public final boolean useSyntheticClass;
public final boolean usePunctuation;
Indexer<String> punctuationSignatures;
boolean[] isPunctuation;
public final boolean useOnlyWords = true;
// public final boolean useCapitalization;
public final int minFeatureFrequency;
public final int minSpanLength = 3;
public double[][] firstWordScore;
public double[][] lastWordScore;
public double[][] previousWordScore;
public double[][] nextWordScore;
public double[][] beginPairScore;
public double[][] endPairScore;
private HashMap<Pair<Integer,Integer>,Integer> beginMap;
private HashMap<Pair<Integer,Integer>,Integer> endMap;
public double[][] punctuationScores;
public int nWords;
public int nFeatures;
private int[] stateClass;
private int nClasses;
private Indexer<String> wordIndexer;
// public int startIndexPrevious, startIndexBegin;
public SpanPredictor(int nWords, StateSetTreeList trainTrees, Numberer tagNumberer, Indexer<String> wordIndexer){
this.useFirstAndLast = ConditionalTrainer.Options.useFirstAndLast;
this.usePreviousAndNext = ConditionalTrainer.Options.usePreviousAndNext;
this.useBeginAndEndPairs = ConditionalTrainer.Options.useBeginAndEndPairs;
this.useSyntheticClass = ConditionalTrainer.Options.useSyntheticClass;
this.usePunctuation = ConditionalTrainer.Options.usePunctuation;
this.minFeatureFrequency = ConditionalTrainer.Options.minFeatureFrequency;
this.wordIndexer = wordIndexer;
this.nWords = nWords;
this.nFeatures = 0;
if (useSyntheticClass){
System.out.println("Distinguishing between real and synthetic classes.");
stateClass = new int[tagNumberer.total()];
for (int i=0; i<tagNumberer.total(); i++){
String state = (String)tagNumberer.object(i);
if (state.charAt(0)=='@')
stateClass[i] = 1; // synthetic
}
nClasses = 2;
} else {
stateClass = new int[tagNumberer.total()];
nClasses = 1;
}
if (useFirstAndLast){
firstWordScore = new double[nWords][nClasses];
lastWordScore = new double[nWords][nClasses];
ArrayUtil.fill(firstWordScore,1);
ArrayUtil.fill(lastWordScore,1);
this.nFeatures += 2*nWords*nClasses;
}
if (usePreviousAndNext){
previousWordScore = new double[nWords][nClasses];
nextWordScore = new double[nWords][nClasses];
ArrayUtil.fill(previousWordScore,1);
ArrayUtil.fill(nextWordScore,1);
this.nFeatures += 2*nWords*nClasses;
}
if (useBeginAndEndPairs){
initPairs(trainTrees);
}
if (usePunctuation){
initPunctuations(trainTrees);
}
}
private void initPairs(StateSetTreeList trainTrees) {
beginMap = new HashMap<Pair<Integer,Integer>, Integer>();
endMap = new HashMap<Pair<Integer,Integer>, Integer>();
Counter<Pair<Integer,Integer>> beginPairCounter = new Counter<Pair<Integer,Integer>>();
Counter<Pair<Integer,Integer>> endPairCounter = new Counter<Pair<Integer,Integer>>();
int beginPairs = 0, endPairs = 0;
for (Tree<StateSet> tree : trainTrees){
List<StateSet> words = tree.getYield();
StateSet stateSet = words.get(0);
int prevIndex = (stateSet.sigIndex<0) ? stateSet.wordIndex : stateSet.sigIndex;
if (useOnlyWords) prevIndex = stateSet.wordIndex;
int currIndex = -1;
for (int i=1; i<=words.size()-minSpanLength; i++){
stateSet = words.get(i);
currIndex = (stateSet.sigIndex<0) ? stateSet.wordIndex : stateSet.sigIndex;
if (useOnlyWords) currIndex = stateSet.wordIndex;
Pair<Integer,Integer> pair = new Pair<Integer,Integer>(prevIndex,currIndex);
beginPairCounter.incrementCount(pair, 1.0);
if (!beginMap.containsKey(pair))
beginMap.put(pair,beginPairs++);
prevIndex = currIndex;
}
if (words.size() < minSpanLength) continue;
stateSet = words.get(minSpanLength-1);
prevIndex = (stateSet.sigIndex<0) ? stateSet.wordIndex : stateSet.sigIndex;
if (useOnlyWords) currIndex = stateSet.wordIndex;
for (int i=minSpanLength; i<words.size(); i++){
stateSet = words.get(i);
currIndex = (stateSet.sigIndex<0) ? stateSet.wordIndex : stateSet.sigIndex;
if (useOnlyWords) currIndex = stateSet.wordIndex;
Pair<Integer,Integer> pair = new Pair<Integer,Integer>(prevIndex,currIndex);
endPairCounter.incrementCount(pair, 1.0);
if (!endMap.containsKey(pair))
endMap.put(pair,endPairs++);
prevIndex = currIndex;
}
}
HashMap<Pair<Integer,Integer>, Integer> newBeginMap = new HashMap<Pair<Integer,Integer>, Integer>();
HashMap<Pair<Integer,Integer>, Integer> newEndMap = new HashMap<Pair<Integer,Integer>, Integer>();
int newBeginPairs = 0;
for (Pair<Integer,Integer> pair : beginMap.keySet()){
if (beginPairCounter.getCount(pair) >= minFeatureFrequency){
newBeginMap.put(pair, newBeginPairs++);
}
}
beginMap = newBeginMap;
beginPairs = newBeginPairs;
int newEndPairs = 0;
for (Pair<Integer,Integer> pair : endMap.keySet()){
if (endPairCounter.getCount(pair) >= minFeatureFrequency){
newEndMap.put(pair, newEndPairs++);
}
}
endMap = newEndMap;
endPairs = newEndPairs;
beginPairScore = new double[beginPairs][nClasses];
endPairScore = new double[endPairs][nClasses];
nFeatures += (beginPairs + endPairs)*nClasses;
System.out.println("There were "+beginPairs+" begin-pair types and "+endPairs+" end-pair types.");
}
public double[] scoreSpan(int previousIndex, int firstIndex, int lastIndex, int followingIndex){
double[] result = new double[nClasses];
Arrays.fill(result, 1);
if (firstIndex<0||lastIndex<0) {
// System.out.println("unseen index when scoring span: "+firstIndex+" "+lastIndex);
return result;
}
for (int c=0; c<nClasses; c++){
// if (c==1) continue;
if (useFirstAndLast) result[c] *= firstWordScore[firstIndex][c] * lastWordScore[lastIndex][c];
if (usePreviousAndNext){
if (previousIndex>=0) result[c] *= previousWordScore[previousIndex][c];
if (followingIndex>=0) result[c] *= nextWordScore[followingIndex][c];
}
if (useBeginAndEndPairs){
if (previousIndex>=0) {
int index = getBeginIndex(previousIndex, firstIndex);
if (index>=0) result[c] *= beginPairScore[index][c];
}
if (followingIndex>=0) {
int index = getEndIndex(lastIndex, followingIndex);
if (index>=0) result[c] *= endPairScore[index][c];
}
}
if (SloppyMath.isDangerous(result[c])){
System.out.println("Dangerous span prediction set to 1, since it was "+result);
result[c] = 1;
}
}
return result;
}
public double[][][] predictSpans(List<StateSet> sentence) {
int previousIndex=-1, firstIndex, lastIndex, followingIndex=-1;
int length = sentence.size();
double[][][] spanScores = new double[length][length+1][nClasses];
// all spans of size <=minSpanLength are ok
for (int start = 0; start < length; start++) {
for (int end = start + 1; end < start+minSpanLength && end<=length; end++) {
for (int clas=0; clas < nClasses; clas++){
spanScores[start][end][clas] = 1;
}
}
}
for (int start = 0; start <= length-minSpanLength; start++) {
StateSet stateSet = sentence.get(start);
firstIndex = (stateSet.sigIndex<0) ? stateSet.wordIndex : stateSet.sigIndex;
if (useOnlyWords) firstIndex = stateSet.wordIndex;
for (int end = start + minSpanLength; end <= length; end++) {
stateSet = sentence.get(end-1);
lastIndex = (stateSet.sigIndex<0) ? stateSet.wordIndex : stateSet.sigIndex;
if (useOnlyWords) lastIndex = stateSet.wordIndex;
if (end<length){
stateSet = sentence.get(end);
followingIndex = (stateSet.sigIndex<0) ? stateSet.wordIndex : stateSet.sigIndex;
if (useOnlyWords) followingIndex = stateSet.wordIndex;
} else {
followingIndex = -1;
}
spanScores[start][end] = scoreSpan(previousIndex, firstIndex, lastIndex, followingIndex);
}
previousIndex = firstIndex;
}
if (usePunctuation){
int[][] punctSignatures = getPunctuationSignatures(sentence);
for (int start = 0; start <= length-minSpanLength; start++) {
for (int end = start + minSpanLength; end <= length; end++) {
int sig = punctSignatures[start][end];
if (sig==-1) continue;
for (int c=0; c<nClasses; c++){
spanScores[start][end][c] *= punctuationScores[sig][c];
}
}
}
}
return spanScores;
}
public double[] countGoldSpanFeatures(StateSetTreeList trainTrees){
int[][] firstWordCount = null, lastWordCount = null;
int[][] previousWordCount = null, nextWordCount = null;
int[][] beginPairsCount = null, endPairsCount = null;
int[][] punctuationCount = null, punctuationSig = null;
if (useFirstAndLast){
firstWordCount = new int[nWords][nClasses];
lastWordCount = new int[nWords][nClasses];
}
if (usePreviousAndNext){
previousWordCount = new int[nWords][nClasses];
nextWordCount = new int[nWords][nClasses];
}
if (useBeginAndEndPairs){
beginPairsCount = new int[beginPairScore.length][nClasses];
endPairsCount = new int[endPairScore.length][nClasses];
}
if (usePunctuation){
punctuationCount = new int[punctuationSignatures.size()][nClasses];
}
for (Tree<StateSet> tree : trainTrees){
List<StateSet> words = tree.getYield();
if (usePunctuation) punctuationSig = getPunctuationSignatures(words);
countGoldSpanFeaturesHelper(tree, words, firstWordCount, lastWordCount,
previousWordCount, nextWordCount, beginPairsCount, endPairsCount,
punctuationCount, punctuationSig);
}
double[] res = new double[nFeatures];
int index = 0;
if (useFirstAndLast){
int firstSum = 0, lastSum = 0;
for (int c=0; c<nWords; c++){
firstSum += ArrayUtil.sum(firstWordCount[c]);
lastSum += ArrayUtil.sum(lastWordCount[c]);
}
System.out.println("Number of first words: "+firstSum);
System.out.println("Number of last words: "+lastSum);
for (int i=0; i<nWords; i++){
for (int c=0; c<nClasses; c++){
res[index++] = firstWordCount[i][c];
}
}
for (int i=0; i<nWords; i++){
for (int c=0; c<nClasses; c++){
res[index++] = lastWordCount[i][c];
}
}
}
if (usePreviousAndNext){
int prevSum = 0, nextSum = 0;
for (int c=0; c<nWords; c++){
prevSum += ArrayUtil.sum(previousWordCount[c]);
nextSum += ArrayUtil.sum(nextWordCount[c]);
}
System.out.println("Number of previous words: "+prevSum);
System.out.println("Number of next words: "+nextSum);
for (int i=0; i<nWords; i++){
for (int c=0; c<nClasses; c++){
res[index++] = previousWordCount[i][c];
}
}
for (int i=0; i<nWords; i++){
for (int c=0; c<nClasses; c++){
res[index++] = nextWordCount[i][c];
}
}
}
if (useBeginAndEndPairs){
int beginSum = 0, endSum = 0;
for (int i=0; i<beginPairsCount.length; i++){
beginSum += ArrayUtil.sum(beginPairsCount[i]);
}
for (int i=0; i<endPairsCount.length; i++){
endSum += ArrayUtil.sum(endPairsCount[i]);
}
System.out.println("Number of begin pairs: "+beginSum);
System.out.println("Number of end pairs: "+endSum);
for (int i=0; i<beginPairsCount.length; i++){
for (int c=0; c<nClasses; c++){
res[index++] = beginPairsCount[i][c];
}
}
for (int i=0; i<endPairsCount.length; i++){
for (int c=0; c<nClasses; c++){
res[index++] = endPairsCount[i][c];
}
}
}
if (usePunctuation){
for (int i=0; i<punctuationCount.length; i++){
for (int c=0; c<nClasses; c++){
res[index++] = punctuationCount[i][c];
}
}
}
return res;
}
private void countGoldSpanFeaturesHelper(Tree<StateSet> tree, List<StateSet> words,
int[][] firstWordCount, int[][] lastWordCount, int[][] previousWordCount, int[][] nextWordCount,
int[][] beginPairsCount, int[][] endPairsCount, int[][] punctuationCount, int[][] punctuationSignatures) {
StateSet node = tree.getLabel();
if (node.to - node.from < minSpanLength) return;
short state = node.getState();
int thisClass = stateClass[state];
StateSet stateSet = words.get(node.from);
int firstWord = (stateSet.sigIndex<0) ? stateSet.wordIndex : stateSet.sigIndex;
if (useOnlyWords) firstWord = stateSet.wordIndex;
stateSet = words.get(node.to-1);
int lastWord = (stateSet.sigIndex<0) ? stateSet.wordIndex : stateSet.sigIndex;
if (useOnlyWords) lastWord = stateSet.wordIndex;
int previousWord = 0, nextWord = 0;
if (node.from > 0) {
stateSet = words.get(node.from-1);
previousWord = (stateSet.sigIndex<0) ? stateSet.wordIndex : stateSet.sigIndex;
if (useOnlyWords) previousWord = stateSet.wordIndex;
}
if (node.to < words.size()) {
stateSet = words.get(node.to);
nextWord = (stateSet.sigIndex<0) ? stateSet.wordIndex : stateSet.sigIndex;
if (useOnlyWords) nextWord = stateSet.wordIndex;
}
if (useFirstAndLast){
firstWordCount[firstWord][thisClass]++;
lastWordCount[lastWord][thisClass]++;
}
if (usePreviousAndNext){
if (node.from > 0) previousWordCount[previousWord][thisClass]++;
if (node.to < words.size()) nextWordCount[nextWord][thisClass]++;
}
if (useBeginAndEndPairs){
if (node.from > 0) {
int beginIndex = getBeginIndex(previousWord, firstWord);
if (beginIndex>=0)
beginPairsCount[beginIndex][thisClass]++;
}
if (node.to < words.size()) {
int endIndex = getEndIndex(lastWord, nextWord);
if (endIndex>=0)
endPairsCount[endIndex][thisClass]++;
}
}
if (usePunctuation){
int punctSig = punctuationSignatures[node.from][node.to];
if (punctSig>=0)
punctuationCount[punctSig][thisClass]++;
}
for (Tree<StateSet> child : tree.getChildren()){
countGoldSpanFeaturesHelper(child, words, firstWordCount, lastWordCount, previousWordCount, nextWordCount,
beginPairsCount, endPairsCount, punctuationCount, punctuationSignatures);
}
}
private void initPunctuations(StateSetTreeList trainTrees){
punctuationSignatures = new Indexer<String>();
isPunctuation = new boolean[nWords];
Counter<String> punctSigCounter = new Counter<String>();
for (int word=0; word<nWords; word++){
isPunctuation[word] = isPunctuation(wordIndexer.get(word));
}
for (Tree<StateSet> tree : trainTrees){
getPunctuationSignatures(tree.getYield(), true, punctSigCounter);
}
Indexer<String> newPunctuationSignatures = new Indexer<String>();
for (String sig : punctSigCounter.keySet()){
if (punctSigCounter.getCount(sig) >= minFeatureFrequency)
newPunctuationSignatures.add(sig);
}
punctuationSignatures = newPunctuationSignatures;
punctuationScores = new double[punctuationSignatures.size()][nClasses];
ArrayUtil.fill(punctuationScores,1);
nFeatures += nClasses*punctuationScores.length;
}
private boolean isPunctuation(String word){
if (word.length()>2) return false;
if (Character.isLetterOrDigit(word.charAt(0))) return false;
if (word.length()==1) return true;
return !Character.isLetterOrDigit(word.charAt(1));
}
private int appendItem(StringBuilder sb, String maskedWord, int nWordsBefore){
if (maskedWord != X) {
sb.append(maskedWord);
nWordsBefore = 0;
} else if (nWordsBefore==0){
sb.append("x");
nWordsBefore++;
} else if (nWordsBefore==1){
sb.append("+");
nWordsBefore++;
}
return nWordsBefore;
}
public int[][] getPunctuationSignatures(List<StateSet> sentence){
return getPunctuationSignatures(sentence, false, null);
}
private final String X = "x".intern();
// replace words with x and leave only punctuation, collapse xx,xxx,xxxx,... to x+
public int[][] getPunctuationSignatures(List<StateSet> sentence, boolean update, Counter<String> punctSigCounter){
int length = sentence.size();
String[] masked = new String[length];
for (int i=0; i<length; i++) {
StateSet thisStateSet = sentence.get(i);
masked[i] = (thisStateSet.wordIndex>0&&isPunctuation[thisStateSet.wordIndex]) ? thisStateSet.getWord() : X;
}
int[][] result = new int[length][length+1];
ArrayUtil.fill(result, -1);
for (int start = 0; start <= length-minSpanLength; start++) {
StringBuilder sb = new StringBuilder();
String prev = "";
if (start<=1) sb.append("<S>");
int nWordsBefore = 0;
if (start>0){
appendItem(sb, masked[start-1], nWordsBefore);
}
sb.append("[");
nWordsBefore = appendItem(sb, masked[start], 0);
for (int end = start + minSpanLength; end <= length; end++) {
nWordsBefore = appendItem(sb, masked[end-1], nWordsBefore);
prev = sb.toString();
sb.append("]");
if (end<length){
appendItem(sb, masked[end], 0);
}
if (end<length-1){
sb.append("<E>");
}
String sig = sb.toString();
if (update) {
punctuationSignatures.add(sig);
punctSigCounter.incrementCount(sig, 1.0);
}
result[start][end] = punctuationSignatures.indexOf(sig);
sb = new StringBuilder(prev);
}
}
return result;
}
public String toString(){
return toString(null);
}
public String toString(Indexer<String> wordIndexer){
StringBuffer sb = new StringBuffer();
if (useFirstAndLast||usePreviousAndNext){
sb.append("word");
if (useFirstAndLast) sb.append("\tfirst\t\tlast\t");
if (usePreviousAndNext) sb.append("\tprevious\tfollowing");
sb.append("\n");
for (int word=0; word<nWords; word++){
String w = (wordIndexer!=null) ? wordIndexer.get(word) : word+"";
sb.append(w);
if (useFirstAndLast) sb.append("\t"+Arrays.toString(firstWordScore[word])+"\t"+Arrays.toString(lastWordScore[word]));
if (usePreviousAndNext) sb.append("\t"+Arrays.toString(previousWordScore[word])+"\t"+Arrays.toString(nextWordScore[word]));
sb.append("\n");
}
if (useFirstAndLast){
PriorityQueue<String> pQf = new PriorityQueue<String>();
PriorityQueue<String> pQl = new PriorityQueue<String>();
PriorityQueue<String> pQp = null;
PriorityQueue<String> pQn = null;
if (usePreviousAndNext){
pQp = new PriorityQueue<String>();
pQn = new PriorityQueue<String>();
}
for (int word=0; word<nWords; word++){
String w = (wordIndexer!=null) ? wordIndexer.get(word) : word+"";
pQf.add(w, firstWordScore[word][0]);
pQl.add(w, lastWordScore[word][0]);
if (usePreviousAndNext){
pQp.add(w, previousWordScore[word][0]);
pQn.add(w, nextWordScore[word][0]);
}
}
sb.append("First word weights\tLast word weights");
if (usePreviousAndNext){
sb.append("\tPrevious word weights\tNext word weights");
}
sb.append("\n");
while (pQf.hasNext()){
double weight = pQf.getPriority();
sb.append(pQf.next()+" "+weight+"\t");
weight = pQl.getPriority();
sb.append(pQl.next()+" "+weight+"\t");
if (usePreviousAndNext){
weight = pQp.getPriority();
sb.append(pQp.next()+" "+weight+"\t");
weight = pQn.getPriority();
sb.append(pQn.next()+" "+weight);
}
sb.append("\n");
}
}
}
if (useBeginAndEndPairs){
sb.append("Begin pairs\t\t\t\tEnd pairs\n");
PriorityQueue<String> pQb = new PriorityQueue<String>();
PriorityQueue<String> pQe = new PriorityQueue<String>();
for (Pair p : beginMap.keySet()){
String w1 = wordIndexer.get((Integer)p.getFirst());
String w2 = wordIndexer.get((Integer)p.getSecond());
pQb.add("("+w1+" | "+w2+"),", beginPairScore[beginMap.get(p)][0]);
}
for (Pair p : endMap.keySet()){
String w1 = wordIndexer.get((Integer)p.getFirst());
String w2 = wordIndexer.get((Integer)p.getSecond());
pQe.add("("+w1+" | "+w2+"),", endPairScore[endMap.get(p)][0]);
}
while (pQb.hasNext()||pQe.hasNext()){
double weight = 0;
if (pQb.hasNext()){
weight = pQb.getPriority();
sb.append(pQb.next()+" "+weight+"\t");
} else sb.append("\t\t\t\t");
if (pQe.hasNext()){
weight = pQe.getPriority();
sb.append(pQe.next()+" "+weight+"\n");
} else sb.append("\n");
}
}
if (usePunctuation){
sb.append("Punctuation features:\n");
PriorityQueue<String> pQp = new PriorityQueue<String>();
for (int f=0; f<punctuationSignatures.size(); f++){
String w = punctuationSignatures.get(f);
pQp.add(w, punctuationScores[f][0]);
}
while (pQp.hasNext()){
double weight = pQp.getPriority();
String word = pQp.next();
sb.append(word+"\t");
if (word.length()<8) sb.append("\t");
sb.append(+weight+"\n");
}
}
return sb.toString();
}
public int getBeginIndex(int previousIndex, int currIndex) {
Pair<Integer,Integer> pair = new Pair<Integer,Integer>(previousIndex, currIndex);
if (!beginMap.containsKey(pair)) return -1;
return beginMap.get(pair);
}
public int getEndIndex(int previousIndex, int currIndex) {
Pair<Integer,Integer> pair = new Pair<Integer,Integer>(previousIndex, currIndex);
if (!endMap.containsKey(pair)) return -1;
return endMap.get(pair);
}
/**
* @return the stateClass
*/
public int[] getStateClass() {
return stateClass;
}
/**
* @return the nClasses
*/
public final int getNClasses() {
return nClasses;
}
// public class FeatureBundle{
// public int firstWord;
// public int lastWord;
// public int previousWord;
// public int nextWord;
//
// public int beginPair;
// public int endPair;
//
//
// }
}