package quickml.supervised.tree.decisionTree;
import com.beust.jcommander.internal.Lists;
import org.junit.Assert;
import org.junit.Test;
import quickml.data.instances.ClassifierInstance;
import quickml.InstanceLoader;
import quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst.OldTree;
import quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst.OldTreeBuilder;
import quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst.oldScorers.GiniImpurityOldScorer;
import quickml.supervised.ensembles.randomForest.randomDecisionForest.RandomDecisionForest;
import quickml.supervised.ensembles.randomForest.randomDecisionForest.RandomDecisionForestBuilder;
import quickml.supervised.tree.attributeIgnoringStrategies.IgnoreAttributesWithConstantProbability;
import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter;
import quickml.supervised.tree.nodes.Branch;
import quickml.supervised.tree.nodes.Leaf;
import quickml.supervised.tree.nodes.Node;
import quickml.supervised.tree.summaryStatistics.ValueCounter;
import java.util.List;
/**
 * Structural sanity tests for decision-tree building on the advertising data set.
 *
 * <p>Each test builds one or more trees and walks every node. It asserts that branches respect the
 * configured minimum split fraction and minimum counts, and that leaves respect the minimum instance
 * count and the maximum depth.
 *
 * <p>Created by alexanderhawk on 4/26/15.
 */
public class DecisionOldOldTreeBuilderTest {
    // Builder settings shared by every test. newConditions() checks the built trees against the
    // same values, so the two cannot drift apart.
    private static final int MAX_DEPTH = 8;
    private static final double MIN_SPLIT_FRACTION = 0.1;
    private static final int MIN_LEAF_INSTANCES = 20;
    private static final int MIN_ATTRIBUTE_OCCURRENCES = 11;
    private static final int NUM_NUMERIC_BINS = 6;
    private static final double IGNORE_ATTRIBUTE_PROBABILITY = 0.7;
    private static final double DEGREE_OF_GAIN_RATIO_PENALTY = 1.0;
    private static final int NUM_FOREST_TREES = 5;

    /**
     * Builds a legacy {@link OldTree} and a new {@link DecisionTree} on the same instances.
     *
     * <p>Checks the new tree's structure and verifies that both trees can score every instance. For
     * the new tree, it also checks that the probabilities of classifications 1.0 and 0.0 sum to 1.
     */
    @Test
    public void singleTreeTest() {
        List<ClassifierInstance> instances = loadInstances();
        OldTreeBuilder oldTreeBuilder = new OldTreeBuilder().scorer(new GiniImpurityOldScorer())
                .maxDepth(MAX_DEPTH)
                .minSplitFraction(MIN_SPLIT_FRACTION)
                .degreeOfGainRatioPenalty(DEGREE_OF_GAIN_RATIO_PENALTY)
                .minCategoricalAttributeValueOccurances(MIN_ATTRIBUTE_OCCURRENCES)
                .attributeIgnoringStrategy(new quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst.oldTree.oldAttributeIgnoringStrategies.IgnoreAttributesWithConstantProbability(IGNORE_ATTRIBUTE_PROBABILITY));
        OldTree oldTree = oldTreeBuilder.buildPredictiveModel(instances);

        DecisionTree decisionTree = newDecisionTreeBuilder(18).buildPredictiveModel(instances);
        recurseTree(decisionTree.root, newConditions());

        for (ClassifierInstance instance : instances) {
            // The legacy tree is only smoke-tested: scoring must not throw.
            oldTree.getProbability(instance.getAttributes(), 1.0);
            double positive = decisionTree.getProbability(instance.getAttributes(), 1.0);
            double negative = decisionTree.getProbability(instance.getAttributes(), 0.0);
            // JUnit order is (expected, actual, delta).
            Assert.assertEquals(1.0, positive + negative, 1E-5);
        }
    }

    /**
     * Builds a single decision tree and a random forest from the same builder settings, then checks
     * the structure of every resulting tree.
     */
    @Test
    public void mockTreeTest() {
        DecisionTreeBuilder<ClassifierInstance> decisionTreeBuilder = newDecisionTreeBuilder(25);
        List<ClassifierInstance> instances = loadInstances();
        Conditions<ClassificationCounter> conditions = newConditions();

        DecisionTree decisionTree = decisionTreeBuilder.buildPredictiveModel(instances);
        recurseTree(decisionTree.root, conditions);

        assertForestTreesSatisfy(buildForest(decisionTreeBuilder, instances), conditions);
    }

    /**
     * Builds a random forest, checks the structure of each member tree, and verifies that the forest
     * can score every instance for both classifications.
     */
    @Test
    public void randomForestTest() {
        List<ClassifierInstance> instances = loadInstances();
        RandomDecisionForest randomDecisionForest = buildForest(newDecisionTreeBuilder(25), instances);
        assertForestTreesSatisfy(randomDecisionForest, newConditions());

        for (ClassifierInstance instance : instances) {
            // Smoke test only: scoring must not throw for either classification.
            randomDecisionForest.getProbability(instance.getAttributes(), 1.0);
            randomDecisionForest.getProbability(instance.getAttributes(), 0.0);
        }
    }

    /** Loads the advertising data set into a fresh mutable list. */
    private static List<ClassifierInstance> loadInstances() {
        return Lists.newArrayList(InstanceLoader.getAdvertisingInstances());
    }

    /**
     * Creates a decision-tree builder configured with the shared test settings.
     *
     * @param numSamplesPerNumericBin value passed to {@code numSamplesPerNumericBin}; the tests use
     *     different values for this one setting
     */
    private static DecisionTreeBuilder<ClassifierInstance> newDecisionTreeBuilder(int numSamplesPerNumericBin) {
        return new DecisionTreeBuilder<>().numSamplesPerNumericBin(numSamplesPerNumericBin).numNumericBins(NUM_NUMERIC_BINS)
                .attributeIgnoringStrategy(new IgnoreAttributesWithConstantProbability(IGNORE_ATTRIBUTE_PROBABILITY))
                .maxDepth(MAX_DEPTH).minSplitFraction(MIN_SPLIT_FRACTION)
                .degreeOfGainRatioPenalty(DEGREE_OF_GAIN_RATIO_PENALTY)
                .minAttributeValueOccurences(MIN_ATTRIBUTE_OCCURRENCES).minLeafInstances(MIN_LEAF_INSTANCES);
    }

    /** Creates the structural conditions matching the shared builder settings. */
    private static Conditions<ClassificationCounter> newConditions() {
        return new Conditions<ClassificationCounter>(MAX_DEPTH, MIN_ATTRIBUTE_OCCURRENCES, MIN_SPLIT_FRACTION, MIN_LEAF_INSTANCES);
    }

    /** Builds a forest of {@code NUM_FOREST_TREES} trees using the given tree builder. */
    private static RandomDecisionForest buildForest(DecisionTreeBuilder<ClassifierInstance> treeBuilder,
                                                    List<ClassifierInstance> instances) {
        RandomDecisionForestBuilder<ClassifierInstance> forestBuilder =
                new RandomDecisionForestBuilder<>(treeBuilder).numTrees(NUM_FOREST_TREES);
        return forestBuilder.buildPredictiveModel(instances);
    }

    /** Walks every tree in the forest and asserts that each node satisfies {@code conditions}. */
    private static void assertForestTreesSatisfy(RandomDecisionForest forest, Conditions<ClassificationCounter> conditions) {
        for (DecisionTree forestTree : forest.decisionTrees) {
            recurseTree(forestTree.root, conditions);
        }
    }

    /**
     * Checks {@code node} against {@code conditions} and, for a branch, recurses into both the true
     * child and the false child.
     */
    private static void recurseTree(Node<ClassificationCounter> node, Conditions<ClassificationCounter> conditions) {
        conditions.satisfiesConditions(node);
        if (node instanceof Branch) {
            Branch<ClassificationCounter> branch = (Branch<ClassificationCounter>) node;
            recurseTree(branch.getTrueChild(), conditions);
            recurseTree(branch.getFalseChild(), conditions);
        }
    }

    /**
     * Structural invariants a built tree is expected to satisfy.
     *
     * <ul>
     *   <li>A branch sends at least {@code minSplitFraction} of its probability mass to each child.</li>
     *   <li>A branch's value-counter total exceeds {@code minAttributeOccurrences}.</li>
     *   <li>A leaf holds at least {@code minInstancesPerLeaf}.</li>
     *   <li>A leaf sits at depth at most {@code maxDepth}.</li>
     * </ul>
     *
     * @param <VC> the value-counter type stored in the tree's nodes
     */
    public static class Conditions<VC extends ValueCounter<VC>> {
        private final int maxDepth;
        private final int minAttributeOccurrences;
        private final double minSplitFraction;
        private final int minInstancesPerLeaf;

        public Conditions(int maxDepth, int minAttributeOccurrences, double minSplitFraction, int minInstancesPerLeaf) {
            this.maxDepth = maxDepth;
            this.minAttributeOccurrences = minAttributeOccurrences;
            this.minSplitFraction = minSplitFraction;
            this.minInstancesPerLeaf = minInstancesPerLeaf;
        }

        /**
         * Asserts that {@code node} is non-null and satisfies the branch or leaf invariants.
         *
         * <p>Any node that is not a {@link Branch} is treated as a {@link Leaf}.
         */
        public void satisfiesConditions(Node<VC> node) {
            Assert.assertNotNull("node must not be null", node);
            if (node instanceof Branch) {
                satisfiesBranchConditions((Branch<VC>) node);
            } else {
                satisfiesLeafConditions((Leaf<VC>) node);
            }
        }

        private void satisfiesBranchConditions(Branch<VC> branch) {
            // Both sides of the split must receive at least minSplitFraction of the mass.
            Assert.assertTrue("attribute: " + branch.attribute + ". branch.getProbabilityOfTrueChild(): " + branch.getProbabilityOfTrueChild(),
                    branch.getProbabilityOfTrueChild() >= minSplitFraction && 1.0 - branch.getProbabilityOfTrueChild() >= minSplitFraction);
            Assert.assertTrue("branch.getValueCounter().getTotal(): " + branch.getValueCounter().getTotal(),
                    branch.getValueCounter().getTotal() > minAttributeOccurrences);
        }

        private void satisfiesLeafConditions(Leaf<VC> leaf) {
            // minInstancesPerLeaf is an inclusive minimum, so a leaf holding exactly that many passes.
            Assert.assertTrue("instances at leaf: " + leaf.getValueCounter().getTotal(),
                    leaf.getValueCounter().getTotal() >= minInstancesPerLeaf);
            Assert.assertTrue("leafDepth: " + leaf.getDepth(), leaf.getDepth() <= maxDepth);
        }
    }
}