package edu.stanford.nlp.sequences;

import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;

import java.util.Arrays;

/** A SequenceFinder which can efficiently return a k-best list of sequence labellings.
 *
 *  @author Jenny Finkel
 *  @author Sven Zethelius
 */
public class KBestSequenceFinder implements BestSequenceFinder {

  /**
   * Runs the Viterbi algorithm on the sequence model
   * in order to find the best sequence.
   * This sequence finder only works on SequenceModel's with rightWindow == 0.
   * This is implemented as the argmax over the 1-best list returned by
   * {@link #kBestSequences(SequenceModel, int)}.
   *
   * @param ts The SequenceModel to find the best label sequence assignment of
   * @return An array containing the int tags of the best sequence
   */
  @Override
  public int[] bestSequence(SequenceModel ts) {
    return Counters.argmax(kBestSequences(ts, 1));
  }

  /**
   * Runs the Viterbi algorithm on the sequence model, and then proceeds to efficiently
   * backwards decode the best k label sequence assignments.
   * This sequence finder only works on SequenceModel's with rightWindow == 0.
   * <p>
   * Each returned key is an int array of length {@code ts.length() + ts.leftWindow()},
   * i.e. it includes the left padding positions. Fewer than k entries are returned when
   * fewer than k assignments have a score greater than negative infinity. An empty
   * sequence model ({@code ts.length() == 0}) yields an empty Counter.
   *
   * @param ts The SequenceModel to find the best k label sequence assignments of
   * @param k The number of top-scoring assignments to find. Must be at least 1.
   * @return A Counter with (at most) k entries that map from a sequence assignment (int array)
   *         to a double score
   * @throws IllegalArgumentException if the model's rightWindow is not 0, or if k is less than 1
   */
  @SuppressWarnings("MethodMayBeStatic")
  public Counter<int[]> kBestSequences(SequenceModel ts, int k) {

    // Set up tag options
    int length = ts.length();
    int leftWindow = ts.leftWindow();
    int rightWindow = ts.rightWindow();

    if (rightWindow != 0) {
      throw new IllegalArgumentException("KBestSequenceFinder only works with rightWindow == 0 not " + rightWindow);
    }
    // Without this check, k == 0 fails later with an ArrayIndexOutOfBoundsException and
    // k < 0 with a NegativeArraySizeException; fail fast with a meaningful message instead.
    if (k < 1) {
      throw new IllegalArgumentException("KBestSequenceFinder requires k >= 1 not " + k);
    }
    // Nothing to label. Without this check an empty sequence with leftWindow == 0 would
    // index productSizes[-1] below; with leftWindow > 0 the result is already empty.
    if (length == 0) {
      return new ClassicCounter<>();
    }

    int padLength = length + leftWindow + rightWindow;

    int[][] tags = new int[padLength][];
    int[] tagNum = new int[padLength];
    for (int pos = 0; pos < padLength; pos++) {
      tags[pos] = ts.getPossibleValues(pos);
      tagNum[pos] = tags[pos].length;
    }

    int[] tempTags = new int[padLength];

    // Set up product space sizes.
    // productSizes[pos] is the number of tag combinations over the window
    // [pos - leftWindow, pos]; it stays 0 for the padding positions pos < leftWindow.
    int[] productSizes = new int[padLength];

    int curProduct = 1;
    for (int i = 0; i < leftWindow; i++) {
      curProduct *= tagNum[i];
    }
    for (int pos = leftWindow; pos < padLength; pos++) {
      if (pos > leftWindow + rightWindow) {
        curProduct /= tagNum[pos - leftWindow - rightWindow - 1]; // shift off
      }
      curProduct *= tagNum[pos]; // shift on
      productSizes[pos - rightWindow] = curProduct;
    }

    double[][] windowScore = new double[padLength][];

    // Score all of each window's options.
    // A product index is a mixed-radix number over the window's tag indices, with the
    // tag at pos as the least significant digit and the tag at pos - leftWindow as the
    // most significant one.
    for (int pos = leftWindow; pos < leftWindow + length; pos++) {
      windowScore[pos] = new double[productSizes[pos]];
      Arrays.fill(tempTags, tags[0][0]);

      for (int product = 0; product < productSizes[pos]; product++) {
        int p = product;
        // Stride between consecutive tags at pos within the product index. Because
        // rightWindow == 0 no position lies to the right of pos, so this stays 1.
        int shift = 1;
        for (int curPos = pos; curPos >= pos - leftWindow; curPos--) {
          tempTags[curPos] = tags[curPos][p % tagNum[curPos]];
          p /= tagNum[curPos];
          if (curPos > pos) {
            shift *= tagNum[curPos];
          }
        }
        // Only query the model once per left context: when the tag at pos is the first
        // possible tag, scoresOf gives the scores for every tag at pos.
        if (tempTags[pos] == tags[pos][0]) {
          // get all tags at once
          double[] scores = ts.scoresOf(tempTags, pos);
          // fill in the relevant windowScores
          for (int t = 0; t < tagNum[pos]; t++) {
            windowScore[pos][product + t * shift] = scores[t];
          }
        }
      }
    }

    // Set up score and backtrace arrays.
    // score[pos][product] holds up to k partial-path scores in ascending order;
    // trace[pos][product][i] is the pair {predecessor product, index into the
    // predecessor's k-best list} for the i-th of those scores.
    double[][][] score = new double[padLength][][];
    int[][][][] trace = new int[padLength][][][];
    int[][] numWaysToMake = new int[padLength][];
    for (int pos = 0; pos < padLength; pos++) {
      score[pos] = new double[productSizes[pos]][];
      trace[pos] = new int[productSizes[pos]][][]; // the 2 is for backtrace, and which of the k best for that backtrace
      numWaysToMake[pos] = new int[productSizes[pos]];
      Arrays.fill(numWaysToMake[pos], 1);
      for (int product = 0; product < productSizes[pos]; product++) {
        if (pos > leftWindow) {
          // loop over possible predecessor types
          int sharedProduct = product / tagNum[pos];
          int factor = productSizes[pos] / tagNum[pos];

          // Count (capped at k) how many distinct paths can end in this product, so
          // that the k-best arrays are no larger than necessary.
          numWaysToMake[pos][product] = 0;
          for (int newTagNum = 0; newTagNum < tagNum[pos - leftWindow - 1] && numWaysToMake[pos][product] < k; newTagNum++) {
            int predProduct = newTagNum * factor + sharedProduct;
            numWaysToMake[pos][product] += numWaysToMake[pos - 1][predProduct];
          }
          if (numWaysToMake[pos][product] > k) {
            numWaysToMake[pos][product] = k;
          }
        }

        score[pos][product] = new double[numWaysToMake[pos][product]];
        Arrays.fill(score[pos][product], Double.NEGATIVE_INFINITY);
        trace[pos][product] = new int[numWaysToMake[pos][product]][];
        // The shared placeholder array is never mutated; slots are overwritten with
        // fresh arrays when a real backtrace is recorded.
        Arrays.fill(trace[pos][product], new int[]{-1, -1});
      }
    }

    // Do forward Viterbi algorithm
    // this is the hottest loop, so cache loop control variables hoping for a little speed....

    // loop over the classification spot
    for (int pos = leftWindow, posMax = length + leftWindow; pos < posMax; pos++) {
      // loop over window product types
      for (int product = 0, productMax = productSizes[pos]; product < productMax; product++) {
        // check for initial spot
        double[] scorePos = score[pos][product];
        int[][] tracePos = trace[pos][product];
        if (pos == leftWindow) {
          // no predecessor type
          scorePos[0] = windowScore[pos][product];
        } else {
          // loop over possible predecessor types/k-best
          int sharedProduct = product / tagNum[pos + rightWindow];
          int factor = productSizes[pos] / tagNum[pos + rightWindow];
          for (int newTagNum = 0, maxTagNum = tagNum[pos - leftWindow - 1]; newTagNum < maxTagNum; newTagNum++) {
            int predProduct = newTagNum * factor + sharedProduct;
            double[] scorePosPrev = score[pos - 1][predProduct];
            for (int k1 = 0; k1 < scorePosPrev.length; k1++) {
              double predScore = scorePosPrev[k1] + windowScore[pos][product];
              if (predScore > scorePos[0]) { // new value higher than lowest value we should keep
                // scorePos is sorted ascending; find the last slot whose value is below
                // predScore (binarySearch returns -(insertionPoint) - 1 when not found).
                int k2 = Arrays.binarySearch(scorePos, predScore);
                k2 = k2 < 0 ? -k2 - 2 : k2 - 1;
                // open a spot at k2 by shifting off the lowest value
                System.arraycopy(scorePos, 1, scorePos, 0, k2);
                System.arraycopy(tracePos, 1, tracePos, 0, k2);

                scorePos[k2] = predScore;
                tracePos[k2] = new int[] {predProduct, k1};
              }
            }
          }
        }
      }
    }

    // Project the actual tag sequence.
    // The three arrays below are parallel and kept in ascending score order:
    // final score, final product, and index into that product's k-best list.
    int[] whichDerivation = new int[k];
    int[] bestCurrentProducts = new int[k];
    double[] bestFinalScores = new double[k];
    Arrays.fill(bestFinalScores, Double.NEGATIVE_INFINITY);

    // just the last guy
    for (int product = 0; product < productSizes[padLength - 1]; product++) {
      double[] scorePos = score[padLength - 1][product];
      // Walk this product's k-best list from best to worst and stop as soon as an
      // entry no longer beats the current overall k-th best.
      for (int k1 = scorePos.length - 1; k1 >= 0 && scorePos[k1] > bestFinalScores[0]; k1--) {
        int k2 = Arrays.binarySearch(bestFinalScores, scorePos[k1]);
        k2 = k2 < 0 ? -k2 - 2 : k2 - 1;
        // open a spot at k2 by shifting off the lowest value
        System.arraycopy(bestFinalScores, 1, bestFinalScores, 0, k2);
        System.arraycopy(whichDerivation, 1, whichDerivation, 0, k2);
        System.arraycopy(bestCurrentProducts, 1, bestCurrentProducts, 0, k2);

        bestCurrentProducts[k2] = product;
        whichDerivation[k2] = k1;
        bestFinalScores[k2] = scorePos[k1];
      }
    }

    ClassicCounter<int[]> kBestWithScores = new ClassicCounter<>();
    // Emit from best to worst; slots still at negative infinity were never filled.
    for (int k1 = k - 1; k1 >= 0 && bestFinalScores[k1] > Double.NEGATIVE_INFINITY; k1--) {
      // Decode the tags of the final window directly from the final product.
      int lastProduct = bestCurrentProducts[k1];
      for (int last = padLength - 1; last >= length - 1 && last >= 0; last--) {
        tempTags[last] = tags[last][lastProduct % tagNum[last]];
        lastProduct /= tagNum[last];
      }

      // Follow the backtrace; each step reveals the tag that was shifted off the
      // window, i.e. the most significant digit of the predecessor product.
      for (int pos = leftWindow + length - 2; pos >= leftWindow; pos--) {
        int bestNextProduct = bestCurrentProducts[k1];
        bestCurrentProducts[k1] = trace[pos + 1][bestNextProduct][whichDerivation[k1]][0];
        whichDerivation[k1] = trace[pos + 1][bestNextProduct][whichDerivation[k1]][1];
        tempTags[pos - leftWindow] =
            tags[pos - leftWindow][bestCurrentProducts[k1] / (productSizes[pos] / tagNum[pos - leftWindow])];
      }
      // tempTags is reused across iterations, so store a copy as the key.
      kBestWithScores.setCount(Arrays.copyOf(tempTags, tempTags.length), bestFinalScores[k1]);
    }

    return kBestWithScores;
  }

}