package com.spbsu.exp.multiclass.weak; import com.spbsu.commons.math.Trans; import com.spbsu.commons.math.vectors.Mx; import com.spbsu.commons.seq.IntSeq; import com.spbsu.ml.*; import com.spbsu.ml.data.set.VecDataSet; import com.spbsu.ml.data.tools.FakePool; import com.spbsu.ml.data.tools.MCTools; import com.spbsu.ml.func.Ensemble; import com.spbsu.ml.func.FuncJoin; import com.spbsu.ml.loss.L2; import com.spbsu.ml.loss.SatL2; import com.spbsu.ml.loss.blockwise.BlockwiseMLLLogit; import com.spbsu.ml.methods.GradientBoosting; import com.spbsu.ml.methods.MultiClass; import com.spbsu.ml.methods.VecOptimization; import com.spbsu.ml.methods.trees.GreedyObliviousTree; import com.spbsu.ml.models.MultiClassModel; import com.spbsu.ml.models.multiclass.MCModel; /** * User: qdeee * Date: 16.08.14 */ public class CustomWeakMultiClass extends VecOptimization.Stub<BlockwiseMLLLogit> { private final int iters; private final double step; public CustomWeakMultiClass(int iters, double step) { this.iters = iters; this.step = step; } @Override public Trans fit(final VecDataSet learnData, final BlockwiseMLLLogit loss) { final BFGrid grid = GridTools.medianGrid(learnData, 32); final GradientBoosting<TargetFunc> boosting = new GradientBoosting<>(new MultiClass(new GreedyObliviousTree<L2>(grid, 5), SatL2.class), iters, step); final IntSeq intTarget = loss.labels(); final FakePool ds = new FakePool(learnData.data(), intTarget); System.out.println(prepareComment(intTarget)); final ProgressHandler calcer = new ProgressHandler() { int iter = 0; @Override public void invoke(Trans partial) { if ((iter + 1) % 20 == 0) { final FuncJoin internModel = MCTools.joinBoostingResult((Ensemble) partial); final MultiClassModel multiClassModel = new MultiClassModel(internModel); final Mx x = internModel.transAll(learnData.data()); System.out.println("iter=" + iter + ", [learn]MLLLogitValue=" + String.format("%.10f", loss.value(x)) + ", stats=" + MCTools.evalModel(multiClassModel, ds, "[LEARN]", true) + "\r"); } iter++; } }; // boosting.addListener(calcer); final Ensemble ensemble = boosting.fit(learnData, loss); System.out.println(); final MCModel model = new MultiClassModel(MCTools.joinBoostingResult(ensemble)); return model; } private static String prepareComment(final IntSeq labels) { final StringBuilder builder = new StringBuilder("Class entries count: { "); final int countClasses = MCTools.countClasses(labels); for (int i = 0; i < countClasses; i++) { builder.append(i) .append(" : ") .append(MCTools.classEntriesCount(labels, i)) .append(", "); } return builder.append("}").toString(); } }