package edu.stanford.nlp.sequences;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counters;
import java.util.Arrays;
/**
* @author Jenny Finkel
*/
public class KBestSequenceFinder implements BestSequenceFinder {
/**
* Runs the Viterbi algorithm on the sequence model
* in order to find the best sequence.
*
* @return An array containing the int tags of the best sequence
*/
@Override
public int[] bestSequence(SequenceModel ts) {
return Counters.argmax(kBestSequences(ts, 1));
}
public ClassicCounter<int[]> kBestSequences(SequenceModel ts, int k) {
// Set up tag options
int length = ts.length();
int leftWindow = ts.leftWindow();
int rightWindow = ts.rightWindow();
assert (rightWindow == 0);
int padLength = length + leftWindow + rightWindow;
int[][] tags = new int[padLength][];
int[] tagNum = new int[padLength];
for (int pos = 0; pos < padLength; pos++) {
tags[pos] = ts.getPossibleValues(pos);
tagNum[pos] = tags[pos].length;
}
int[] tempTags = new int[padLength];
// Set up product space sizes
int[] productSizes = new int[padLength];
int curProduct = 1;
for (int i = 0; i < leftWindow; i++) {
curProduct *= tagNum[i];
}
for (int pos = leftWindow; pos < padLength; pos++) {
if (pos > leftWindow + rightWindow) {
curProduct /= tagNum[pos - leftWindow - rightWindow - 1]; // shift off
}
curProduct *= tagNum[pos]; // shift on
productSizes[pos - rightWindow] = curProduct;
}
double[][] windowScore = new double[padLength][];
// Score all of each window's options
for (int pos = leftWindow; pos < leftWindow + length; pos++) {
windowScore[pos] = new double[productSizes[pos]];
Arrays.fill(tempTags, tags[0][0]);
for (int product = 0; product < productSizes[pos]; product++) {
int p = product;
int shift = 1;
for (int curPos = pos; curPos >= pos - leftWindow; curPos--) {
tempTags[curPos] = tags[curPos][p % tagNum[curPos]];
p /= tagNum[curPos];
if (curPos > pos) {
shift *= tagNum[curPos];
}
}
if (tempTags[pos] == tags[pos][0]) {
// get all tags at once
double[] scores = ts.scoresOf(tempTags, pos);
// fill in the relevant windowScores
for (int t = 0; t < tagNum[pos]; t++) {
windowScore[pos][product + t * shift] = scores[t];
}
}
}
}
// Set up score and backtrace arrays
double[][][] score = new double[padLength][][];
int[][][][] trace = new int[padLength][][][];
int[][] numWaysToMake = new int[padLength][];
for (int pos = 0; pos < padLength; pos++) {
score[pos] = new double[productSizes[pos]][];
trace[pos] = new int[productSizes[pos]][][]; // the 2 is for backtrace, and which of the k best for that backtrace
numWaysToMake[pos] = new int[productSizes[pos]];
Arrays.fill(numWaysToMake[pos], 1);
for (int product = 0; product < productSizes[pos]; product++) {
if (pos == leftWindow) {
numWaysToMake[pos][product] = 1;
} else if (pos > leftWindow) {
// loop over possible predecessor types
int sharedProduct = product / tagNum[pos];
int factor = productSizes[pos] / tagNum[pos];
numWaysToMake[pos][product] = 0;
for (int newTagNum = 0; newTagNum < tagNum[pos - leftWindow - 1]; newTagNum++) {
int predProduct = newTagNum * factor + sharedProduct;
numWaysToMake[pos][product] += numWaysToMake[pos-1][predProduct];
}
if (numWaysToMake[pos][product] > k) { numWaysToMake[pos][product] = k; }
} else {
numWaysToMake[pos][product] = 1;
}
score[pos][product] = new double[numWaysToMake[pos][product]];
trace[pos][product] = new int[numWaysToMake[pos][product]][2];
}
}
// Do forward Viterbi algorithm
// loop over the classification spot
for (int pos = leftWindow; pos < length + leftWindow; pos++) {
// loop over window product types
for (int product = 0; product < productSizes[pos]; product++) {
// check for initial spot
if (pos == leftWindow) {
// no predecessor type
score[pos][product][0] = windowScore[pos][product];
trace[pos][product][0][0] = -1;
trace[pos][product][0][1] = -1;
} else {
// loop over possible predecessor types/k-best
for (int k1 = 0; k1 < score[pos][product].length; k1++) {
score[pos][product][k1] = Double.NEGATIVE_INFINITY;
trace[pos][product][k1][0] = -1;
trace[pos][product][k1][1] = -1;
}
int sharedProduct = product / tagNum[pos + rightWindow];
int factor = productSizes[pos] / tagNum[pos + rightWindow];
for (int newTagNum = 0; newTagNum < tagNum[pos - leftWindow - 1]; newTagNum++) {
int predProduct = newTagNum * factor + sharedProduct;
for (int k1 = 0; k1 < score[pos-1][predProduct].length; k1++) {
double predScore = score[pos - 1][predProduct][k1] + windowScore[pos][product];
for (int k2 = 0; k2 < score[pos][product].length; k2++) {
if (predScore > score[pos][product][k2]) {
System.arraycopy(score[pos][product], k2, score[pos][product], k2+1, score[pos][product].length-(k2+1));
System.arraycopy(trace[pos][product], k2, trace[pos][product], k2+1, trace[pos][product].length-(k2+1));
score[pos][product][k2] = predScore;
trace[pos][product][k2]= new int[2];
trace[pos][product][k2][0] = predProduct;
trace[pos][product][k2][1] = k1;
break;
}
}
}
}
}
}
}
// Project the actual tag sequence
int[][] kBest = new int[k][padLength];
int[] whichDerivation = new int[k];
int[] bestCurrentProducts = new int[k];
double[] bestFinalScores = new double[k];
Arrays.fill(bestFinalScores, Double.NEGATIVE_INFINITY);
// just the last guy
for (int product = 0; product < productSizes[padLength - 1]; product++) {
for (int k1 = 0; k1 < score[padLength - 1][product].length; k1++) {
for (int k2 = 0; k2 < bestFinalScores.length; k2++) {
if (score[padLength - 1][product][k1] > bestFinalScores[k2]) {
System.arraycopy(bestFinalScores, k1, bestFinalScores, k1+1, bestFinalScores.length-(k1+1));
System.arraycopy(whichDerivation, k1, whichDerivation, k1+1, whichDerivation.length-(k1+1));
System.arraycopy(bestCurrentProducts, k1, bestCurrentProducts, k1+1, bestCurrentProducts.length-(k1+1));
bestCurrentProducts[k2] = product;
whichDerivation[k2] = k1;
bestFinalScores[k2] = score[padLength - 1][product][k1];
break;
}
}
}
}
int[] lastProducts = new int[k];
System.arraycopy(bestCurrentProducts, 0, lastProducts, 0, lastProducts.length);
for (int last = padLength - 1; last >= length - 1 && last >= 0; last--) {
for (int k1 = 0; k1 < lastProducts.length; k1++) {
kBest[k1][last] = tags[last][lastProducts[k1] % tagNum[last]];
lastProducts[k1] /= tagNum[last];
}
}
for (int pos = padLength - 2; pos >= leftWindow; pos--) {
System.arraycopy(bestCurrentProducts, 0, lastProducts, 0, lastProducts.length);
Arrays.fill(bestCurrentProducts, -1);
for (int k1 = 0; k1 < lastProducts.length; k1++) {
bestCurrentProducts[k1] = trace[pos + 1][lastProducts[k1]][whichDerivation[k1]][0];
whichDerivation[k1] = trace[pos + 1][lastProducts[k1]][whichDerivation[k1]][1];
kBest[k1][pos - leftWindow] = tags[pos - leftWindow][bestCurrentProducts[k1] / (productSizes[pos] / tagNum[pos - leftWindow])];
}
}
ClassicCounter<int[]> kBestWithScores = new ClassicCounter<int[]>();
for (int i = 0; i < kBest.length; i++) {
if(bestFinalScores[i] > Double.NEGATIVE_INFINITY) {
kBestWithScores.setCount(kBest[i], bestFinalScores[i]);
//System.err.println(bestFinalScores[i]+"\t"+Arrays.toString(kBest[i]));
}
}
return kBestWithScores;
}
}