package quickml.supervised.tree.decisionTree.reducers; import com.beust.jcommander.internal.Lists; import com.google.common.base.Optional; import org.junit.Assert; import org.junit.Test; import quickml.data.AttributesMap; import quickml.data.instances.ClassifierInstance; import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter; import quickml.supervised.tree.reducers.AttributeStats; import java.util.List; /** * Created by alexanderhawk on 6/29/15. */ public class DTCatOldBranchReducerTest { @Test public void getAttributeStatsTest() { List<ClassifierInstance> instances = getInstances(); DTCatBranchReducer<ClassifierInstance> reducer = new DTCatBranchReducer<>(instances); Optional<AttributeStats<ClassificationCounter>> attributeStatsOptional = reducer.getAttributeStats("t"); AttributeStats<ClassificationCounter> attributeStats = attributeStatsOptional.get(); Assert.assertEquals(attributeStats.getStatsOnEachValue().size(), 2); ClassificationCounter first = attributeStats.getStatsOnEachValue().get(0); assertionsAboutInstances(first); ClassificationCounter second = attributeStats.getStatsOnEachValue().get(0); assertionsAboutInstances(second); } private void assertionsAboutInstances(ClassificationCounter first) { if(first.attrVal.equals("1.0")) { Assert.assertEquals(first.getCount(1.0), 2.0, 1E-5); Assert.assertEquals(first.getCount(0.0), 2.0, 1E-5); } else { Assert.assertEquals(first.getCount(1.0), 3.0, 1E-5); Assert.assertEquals(first.getCount(0.0), 1.0, 1E-5); } } public static List<ClassifierInstance> getInstances() { List<ClassifierInstance> td = Lists.newArrayList(); AttributesMap atMap = AttributesMap.newHashMap(); atMap.put("t", "1.0"); td.add(new ClassifierInstance(atMap, 0.0)); atMap = AttributesMap.newHashMap(); atMap.put("t", "1.0"); td.add(new ClassifierInstance(atMap, 0.0)); atMap = AttributesMap.newHashMap(); atMap.put("t", "1.0"); td.add(new ClassifierInstance(atMap, 1.0)); atMap = AttributesMap.newHashMap(); atMap.put("t", "1.0"); td.add(new ClassifierInstance(atMap, 1.0)); atMap = AttributesMap.newHashMap(); atMap.put("t", "2.0"); td.add(new ClassifierInstance(atMap, 1.0)); atMap = AttributesMap.newHashMap(); atMap.put("t", "2.0"); td.add(new ClassifierInstance(atMap, 1.0)); atMap = AttributesMap.newHashMap(); atMap.put("t", "2.0"); td.add(new ClassifierInstance(atMap, 1.0)); atMap = AttributesMap.newHashMap(); atMap.put("t", "2.0"); td.add(new ClassifierInstance(atMap, 0.0)); return td; } }