package com.spbsu.exp.multiclass;
import com.spbsu.commons.io.StreamTools;
import com.spbsu.commons.math.MathTools;
import com.spbsu.commons.math.vectors.Mx;
import com.spbsu.commons.math.vectors.VecTools;
import com.spbsu.commons.random.FastRandom;
import com.spbsu.commons.seq.IntSeq;
import com.spbsu.exp.multiclass.weak.CustomWeakBinClass;
import com.spbsu.exp.multiclass.weak.CustomWeakMultiClass;
import com.spbsu.ml.data.tools.DataTools;
import com.spbsu.ml.data.tools.MCTools;
import com.spbsu.ml.data.tools.Pool;
import com.spbsu.ml.data.tools.SubPool;
import com.spbsu.ml.loss.L2;
import com.spbsu.ml.loss.blockwise.BlockwiseMLLLogit;
import com.spbsu.ml.meta.FeatureMeta;
import com.spbsu.ml.meta.impl.fake.FakeTargetMeta;
import com.spbsu.ml.methods.VecOptimization;
import com.spbsu.ml.methods.multiclass.spoc.AbstractCodingMatrixLearning;
import com.spbsu.ml.methods.multiclass.spoc.SPOCMethodClassic;
import com.spbsu.ml.methods.multiclass.spoc.impl.CodingMatrixLearning;
import com.spbsu.ml.methods.multiclass.spoc.impl.CodingMatrixLearningGreedy;
import com.spbsu.ml.methods.multiclass.spoc.impl.CodingMatrixLearningGreedyParallels;
import com.spbsu.ml.models.multiclass.MCModel;
import com.spbsu.ml.models.multiclass.MulticlassCodingMatrixModel;
import com.spbsu.ml.testUtils.TestResourceLoader;
import gnu.trove.list.TDoubleList;
import gnu.trove.list.array.TDoubleArrayList;
import junit.framework.TestCase;
import java.io.IOException;
/**
 * Exploratory tests for multiclass learning with coding matrices ({@link SPOCMethodClassic}).
 *
 * <p>None of the tests assert anything: each one trains a model and prints the output of
 * {@link MCTools#evalModel} for the learn and test pools to standard output.
 *
 * <p>Methods whose names start with an underscore do not match the {@code test*} naming
 * convention used by {@link TestCase}-based suites, so they are not expected to be picked up
 * automatically; they can still be invoked by hand.
 *
 * User: qdeee
 * Date: 07.05.14
 */
public class SPOCMethodTest extends TestCase {
    /** Borders handed to {@link MCTools#transformRegressionToMC} when the default data is prepared. */
    private static final double[] hierBorders = new double[] {0.038125, 0.07625, 0.114375, 0.1525, 0.61};

    /**
     * Alternative set of borders; currently not referenced anywhere in this class.
     * NOTE(review): the similarity matrix loaded below is named "...-classic.txt" while the
     * borders in use are {@link #hierBorders} - confirm that this pairing is intended.
     */
    private static final double[] classicBorders = new double[]{0.06999, 0.13999, 0.40999, 0.60999, 0.61};

    /** Pool loaded from "features.txt.gz", with an additional integer target attached. */
    protected Pool<?> learn;
    /** Pool loaded from "featuresTest.txt.gz", with an additional integer target attached. */
    protected Pool<?> test;
    /** Matrix parsed from the similarity-matrix resource; input of the coding matrix learners. */
    protected Mx S;
    /** First constructor argument of the coding matrix learners; set to the number of borders. */
    protected int k;
    /** Second constructor argument of the coding matrix learners; fixed to 5. */
    protected int l;

    /**
     * Loads the learn/test pools and the similarity matrix and initialises {@link #k} and
     * {@link #l}, unless both pools are already present on this instance.
     *
     * <p>NOTE(review): the guarded fields are instance fields, so the guard only helps when this
     * method runs more than once on the same object. If the intent was to share the loaded data
     * between test methods, the fields would have to be static - confirm before changing.
     *
     * @throws IOException if one of the resources cannot be read
     */
    private synchronized void initDefaultData() throws IOException {
        if (learn != null && test != null) {
            return;
        }

        final TDoubleList borders = new TDoubleArrayList(hierBorders);
        learn = TestResourceLoader.loadPool("features.txt.gz");
        test = TestResourceLoader.loadPool("featuresTest.txt.gz");

        // Derive an integer target from the L2 (regression) target of each pool and attach it.
        final IntSeq learnTarget = MCTools.transformRegressionToMC(learn.target(L2.class).target, borders.size(), borders);
        final IntSeq testTarget = MCTools.transformRegressionToMC(test.target(L2.class).target, borders.size(), borders);
        learn.addTarget(new FakeTargetMeta(learn.vecData(), FeatureMeta.ValueType.INTS), learnTarget);
        test.addTarget(new FakeTargetMeta(test.vecData(), FeatureMeta.ValueType.INTS), testTarget);

        // Previously used resource, kept for reference:
        // final CharSequence mxStr = StreamTools.readStream(TestResourceLoader.loadResourceAsStream("multiclass/regression_based/features.txt.similarityMx"));
        final CharSequence mxStr = StreamTools.readStream(TestResourceLoader.loadResourceAsStream("multiclass/regression_based/features-simmatrix-classic.txt"));
        S = MathTools.CONVERSION.convert(mxStr, Mx.class);

        k = borders.size();
        l = 5;
    }

    /** Prepares the default pools and similarity matrix before every test method. */
    @Override
    protected void setUp() throws Exception {
        super.setUp();
        initDefaultData();
    }

    /**
     * Prints the evaluation of {@code model} on the default {@link #learn} and {@link #test} pools.
     *
     * @param model the trained model to evaluate
     */
    private void printResult(final MCModel model) {
        printResult(model, learn, test);
    }

    /**
     * Prints three lines to standard output: the evaluation on the learn pool, the evaluation on
     * the test pool, and a combined line prefixed with the name of the running test.
     *
     * @param model     the trained model to evaluate
     * @param learnPool pool reported under the "[LEARN]" label
     * @param testPool  pool reported under the "[TEST]" label
     */
    private void printResult(final MCModel model, final Pool<?> learnPool, final Pool<?> testPool) {
        System.out.println(MCTools.evalModel(model, learnPool, "[LEARN]", false));
        System.out.println(MCTools.evalModel(model, testPool, "[TEST]", false));
        System.out.println(MCTools.evalModel(model, learnPool, getName(), true) + MCTools.evalModel(model, testPool, "", true));
    }

    /**
     * Trains a coding matrix on {@link #S}, fits {@link SPOCMethodClassic} with it on the learn
     * pool and prints the evaluation.
     *
     * @param matrixLearning learner that produces the coding matrix
     * @param iters          first argument of {@link CustomWeakBinClass}
     * @param step           second argument of {@link CustomWeakBinClass}
     */
    private void fitModel(final AbstractCodingMatrixLearning matrixLearning, final int iters, final double step) {
        final Mx codingMatrix = matrixLearning.trainCodingMatrix(S);
        // Disabled sanity check, kept for reference:
        // if (!CodingMatrixLearning.checkConstraints(codeMatrix)) {
        //     throw new IllegalStateException("Result matrix is out of constraints");
        // }
        final VecOptimization method = new SPOCMethodClassic(codingMatrix, new CustomWeakBinClass(iters, step));
        final MulticlassCodingMatrixModel model = (MulticlassCodingMatrixModel) method.fit(learn.vecData(), learn.target(BlockwiseMLLLogit.class));
        printResult(model);
    }

    /**
     * Loads the letter dataset and attaches an integer target converted from its L2 target.
     *
     * @return the pool with the additional target attached
     * @throws IOException if the resource cannot be read
     */
    private static Pool<?> loadLetterPool() throws IOException {
        final Pool<?> pool = TestResourceLoader.loadPool("multiclass/ds_letter/letter.tsv.gz");
        pool.addTarget(new FakeTargetMeta(pool.vecData(), FeatureMeta.ValueType.INTS),
                VecTools.toIntSeq(pool.target(L2.class).target));
        return pool;
    }

    /** Fits {@link CustomWeakMultiClass} directly on the learn pool, without a coding matrix. */
    public void testBaseline() throws Exception {
        final CustomWeakMultiClass customWeakMultiClass = new CustomWeakMultiClass(100, 0.5);
        final MCModel model = (MCModel) customWeakMultiClass.fit(learn.vecData(), learn.target(BlockwiseMLLLogit.class));
        printResult(model);
    }

    /** Runs {@link #fitModel} with a {@link CodingMatrixLearning} coding matrix learner. */
    public void testMathFit() throws Exception {
        fitModel(new CodingMatrixLearning(k, l, 3.0, 2.5, 7.0, 1.8),
                100, 0.3);
    }

    /** Runs {@link #fitModel} with a {@link CodingMatrixLearningGreedy} coding matrix learner. */
    public void testGreedyFit() throws Exception {
        fitModel(new CodingMatrixLearningGreedy(k, l, 3.0, 2.5, 7.0),
                200, 0.3);
    }

    /** Runs {@link #fitModel} with a {@link CodingMatrixLearningGreedyParallels} coding matrix learner. */
    public void testParallelsGreedyFit() throws Exception {
        fitModel(new CodingMatrixLearningGreedyParallels(k, l, 3.0, 2.5, 7.0),
                200, 0.3);
    }

    /**
     * Baseline on the letter dataset: random 0.8/0.2 split, {@link CustomWeakMultiClass} fitted
     * on the first part and evaluated on both parts.
     */
    public void _testBaselineBigDS() throws Exception {
        final Pool<?> pool = loadLetterPool();
        final int[][] idxs = DataTools.splitAtRandom(pool.size(), new FastRandom(100501), 0.8, 0.2);
        final SubPool<?> learnPart = new SubPool<>(pool, idxs[0]);
        final SubPool<?> testPart = new SubPool<>(pool, idxs[1]);

        final CustomWeakMultiClass customWeakMultiClass = new CustomWeakMultiClass(300, 0.7);
        final MCModel model = (MCModel) customWeakMultiClass.fit(learnPart.vecData(), learnPart.target(BlockwiseMLLLogit.class));
        printResult(model, learnPart, testPart);
    }

    /**
     * Coding-matrix variant on the letter dataset: random 0.8/0.2 split, coding matrix obtained
     * from the letter similarity matrix, {@link SPOCMethodClassic} fitted on the first part and
     * evaluated on both parts.
     *
     * <p>NOTE(review): this test calls {@code findMatrixB} while {@link #fitModel} calls
     * {@code trainCodingMatrix} - confirm whether the difference is deliberate.
     */
    public void _testMathFitBigDS() throws Exception {
        final Pool<?> pool = loadLetterPool();
        final int[][] idxs = DataTools.splitAtRandom(pool.size(), new FastRandom(100500), 0.8, 0.2);
        final SubPool<?> learnPart = new SubPool<>(pool, idxs[0]);
        final SubPool<?> testPart = new SubPool<>(pool, idxs[1]);

        final CharSequence mxStr = StreamTools.readStream(TestResourceLoader.loadResourceAsStream("multiclass/ds_letter/letter.similarityMx"));
        final Mx similarityMx = MathTools.CONVERSION.convert(mxStr, Mx.class);
        final CodingMatrixLearning codingMatrixLearning = new CodingMatrixLearning(26, 10, 10.0, 2.5, 5.0, 1.8);
        final Mx codeMx = codingMatrixLearning.findMatrixB(similarityMx);

        final SPOCMethodClassic spoc = new SPOCMethodClassic(codeMx, new CustomWeakBinClass(300, 0.7));
        final MCModel model = (MCModel) spoc.fit(learnPart.vecData(), learnPart.target(BlockwiseMLLLogit.class));
        printResult(model, learnPart, testPart);
    }
}