package quickml.supervised.tree.regressionTree; import com.google.common.collect.Lists; import org.javatuples.Pair; import quickml.data.instances.RegressionInstance; import quickml.supervised.PredictiveModelBuilder; import quickml.supervised.tree.attributeIgnoringStrategies.AttributeIgnoringStrategy; import quickml.supervised.tree.attributeIgnoringStrategies.IgnoreAttributesWithConstantProbability; import quickml.supervised.tree.attributeValueIgnoringStrategies.AttributeValueIgnoringStrategyBuilder; import quickml.supervised.tree.branchFinders.branchFinderBuilders.BranchFinderBuilder; import quickml.supervised.tree.decisionTree.DecisionTree; import quickml.supervised.tree.decisionTree.branchingConditions.DTBranchingConditions; import quickml.supervised.tree.nodes.LeafBuilder; import quickml.supervised.tree.nodes.Node; import quickml.supervised.tree.regressionTree.nodes.RTLeafBuilder; import quickml.supervised.tree.regressionTree.treeBuildContexts.RTreeContextBuilder; import quickml.supervised.tree.regressionTree.valueCounters.MeanValueCounter; import quickml.supervised.tree.scorers.ScorerFactory; import java.io.Serializable; import java.util.*; /** * Created by alexanderhawk on 6/20/15. */ public class RegressionTreeBuilder<I extends RegressionInstance> implements PredictiveModelBuilder< RegressionTree, I> { //why implement TreeBuilder, why not PredictiveModelBuilder public static final int DEFAULT_MAX_DEPTH = 5; public static final int DEFAULT_NUM_SAMPLES_PER_NUMERIC_BIN = 20; public static final IgnoreAttributesWithConstantProbability DEFAULT_ATTRIBUTE_IGNORING_STRATEGY = new IgnoreAttributesWithConstantProbability(0.7); public static final int DEFAULT_NUM_NUMERIC_BINS = 5; public static final DTBranchingConditions DEFAULT_BRANCHING_CONDITIONS = new DTBranchingConditions(); public static final double DEFAULT_DEGREE_OF_GAIN_RATIO_PENALTY = 1.0; public static final double DEFAULT_IMBALANCE_PENALTY_POWER = 0.0; public static final double DEFAULT_MIN_SPLIT_FRACTION = 0.005; public static final int DEFAULT_MIN_LEAF_INSTANCES = 0; public static final int DEFAULT_MIN_ATTRIBUTE_OCCURENCES = 0; public static final LeafBuilder<MeanValueCounter> DEFAULT_LEAF_BUILDER = new RTLeafBuilder(); public static final double DEFAULT_MIN_SCORE = 0.00000000000001; private final RTreeContextBuilder<I> tcb; private RegressionTreeBuilder(RTreeContextBuilder<I> tcb) { this.tcb = tcb.copy(); } public RegressionTreeBuilder(){ this.tcb = new RTreeContextBuilder<>(); } @Override public RegressionTree buildPredictiveModel(Iterable<I> trainingData) { tcb.initializeConfig(); RegressionTreeBuilderHelper<I> treeBuilderHelper = new RegressionTreeBuilderHelper<>(tcb); ArrayList<I> trainingDataList = Lists.newArrayList(trainingData); Node<MeanValueCounter> root = treeBuilderHelper.computeNodes(trainingDataList); return new RegressionTree(root); } @Override public void updateBuilderConfig(Map<String, Serializable> config) { tcb.setConfig(config); } public synchronized RegressionTreeBuilder<I> copy() { return new RegressionTreeBuilder<>(tcb); } public RegressionTreeBuilder<I> maxDepth(int maxDepth) { tcb.maxDepth(maxDepth); return this; } public RegressionTreeBuilder<I> leafBuilder(LeafBuilder<MeanValueCounter> leafBuilder) { tcb.leafBuilder(leafBuilder); return this; } public RegressionTreeBuilder<I> ignoreAttributeProbability(double ignoreAttributeProbability) { tcb.ignoreAttributeProbability(ignoreAttributeProbability); return this; } public RegressionTreeBuilder<I> minSplitFraction(double minSplitFraction) { tcb.minSplitFraction(minSplitFraction); return this; } public RegressionTreeBuilder<I> minLeafInstances(int minLeafInstances) { tcb.minLeafInstances(minLeafInstances); return this; } public RegressionTreeBuilder<I> exemptAttributes(HashSet<String> exemptAttributes) { tcb.exemptAttributes(exemptAttributes); return this; } public RegressionTreeBuilder<I> attributeIgnoringStrategy(AttributeIgnoringStrategy attributeIgnoringStrategy) { tcb.attributeIgnoringStrategy(attributeIgnoringStrategy); return this; } public RegressionTreeBuilder<I> attributeValueIgnoringStrategyBuilder(AttributeValueIgnoringStrategyBuilder<MeanValueCounter> attributeValueIgnoringStrategyBuilder) { tcb.attributeValueIgnoringStrategyBuilder(attributeValueIgnoringStrategyBuilder); return this; } public RegressionTreeBuilder<I> numSamplesPerNumericBin(int numSamplesPerNumericBin) { tcb.numSamplesPerNumericBin(numSamplesPerNumericBin); return this; } public RegressionTreeBuilder<I> numNumericBins(int numNumericBins) { tcb.numNumericBins(numNumericBins); return this; } public RegressionTreeBuilder<I> branchingConditions(DTBranchingConditions branchingConditions) { tcb.branchingConditions(branchingConditions); return this; } public RegressionTreeBuilder<I> scorerFactory(ScorerFactory<MeanValueCounter> scorerFactory) { tcb.scorerFactory(scorerFactory); return this; } public RegressionTreeBuilder<I> degreeOfGainRatioPenalty(double degreeOfGainRatioPenalty) { tcb.degreeOfGainRatioPenalty(degreeOfGainRatioPenalty); return this; } public RegressionTreeBuilder<I> imbalancePenaltyPower(double imbalancePenaltyPower) { tcb.imbalancePenaltyPower(imbalancePenaltyPower); return this; } public RegressionTreeBuilder<I> branchFinderBuilders(ArrayList<? extends BranchFinderBuilder<MeanValueCounter>> branchFinderBuilders) { tcb.branchFinderBuilders(branchFinderBuilders); return this; } public RegressionTreeBuilder<I> minAttributeValueOccurences(int minAttributeValueOccurences) { tcb.minAttributeValueOccurences(minAttributeValueOccurences); return this; } public RegressionTreeBuilder<I> minScore(double minScore) { tcb.minScore(minScore); return this; } }