package edu.fudan.ml.classifier.struct.inf; import java.util.Arrays; import edu.fudan.ml.classifier.Predict; import edu.fudan.ml.classifier.struct.inf.AbstractViterbi; import edu.fudan.ml.types.Instance; import edu.fudan.nlp.pipe.seq.templet.TempletGroup; /** * 双链最优解码器 * * @author Feng Ji * */ public class HybridViterbi extends AbstractViterbi { private static final long serialVersionUID = 5485421022552472597L; private int[] ysize; private int length; private int[][] orders; public HybridViterbi(TempletGroup[] templets, int ssize, int tsize) { this.ysize = new int[] { ssize, tsize }; this.orders = new int[templets.length][]; for (int i = 0; i < templets.length; i++) { this.orders[i] = templets[i].getOrders(); } } public int[] ysize() { return ysize; } public int[][] orders() { return orders; } @Override public Predict<int[][]> getBest(Instance inst) { Node[][][] lattice = initialLattice(inst); doForwardViterbi(lattice); return getForwardPath(lattice); } /** * 已知分段结果的情况下,最优双链解码方法 * * @param inst * 样本实例 * @return 双链标注结果 */ public Predict<int[][]> getBestWithSegs(Instance inst) { Node[][][] lattice = initialLatticeWithSegs(inst); doForwardViterbi(lattice); return getForwardPath(lattice); } /** * 已知分段结果的情况下,构造并初始化网格,不经过的节点设置为NULL * * @param inst * 样本实例 * @return 双链网格 */ private Node[][][] initialLatticeWithSegs(Instance inst) { int[][][] data = (int[][][]) inst.getData(); int[] tags = (int[]) inst.getTempData(); length = inst.length(); Node[][][] lattice = new Node[2][length][]; for (int i = 0; i < length; i++) { lattice[0][i] = new Node[ysize[0]]; lattice[0][i][tags[i]] = new Node(ysize[0], ysize[1]); initialClique(lattice[0][i], data[0][i], orders[0], ysize[0], ysize[1]); lattice[1][i] = new Node[ysize[1]]; for (int j = 0; j < ysize[1]; j++) { lattice[1][i][j] = new Node(ysize[1], ysize[0]); } initialClique(lattice[1][i], data[1][i], orders[1], ysize[1], ysize[0]); } return lattice; } /** * 构造并初始化网格 * * @param inst * 样本实例 * @return 双链网格 */ private Node[][][] initialLattice(Instance inst) { int[][][] data = (int[][][]) inst.getData(); length = inst.length(); Node[][][] lattice = new Node[2][length][]; for (int i = 0; i < length; i++) { lattice[0][i] = new Node[ysize[0]]; for (int j = 0; j < ysize[0]; j++) { lattice[0][i][j] = new Node(ysize[0], ysize[1]); } initialClique(lattice[0][i], data[0][i], orders[0], ysize[0], ysize[1]); lattice[1][i] = new Node[ysize[1]]; for (int j = 0; j < ysize[1]; j++) { lattice[1][i][j] = new Node(ysize[1], ysize[0]); } initialClique(lattice[1][i], data[1][i], orders[1], ysize[1], ysize[0]); } return lattice; } private void initialClique(Node[] node, int[] data, int[] order, int nsize, int msize) { int scalar = msize * nsize; for (int k = 0; k < node.length; k++) { if (node[k] == null) continue; for (int t = 0; t < order.length; t++) { if (data[t] == -1) continue; int base = data[t]; if (order[t] == 0) { for (int j = 0; j < msize; j++) { node[k].score[j] += weights[base + k]; } } if (order[t] == -1) { int offset = k; for (int j = 0; j < msize; j++) { node[k].score[j] += weights[base + offset]; offset += nsize; } } if (order[t] == 1) { int offset = k; for (int i = 0; i < nsize; i++) { for (int j = 0; j < msize; j++) { node[k].trans[i][j] += weights[base + offset]; } offset += nsize; } } if (order[t] == 2) { for (int i = 0; i < nsize; i++) { int offset = i * scalar + k; for (int j = 0; j < msize; j++) { node[k].trans[i][j] += weights[base + offset]; offset += nsize; } } } } } } private void doForwardViterbi(Node[][][] lattice) { for (int j = 0; j < ysize[1]; j++) { for (int i = 0; i < ysize[0]; i++) { if (lattice[0][0][i] == null) continue; float score = lattice[1][0][j].score[i] + lattice[0][0][i].score[0]; lattice[1][0][j].addScore(score, i, -1); } } for (int p = 1; p < length; p++) { for (int k = 0; k < ysize[0]; k++) { if (lattice[0][p][k] == null) continue; for (int j = 0; j < ysize[1]; j++) { float bestScore = Float.NEGATIVE_INFINITY; int bestPath = -1; for (int i = 0; i < ysize[0]; i++) { if (lattice[0][p - 1][i] == null) continue; float score = lattice[1][p - 1][j].score[i]; score += lattice[0][p][k].trans[i][j]; if (score > bestScore) { bestScore = score; bestPath = i; } } bestScore += lattice[0][p][k].score[j]; lattice[0][p][k].addScore(bestScore, j, bestPath); } } for (int k = 0; k < ysize[1]; k++) { for (int j = 0; j < ysize[0]; j++) { if (lattice[0][p][j] == null) continue; float bestScore = Float.NEGATIVE_INFINITY; int bestPath = -1; for (int i = 0; i < ysize[1]; i++) { float score = lattice[0][p][j].score[i]; score += lattice[1][p][k].trans[i][j]; if (score > bestScore) { bestScore = score; bestPath = i; } } bestScore += lattice[1][p][k].score[j]; lattice[1][p][k].addScore(bestScore, j, bestPath); } } } } private Predict<int[][]> getForwardPath(Node[][][] lattice) { Predict<int[][]> res = new Predict<int[][]>(); float maxScore = Float.NEGATIVE_INFINITY; int u = -1; int d = -1; for (int j = 0; j < ysize[1]; j++) { for (int i = 0; i < ysize[0]; i++) { if (lattice[1][length - 1][j].prev[i] != -1 || length == 1) { float score = lattice[1][length - 1][j].score[i]; if (score > maxScore) { maxScore = score; u = i; d = j; } } } } int[][] path = new int[2][length]; path[0][length - 1] = u; path[1][length - 1] = d; for (int i = length - 2; i >= 0; i--) { d = lattice[1][i + 1][d].prev[u]; path[1][i] = d; u = lattice[0][i + 1][u].prev[d]; path[0][i] = u; } res.add(path,maxScore); return res; } final class Node { float[][] trans = null; float[] score = null; int[] prev = null; public Node(int m, int n) { trans = new float[m][]; for (int i = 0; i < m; i++) { trans[i] = new float[n]; } score = new float[n]; prev = new int[n]; Arrays.fill(prev, -1); } public void addScore(float score, int j, int i) { this.score[j] = score; this.prev[j] = i; } } }