package edu.fudan.ml.classifier.linear.update; import edu.fudan.ml.loss.Loss; import edu.fudan.ml.types.Instance; import edu.fudan.ml.types.sv.HashSparseVector; /** * 抽象参数更新类,采用PA算法 * \mathbf{w_{t+1}} = \w_t + {\alpha^*(\Phi(x,y)- \Phi(x,\hat{y}))}. * \alpha =\frac{1- \mathbf{w_t}^T \left(\Phi(x,y) - \Phi(x,\hat{y})\right)}{||\Phi(x,y) - \Phi(x,\hat{y})||^2}. * @author Feng Ji * */ public abstract class AbstractPAUpdate implements Update { /** * \mathbf{w_t}^T \left(\Phi(x,y) - \Phi(x,\hat{y})\right) */ protected float diffw; /** * \Phi(x,y)- \Phi(x,\hat{y}) */ protected HashSparseVector diffv; protected Loss loss; /** * 是否使用样本的权重进行加权 */ public boolean useInstWeight; public AbstractPAUpdate(Loss loss) { diffw = 0; diffv = new HashSparseVector(); this.loss = loss; } /** * 参数更新方法 * @param inst 样本实例 * @param weights 权重 * @param predict 预测答案 * @param c 步长阈值 * @return 预测答案和标准答案之间的损失 */ public float update(Instance inst, float[] weights, Object predict, float c) { return update(inst, weights, inst.getTarget(), predict, c); } /** * 参数更新方法 * @param inst 样本实例 * @param weights 权重 * @param target 对照答案 * @param predict 预测答案 * @param c 步长阈值 * @return 预测答案和对照答案之间的损失 */ public float update(Instance inst, float[] weights, Object target, Object predict, float c) { int lost = diff(inst, weights, target, predict); if(lost==0) return 0f; float lamda = diffv.l2Norm2(); if (diffw <= lost) { float alpha = (lost - diffw) / lamda; if(useInstWeight) alpha = alpha*inst.getWeight(); if(alpha>c){ alpha = c; }else{ alpha=alpha; } int[] idx = diffv.indices(); for (int i = 0; i < idx.length; i++) { weights[idx[i]] += diffv.get(idx[i]) * alpha; } } diffv.clear(); diffw = 0; return loss.calc(target, predict); } /** * 计算预测答案和对照答案之间的距离 * @param inst 样本实例 * @param weights 权重 * @param target 对照答案 * @param predict 预测答案 * @return 预测答案和对照答案之间的距离 */ protected abstract int diff(Instance inst, float[] weights, Object target, Object predict); protected void adjust(float[] weights, int ts, int ps) { assert (ts != -1 && ps != -1); diffv.put(ts, 1.0f); diffv.put(ps, -1.0f); diffw += weights[ts] - weights[ps]; } }