package edu.fudan.ml.classifier.struct.update;
import java.io.Serializable;
import edu.fudan.ml.classifier.linear.update.Update;
import edu.fudan.ml.types.Instance;
import edu.fudan.nlp.pipe.seq.templet.TempletGroup;
import gnu.trove.iterator.TIntIntIterator;
import gnu.trove.map.hash.TIntIntHashMap;
/**
 * Passive-aggressive (PA) weight update for sequence labelling with higher-order
 * label cliques described by a {@link TempletGroup}.
 *
 * <p>At every position the features of the gold and of the predicted label sequence
 * are compared over the clique made of that position and the preceding
 * {@code templets.maxOrder} positions. The weights are then moved toward the gold
 * features and away from the predicted ones, with a step size clipped at {@code c}.
 *
 * @deprecated This class is marked deprecated; no replacement is named in this file.
 * @author feng
 */
@Deprecated
public class HigherOrderViterbiPAUpdate implements Serializable, Update {
	private static final long serialVersionUID = 4198902147638417425L;

	/** Feature templates; also provide the clique order, state count and weight offsets. */
	TempletGroup templets;

	/** Number of templates. */
	int numTemplets;

	/** Number of labels. */
	int numLabels;

	/**
	 * Whether the required margin is the structured loss (true) or the constant 1 (false).
	 */
	private boolean useLoss = false;

	/**
	 * @param templets  feature templates (their count is cached in {@link #numTemplets})
	 * @param numLabels number of distinct labels
	 * @param useLoss   if true, require a margin equal to the loss; otherwise a margin of 1
	 */
	public HigherOrderViterbiPAUpdate(TempletGroup templets, int numLabels, boolean useLoss) {
		this.templets = templets;
		this.numTemplets = templets.size();
		this.numLabels = numLabels;
		this.useLoss = useLoss;
	}

	/**
	 * Same as {@link #update(Instance, float[], Object, Object, float)} with the gold
	 * labels taken from {@code inst.getTarget()}.
	 */
	public float update(Instance inst, float[] weights, Object predictLabel,
			float c) {
		return update(inst, weights, predictLabel, null, c);
	}

	/**
	 * Performs one PA update of {@code weights} from a gold and a predicted label sequence.
	 *
	 * <p>{@code inst.getData()} is an {@code int[N][T]} matrix, where N is the sequence
	 * length and T the number of templates. {@code data[r][t]} is the starting index in
	 * feature space obtained by applying template {@code t} at position {@code r}. A value
	 * of {@code -1} means the template yields no feature there. {@code target[r]} and
	 * {@code predict[r]} are label ids.
	 *
	 * @param inst         instance holding the feature matrix described above
	 * @param weights      weight vector, modified in place
	 * @param predictLabel predicted label sequence ({@code int[]})
	 * @param goldenLabel  gold label sequence ({@code int[]}), or {@code null} to use
	 *                     {@code inst.getTarget()}
	 * @param c            upper bound on the step size
	 * @return the number of cliques containing at least one wrong label when the margin
	 *         constraint was violated (an update was attempted), otherwise 0
	 */
	public float update(Instance inst, float[] weights, Object predictLabel,
			Object goldenLabel, float c) {
		int[][] data = (int[][]) inst.getData();
		int[] target = (goldenLabel == null) ? (int[]) inst.getTarget() : (int[]) goldenLabel;
		int[] predict = (int[]) predictLabel;

		// Sparse f(x, y) - f(x, ybar): weight index -> count difference.
		// At most 2 * length * numTemplets distinct keys.
		TIntIntHashMap diffF = new TIntIntHashMap();
		float diffW = 0; // w^T (f(x, y) - f(x, ybar))
		int loss = 0;    // L(y, ybar): number of cliques with at least one wrong label

		int tS = 0; // rolling code of the gold labels in the current clique
		int pS = 0; // rolling code of the predicted labels in the current clique
		int ne = 0; // number of positions in the current clique whose labels differ
		int length = data.length;
		// o trails l by maxOrder + 1: it is the position that has just left the clique.
		for (int o = -templets.maxOrder - 1, l = 0; l < length; o++, l++) {
			// Shift the newest label into the state code; the modulo drops the oldest one.
			// NOTE(review): presumably numStates == numLabels^(maxOrder + 1); confirm.
			tS = tS * numLabels % templets.numStates + target[l];
			pS = pS * numLabels % templets.numStates + predict[l];
			if (predict[l] != target[l]) {
				ne++;
			}
			if (o >= 0 && predict[o] != target[o]) {
				ne--; // discount the mismatch of the position leaving the clique
			}
			if (ne > 0) { // the clique contains at least one differing label
				loss++;
				for (int t = 0; t < numTemplets; t++) {
					int base = data[l][t];
					if (base == -1) {
						continue; // template t produces no feature at position l
					}
					// Weight indices of the gold and predicted features for this template.
					int tI = base + templets.offset[t][tS];
					int pI = base + templets.offset[t][pS];
					if (tI != pI) {
						diffF.adjustOrPutValue(tI, 1, 1);
						diffF.adjustOrPutValue(pI, -1, -1);
						diffW += weights[tI] - weights[pI];
					}
				}
			}
		}

		float delta = useLoss ? loss : 1;
		if (diffW >= delta) {
			return 0; // margin already satisfied: passive step
		}

		float diff = squaredNorm(diffF);
		// With a zero-norm difference there is nothing to move; skip to avoid dividing by 0.
		if (diff > 0) {
			float alpha = Math.min(c, (delta - diffW) / diff);
			for (TIntIntIterator it = diffF.iterator(); it.hasNext();) {
				it.advance();
				weights[it.key()] += it.value() * alpha;
			}
		}
		return loss;
	}

	/** Returns the squared Euclidean norm of the sparse integer vector {@code vec}. */
	private static float squaredNorm(TIntIntHashMap vec) {
		float sum = 0;
		for (TIntIntIterator it = vec.iterator(); it.hasNext();) {
			it.advance();
			int v = it.value();
			sum += v * v;
		}
		return sum;
	}
}