package edu.fudan.ml.classifier.struct; import java.io.IOException; import java.util.Arrays; import edu.fudan.ml.classifier.Predict; import edu.fudan.ml.classifier.linear.Linear; import edu.fudan.ml.classifier.linear.OnlineTrainer; import edu.fudan.ml.classifier.linear.inf.Inferencer; import edu.fudan.ml.classifier.linear.update.Update; import edu.fudan.ml.loss.Loss; import edu.fudan.ml.types.Instance; import edu.fudan.ml.types.InstanceSet; import edu.fudan.util.MyArrays; public class OnlineStructTrainer extends OnlineTrainer { public OnlineStructTrainer(Inferencer msolver, Update update, Loss loss, int fsize, int iternum, float c) { super(msolver, update, loss, fsize, iternum, c); } public Linear train(InstanceSet trainset, InstanceSet devset) { int numSamples = trainset.size(); System.out.println("Training Number: " + numSamples); float hisErrRate = Float.MAX_VALUE; long beginTime, endTime; long beginTimeIter, endTimeIter; int iter = 0; int frac = numSamples / 10; float[] averageWeights = null; if (method == TrainMethod.Average || method == TrainMethod.FastAverage) { averageWeights = new float[weights.length]; } beginTime = System.currentTimeMillis(); if (shuffle) trainset.shuffle(); while (iter++ < iternum) { if (!simpleOutput) { System.out.print("iter:"); System.out.print(iter + "\t"); } float err = 0; float errtot = 0; int cnt = 0; int cnttot = 0; int progress = frac; beginTimeIter = System.currentTimeMillis(); float[] innerWeights = null; if (method == TrainMethod.Average) { innerWeights = Arrays.copyOf(weights, weights.length); } int innerCount = 0; for (int ii = 0; ii < numSamples; ii++) { Instance inst = trainset.getInstance(ii); float l = inst.length(); float dl = Float.MAX_VALUE; do { dl = l; Predict pred = (Predict) inferencer.getBest(inst); l = loss.calc(pred.getLabel(0), inst.getTarget()); if (l > 0) { update.update(inst, weights, pred.getLabel(0), c); innerCount++; if (DEBUG) { pred = (Predict) inferencer.getBest(inst); float nl = loss.calc(pred.getLabel(0), inst.getTarget()); } } dl -= l; } while (l != 0 && Math.abs(dl) > 0); cnt += inst.length(); cnttot++; if (l > 0) { err += l; errtot++; } if (method == TrainMethod.Average) { for (int i = 0; i < weights.length; i++) { innerWeights[i] += weights[i]; } } if (!simpleOutput && progress != 0 && ii % progress == 0) { System.out.print('.'); progress += frac; } } if (method == TrainMethod.Average) { for (int i = 0; i < innerWeights.length; i++) { averageWeights[i] += innerWeights[i] / numSamples; } } else if (method == TrainMethod.FastAverage) { for (int i = 0; i < weights.length; i++) { averageWeights[i] += weights[i]; } } float curErrRate = err / cnt; endTimeIter = System.currentTimeMillis(); if (!simpleOutput) { System.out.println("\ttime:" + (endTimeIter - beginTimeIter) / 1000.0 + "s"); System.out.print("Train:"); System.out.print("\tTag acc:"); } System.out.print(1 - curErrRate); if (!simpleOutput) { System.out.print("\tSentence acc:"); System.out.println(1 - errtot / cnttot); } if (devset != null) { evaluate(devset); } System.out.println(); hisErrRate = curErrRate; if (interim) { Linear p = new Linear(inferencer, trainset.getAlphabetFactory()); try { p.saveTo("tmp.model"); } catch (IOException e) { System.err.println("write model error!"); } } } if (method == TrainMethod.Average || method == TrainMethod.FastAverage) { for (int i = 0; i < averageWeights.length; i++) { averageWeights[i] /= iternum; } weights = null; weights = averageWeights; inferencer.setWeights(weights); } System.out.print("Weight Numbers: " + MyArrays.countNoneZero(weights)); if (finalOptimized) { int[] idx = MyArrays.getTop(weights.clone(), threshold, false); System.out.print("Opt: weight numbers: " + MyArrays.countNoneZero(weights)); MyArrays.set(weights, idx, 0.0f); System.out.println(" -> " + MyArrays.countNoneZero(weights)); } System.out.println(); endTime = System.currentTimeMillis(); System.out.println("time escape:" + (endTime - beginTime) / 1000.0 + "s"); Linear p = new Linear(inferencer, trainset.getAlphabetFactory()); return p; } }