package edu.fudan.ml.classifier.struct.update; import edu.fudan.ml.classifier.linear.update.AbstractPAUpdate; import edu.fudan.ml.classifier.struct.inf.LinearViterbi; import edu.fudan.ml.loss.struct.HammingLoss; import edu.fudan.ml.types.Instance; import edu.fudan.nlp.pipe.seq.templet.TempletGroup; /** * 一阶线性序列的参数更新类,采用PA算法 * @author Feng Ji * */ public class LinearViterbiPAUpdate extends AbstractPAUpdate { private int ysize; private int[] golds; private int[] preds; private int[][] data; private int[] orders; public LinearViterbiPAUpdate(LinearViterbi inf, HammingLoss loss) { super(loss); this.ysize = inf.ysize(); this.orders = inf.orders(); } public LinearViterbiPAUpdate(LinearViterbi inf, HammingLoss loss, TempletGroup dynamic) { super(loss); this.ysize = inf.ysize(); this.orders = concat(inf.orders(), dynamic.getOrders()); } private int[] concat(int[] A, int[] B) { int[] C= new int[A.length+B.length]; System.arraycopy(A, 0, C, 0, A.length); System.arraycopy(B, 0, C, A.length, B.length); return C; } /** * @return 预测序列和对照序列之间不同的Clique数量 */ @Override protected int diff(Instance inst, float[] weights, Object targets, Object predicts) { data = (int[][]) inst.getData(); if (targets == null) golds = (int[]) inst.getTarget(); else golds = (int[]) targets; preds = (int[]) predicts; int diff = 0; if (golds[0] != preds[0]) { diff++; diffClique(weights, 0); } for (int p = 1; p < data.length; p++) { if (golds[p - 1] != preds[p - 1] || golds[p] != preds[p]) { diff++; diffClique(weights, p); } } return diff; } /** * 调整权重 * @param weights 权重 * @param p 位置 */ private void diffClique(float[] weights, int p) { for (int t = 0; t < orders.length; t++) { if (data[p][t] == -1) continue; if (orders[t] == 0) { if (golds[p] != preds[p]) { int ts = data[p][t] + golds[p]; int ps = data[p][t] + preds[p]; adjust(weights, ts, ps); } } if (p > 0 && orders[t] == 1) { int ts = data[p][t] + (golds[p - 1] * ysize + golds[p]); int ps = data[p][t] + (preds[p - 1] * ysize + preds[p]); adjust(weights, ts, ps); } } } }