package edu.stanford.nlp.sequences;
import edu.stanford.nlp.util.logging.Redwood;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.fsm.*;
import java.util.Arrays;
/**
 * Builds a {@link DFSA} search graph from a {@link SequenceModel}: one start
 * state, one state per (position, tag) pair, and one accepting end state.
 * Each transition is labelled with the class name of the tag it emits and is
 * weighted with the negated window score reported by the model.
 *
 * @author Michel Galley
 * @author Sarah Spikes (sdspikes@cs.stanford.edu) - cleanup and filling in types
 */
public class ViterbiSearchGraphBuilder {

  /** A logger for this class */
  private static final Redwood.RedwoodChannels log = Redwood.channels(ViterbiSearchGraphBuilder.class);

  /**
   * Constructs the search graph for the given sequence model.
   *
   * <p>State ids are assigned in creation order: the start state gets id 0,
   * the per-position states follow (position-major, then tag), and the
   * accepting end state gets the last id.
   *
   * <p>NOTE(review): states are looked up by tag <em>value</em>, so this
   * presumably expects {@code ts.getPossibleValues(pos)} to return values that
   * are valid indices into that position's state array (e.g. 0..n-1) - confirm
   * against the models passed in. The state array for lattice column {@code k}
   * is sized from {@code tags[k]}, which lines up with the padded positions
   * only when {@code leftWindow == 1} - verify before using other window sizes.
   *
   * @param ts the sequence model providing the length, window sizes, possible
   *        tag values per (padded) position, and window scores
   * @param classIndex maps a tag value to the label string used as the
   *        transition input
   * @return the graph, with its initial state set and a single accepting state
   */
  public static DFSA<String, Integer> getGraph(SequenceModel ts, Index<String> classIndex) {
    DFSA<String, Integer> viterbiSearchGraph = new DFSA<>(null);

    // Set up tag options
    int length = ts.length();
    int leftWindow = ts.leftWindow();
    int rightWindow = ts.rightWindow();
    assert rightWindow == 0 : "right context is not supported, got rightWindow=" + rightWindow;
    int padLength = length + leftWindow + rightWindow;

    // tags[i][j]: i indexes the padded position, j the tag option at that position
    int[][] tags = new int[padLength][];
    int[] tagNum = new int[padLength];
    for (int pos = 0; pos < padLength; pos++) {
      tags[pos] = ts.getPossibleValues(pos);
      tagNum[pos] = tags[pos].length;
    }

    // Set up the states: start (id 0), lattice states, then the accepting end state.
    int nextStateId = 0;
    DFSAState<String, Integer> startState = new DFSAState<>(nextStateId++, viterbiSearchGraph, 0.0);
    viterbiSearchGraph.setInitialState(startState);
    DFSAState<String, Integer>[][] graphStates =
        createLatticeStates(viterbiSearchGraph, tags, length, nextStateId);
    for (DFSAState<String, Integer>[] column : graphStates) {
      nextStateId += column.length;
    }
    DFSAState<String, Integer> endState = new DFSAState<>(nextStateId, viterbiSearchGraph, 0.0);
    endState.setAccepting(true);

    int[] productSizes = computeProductSizes(tagNum, leftWindow, rightWindow);
    double[][] windowScore = scoreWindows(ts, tags, tagNum, productSizes, length, leftWindow);

    // loop over the classification spot
    for (int pos = leftWindow; pos < length + leftWindow; pos++) {
      int seqIndex = pos - leftWindow;
      boolean lastPosition = (seqIndex + 1 == graphStates.length);

      // loop over window product types
      for (int product = 0; product < productSizes[pos]; product++) {
        int curTag = tags[pos][product % tagNum[pos]];
        // Emitting the final tag always leads to the accepting state; this also
        // covers a sequence of length one, where the first position is the last.
        DFSAState<String, Integer> destState =
            lastPosition ? endState : graphStates[seqIndex + 1][curTag];
        String label = classIndex.get(curTag);
        double cost = -windowScore[pos][product];

        if (pos == leftWindow) {
          // all transitions for the first spot leave from startState
          DFSATransition<String, Integer> tr =
              new DFSATransition<>("", startState, destState, label, "", cost);
          startState.addTransition(tr);
        } else {
          int sharedProduct = product / tagNum[pos + rightWindow];
          int factor = productSizes[pos] / tagNum[pos + rightWindow];
          for (int newTagNum = 0; newTagNum < tagNum[pos - leftWindow - 1]; newTagNum++) {
            int predProduct = newTagNum * factor + sharedProduct;
            int predTag = tags[pos - 1][predProduct % tagNum[pos - 1]];
            DFSAState<String, Integer> sourceState = graphStates[seqIndex][predTag];
            DFSATransition<String, Integer> tr =
                new DFSATransition<>("", sourceState, destState, label, "", cost);
            sourceState.addTransition(tr);
          }
        }
      }
    }
    return viterbiSearchGraph;
  }

  /**
   * Allocates one state per tag option for each of the first {@code length}
   * entries of {@code tags}, numbering them consecutively from
   * {@code firstStateId}.
   */
  @SuppressWarnings({"unchecked", "rawtypes"}) // generic array creation is unavoidable here
  private static DFSAState<String, Integer>[][] createLatticeStates(
      DFSA<String, Integer> graph, int[][] tags, int length, int firstStateId) {
    DFSAState<String, Integer>[][] states = new DFSAState[length][];
    int stateId = firstStateId;
    for (int pos = 0; pos < length; ++pos) {
      states[pos] = new DFSAState[tags[pos].length];
      for (int k = 0; k < states[pos].length; ++k) {
        states[pos][k] = new DFSAState<>(stateId++, graph);
      }
    }
    return states;
  }

  /**
   * Computes, for every padded position, the number of tag combinations in the
   * window ending at that position (a running product of tag counts with the
   * oldest position divided out as the window slides).
   */
  private static int[] computeProductSizes(int[] tagNum, int leftWindow, int rightWindow) {
    int padLength = tagNum.length;
    int[] productSizes = new int[padLength];
    int curProduct = 1;
    for (int i = 0; i < leftWindow; i++) {
      curProduct *= tagNum[i];
    }
    for (int pos = leftWindow; pos < padLength; pos++) {
      if (pos > leftWindow + rightWindow) {
        curProduct /= tagNum[pos - leftWindow - rightWindow - 1]; // shift off
      }
      curProduct *= tagNum[pos]; // shift on
      productSizes[pos - rightWindow] = curProduct;
    }
    return productSizes;
  }

  /**
   * Scores every tag combination of every window. The model is queried once
   * per left context (when the current position holds its first tag option)
   * and the returned scores are spread over all tag options at that position.
   */
  private static double[][] scoreWindows(SequenceModel ts, int[][] tags, int[] tagNum,
                                         int[] productSizes, int length, int leftWindow) {
    int padLength = tags.length;
    int[] tempTags = new int[padLength];
    double[][] windowScore = new double[padLength][];
    for (int pos = leftWindow; pos < leftWindow + length; pos++) {
      windowScore[pos] = new double[productSizes[pos]];
      Arrays.fill(tempTags, tags[0][0]);
      for (int product = 0; product < productSizes[pos]; product++) {
        // Decode the product index into one tag per window position,
        // least-significant digit first (the current position).
        int p = product;
        for (int curPos = pos; curPos >= pos - leftWindow; curPos--) {
          tempTags[curPos] = tags[curPos][p % tagNum[curPos]];
          p /= tagNum[curPos];
        }
        if (tempTags[pos] == tags[pos][0]) {
          // get all tags at once
          double[] scores = ts.scoresOf(tempTags, pos);
          // The current position is the least-significant digit, so its tag
          // options occupy consecutive product indices.
          for (int t = 0; t < tagNum[pos]; t++) {
            windowScore[pos][product + t] = scores[t];
          }
        }
      }
    }
    return windowScore;
  }
}