package edu.fudan.ml.classifier.struct;
import java.io.IOException;
import edu.fudan.ml.classifier.Predict;
import edu.fudan.ml.classifier.linear.Linear;
import edu.fudan.ml.classifier.linear.inf.Inferencer;
import edu.fudan.ml.classifier.linear.update.Update;
import edu.fudan.ml.loss.Loss;
import edu.fudan.ml.types.Instance;
import edu.fudan.ml.types.InstanceSet;
import edu.fudan.ml.types.alphabet.IFeatureAlphabet;
import edu.fudan.util.MyArrays;
/**
 * Iterative trainer for a structured linear model ("PA" trainer; presumably
 * passive-aggressive, judging by the class name and the {@code c} parameter).
 *
 * <p>Each training pass decodes every instance with the configured
 * {@link Inferencer}, scores the prediction with the {@link Loss}, and lets
 * the {@link Update} strategy modify the shared weight vector. Progress and
 * accuracy figures are written to {@code System.out}.
 */
public class PATrainer {
    /**
     * Feature weights. The same array instance is handed to the inferencer in
     * the constructor and modified in place during training.
     */
    float[] weights;
    /**
     * Inference algorithm used to obtain predictions.
     */
    Inferencer msolver;
    /**
     * Weight update strategy.
     */
    public Update update;
    /**
     * Loss function comparing a predicted label with the instance target.
     */
    Loss loss;
    /**
     * Maximum number of training iterations.
     */
    int maxIter = Integer.MAX_VALUE;
    /**
     * Convergence tolerance: training stops once the per-label error rate
     * changes by less than this amount between two consecutive iterations.
     */
    private float eps = 1e-5f;
    /**
     * Total number of target labels (characters) in the training set;
     * recomputed at the start of every {@link #train} call.
     */
    public int count;
    /**
     * Parameter C of the PA algorithm, used to tune how strong the constraint
     * is. It is passed unchanged to every {@link Update} call.
     */
    float c;
    /**
     * Whether to optimize the model: when true, insignificant feature weights
     * are set to 0 at the end of each iteration.
     */
    public boolean isOptimized = false;
    /**
     * Optimization threshold. Weights are ranked by absolute value and only
     * the n largest are kept, where n is chosen such that
     * sum_{i=1}^n{(w_i^2)/||w||}^2 &gt; threshold. Defaults to 0.999.
     */
    float threshold = 0.999f;
    /** Not read anywhere in this class. */
    public int forceUpdateLen = -1;
    /** When true, only the bare accuracy numbers are printed. */
    boolean simpleOutput = false;
    /** Not read anywhere in this class. */
    boolean usePerceptron = true;
    /** When true, the current model is saved to "tmp.model" after each iteration. */
    public boolean interim = false;

    /**
     * Creates a trainer and allocates a zero-initialized weight vector with
     * one entry per feature, which is also installed into the inferencer.
     *
     * @param msolver  inference algorithm; receives the new weight array
     * @param update   weight update strategy
     * @param loss     loss function
     * @param features feature alphabet; its size determines the weight vector length
     * @param maxIter  maximum number of training iterations
     * @param c        PA parameter C forwarded to the update strategy
     */
    public PATrainer(Inferencer msolver, Update update, Loss loss,
            IFeatureAlphabet features, int maxIter, float c) {
        this.msolver = msolver;
        this.update = update;
        this.loss = loss;
        this.maxIter = maxIter;
        this.c = c;
        weights = new float[features.size()];
        this.msolver.setWeights(weights);
    }

    /**
     * Trains the model.
     *
     * <p>Every iteration visits each training instance once. If the loss of
     * the best prediction is positive, the weights are updated against that
     * prediction; otherwise, if the prediction carries more than one
     * candidate, the weights are updated against the second candidate. After
     * each pass the training accuracy is printed, {@code testList} is
     * evaluated when given, and training stops early once the error rate
     * changes by less than {@code eps}. Depending on {@link #interim} and
     * {@link #isOptimized}, the model is then saved and/or pruned.
     *
     * @param trainingList training instances; every target must be an {@code int[]}
     * @param testList     optional evaluation set, may be {@code null}
     * @return a {@link Linear} model built from the inferencer and the
     *         training set's alphabet factory
     */
    public Linear train(InstanceSet trainingList, InstanceSet testList) {
        int numSamples = trainingList.size();
        count = countLabels(trainingList);
        System.out.println("Chars Number: " + count);

        float oldErrorRate = Float.MAX_VALUE;
        long beginTime = System.currentTimeMillis();
        int iter = 0;
        // Progress step. Clamped to at least 1: for fewer than 10 samples the
        // integer division yields 0 and "ii % progress" below would throw
        // ArithmeticException (division by zero).
        int frac = Math.max(1, numSamples / 10);

        // Main training loop.
        while (iter++ < maxIter) {
            if (!simpleOutput) {
                System.out.print("iter:");
                System.out.print(iter + "\t");
            }
            float err = 0;       // accumulated loss over this pass
            float errorAll = 0;  // number of instances with positive loss
            long beginTimeIter = System.currentTimeMillis();
            int progress = frac;
            for (int ii = 0; ii < numSamples; ii++) {
                Instance inst = trainingList.getInstance(ii);
                Predict pred = (Predict) msolver.getBest(inst, 1);
                float l = loss.calc(pred.getLabel(0), inst.getTarget());
                if (l > 0) {
                    // Wrong prediction: update the weights.
                    errorAll += 1.0;
                    err += l;
                    update.update(inst, weights, pred.getLabel(0), c);
                } else if (pred.size() > 1) {
                    // Correct prediction: still update against the runner-up
                    // candidate when the inferencer returned one.
                    update.update(inst, weights, pred.getLabel(1), c);
                }
                if (!simpleOutput && ii % progress == 0) {
                    // Show progress.
                    System.out.print('.');
                    progress += frac;
                }
            }
            float errRate = err / count;
            long endTimeIter = System.currentTimeMillis();

            if (!simpleOutput) {
                System.out.println("\ttime:" + (endTimeIter - beginTimeIter)
                        / 1000.0 + "s");
                System.out.print("Train:");
                System.out.print("\tTag acc:");
            }
            System.out.print(1 - errRate);
            if (!simpleOutput) {
                System.out.print("\tSentence acc:");
                System.out.print(1 - errorAll / numSamples);
                System.out.println();
            }

            if (testList != null) {
                test(testList);
            }

            if (Math.abs(errRate - oldErrorRate) < eps) {
                System.out.println("Convergence!");
                break;
            }
            oldErrorRate = errRate;

            if (interim) {
                saveInterimModel(trainingList);
            }
            if (isOptimized) {
                pruneWeights();
            }
        }

        long endTime = System.currentTimeMillis();
        System.out.println("done!");
        System.out.println("time escape:" + (endTime - beginTime) / 1000.0
                + "s");
        return new Linear(msolver, trainingList.getAlphabetFactory());
    }

    /**
     * Sums the lengths of the {@code int[]} targets of all instances in the set.
     */
    private static int countLabels(InstanceSet instances) {
        int labels = 0;
        for (int i = 0; i < instances.size(); i++) {
            Instance inst = instances.getInstance(i);
            labels += ((int[]) inst.getTarget()).length;
        }
        return labels;
    }

    /**
     * Writes the current model to "tmp.model". This is best effort: an
     * {@link IOException} is reported on {@code System.err} and training continues.
     */
    private void saveInterimModel(InstanceSet trainingList) {
        Linear p = new Linear(msolver, trainingList.getAlphabetFactory());
        try {
            p.saveTo("tmp.model");
        } catch (IOException e) {
            System.err.println("write model error!");
        }
    }

    /**
     * Model optimization: zeroes the weights selected by
     * {@code MyArrays.getTop} for the configured threshold and prints the
     * number of non-zero weights before and after.
     */
    private void pruneWeights() {
        // getTop receives a copy of the weights, so the live array is only
        // modified by the explicit set() below.
        // NOTE(review): presumably getTop returns the indices of the
        // insignificant weights to drop; confirm against MyArrays.
        int[] idx = MyArrays.getTop(weights.clone(), threshold, false);
        System.out.print("Opt: weight numbers: "
                + MyArrays.countNoneZero(weights));
        MyArrays.set(weights, idx, 0.0f);
        System.out.println(" -> " + MyArrays.countNoneZero(weights));
    }

    /**
     * Evaluates the current model on a test set and prints the accuracy.
     *
     * <p>Prints the per-label accuracy ({@code 1 - loss / labels}) and, unless
     * simple output is enabled, also the correct/total label counts and the
     * per-instance accuracy.
     *
     * @param testSet instances to evaluate; every target must be an {@code int[]}
     */
    public void test(InstanceSet testSet) {
        float err = 0;       // accumulated loss
        float errorAll = 0;  // number of instances with positive loss
        int total = 0;       // number of target labels seen
        for (int i = 0; i < testSet.size(); i++) {
            Instance inst = testSet.getInstance(i);
            total += ((int[]) inst.getTarget()).length;
            Predict pred = (Predict) msolver.getBest(inst, 1);
            float l = loss.calc(pred.getLabel(0), inst.getTarget());
            if (l > 0) {
                // Wrong prediction.
                errorAll += 1.0;
                err += l;
            }
        }
        if (!simpleOutput) {
            System.out.print("Test:\t");
            System.out.print(total - err);
            System.out.print('/');
            System.out.print(total);
            System.out.print("\tTag acc:");
        } else {
            System.out.print('\t');
        }
        System.out.print(1 - err / total);
        if (!simpleOutput) {
            System.out.print("\tSentence acc:");
            System.out.println(1 - errorAll / testSet.size());
        }
        System.out.println();
    }
}