package com.spbsu.exp.multiclass; import com.spbsu.commons.random.FastRandom; import com.spbsu.ml.GridTools; import com.spbsu.ml.cli.output.printers.MulticlassProgressPrinter; import com.spbsu.ml.data.set.VecDataSet; import com.spbsu.ml.data.tools.DataTools; import com.spbsu.ml.data.tools.MCTools; import com.spbsu.ml.data.tools.Pool; import com.spbsu.ml.factorization.impl.ElasticNetFactorization; import com.spbsu.ml.factorization.impl.SVDAdapterEjml; import com.spbsu.ml.func.Ensemble; import com.spbsu.ml.func.FuncJoin; import com.spbsu.ml.loss.L2; import com.spbsu.ml.loss.LogL2; import com.spbsu.ml.loss.blockwise.BlockwiseMLLLogit; import com.spbsu.ml.methods.GradientBoosting; import com.spbsu.ml.methods.MultiClass; import com.spbsu.ml.methods.VecOptimization; import com.spbsu.ml.methods.multiclass.gradfac.GradFacMulticlass; import com.spbsu.ml.methods.multiclass.gradfac.MultiClassColumnBootstrapOptimization; import com.spbsu.ml.methods.trees.GreedyObliviousTree; import com.spbsu.ml.models.MultiClassModel; import junit.framework.TestCase; import java.io.IOException; /** * User: qdeee * Date: 24.05.15 */ public class DiplomaGradFacTest extends TestCase{ private static Pool<?> learn; private static Pool<?> test; @Override protected void setUp() throws Exception { super.setUp(); init(); } private synchronized static void init() throws IOException { if (learn == null || test == null) { learn = DataTools.loadFromFeaturesTxt("/Users/qdeee/datasets/letter.tsv.learn"); test = DataTools.loadFromFeaturesTxt("/Users/qdeee/datasets/letter.tsv.test"); } } public void testBaseline() throws Exception { final MultiClass learner = new MultiClass( new GreedyObliviousTree<L2>(GridTools.medianGrid(learn.vecData(), 32), 5), LogL2.class ); fitModel(learner, 400, 0.3); } public void testGradFacBaseline() throws Exception { final GradFacMulticlass learner = new GradFacMulticlass( new GreedyObliviousTree<L2>(GridTools.medianGrid(learn.vecData(), 32), 5), new SVDAdapterEjml(1), LogL2.class ); fitModel(learner, 400, 7.); } public void testGradFacColumnsBootstrap() throws Exception { final MultiClassColumnBootstrapOptimization learner = new MultiClassColumnBootstrapOptimization( new GradFacMulticlass( new GreedyObliviousTree<L2>(GridTools.medianGrid(learn.vecData(), 32), 5), new SVDAdapterEjml(1), LogL2.class ), new FastRandom(100500), 1. ); fitModel(learner, 7500, 1.5); } public void testGradFacElasticNet() throws Exception { final GradFacMulticlass learner = new GradFacMulticlass( new GreedyObliviousTree<L2>(GridTools.medianGrid(learn.vecData(), 32), 5), new ElasticNetFactorization(20, 1e-2, 0.95, 0.15 * 1e-6), LogL2.class ); fitModel(learner, 7500, 7.); } public void testGradFacElasticNetColumnsBootstrap() throws Exception { final MultiClassColumnBootstrapOptimization learner = new MultiClassColumnBootstrapOptimization( new GradFacMulticlass( new GreedyObliviousTree<L2>(GridTools.medianGrid(learn.vecData(), 32), 5), new ElasticNetFactorization(20, 1e-2, 0.95, 0.15 * 1e-6), LogL2.class ), new FastRandom(100500), 1. ); fitModel(learner, 5000, 7.); } private void fitModel(final VecOptimization<L2> weak, final int iters, final double step) { final VecDataSet vecDataSet = learn.vecData(); final BlockwiseMLLLogit globalLoss = learn.target(BlockwiseMLLLogit.class); final MulticlassProgressPrinter multiclassProgressPrinter = new MulticlassProgressPrinter(learn, test); final GradientBoosting<BlockwiseMLLLogit> boosting = new GradientBoosting<>(weak, L2.class, iters, step); boosting.addListener(multiclassProgressPrinter); final Ensemble ensemble = boosting.fit(vecDataSet, globalLoss); final FuncJoin joined = MCTools.joinBoostingResult(ensemble); final MultiClassModel multiclassModel = new MultiClassModel(joined); System.out.println(MCTools.evalModel(multiclassModel, learn, "[LEARN] ", false)); System.out.println(MCTools.evalModel(multiclassModel, test, "[TEST] ", false)); } }