package quickml.supervised; import com.beust.jcommander.internal.Lists; import com.beust.jcommander.internal.Sets; import org.junit.Assert; import org.junit.Test; import quickml.InstanceLoader; import quickml.data.AttributesMap; import quickml.data.instances.ClassifierInstance; import quickml.supervised.tree.branchFinders.SplittingUtilsTest; import quickml.supervised.tree.decisionTree.nodes.DTCatBranch; import quickml.supervised.tree.decisionTree.nodes.DTNumBranch; import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter; import quickml.supervised.tree.nodes.Branch; import java.io.Serializable; import java.util.List; import java.util.Set; /** * Created by alexanderhawk on 6/29/15. */ public class UtilsTest { private List<ClassifierInstance> setUpLast(){ return InstanceLoader.getAdvertisingInstances(); } @Test public void setTrueAndFalseTrainingSetsTestForCatBranch(){ Set<Serializable> trueSet = Sets.newHashSet(); //check trivial case first where i give it a training list already sorted correctly trueSet.add("1.0"); DTCatBranch branch = new DTCatBranch(null, "t",trueSet, .5, .5, new ClassificationCounter()); Utils.TrueFalsePair<ClassifierInstance> tfPair = Utils.<ClassifierInstance>setTrueAndFalseTrainingSets(getInstances(), branch); Assert.assertEquals(tfPair.falseTrainingSet.size(), 4); Assert.assertEquals(tfPair.trueTrainingSet.size(), 4); for(ClassifierInstance instance : tfPair.falseTrainingSet) { Assert.assertEquals(instance.getAttributes().get("t"), "2.0"); } for(ClassifierInstance instance : tfPair.trueTrainingSet) { Assert.assertEquals(instance.getAttributes().get("t"), "1.0"); } //check non trivial case where data is in reverse order of what it needs to be trueSet = Sets.newHashSet(); trueSet.add("2.0"); branch = new DTCatBranch(null, "t", trueSet, .5, .5, new ClassificationCounter()); tfPair = Utils.<ClassifierInstance>setTrueAndFalseTrainingSets(getInstances(),branch); Assert.assertEquals(tfPair.falseTrainingSet.size(), 4); Assert.assertEquals(tfPair.trueTrainingSet.size(), 4); for(ClassifierInstance instance : tfPair.falseTrainingSet) { Assert.assertEquals(instance.getAttributes().get("t"), "1.0"); } for(ClassifierInstance instance : tfPair.trueTrainingSet) { Assert.assertEquals(instance.getAttributes().get("t"), "2.0"); } } @Test public void setTrueAndFalseTrainingSetsTestForNumBranch(){ //check trivial case first where i give it a training list already sorted correctly DTNumBranch branch = new DTNumBranch(null, "t", 0.5, .5, new ClassificationCounter(), 4.5); Utils.TrueFalsePair<ClassifierInstance> tfPair = Utils.<ClassifierInstance>setTrueAndFalseTrainingSets(SplittingUtilsTest.getExtendedInstances(),branch); Assert.assertEquals(tfPair.falseTrainingSet.size(), 4); Assert.assertEquals(tfPair.trueTrainingSet.size(), 4); for(ClassifierInstance instance : tfPair.falseTrainingSet) { Assert.assertTrue((Double) (instance.getAttributes().get("t")) < 4.5); } for(ClassifierInstance instance : tfPair.trueTrainingSet) { Assert.assertTrue((Double) (instance.getAttributes().get("t")) > 4.5); } //check non trivial case where data is in reverse order of what it needs to be branch = new DTNumBranch(null, "t", 0.5, .5, new ClassificationCounter(), 6.5); tfPair = Utils.<ClassifierInstance>setTrueAndFalseTrainingSets(SplittingUtilsTest.getExtendedInstances(),branch); Assert.assertEquals(tfPair.falseTrainingSet.size(), 6); Assert.assertEquals(tfPair.trueTrainingSet.size(), 2); for(ClassifierInstance instance : tfPair.falseTrainingSet) { Assert.assertTrue((Double) (instance.getAttributes().get("t")) < 6.5); } for(ClassifierInstance instance : tfPair.trueTrainingSet) { Assert.assertTrue((Double) (instance.getAttributes().get("t")) > 6.5); } } public static List<ClassifierInstance> getInstances() { List<ClassifierInstance> td = Lists.newArrayList(); AttributesMap atMap = AttributesMap.newHashMap(); atMap.put("t", "1.0"); td.add(new ClassifierInstance(atMap, 0.0)); atMap = AttributesMap.newHashMap(); atMap.put("t", "1.0"); td.add(new ClassifierInstance(atMap, 0.0)); atMap = AttributesMap.newHashMap(); atMap.put("t", "1.0"); td.add(new ClassifierInstance(atMap, 1.0)); atMap = AttributesMap.newHashMap(); atMap.put("t", "1.0"); td.add(new ClassifierInstance(atMap, 1.0)); atMap = AttributesMap.newHashMap(); atMap.put("t", "2.0"); td.add(new ClassifierInstance(atMap, 1.0)); atMap = AttributesMap.newHashMap(); atMap.put("t", "2.0"); td.add(new ClassifierInstance(atMap, 1.0)); atMap = AttributesMap.newHashMap(); atMap.put("t", "2.0"); td.add(new ClassifierInstance(atMap, 1.0)); atMap = AttributesMap.newHashMap(); atMap.put("t", "2.0"); td.add(new ClassifierInstance(atMap, 0.0)); return td; } @Test public void setTrueFalseOnespotInstances(){ Set<Serializable> splitTrueSet = Sets.<Serializable>newHashSet(); splitTrueSet.add(true); Branch<ClassificationCounter> bestBranch = new DTCatBranch(null, "seenPixel", splitTrueSet, 0.0, 0.0, null); List<ClassifierInstance> instances =setUpLast(); Utils.TrueFalsePair<ClassifierInstance> trueFalsePair = Utils.setTrueAndFalseTrainingSets(instances, bestBranch); List<ClassifierInstance> trueTrainingSet = com.google.common.collect.Lists.newArrayList(); List<ClassifierInstance> falseTrainingSet = com.google.common.collect.Lists.newArrayList(); Set<ClassifierInstance> trueSet = com.google.common.collect.Sets.newHashSet(); Set<ClassifierInstance> falseSet = com.google.common.collect.Sets.newHashSet(); setTrueAndFalseTrainingSets(instances, bestBranch, trueTrainingSet, falseTrainingSet); trueSet.addAll(trueTrainingSet); falseSet.addAll(falseTrainingSet); for (ClassifierInstance instance : trueFalsePair.trueTrainingSet) { Assert.assertTrue(trueSet.contains(instance) && !falseSet.contains(instance)); } for (ClassifierInstance instance : trueFalsePair.falseTrainingSet) { Assert.assertTrue(falseSet.contains(instance) && !trueSet.contains(instance)); } } private void setTrueAndFalseTrainingSets(Iterable<ClassifierInstance> trainingData, Branch<ClassificationCounter> bestNode, List<ClassifierInstance> trueTrainingSet, List<ClassifierInstance> falseTrainingSet) { //put instances with attribute values into appropriate training sets for (ClassifierInstance instance : trainingData) { if (bestNode.decide(instance.getAttributes())) { trueTrainingSet.add(instance); } else { falseTrainingSet.add(instance); } } } }