package edu.fudan.ml.classifier.struct.inf;
import java.util.Arrays;
import edu.fudan.ml.classifier.Predict;
import edu.fudan.ml.types.Instance;
import edu.fudan.nlp.pipe.seq.templet.TempletGroup;
/**
 * First-order linear-chain decoder that finds the best-scoring label sequence
 * with the Viterbi algorithm.
 * (Reworked into a parallelizable version on 2011.9.15: the lattice is built
 * per call and passed between the decoding steps instead of being stored on
 * the decoder.)
 *
 * @author Feng Ji
 */
public class LinearViterbi extends AbstractViterbi {

    private static final long serialVersionUID = -8237762672065700553L;

    /**
     * Creates a decoder from a template group.
     *
     * @param templets feature templates; their orders are read via {@code getOrders()}
     * @param ysize    number of labels
     */
    public LinearViterbi(TempletGroup templets, int ysize) {
        this.ysize = ysize;
        this.setTemplets(templets);
        this.orders = templets.getOrders();
    }

    /**
     * Creates a decoder from explicit template orders, without a template group.
     *
     * @param orders order of each template ({@code 0} and {@code 1} are the
     *               values handled by {@link #initialLattice(Instance)})
     * @param ysize  number of labels
     */
    public LinearViterbi(int[] orders, int ysize) {
        this.ysize = ysize;
        this.orders = orders;
    }

    /**
     * @return number of labels
     */
    public int ysize() {
        return ysize;
    }

    /**
     * @return order of each template (the internal array, not a copy)
     */
    public int[] orders() {
        return orders;
    }

    /**
     * Copy-style constructor: takes templates, label count and weights from
     * another decoder.
     * <p>
     * Note: {@code viterbi.getTemplets()} is dereferenced, so the given decoder
     * must have a template group set.
     *
     * @param viterbi decoder to take the configuration from
     */
    public LinearViterbi(AbstractViterbi viterbi) {
        this(viterbi.getTemplets(), viterbi.ysize);
        this.weights = viterbi.getWeights();
    }

    /**
     * Decodes one instance: builds the lattice, runs the forward pass and
     * back-traces the best path.
     *
     * @param carrier instance to decode
     * @return the best label sequence together with its score
     */
    @Override
    public Predict<int[]> getBest(Instance carrier) {
        Node[][] lattice = initialLattice(carrier);
        doForwardViterbi(lattice, carrier);
        return getPath(lattice);
    }

    /**
     * Builds and initializes the lattice.
     * <p>
     * {@code data[l][i]} is used as the base index into {@code weights} for
     * template {@code i} at position {@code l}; {@code -1} marks a template
     * that contributes nothing at that position. An order-0 template owns the
     * {@code ysize} weights starting at its base index, an order-1 template
     * owns {@code ysize * ysize} weights. A template whose weight block does
     * not lie completely inside {@code weights} is skipped.
     *
     * @param carrier instance whose data is an {@code int[][]} of base indices
     * @return lattice of {@code carrier.length()} rows with {@code ysize} nodes each
     */
    protected Node[][] initialLattice(Instance carrier) {
        int[][] data = (int[][]) carrier.getData();
        int length = carrier.length();
        // Largest base index whose complete weight block still fits into
        // weights. Written as "length - span" so the comparison cannot overflow.
        int maxUnigramBase = weights.length - ysize;
        int maxBigramBase = weights.length - ysize * ysize;

        Node[][] lattice = new Node[length][];
        for (int l = 0; l < length; l++) {
            lattice[l] = new Node[ysize];
            for (int c = 0; c < ysize; c++) {
                Node node = new Node(ysize);
                lattice[l][c] = node;
                for (int i = 0; i < orders.length; i++) {
                    int base = data[l][i];
                    if (base == -1) {
                        continue;
                    }
                    if (orders[i] == 0) {
                        // Checking only "base < weights.length" is not enough:
                        // the index actually read is base + c.
                        if (base > maxUnigramBase) {
                            continue;
                        }
                        node.score += weights[base + c];
                    } else if (orders[i] == 1) {
                        // Indices read here go up to base + ysize * ysize - 1.
                        if (base > maxBigramBase) {
                            continue;
                        }
                        // The transition weights are stored flattened:
                        // the entry for (previous p, current c) is at base + p * ysize + c.
                        int offset = c;
                        for (int p = 0; p < ysize; p++) {
                            node.trans[p] += weights[base + offset];
                            offset += ysize;
                        }
                    }
                }
            }
        }
        return lattice;
    }

    /**
     * Forward Viterbi pass. For every node from the second position on, picks
     * the predecessor maximizing {@code previous.score + trans[p]} and stores
     * the accumulated score and the chosen predecessor in the node.
     * {@code null} nodes are skipped, both as current node and as predecessor.
     *
     * @param lattice lattice to update in place
     * @param carrier instance being decoded (not used by this implementation)
     */
    protected void doForwardViterbi(Node[][] lattice, Instance carrier) {
        for (int l = 1; l < lattice.length; l++) {
            Node[] previous = lattice[l - 1];
            Node[] current = lattice[l];
            for (int c = 0; c < current.length; c++) {
                Node node = current[c];
                if (node == null) {
                    continue;
                }
                float bestScore = Float.NEGATIVE_INFINITY;
                int bestPath = -1;
                for (int p = 0; p < previous.length; p++) {
                    if (previous[p] == null) {
                        continue;
                    }
                    float score = previous[p].score + node.trans[p];
                    if (score > bestScore) {
                        bestScore = score;
                        bestPath = p;
                    }
                }
                // node.score still holds the local (order-0) score at this point.
                bestScore += node.score;
                node.addScore(bestScore, bestPath);
            }
        }
    }

    /**
     * Back-traces the best path through a lattice processed by
     * {@link #doForwardViterbi(Node[][], Instance)}.
     *
     * @param lattice processed lattice
     * @return the best path and its score; an empty result for an empty lattice
     */
    protected Predict<int[]> getPath(Node[][] lattice) {
        Predict<int[]> res = new Predict<int[]>();
        if (lattice.length == 0) {
            return res;
        }
        int last = lattice.length - 1;

        // Best-scoring label at the final position (first one wins on ties).
        float max = Float.NEGATIVE_INFINITY;
        int cur = 0;
        for (int c = 0; c < ysize(); c++) {
            Node node = lattice[last][c];
            if (node == null) {
                continue;
            }
            if (node.score > max) {
                max = node.score;
                cur = c;
            }
        }

        // Follow the predecessor links back to the first position.
        int[] path = new int[lattice.length];
        path[last] = cur;
        for (int l = last; l > 0; l--) {
            cur = lattice[l][cur].prev;
            path[l - 1] = cur;
        }
        res.add(path, max);
        return res;
    }

    /**
     * One lattice cell: a label at a sequence position.
     */
    final class Node {
        /** Not read by the decoder in this class; only reset by {@link #clear()}. */
        float base = 0;
        /** Local score before the forward pass, accumulated best score afterwards. */
        float score = 0;
        /** Label index of the best predecessor, {@code -1} if none. */
        int prev = -1;
        /** Transition scores into this node, indexed by the previous label. */
        float[] trans = null;

        /**
         * @param n number of possible previous labels
         */
        public Node(int n) {
            base = 0;
            score = 0;
            prev = -1;
            trans = new float[n];
        }

        /**
         * Overwrites (does not add to) the node's score and predecessor.
         *
         * @param score new accumulated score
         * @param path  label index of the chosen predecessor
         */
        public void addScore(float score, int path) {
            this.score = score;
            this.prev = path;
        }

        /**
         * Resets the node to its freshly constructed state, keeping the
         * transition array and zeroing it.
         */
        public void clear() {
            base = 0;
            score = 0;
            prev = -1;
            Arrays.fill(trans, 0);
        }
    }
}