package edu.fudan.ml.classifier.struct.inf;
import java.util.Arrays;
import edu.fudan.ml.classifier.Predict;
import edu.fudan.ml.classifier.linear.inf.Inferencer;
import edu.fudan.ml.types.Instance;
import edu.fudan.ml.types.alphabet.IFeatureAlphabet;
import edu.fudan.nlp.pipe.seq.templet.TempletGroup;
import edu.fudan.util.MyArrays;
/**
* 任意阶Viterbi算法
*
* @author xpqiu
*
*/
public class HigherOrderViterbi extends AbstractViterbi {
private static final long serialVersionUID = 6023318778006156804L;
/**
* 构造函数
*
* @param features
* 特征
* @param numLabels
* 标记
* @param templets
* 模板
*/
public HigherOrderViterbi(TempletGroup templets, int numLabels) {
this.ysize = numLabels;
this.templets = templets;
this.templets.calc(numLabels);
this.numTemplets = templets.size();
numStates = templets.numStates;
}
/**
* 标记给定实例
*
* @param instance
*/
public Predict<int[]> getBest(Instance instance, int nbest) {
int[][] data;
/**
* 节点矩阵
*/
Node[][] lattice;
data = (int[][]) instance.getData();
// target = (int[]) instance.getTarget();
lattice = new Node[data.length][getTemplets().numStates];
for (int ip = 0; ip < data.length; ip++)
for (int s = 0; s < getTemplets().numStates; s++)
lattice[ip][s] = new Node(nbest);
for (int ip = 0; ip < data.length; ip++) {
// 对于每一个n阶的可能组合
for (int s = 0; s < numStates; s++) {
// 计算所有特征的权重和
for (int t = 0; t < numTemplets; t++) {
if (data[ip][t] == -1)
continue;
lattice[ip][s].weight += weights[data[ip][t]
+ getTemplets().offset[t][s]];
}
}
}
for (int s = 0; s < ysize; s++) {
lattice[0][s].best[0] = lattice[0][s].weight;
}
float[] best = new float[nbest];
Integer[] prev = new Integer[nbest];
for (int ip = 1; ip < data.length; ip++) {
for (int s = 0; s < numStates; s += ysize) {
Arrays.fill(best, Float.NEGATIVE_INFINITY);
for (int k = 0; k < ysize; k++) {
int sp = (k * getTemplets().numStates + s) / ysize;
for (int ibest = 0; ibest < nbest; ibest++) {
float b = lattice[ip - 1][sp].best[ibest];
MyArrays.addBest(best, prev, b, sp * nbest + ibest);
}
}
for (int r = s; r < s + ysize; r++) {
for (int n = 0; n < nbest; n++) {
lattice[ip][r].best[n] = best[n]
+ lattice[ip][r].weight;
lattice[ip][r].prev[n] = prev[n];
}
}
}
}
Predict<int[]> res = getPath(lattice, nbest);
return res;
}
public Predict<int[]> getBest(Instance instance) {
return getBest(instance, 1);
}
private Predict<int[]> getPath(Node[][] lattice, int nbest) {
float best;
Node lastNode = new Node(nbest);
int last = lattice.length - 1;
for (int s = 0; s < getTemplets().numStates; s++) {
for (int ibest = 0; ibest < nbest; ibest++) {
best = lattice[last][s].best[ibest];
lastNode.addBest(best, s * nbest + ibest);
}
}
Predict<int[]> res = new Predict<int[]>(nbest);
for (int k = 0; k < nbest; k++) {
int[] path = new int[lattice.length];
int p = last;
int s = lastNode.prev[k];
float score = lastNode.best[k];
for (int d = s / nbest, i = 0; i < getTemplets().maxOrder && p >= 0; i++, p--) {
path[p] = d % ysize;
d = d / ysize;
}
while (p >= 0) {
path[p] = s / nbest / getTemplets().base[getTemplets().maxOrder];
s = lattice[p + getTemplets().maxOrder][s / nbest].prev[s % nbest];
--p;
}
res.add(path,score);
}
return res;
}
public final class Node {
int n;
float weight = 0.0f;
float[] best;
int[] prev;
public Node(int n) {
this.n = n;
best = new float[n];
prev = new int[n];
}
/**
* 记录之前的label和得分,保留前n个
*
* @param score
* @param p
*/
public int addBest(float score, int p) {
int i;
for (i = 0; i < n; i++) {
if (score > best[i])
break;
}
if (i >= n)
return -1;
for (int k = n - 2; k >= i; k--) {
best[k + 1] = best[k];
prev[k + 1] = prev[k];
}
best[i] = score;
prev[i] = p;
return i;
}
public String toString() {
return String.format("%f %f %d", weight, best[0], prev[0]);
}
}
}