package com.spbsu.exp.multiclass;
import com.spbsu.commons.seq.IntSeq;
import com.spbsu.commons.util.tree.IntTree;
import com.spbsu.exp.multiclass.weak.CustomWeakBinClass;
import com.spbsu.exp.multiclass.weak.CustomWeakMultiClass;
import com.spbsu.ml.data.tools.HierTools;
import com.spbsu.ml.data.tools.MCTools;
import com.spbsu.ml.data.tools.Pool;
import com.spbsu.ml.loss.L2;
import com.spbsu.ml.loss.blockwise.BlockwiseMLLLogit;
import com.spbsu.ml.meta.FeatureMeta;
import com.spbsu.ml.meta.impl.fake.FakeTargetMeta;
import com.spbsu.ml.methods.multiclass.hierarchical.HierarchicalClassification;
import com.spbsu.ml.methods.multiclass.hierarchical.HierarchicalRefinedClassification;
import com.spbsu.ml.models.multiclass.HierarchicalModel;
import com.spbsu.ml.models.multiclass.MCModel;
import com.spbsu.ml.testUtils.TestResourceLoader;
import gnu.trove.list.TDoubleList;
import gnu.trove.list.array.TDoubleArrayList;
import gnu.trove.map.TIntIntMap;
import junit.framework.TestCase;
/**
* User: qdeee
* Date: 28.07.14
*/
public class HierClassTests extends TestCase {
private static Pool<?> learn;
private static Pool<?> test;
private static IntTree tree;
private static int iters;
private static double step;
private synchronized void init() throws Exception {
if (learn == null || test == null) {
learn = TestResourceLoader.loadPool("features.txt.gz");
test = TestResourceLoader.loadPool("featuresTest.txt.gz");
final TDoubleList borders = new TDoubleArrayList();
final IntSeq learnTarget = MCTools.transformRegressionToMC(learn.target(L2.class).target, 16, borders);
final IntSeq testTarget = MCTools.transformRegressionToMC(test.target(L2.class).target, borders.size(), borders);
final HierTools.TreeBuilder treeBuilder = new HierTools.TreeBuilder(450);
treeBuilder.createFromOrderedMulticlass(learnTarget);
tree = treeBuilder.releaseTree();
final TIntIntMap map = treeBuilder.releaseMapping();
learn.addTarget(new FakeTargetMeta(learn.vecData(), FeatureMeta.ValueType.INTS), MCTools.mapTarget(learnTarget, map));
test.addTarget(new FakeTargetMeta(test.vecData(), FeatureMeta.ValueType.INTS), MCTools.mapTarget(testTarget, map));
iters = 200;
step = 1.5;
}
}
@Override
protected void setUp() throws Exception {
super.setUp();
init();
}
private void printResult(final MCModel model) {
System.out.println(MCTools.evalModel(model, learn, "[LEARN]", false));
System.out.println(MCTools.evalModel(model, test, "[TEST]", false));
System.out.println(MCTools.evalModel(model, learn, getName(), true) + MCTools.evalModel(model, test, "", true));
}
public void testHierClass() throws Exception {
final CustomWeakMultiClass customWeakMultiClass = new CustomWeakMultiClass(iters, step);
final HierarchicalClassification hierarchicalClassification = new HierarchicalClassification(customWeakMultiClass, tree);
final HierarchicalModel model = (HierarchicalModel) hierarchicalClassification.fit(learn.vecData(), learn.target(BlockwiseMLLLogit.class));
printResult(model);
}
public void testHierRefinedClass() throws Exception {
final CustomWeakMultiClass customWeakMultiClass = new CustomWeakMultiClass(iters, step);
final CustomWeakBinClass customWeakBinClass = new CustomWeakBinClass(iters, step);
final HierarchicalRefinedClassification hierarchicalRefinedClassification = new HierarchicalRefinedClassification(customWeakBinClass, customWeakMultiClass, tree);
final HierarchicalModel model = (HierarchicalModel) hierarchicalRefinedClassification.fit(learn.vecData(), learn.target(BlockwiseMLLLogit.class));
printResult(model);
}
}