package com.spbsu.exp.multiclass;

import com.spbsu.commons.func.Action;
import com.spbsu.commons.io.StreamTools;
import com.spbsu.commons.math.MathTools;
import com.spbsu.commons.math.vectors.Mx;
import com.spbsu.commons.math.vectors.MxTools;
import com.spbsu.commons.math.vectors.Vec;
import com.spbsu.commons.math.vectors.VecTools;
import com.spbsu.commons.math.vectors.impl.mx.VecBasedMx;
import com.spbsu.commons.math.vectors.impl.vectors.ArrayVec;
import com.spbsu.commons.random.FastRandom;
import com.spbsu.commons.seq.IntSeq;
import com.spbsu.commons.util.Pair;
import com.spbsu.commons.util.logging.Interval;
import com.spbsu.exp.multiclass.spoc.full.mx.optimization.ECOCMulticlass;
import com.spbsu.exp.multiclass.spoc.full.mx.optimization.SeparatedMLLLogit;
import com.spbsu.exp.multiclass.weak.CustomWeakBinClass;
import com.spbsu.exp.multiclass.weak.CustomWeakMultiClass;
import com.spbsu.ml.BFGrid;
import com.spbsu.ml.GridTools;
import com.spbsu.ml.TargetFunc;
import com.spbsu.ml.cli.output.printers.MulticlassProgressPrinter;
import com.spbsu.ml.data.set.VecDataSet;
import com.spbsu.ml.data.tools.DataTools;
import com.spbsu.ml.data.tools.MCTools;
import com.spbsu.ml.data.tools.Pool;
import com.spbsu.ml.data.tools.SubPool;
import com.spbsu.ml.factorization.Factorization;
import com.spbsu.ml.factorization.impl.ALS;
import com.spbsu.ml.factorization.impl.SVDAdapterEjml;
import com.spbsu.ml.factorization.impl.StochasticALS;
import com.spbsu.ml.func.Ensemble;
import com.spbsu.ml.func.FuncJoin;
import com.spbsu.ml.loss.L2;
import com.spbsu.ml.loss.SatL2;
import com.spbsu.ml.loss.blockwise.BlockwiseMLLLogit;
import com.spbsu.ml.meta.FeatureMeta;
import com.spbsu.ml.meta.impl.fake.FakeTargetMeta;
import com.spbsu.ml.methods.GradientBoosting;
import com.spbsu.ml.methods.multiclass.gradfac.GradFacMulticlass;
import com.spbsu.ml.methods.multiclass.gradfac.GradFacSvdNMulticlass;
import com.spbsu.ml.methods.multiclass.spoc.ECOCCombo;
import com.spbsu.ml.methods.trees.GreedyObliviousTree;
import com.spbsu.ml.models.MultiClassModel;
import com.spbsu.ml.models.multiclass.MCModel;
import com.spbsu.ml.models.multiclass.MulticlassCodingMatrixModel;
import com.spbsu.ml.testUtils.TestResourceLoader;
import junit.framework.TestCase;

import java.io.FileInputStream;
import java.io.IOException;

/**
 * Experimental multiclass tests: ECOC-combo fitting, boosted ECOC, gradient-factorization
 * boosting and matrix-factorization checks.
 *
 * <p>Most tests read the shared static fixtures {@link #learn}, {@link #test} and {@link #S},
 * which are populated only by {@link #init()}. Note that {@link #setUp()} currently does NOT
 * call {@code init()}, so those tests will hit {@code null} fixtures unless it is re-enabled.
 */
public class ECOCComboTest extends TestCase {
  /** Learn part of the random split of the loaded pool; {@code null} until {@link #init()} runs. */
  private static Pool<?> learn;
  /** Test part of the random split of the loaded pool; {@code null} until {@link #init()} runs. */
  private static Pool<?> test;
  /** Matrix read from the {@code .simmx} file and passed to {@code ECOCCombo}; {@code null} until {@link #init()} runs. */
  private static Mx S;

  /**
   * Loads the pool, adds an integer target derived from the {@code L2} target, splits the pool
   * into {@link #learn} and {@link #test}, and reads the matrix {@link #S} from disk.
   * Does nothing when both {@code learn} and {@code test} are already set.
   *
   * @throws IOException if the pool or the matrix file cannot be read
   */
  private synchronized static void init() throws IOException {
    if (learn == null || test == null) {
      final Pool<?> pool = TestResourceLoader.loadPool("multiclass/ds_letter/letter.tsv.gz");
      pool.addTarget(
          new FakeTargetMeta(pool.vecData(), FeatureMeta.ValueType.INTS),
          VecTools.toIntSeq(pool.target(L2.class).target)
      );
      final int[][] idxs = DataTools.splitAtRandom(pool.size(), new FastRandom(100500), 0.8, 0.5);
      learn = new SubPool<>(pool, idxs[0]);
      test = new SubPool<>(pool, idxs[1]);

      // NOTE(review): absolute path to a developer's home directory -- this only works on that
      // machine; presumably the reason init() is disabled in setUp(). Consider moving the file
      // to test resources.
      // try-with-resources guarantees the stream is closed even if reading or conversion fails.
      try (final FileInputStream mxStream = new FileInputStream("/Users/qdeee/datasets/catalog-final/catalog50-1stlevel-gt5000.tsv.simmx")) {
        final CharSequence mxStr = StreamTools.readStream(mxStream);
        S = MathTools.CONVERSION.convert(mxStr, Mx.class);
      }
    }
  }

  /**
   * Intentionally does nothing: fixture loading is disabled in the original code.
   * Re-enable the call below to populate {@link #learn}, {@link #test} and {@link #S}.
   */
  @Override
  protected void setUp() throws Exception {
//    init();
  }

  /**
   * Fits an {@code ECOCCombo} on the learn set, printing the pairwise interaction matrix and
   * learn/test evaluation from the listener, then prints the final evaluation.
   */
  public void testFit() throws Exception {
    final BlockwiseMLLLogit mllLogit = learn.target(BlockwiseMLLLogit.class);
    final VecDataSet vecDataSet = learn.vecData();
    final int k = MCTools.countClasses(mllLogit.labels());
    final ECOCCombo ecocComboMethod = new ECOCCombo(k, k, 5.0, 2.5, 3.0, S, new CustomWeakBinClass(100, 0.3));
    final Action<MulticlassCodingMatrixModel> listener = new Action<MulticlassCodingMatrixModel>() {
      @Override
      public void invoke(final MulticlassCodingMatrixModel model) {
        System.out.println("L == " + model.getInternalModel().ydim());
        System.out.println(getPairwiseInteractions(model));
        System.out.println(MCTools.evalModel(model, learn, "[LEARN] ", true));
        System.out.println(MCTools.evalModel(model, test, "[TEST] ", true));
      }
    };
    ecocComboMethod.addListener(listener);
    final MulticlassCodingMatrixModel model = (MulticlassCodingMatrixModel) ecocComboMethod.fit(vecDataSet, mllLogit);
    System.out.println("\n\n\n");
    System.out.println(MCTools.evalModel(model, learn, "[LEARN] ", false));
    System.out.println(MCTools.evalModel(model, test, "[TEST] ", false));
  }

  /**
   * Same fit as {@link #testFit()} but the listener prints only the evaluation
   * (no pairwise interaction matrix).
   */
  public void testDefaultGreedyModelFit() throws Exception {
    final BlockwiseMLLLogit mllLogit = learn.target(BlockwiseMLLLogit.class);
    final VecDataSet vecDataSet = learn.vecData();
    final int k = MCTools.countClasses(mllLogit.labels());
    final ECOCCombo ecocComboMethod = new ECOCCombo(k, k, 5.0, 2.5, 3.0, S, new CustomWeakBinClass(100, 0.3));
    final Action<MulticlassCodingMatrixModel> listener = new Action<MulticlassCodingMatrixModel>() {
      @Override
      public void invoke(final MulticlassCodingMatrixModel model) {
        System.out.println("L == " + model.getInternalModel().ydim());
        System.out.println(MCTools.evalModel(model, learn, "[LEARN] ", true));
        System.out.println(MCTools.evalModel(model, test, "[TEST] ", true));
        System.out.println();
      }
    };
    ecocComboMethod.addListener(listener);
    final MulticlassCodingMatrixModel model = (MulticlassCodingMatrixModel) ecocComboMethod.fit(vecDataSet, mllLogit);
    System.out.println("\n\n\n");
    System.out.println(MCTools.evalModel(model, learn, "[LEARN] ", false));
    System.out.println(MCTools.evalModel(model, test, "[TEST] ", false));
  }

  /**
   * Fits an {@code ECOCCombo} while the listener mutates the shared matrix {@link #S} on every
   * callback, mixing it with the current pairwise interaction matrix using weight {@code lambda}.
   */
  public void testFitUpdatePrior() throws Exception {
    final double lambda = 0.9;
    final BlockwiseMLLLogit mllLogit = learn.target(BlockwiseMLLLogit.class);
    final VecDataSet vecDataSet = learn.vecData();
    final int k = MCTools.countClasses(mllLogit.labels());
    final ECOCCombo ecocComboMethod = new ECOCCombo(k, k, 5.0, 2.5, 3.0, S, new CustomWeakBinClass(100, 0.3));
    final Action<MulticlassCodingMatrixModel> listener = new Action<MulticlassCodingMatrixModel>() {
      @Override
      public void invoke(final MulticlassCodingMatrixModel model) {
        final Mx mx = getPairwiseInteractions(model);
        // S is scaled by lambda, mx by (1 - lambda), then mx is appended to S
        // (presumably an in-place element-wise add -- confirm against VecTools.append).
        VecTools.scale(S, lambda);
        VecTools.scale(mx, 1 - lambda);
        VecTools.append(S, mx);
        System.out.println("L == " + model.getInternalModel().ydim());
        System.out.println(MCTools.evalModel(model, learn, "[LEARN] ", true));
        System.out.println(MCTools.evalModel(model, test, "[TEST] ", true));
        System.out.println();
      }
    };
    ecocComboMethod.addListener(listener);
    final MulticlassCodingMatrixModel model = (MulticlassCodingMatrixModel) ecocComboMethod.fit(vecDataSet, mllLogit);
    System.out.println("\n\n\n");
    System.out.println(MCTools.evalModel(model, learn, "[LEARN] ", false));
    System.out.println(MCTools.evalModel(model, test, "[TEST] ", false));
  }

  /**
   * Builds a symmetric class-by-class matrix from the model's predictions on the learn set.
   *
   * <p>For every learn row the predicted probability vector is accumulated into the matrix row
   * of its argmax class; each matrix row is then divided by the number of points assigned to
   * that class, and finally entries {@code (c1, c2)} and {@code (c2, c1)} are replaced by
   * their average.
   *
   * <p>Rows of classes that were never predicted are left as all zeros (before symmetrization);
   * dividing them by a zero count would otherwise turn them into {@code NaN}.
   *
   * @param model model whose {@code probs} are evaluated on each learn row
   * @return the symmetrized matrix, dimensioned from {@link #S}
   */
  private static Mx getPairwiseInteractions(final MCModel model) {
    // NOTE(review): arguments are (S.columns(), S.rows()) as in the original; this only matters
    // if S is not square -- presumably it is.
    final Mx result = new VecBasedMx(S.columns(), S.rows());
    final Mx features = learn.vecData().data();
    // One counter per class (matrix row), not per data point.
    final int[] counts = new int[result.rows()];
    for (int i = 0; i < learn.size(); i++) {
      final Vec probs = model.probs(features.row(i));
      final int bestClass = VecTools.argmax(probs);
      VecTools.append(result.row(bestClass), probs);
      counts[bestClass]++;
    }
    for (int c = 0; c < result.rows(); c++) {
      // Skip classes with no assigned points: 1.0 / 0 is Infinity and 0 * Infinity is NaN.
      if (counts[c] > 0) {
        VecTools.scale(result.row(c), 1.0 / counts[c]);
      }
    }
    for (int c1 = 0; c1 < result.rows(); c1++) {
      for (int c2 = c1 + 1; c2 < result.columns(); c2++) {
        final double val = 0.5 * (result.get(c1, c2) + result.get(c2, c1));
        result.set(c1, c2, val);
        result.set(c2, c1, val);
      }
    }
    return result;
  }

  /** Runs 3 iterations of gradient boosting over {@code ECOCMulticlass} and prints the ensemble. */
  public void testBoostedECOC() throws Exception {
    final VecDataSet vecDataSet = learn.vecData();
    final IntSeq labels = learn.target(BlockwiseMLLLogit.class).labels();
    final BFGrid bfGrid = GridTools.medianGrid(vecDataSet, 32);
    final SeparatedMLLLogit smlllogit = new SeparatedMLLLogit(5, labels, null);
    final int k = MCTools.countClasses(smlllogit.labels());
    final int l = smlllogit.getBinClassifiersCount();
    final ECOCMulticlass ecocMulticlass = new ECOCMulticlass(new GreedyObliviousTree<L2>(bfGrid, 5), SatL2.class, k, l, 1.0);
    final GradientBoosting<SeparatedMLLLogit> boosting = new GradientBoosting<>(ecocMulticlass, 3, 0.1);
    final Ensemble fit = boosting.fit(vecDataSet, smlllogit);
    System.out.println(fit);
  }

  /**
   * Checks that, for every learn row, {@code bestClass} of the fitted coding-matrix model agrees
   * with the argmax of its {@code probs}.
   */
  public void testMCMMProbs() throws Exception {
    final BlockwiseMLLLogit mllLogit = learn.target(BlockwiseMLLLogit.class);
    final VecDataSet vecDataSet = learn.vecData();
    final int k = MCTools.countClasses(mllLogit.labels());
    final ECOCCombo ecocComboMethod = new ECOCCombo(k, 5, 5.0, 2.5, 3.0, S, new CustomWeakBinClass(10, 0.3));
    final MulticlassCodingMatrixModel model = (MulticlassCodingMatrixModel) ecocComboMethod.fit(vecDataSet, mllLogit);
    for (int i = 0; i < vecDataSet.data().rows(); i++) {
      final Vec features = vecDataSet.data().row(i);
      final Vec probs = model.probs(features);
      final int bestClass = model.bestClass(features);
      assertEquals(bestClass, VecTools.argmax(probs));
    }
  }

  /** Boosts {@code GradFacMulticlass} with an {@code ALS} factorization and prints learn/test evaluation. */
  public void testGradFacALS() throws Exception {
    final VecDataSet vecDataSet = learn.vecData();
    final BlockwiseMLLLogit globalLoss = learn.target(BlockwiseMLLLogit.class);
    final BFGrid bfGrid = GridTools.medianGrid(vecDataSet, 32);
    final int factorIters = 15;
    final GradientBoosting<TargetFunc> boosting = new GradientBoosting<>(
        new GradFacMulticlass(new GreedyObliviousTree<L2>(bfGrid, 5), new ALS(factorIters), SatL2.class),
        500,
        0.7
    );
    boosting.addListener(new MulticlassProgressPrinter(learn, test));
    final Ensemble ensemble = boosting.fit(vecDataSet, globalLoss);
    final FuncJoin joined = MCTools.joinBoostingResult(ensemble);
    final MultiClassModel multiclassModel = new MultiClassModel(joined);
    final String learnResult = MCTools.evalModel(multiclassModel, learn, "[LEARN] ", false);
    final String testResult = MCTools.evalModel(multiclassModel, test, "[TEST] ", false);
    System.out.println(learnResult);
    System.out.println(testResult);
  }

  /** Boosts {@code GradFacSvdNMulticlass} (constructed with factor argument 2) and prints learn/test evaluation. */
  public void testGradFacSvdN() throws Exception {
    final VecDataSet vecDataSet = learn.vecData();
    final BlockwiseMLLLogit globalLoss = learn.target(BlockwiseMLLLogit.class);
    final BFGrid bfGrid = GridTools.medianGrid(vecDataSet, 32);
    final GradientBoosting<TargetFunc> boosting = new GradientBoosting<>(
        new GradFacSvdNMulticlass(new GreedyObliviousTree<L2>(bfGrid, 5), SatL2.class, 2),
        500,
        0.7
    );
    final MulticlassProgressPrinter multiclassProgressPrinter = new MulticlassProgressPrinter(learn, test);
    boosting.addListener(multiclassProgressPrinter);
    final Ensemble ensemble = boosting.fit(vecDataSet, globalLoss);
    final FuncJoin joined = MCTools.joinBoostingResult(ensemble);
    final MultiClassModel multiclassModel = new MultiClassModel(joined);
    final String learnResult = MCTools.evalModel(multiclassModel, learn, "[LEARN] ", false);
    final String testResult = MCTools.evalModel(multiclassModel, test, "[TEST] ", false);
    System.out.println(learnResult);
    System.out.println(testResult);
  }

  /** Fits {@code CustomWeakMultiClass} directly as a baseline and prints learn/test evaluation. */
  public void testBaseline() throws Exception {
    final VecDataSet vecDataSet = learn.vecData();
    final BlockwiseMLLLogit globalLoss = learn.target(BlockwiseMLLLogit.class);
    final CustomWeakMultiClass customWeakMultiClass = new CustomWeakMultiClass(300, 0.3);
    final MCModel model = (MCModel) customWeakMultiClass.fit(vecDataSet, globalLoss);
    System.out.println(MCTools.evalModel(model, learn, "[LEARN]", false));
    System.out.println(MCTools.evalModel(model, test, "[TEST]", false));
  }

  /**
   * Factorizes the loss gradient (taken at a zero-initialized point) with {@code ALS} and, from
   * the listener, prints the factor norms and the RMSE of the outer-product approximation.
   */
  public void testGradMxApproxALS() throws Exception {
    final BlockwiseMLLLogit globalLoss = learn.target(BlockwiseMLLLogit.class);
    final Mx gradient = (Mx) globalLoss.gradient(new ArrayVec(globalLoss.dim()));
    final int factorIters = 25;
    final ALS als = new ALS(factorIters);
    final Action<Pair<Vec, Vec>> action = new Action<Pair<Vec, Vec>>() {
      @Override
      public void invoke(final Pair<Vec, Vec> pair) {
        final Vec h = pair.getFirst();
        final Vec b = pair.getSecond();
        System.out.println("||h|| = " + VecTools.norm(h) + ", ||b|| = " + VecTools.norm(b) + ", RMSE = " + rmse(gradient, VecTools.outer(h, b)));
      }
    };
    als.addListener(action);
    als.factorize(gradient);
  }

  /** Runs {@link #applyFactorMethod(Factorization)} for {@code ALS(15)} and for {@code SVDAdapterEjml}. */
  public void testGradMxApproxSVD() throws Exception {
    applyFactorMethod(new ALS(15));
    applyFactorMethod(new SVDAdapterEjml());
  }

  /**
   * Factorizes the loss gradient with {@code SVDAdapterEjml} for every factor dimension from
   * {@code gradient.columns()} down to 1, printing the dimension and the elapsed seconds since
   * the start of the test.
   */
  public void testGradMxApproxSVDN() throws Exception {
    final BlockwiseMLLLogit globalLoss = learn.target(BlockwiseMLLLogit.class);
    final Mx gradient = (Mx) globalLoss.gradient(new ArrayVec(globalLoss.dim()));
    final long startMillis = System.currentTimeMillis();
    for (int factorDim = gradient.columns(); factorDim >= 1; factorDim--) {
      final Pair<Vec, Vec> pair = new SVDAdapterEjml(factorDim).factorize(gradient);
      final Mx h = (Mx) pair.getFirst();
      final Mx b = (Mx) pair.getSecond();
      System.out.println("factor dim: " + factorDim);
      System.out.println("time: " + ((System.currentTimeMillis() - startMillis) / 1000.0));
      // The product is still computed (as in the original) although the diagnostic print below
      // is disabled, so the multiplication itself remains exercised.
      final Mx afterFactor = MxTools.multiply(h, MxTools.transpose(b));
//      System.out.println("||h|| = " + VecTools.norm(h) + ", ||b|| = " + VecTools.norm(b) + ", l2 = " + VecTools.distance(gradient, afterFactor) + ", l1 = " + VecTools.distanceL1(gradient, afterFactor));
      System.out.println();
    }
  }

  /**
   * Factorizes a 100000 x 100 Gaussian matrix with {@code StochasticALS} and with plain
   * {@code ALS}, prints the timings and reconstruction distances, and asserts that both
   * factor pairs are within 0.001 of each other.
   */
  public void testStochasticALS() {
    final FastRandom rng = new FastRandom(0);
    final Mx X = new VecBasedMx(100000, 100);
    VecTools.fillGaussian(X, rng);

    final StochasticALS sals = new StochasticALS(rng, 1000);
    final ALS als = new ALS(10);

    Interval.start();
    final Pair<Vec, Vec> pair = sals.factorize(X);
    Interval.stopAndPrint("SALS");

    Interval.start();
    final Pair<Vec, Vec> reference = als.factorize(X);
    Interval.stopAndPrint("ALS");

    final VecBasedMx u = new VecBasedMx(pair.first.dim(), pair.first);
    final VecBasedMx v = new VecBasedMx(1, pair.second);
    final Mx mx = VecTools.outer(u, v);

    final VecBasedMx u1 = new VecBasedMx(reference.first.dim(), reference.first);
    final VecBasedMx v1 = new VecBasedMx(1, reference.second);
    final Mx mx1 = VecTools.outer(u1, v1);

    System.out.println(VecTools.distance(mx, X));
    System.out.println(VecTools.distance(mx1, X));

    // Use the (expected, actual, delta) overload. The two-argument call with doubles resolves to
    // assertEquals(Object, Object) and would require the distance to be exactly 0.001.
    assertEquals(0.0, VecTools.distance(pair.first, reference.first), 0.001);
    assertEquals(0.0, VecTools.distance(pair.second, reference.second), 0.001);
  }

  /**
   * Factorizes the loss gradient with the given method, then rescales the factors so that
   * {@code b} is divided by its norm and {@code h} is multiplied by it, and builds their
   * outer product.
   *
   * @param method factorization to apply to the gradient matrix
   */
  private static void applyFactorMethod(final Factorization method) {
    final BlockwiseMLLLogit globalLoss = learn.target(BlockwiseMLLLogit.class);
    final Mx gradient = (Mx) globalLoss.gradient(new ArrayVec(globalLoss.dim()));
    final Pair<Vec, Vec> pair = method.factorize(gradient);
    final Vec h = pair.getFirst();
    final Vec b = pair.getSecond();
    final double normB = VecTools.norm(b);
    VecTools.scale(b, 1 / normB);
    VecTools.scale(h, normB);
    // Kept from the original: the outer product is computed even though the diagnostic print
    // that consumed it is disabled.
    final Mx afterFactor = VecTools.outer(h, b);
//    System.out.println("||h|| = " + VecTools.norm(h) + ", ||b|| = " + VecTools.norm(b) + ", l2 = " + VecTools.distance(gradient, afterFactor) + ", l1 = " + VecTools.distanceL1(gradient, afterFactor));
  }

  /**
   * Root of the mean squared element-wise difference between two vectors.
   *
   * @param target reference vector; its {@code length()} is the divisor
   * @param approx approximation to compare against {@code target}
   * @return {@code sqrt(sum2(target - approx) / target.length())}
   */
  private static double rmse(final Vec target, final Vec approx) {
    return Math.sqrt(VecTools.sum2(VecTools.subtract(target, approx)) / target.length());
  }
}