package edu.fudan.ml.classifier.struct;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import edu.fudan.ml.classifier.linear.Linear;
import edu.fudan.ml.classifier.linear.OnlineTrainer;
import edu.fudan.ml.classifier.linear.inf.Inferencer;
import edu.fudan.ml.classifier.linear.update.Update;
import edu.fudan.ml.loss.Loss;
import edu.fudan.ml.types.Instance;
import edu.fudan.ml.types.InstanceSet;
import edu.fudan.util.MyArrays;
public class OnlineHybridTrainer extends OnlineTrainer {
public OnlineHybridTrainer(Inferencer msolver, Update update,
Loss loss, int fsize, int iternum, float c) {
super(msolver, update, loss, fsize, iternum, c);
}
public Linear train(InstanceSet trainset, InstanceSet devset) {
int numSamples = trainset.size();
int count = 0;
for (int ii = 0; ii < numSamples; ii++) {
Instance inst = trainset.getInstance(ii);
int[][] targets = (int[][]) inst.getTarget();
for(int i = 0; i < targets.length; i++) {
count += targets[i].length;
}
}
System.out.println("Training Number: "+numSamples);
System.out.println("Chars Number: " + count);
float hisErrRate = Float.MAX_VALUE;
long beginTime, endTime;
long beginTimeIter, endTimeIter;
int iter = 0;
int frac = numSamples / 10;
float[] averageWeights = null;
if (method == TrainMethod.Average || method == TrainMethod.FastAverage) {
averageWeights = new float[weights.length];
}
beginTime = System.currentTimeMillis();
if (shuffle)
trainset.shuffle(random);
while (iter++ < iternum) {
if (!simpleOutput) {
System.out.print("iter:");
System.out.print(iter + "\t");
}
float err = 0;
float errtot = 0;
int progress = frac;
beginTimeIter = System.currentTimeMillis();
float[] innerWeights = null;
if (method == TrainMethod.Average) {
innerWeights = Arrays.copyOf(weights, weights.length);
}
for (int ii = 0; ii < numSamples; ii++) {
Instance inst = trainset.getInstance(ii);
List pred = (List) inferencer.getBest(inst, 1);
float l = loss.calc((int[][]) pred.get(0), (int[][]) inst.getTarget());
if (l > 0) {
err += l;
errtot++;
update.update(inst, weights, pred.get(0), c);
}
if (method == TrainMethod.Average) {
for (int i = 0; i < weights.length; i++) {
innerWeights[i] += weights[i];
}
}
if (DEBUG && l > 0) {
pred = (List) inferencer.getBest(inst, 1);
l = loss.calc((int[]) pred.get(0), (int[]) inst.getTarget());
}
if (!simpleOutput && ii % progress == 0) {
System.out.print('.');
progress += frac;
}
}
float curErrRate = err / count;
endTimeIter = System.currentTimeMillis();
if (!simpleOutput) {
System.out.println("\ttime:" + (endTimeIter - beginTimeIter)
/ 1000.0 + "s");
System.out.print("Train:");
System.out.print("\tTag acc:");
}
System.out.print(1 - curErrRate);
if (!simpleOutput) {
System.out.print("\tSentence acc:");
System.out.print(1 - errtot / numSamples);
System.out.println();
}
System.out.print("Weight Numbers: "+MyArrays.countNoneZero(weights));
if (innerOptimized) {
int[] idx = MyArrays.getTop(weights.clone(), threshold, false);
MyArrays.set(weights, idx, 0.0f);
System.out.print(" After Optimized: " + MyArrays.countNoneZero(weights));
}
System.out.println();
if (devset != null) {
evaluate(devset);
}
if (method == TrainMethod.Average) {
for (int i = 0; i < innerWeights.length; i++) {
averageWeights[i] += innerWeights[i] / numSamples;
}
} else if (method == TrainMethod.FastAverage) {
for (int i = 0; i < weights.length; i++) {
averageWeights[i] += weights[i];
}
}
if (interim) {
Linear p = new Linear(inferencer,
trainset.getAlphabetFactory());
try {
p.saveTo("tmp.model");
} catch (IOException e) {
System.err.println("write model error!");
}
}
}
if (method == TrainMethod.Average || method == TrainMethod.FastAverage) {
for (int i = 0; i < averageWeights.length; i++) {
averageWeights[i] /= iternum;
}
weights = null;
weights = averageWeights;
inferencer.setWeights(weights);
}
System.out.print("Weight Numbers: "+MyArrays.countNoneZero(weights));
if (finalOptimized) {
int[] idx = MyArrays.getTop(weights.clone(), threshold, false);
MyArrays.set(weights, idx, 0.0f);
System.out.print(" After Optimized: " + MyArrays.countNoneZero(weights));
}
System.out.println();
endTime = System.currentTimeMillis();
System.out.println("time escape:" + (endTime - beginTime) / 1000.0
+ "s");
Linear p = new Linear(inferencer, trainset.getAlphabetFactory());
return p;
}
public void evaluate(InstanceSet devset) {
float err = 0;
float errtot = 0;
int total = 0;
for (int i = 0; i < devset.size(); i++) {
Instance inst = devset.getInstance(i);
total += ((int[]) inst.getTarget()).length;
List pred = (List) inferencer.getBest(inst, 1);
float l = loss.calc(pred.get(0), inst.getTarget());
if (l > 0) {
errtot += 1.0;
err += l;
}
}
if (!simpleOutput) {
System.out.print("Test:\t");
System.out.print(total - err);
System.out.print('/');
System.out.print(total);
System.out.print("\tTag acc:");
} else {
System.out.print('\t');
}
System.out.print(1 - err / total);
if (!simpleOutput) {
System.out.print("\tSentence acc:");
System.out.println(1 - errtot / devset.size());
}
System.out.println();
}
}