package edu.fudan.ml.classifier.hier;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;

import edu.fudan.ml.classifier.hier.inf.MultiLinearMax;
import edu.fudan.ml.classifier.linear.inf.Inferencer;
import edu.fudan.ml.eval.Evaluation;
import edu.fudan.ml.feature.BaseGenerator;
import edu.fudan.ml.loss.Loss;
import edu.fudan.ml.types.Instance;
import edu.fudan.ml.types.InstanceSet;
import edu.fudan.ml.types.alphabet.LabelAlphabet;
import edu.fudan.ml.types.sv.HashSparseVector;
import edu.fudan.util.MyArrays;
import edu.fudan.util.MyHashSparseArrays;

/**
 * Large-scale hierarchical multi-class classification trainer
 * (Passive-Aggressive style online updates).
 *
 * <p>Designed for a very large number of classes, so each per-class weight
 * vector is stored as a sparse array ({@link HashSparseVector}).
 *
 * @author xpqiu
 * @since 1.0
 */
public class PATrainer {

    /** Feature weight vectors; one weight vector per class. */
    private HashSparseVector[] weights;

    /** The classifier produced by training. */
    private Linear classifier;

    /** Inference engine used to find the best (and oracle) labels. */
    private MultiLinearMax msolver;

    /** Feature generator mapping instances to sparse feature vectors. */
    private BaseGenerator featureGen;

    /** Loss function (kept for API compatibility; the PA hinge loss is computed inline in {@link #train}). */
    private Loss loss;

    /** Maximum number of training iterations. */
    private int maxIter = Integer.MAX_VALUE;

    /** Class hierarchy; may be null for the flat (non-hierarchical) case. */
    private Tree tree;

    /** PA aggressiveness parameter C (upper bound on the step size). */
    private float c;

    /** If true, save an intermediate model to ./tmp/model.gz after every loop. */
    public boolean interim = false;

    /** If true, trim near-zero weight entries during the first two loops to save memory. */
    public boolean optim = false;

    /** If true, continue training an existing classifier instead of re-initializing weights. */
    private boolean incremental = false;

    /** Convergence control: number of recent accuracy values to keep. */
    private static final int historyNum = 5;

    /** Convergence control: minimum variance of recent accuracies before stopping. */
    private static final float eps = 1e-10f;

    /**
     * Constructs a trainer that continues (incrementally) training an
     * existing classifier, reusing its weights, inferencer and feature
     * generator.
     *
     * @param pc      existing classifier to continue training
     * @param loss    loss function
     * @param maxIter maximum number of iterations
     * @param c       PA aggressiveness parameter
     * @param tr      class hierarchy (may be null for flat classification)
     */
    public PATrainer(Linear pc, Loss loss, int maxIter, float c, Tree tr) {
        msolver = (MultiLinearMax) pc.inf;
        msolver.isUseTarget(true);
        featureGen = pc.gen;
        this.loss = loss;
        this.maxIter = maxIter;
        tree = tr;
        this.c = c;
        incremental = true;
        weights = pc.weights;
    }

    /**
     * Constructs a trainer that starts from scratch; weights are initialized
     * to class means at the beginning of {@link #train}.
     *
     * @param msolver    inferencer (must be a {@link MultiLinearMax})
     * @param featureGen feature generator
     * @param loss       loss function
     * @param maxIter    maximum number of iterations
     * @param c          PA aggressiveness parameter
     * @param tr         class hierarchy (may be null for flat classification)
     */
    public PATrainer(Inferencer msolver, BaseGenerator featureGen, Loss loss,
            int maxIter, float c, Tree tr) {
        this.msolver = (MultiLinearMax) msolver;
        this.featureGen = featureGen;
        this.loss = loss;
        this.maxIter = maxIter;
        tree = tr;
        this.c = c;
    }

    /**
     * @return the classifier produced by the most recent call to {@link #train},
     *         or null if training has not run yet
     */
    public Linear getClassifier() {
        return classifier;
    }

    /**
     * Trains the classifier with Passive-Aggressive online updates.
     *
     * <p>Each loop shuffles the training set and, for every instance, compares
     * the top prediction with the oracle (target-constrained) score left in the
     * instance's temp data by the inferencer. On a margin violation, both the
     * correct path and the predicted path in the hierarchy (or the two single
     * classes in the flat case) receive an additive update of step size
     * {@code min(c, hingeLoss / (||x||^2 * error))}. Training stops when
     * {@code maxIter} is reached or the variance of the last
     * {@code historyNum} accuracies falls below {@code eps}.
     *
     * @param trainingList training instances
     * @param eval         optional evaluator run after each loop; may be null
     * @return the trained classifier (also retrievable via {@link #getClassifier()})
     */
    public Linear train(InstanceSet trainingList, Evaluation eval) {
        System.out.println("Sample Size: " + trainingList.size());
        LabelAlphabet labels = trainingList.getAlphabetFactory().DefaultLabelAlphabet();
        System.out.println("Class Size: " + labels.size());
        if (!incremental) {
            // Initialize each weight vector to its class centroid.
            weights = Mean.mean(trainingList, tree);
            msolver.setWeight(weights);
        }
        // Ring buffer of the most recent per-loop accuracies (NOT error rates),
        // used for the variance-based convergence test below.
        float[] hisAcc = new float[historyNum];
        int numSamples = trainingList.size();
        int frac = numSamples / 10;

        System.out.println("Begin Training...");
        long beginTime = System.currentTimeMillis();
        int loops = 0; // loop counter
        while (loops++ < maxIter) {
            System.out.print("Loop: " + loops);
            float totalerror = 0;
            trainingList.shuffle();
            long beginTimeInner = System.currentTimeMillis();
            for (int ii = 0; ii < numSamples; ii++) {
                Instance inst = trainingList.getInstance(ii);
                int maxC = (Integer) inst.getTarget();
                Predict pred = (Predict) msolver.getBest(inst, 1);
                // The inferencer stashes the oracle (correct-label) score in the
                // instance's temp data; fetch it and clear the slot.
                Predict oracle = (Predict) inst.getTempData();
                inst.deleteTempData();
                int maxE = pred.getLabel(0);
                int error;
                if (tree == null) {
                    error = (pred.getLabel(0) == maxC) ? 0 : 1;
                } else {
                    // Hierarchical error: tree distance between predicted and true labels.
                    error = tree.dist(maxE, maxC);
                }
                // PA hinge loss; renamed from "loss" to avoid shadowing the Loss field.
                float hingeLoss = error - (oracle.getScore(0) - pred.getScore(0));
                if (hingeLoss > 0) { // margin violated: update weights
                    totalerror += 1;
                    // Step size: min(C, loss / (||x||^2 * error)).
                    float phi = featureGen.getVector(inst).l2Norm2();
                    float alpha = (float) Math.min(c, hingeLoss / (phi * error));
                    if (tree != null) {
                        // Promote every node on the path to the correct label...
                        int[] anc = tree.getPath(maxC);
                        for (int j = 0; j < anc.length; j++) {
                            weights[anc[j]].plus(featureGen.getVector(inst), alpha);
                        }
                        // ...and demote every node on the path to the wrong prediction.
                        anc = tree.getPath(maxE);
                        for (int j = 0; j < anc.length; j++) {
                            weights[anc[j]].plus(featureGen.getVector(inst), -alpha);
                        }
                    } else {
                        weights[maxC].plus(featureGen.getVector(inst), alpha);
                        weights[maxE].plus(featureGen.getVector(inst), -alpha);
                    }
                }
                if (frac == 0 || ii % frac == 0) { // progress indicator
                    System.out.print('.');
                }
            }
            float acc = 1 - totalerror / numSamples;
            System.out.print("\t Accuracy:" + acc);
            System.out.println("\t Time(s):"
                    + (System.currentTimeMillis() - beginTimeInner) / 1000);
            if (optim && loops <= 2) {
                // Early loops only: trim near-zero entries to shrink the model.
                int oldnum = 0;
                int newnum = 0;
                for (int i = 0; i < weights.length; i++) {
                    oldnum += weights[i].size();
                    MyHashSparseArrays.trim(weights[i], 0.99f);
                    newnum += weights[i].size();
                }
                System.out.println("优化:\t原特征数:" + oldnum + "\t新特征数:" + newnum);
            }
            if (interim) {
                // Checkpoint the current model so long runs can be resumed/inspected.
                Linear p = new Linear(weights, msolver, featureGen,
                        trainingList.getPipes(), trainingList.getAlphabetFactory());
                try {
                    p.saveTo("./tmp/model.gz");
                } catch (IOException e) {
                    System.err.println("write model error!");
                }
                // Re-arm oracle scoring, which saving/eval may have reset.
                msolver.isUseTarget(true);
            }
            if (eval != null) {
                System.out.print("Test:\t");
                // Renamed from "classifier" to avoid shadowing the field of the same name.
                Linear evalModel = new Linear(weights, msolver);
                eval.eval(evalModel, 2);
                msolver.isUseTarget(true);
            }
            hisAcc[loops % historyNum] = acc;
            if (MyArrays.viarance(hisAcc) < eps) {
                System.out.println("convergence!");
                break;
            }
        }
        System.out.println("Training End");
        System.out.println("Training Time(s):"
                + (System.currentTimeMillis() - beginTime) / 1000);
        classifier = new Linear(weights, msolver, featureGen,
                trainingList.getPipes(), trainingList.getAlphabetFactory());
        return classifier;
    }
}