package quickml.supervised.tree.decisionTree; import com.google.common.collect.Lists; import org.javatuples.Pair; import quickml.data.instances.ClassifierInstance; import quickml.supervised.PredictiveModelBuilder; import quickml.supervised.tree.attributeIgnoringStrategies.AttributeIgnoringStrategy; import quickml.supervised.tree.attributeIgnoringStrategies.IgnoreAttributesWithConstantProbability; import quickml.supervised.tree.attributeValueIgnoringStrategies.AttributeValueIgnoringStrategyBuilder; import quickml.supervised.tree.branchFinders.branchFinderBuilders.BranchFinderBuilder; import quickml.supervised.tree.decisionTree.branchingConditions.DTBranchingConditions; import quickml.supervised.tree.decisionTree.nodes.DTLeafBuilder; import quickml.supervised.tree.decisionTree.treeBuildContexts.DTreeContextBuilder; import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter; import quickml.supervised.tree.nodes.LeafBuilder; import quickml.supervised.tree.nodes.Node; import quickml.supervised.tree.scorers.ScorerFactory; import java.io.Serializable; import java.util.*; /** * Created by alexanderhawk on 6/20/15. */ public class DecisionTreeBuilder<I extends ClassifierInstance> implements PredictiveModelBuilder< DecisionTree, I> { //why implement TreeBuilder, why not PredictiveModelBuilder public static final int DEFAULT_MAX_DEPTH = 5; public static final int DEFAULT_NUM_SAMPLES_PER_NUMERIC_BIN = 20; public static final IgnoreAttributesWithConstantProbability DEFAULT_ATTRIBUTE_IGNORING_STRATEGY = new IgnoreAttributesWithConstantProbability(0.7); public static final int DEFAULT_NUM_NUMERIC_BINS = 5; public static final DTBranchingConditions DEFAULT_BRANCHING_CONDITIONS = new DTBranchingConditions(); public static final double DEFAULT_DEGREE_OF_GAIN_RATIO_PENALTY = 1.0; public static final double DEFAULT_IMBALANCE_PENALTY_POWER = 0.0; public static final double DEFAULT_MIN_SPLIT_FRACTION = 0.005; public static final int DEFAULT_MIN_LEAF_INSTANCES = 0; public static final int DEFAULT_MIN_ATTRIBUTE_OCCURENCES = 0; public static final LeafBuilder<ClassificationCounter> DEFAULT_LEAF_BUILDER = new DTLeafBuilder(); public static final double DEFAULT_MIN_SCORE = 0.00000000000001; private final DTreeContextBuilder<I> tcb; private DecisionTreeBuilder(DTreeContextBuilder<I> tcb) { this.tcb = tcb.copy(); } public DecisionTreeBuilder(){ this.tcb = new DTreeContextBuilder<>(); } @Override public DecisionTree buildPredictiveModel(Iterable<I> trainingData) { tcb.initializeConfig(); DecisionTreeBuilderHelper<I> treeBuilderHelper = new DecisionTreeBuilderHelper<>(tcb); ArrayList<I> trainingDataList = Lists.newArrayList(trainingData); Pair<Node<ClassificationCounter>, Set<Serializable>> rootAndClassifications = treeBuilderHelper.computeNodesAndClasses(trainingDataList); Node<ClassificationCounter> root = rootAndClassifications.getValue0(); Set<Serializable> classifications = rootAndClassifications.getValue1(); return new DecisionTree(root, classifications); } @Override public void updateBuilderConfig(Map<String, Serializable> config) { tcb.setConfig(config); } public synchronized DecisionTreeBuilder<I> copy() { return new DecisionTreeBuilder<>(tcb); } public DecisionTreeBuilder<I> maxDepth(int maxDepth) { tcb.maxDepth(maxDepth); return this; } public DecisionTreeBuilder<I> leafBuilder(LeafBuilder<ClassificationCounter> leafBuilder) { tcb.leafBuilder(leafBuilder); return this; } public DecisionTreeBuilder<I> ignoreAttributeProbability(double ignoreAttributeProbability) { tcb.ignoreAttributeProbability(ignoreAttributeProbability); return this; } public DecisionTreeBuilder<I> minSplitFraction(double minSplitFraction) { tcb.minSplitFraction(minSplitFraction); return this; } public DecisionTreeBuilder<I> minLeafInstances(int minLeafInstances) { tcb.minLeafInstances(minLeafInstances); return this; } public DecisionTreeBuilder<I> exemptAttributes(HashSet<String> exemptAttributes) { tcb.exemptAttributes(exemptAttributes); return this; } public DecisionTreeBuilder<I> attributeIgnoringStrategy(AttributeIgnoringStrategy attributeIgnoringStrategy) { tcb.attributeIgnoringStrategy(attributeIgnoringStrategy); return this; } public DecisionTreeBuilder<I> attributeValueIgnoringStrategyBuilder(AttributeValueIgnoringStrategyBuilder<ClassificationCounter> attributeValueIgnoringStrategyBuilder) { tcb.attributeValueIgnoringStrategyBuilder(attributeValueIgnoringStrategyBuilder); return this; } public DecisionTreeBuilder<I> numSamplesPerNumericBin(int numSamplesPerNumericBin) { tcb.numSamplesPerNumericBin(numSamplesPerNumericBin); return this; } public DecisionTreeBuilder<I> numNumericBins(int numNumericBins) { tcb.numNumericBins(numNumericBins); return this; } public DecisionTreeBuilder<I> branchingConditions(DTBranchingConditions branchingConditions) { tcb.branchingConditions(branchingConditions); return this; } public DecisionTreeBuilder<I> scorerFactory(ScorerFactory<ClassificationCounter> scorerFactory) { tcb.scorerFactory(scorerFactory); return this; } public DecisionTreeBuilder<I> degreeOfGainRatioPenalty(double degreeOfGainRatioPenalty) { tcb.degreeOfGainRatioPenalty(degreeOfGainRatioPenalty); return this; } public DecisionTreeBuilder<I> imbalancePenaltyPower(double imbalancePenaltyPower) { tcb.imbalancePenaltyPower(imbalancePenaltyPower); return this; } public DecisionTreeBuilder<I> branchFinderBuilders(ArrayList<? extends BranchFinderBuilder<ClassificationCounter>> branchFinderBuilders) { tcb.branchFinderBuilders(branchFinderBuilders); return this; } public DecisionTreeBuilder<I> minAttributeValueOccurences(int minAttributeValueOccurences) { tcb.minAttributeValueOccurences(minAttributeValueOccurences); return this; } public DecisionTreeBuilder<I> minScore(double minScore) { tcb.minScore(minScore); return this; } }