package quickml.supervised.tree.branchFinders; import com.beust.jcommander.internal.Lists; import com.beust.jcommander.internal.Sets; import com.google.common.base.Optional; import org.junit.Assert; import org.junit.Test; import quickml.data.AttributesMap; import quickml.data.instances.ClassifierInstance; import quickml.supervised.tree.decisionTree.attributeValueIgnoringStrategies.BinaryClassAttributeValueIgnoringStrategy; import quickml.supervised.tree.decisionTree.branchingConditions.DTBranchingConditions; import quickml.supervised.tree.decisionTree.reducers.DTBinaryCatBranchReducer; import quickml.supervised.tree.decisionTree.reducers.DTNumBranchReducer; import quickml.supervised.tree.decisionTree.scorers.GRPenalizedGiniImpurityScorerFactory; import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter; import quickml.supervised.tree.reducers.AttributeStats; import java.util.List; import java.util.Set; /** * Created by alexanderhawk on 6/25/15. */ public class SplittingUtilsTest { @Test public void findBestCategoricalSplit() throws Exception { List<ClassifierInstance> td = getInstances(); DTBinaryCatBranchReducer<ClassifierInstance> reducer = new DTBinaryCatBranchReducer<>(td, 0.0); Optional<AttributeStats<ClassificationCounter>> attributeStatsOptional = reducer.getAttributeStats("t"); AttributeStats<ClassificationCounter> attStats = attributeStatsOptional.get(); ClassificationCounter aggregateData = ClassificationCounter.countAll(td); BinaryClassAttributeValueIgnoringStrategy attributeValueIgnoringStrategy = new BinaryClassAttributeValueIgnoringStrategy(aggregateData, 0); Optional<SplittingUtils.SplitScore> splitScoreOptional = SplittingUtils.splitSortedAttributeStats(attStats, new GRPenalizedGiniImpurityScorerFactory(), new DTBranchingConditions().minSplitFraction(.25).minLeafInstances(0).minScore(0), attributeValueIgnoringStrategy, true); catBranchAssertions(splitScoreOptional); //change scorerFactory splitScoreOptional = SplittingUtils.splitSortedAttributeStats(attStats, new GRPenalizedGiniImpurityScorerFactory(), new DTBranchingConditions().minSplitFraction(.25).minLeafInstances(0).minScore(0), attributeValueIgnoringStrategy, true); catBranchAssertions(splitScoreOptional); } private void catBranchAssertions(Optional<SplittingUtils.SplitScore> splitScoreOptional) { SplittingUtils.SplitScore splitScore; Assert.assertTrue(splitScoreOptional.isPresent()); splitScore = splitScoreOptional.get(); Assert.assertEquals("last index: " + splitScore.indexOfLastValueCounterInTrueSet, splitScore.indexOfLastValueCounterInTrueSet, 1); Assert.assertEquals("probOfTrueSet: " + splitScore.probabilityOfBeingInTrueSet, splitScore.probabilityOfBeingInTrueSet, 0.5, 1E-5); } @Test public void testMinSplitFractionEffect () { List<ClassifierInstance> td = getExtendedInstances(); DTBinaryCatBranchReducer<ClassifierInstance> reducer = new DTBinaryCatBranchReducer<>(td, 0.0); Optional<AttributeStats<ClassificationCounter>> attStatsOptional = reducer.getAttributeStats("t"); AttributeStats<ClassificationCounter> attStats = attStatsOptional.get(); ClassificationCounter aggregateData = ClassificationCounter.countAll(td); BinaryClassAttributeValueIgnoringStrategy attributeValueIgnoringStrategy = new BinaryClassAttributeValueIgnoringStrategy(aggregateData, 0); Optional<SplittingUtils.SplitScore> splitScoreOptional = SplittingUtils.splitSortedAttributeStats(attStats, new GRPenalizedGiniImpurityScorerFactory(), new DTBranchingConditions().minSplitFraction(.3).minLeafInstances(0).minScore(0), attributeValueIgnoringStrategy, true); Assert.assertTrue(splitScoreOptional.isPresent()); SplittingUtils.SplitScore splitScore = splitScoreOptional.get(); int subOptimalNumberOfEntriesGivenMinSplitFraction = 2; Assert.assertEquals("last index: " + splitScore.indexOfLastValueCounterInTrueSet, splitScore.indexOfLastValueCounterInTrueSet, subOptimalNumberOfEntriesGivenMinSplitFraction); Set<Double> expectedTrueSet = Sets.newHashSet(); Assert.assertTrue(splitScore.trueSet.contains(1.0) && splitScore.trueSet.contains(2.0)); Assert.assertEquals("probOfTrueSet: " + splitScore.probabilityOfBeingInTrueSet, 3.0 / 8, splitScore.probabilityOfBeingInTrueSet, 1E-5); } @Test public void findBestNumericSplit() throws Exception { List<ClassifierInstance> td = getExtendedInstances(); int numSamplesPerBin = 2; int numNumericBins = 4; DTNumBranchReducer<ClassifierInstance> reducer = new DTNumBranchReducer<>(td, numSamplesPerBin, numNumericBins); Optional<AttributeStats<ClassificationCounter>> attStatsOptional = reducer.getAttributeStats("t"); AttributeStats<ClassificationCounter> attStats = attStatsOptional.get();//should not be absent ClassificationCounter aggregateData = ClassificationCounter.countAll(td); BinaryClassAttributeValueIgnoringStrategy attributeValueIgnoringStrategy = new BinaryClassAttributeValueIgnoringStrategy(aggregateData, 0); Optional<SplittingUtils.SplitScore> splitScoreOptional = SplittingUtils.splitSortedAttributeStats(attStats, new GRPenalizedGiniImpurityScorerFactory(), new DTBranchingConditions().minSplitFraction(.25).minLeafInstances(0).minScore(0), attributeValueIgnoringStrategy, false); Assert.assertTrue(splitScoreOptional.isPresent()); SplittingUtils.SplitScore splitScore = splitScoreOptional.get(); Assert.assertEquals("last index: " + splitScore.indexOfLastValueCounterInTrueSet, splitScore.indexOfLastValueCounterInTrueSet, 0); Assert.assertEquals("probOfTrueSet: " + splitScore.probabilityOfBeingInTrueSet, splitScore.probabilityOfBeingInTrueSet, 0.25, 1E-5); } public static List<ClassifierInstance> getInstances() { List<ClassifierInstance> td = Lists.newArrayList(); AttributesMap atMap = AttributesMap.newHashMap(); atMap.put("t", 1.0); td.add(new ClassifierInstance(atMap, 0.0)); atMap = AttributesMap.newHashMap(); atMap.put("t", 2.0); td.add(new ClassifierInstance(atMap, 0.0)); atMap = AttributesMap.newHashMap(); atMap.put("t", 3.0); td.add(new ClassifierInstance(atMap, 1.0)); atMap = AttributesMap.newHashMap(); atMap.put("t", 4.0); td.add(new ClassifierInstance(atMap, 1.0)); return td; } public static List<ClassifierInstance> getExtendedInstances() { List<ClassifierInstance> td = getInstances(); AttributesMap atMap = AttributesMap.newHashMap(); atMap = AttributesMap.newHashMap(); atMap.put("t", 5.0); td.add(new ClassifierInstance(atMap, 1.0)); atMap = AttributesMap.newHashMap(); atMap.put("t", 6.0); td.add(new ClassifierInstance(atMap, 1.0)); atMap = AttributesMap.newHashMap(); atMap.put("t", 7.0); td.add(new ClassifierInstance(atMap, 1.0)); atMap = AttributesMap.newHashMap(); atMap.put("t", 8.0); td.add(new ClassifierInstance(atMap, 1.0)); return td; } }