package quickml.supervised.tree.regressionTree.treeBuildContexts;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import quickml.data.instances.RegressionInstance;
import quickml.supervised.tree.attributeIgnoringStrategies.AttributeIgnoringStrategy;
import quickml.supervised.tree.attributeIgnoringStrategies.IgnoreAttributesWithConstantProbability;
import quickml.supervised.tree.attributeValueIgnoringStrategies.AttributeValueIgnoringStrategyBuilder;
import quickml.supervised.tree.branchFinders.BranchFinder;
import quickml.supervised.tree.branchFinders.BranchFinderAndReducerFactory;
import quickml.supervised.tree.branchFinders.branchFinderBuilders.BranchFinderBuilder;
import quickml.supervised.tree.branchingConditions.BranchingConditions;
import quickml.supervised.tree.constants.AttributeType;
import quickml.supervised.tree.constants.BranchType;
import quickml.supervised.dataProcessing.BasicTrainingDataSurveyor;
import quickml.supervised.tree.decisionTree.branchingConditions.DTBranchingConditions;
import quickml.supervised.tree.nodes.LeafBuilder;
import quickml.supervised.tree.reducers.ReducerFactory;
import quickml.supervised.tree.regressionTree.branchFinders.branchFinderBuilders.RTCatBranchFinderBuilder;
import quickml.supervised.tree.regressionTree.branchFinders.branchFinderBuilders.RTNumBranchFinderBuilder;
import quickml.supervised.tree.regressionTree.reducers.reducerFactories.RTCatBranchReducerFactory;
import quickml.supervised.tree.regressionTree.reducers.reducerFactories.RTNumBranchReducerFactory;
import quickml.supervised.tree.regressionTree.scorers.RTPenalizedMSEScorerFactory;
import quickml.supervised.tree.regressionTree.valueCounters.MeanValueCounter;
import quickml.supervised.tree.regressionTree.valueCounters.MeanValueCounterProducer;
import quickml.supervised.tree.scorers.ScorerFactory;
import quickml.supervised.tree.treeBuildContexts.TreeContextBuilder;
import java.io.Serializable;
import java.util.*;
import static quickml.supervised.tree.constants.ForestOptions.*;
import static quickml.supervised.tree.regressionTree.RegressionTreeBuilder.*;
/**
 * Builder for {@link RTreeContext}s used when growing regression trees.
 *
 * <p>All options are stored in the inherited {@code config} map, keyed by option names
 * (e.g. {@code MAX_DEPTH.name()}). {@link #setDefaultsAsNeeded()} fills in options the caller
 * left unset, and {@link #buildContext(List)} combines the configuration with the training data.
 *
 * <p>Created by alexanderhawk on 6/20/15.
 *
 * @param <I> the regression instance type
 */
public class RTreeContextBuilder<I extends RegressionInstance> extends TreeContextBuilder<I, MeanValueCounter> {

    /**
     * Creates a new, independent builder of the same concrete type.
     *
     * @return a fresh {@code RTreeContextBuilder}
     */
    @Override
    public RTreeContextBuilder<I> createTreeBuildContext() {
        return new RTreeContextBuilder<>();
    }

    /**
     * Builds a regression tree context from the current configuration and the training data.
     *
     * <p>The training data is surveyed to group attribute names by {@link AttributeType}. Each
     * configured branch finder is then paired with its matching reducer factory, and the
     * branching conditions, scorer factory and leaf builder are taken from the config.
     *
     * @param trainingData instances used to discover candidate attributes
     * @return the assembled {@link RTreeContext}
     * @throws IllegalStateException if a configured branch finder has a branch type for which
     *     no regression-tree reducer factory exists
     */
    @Override
    @SuppressWarnings("unchecked")
    public RTreeContext<I> buildContext(List<I> trainingData) {
        // Boolean attributes are only surveyed when a BOOLEAN branch finder has been configured.
        boolean considerBooleanAttributes = hasBranchFinderBuilder(BranchType.BOOLEAN);
        BasicTrainingDataSurveyor<I> decTreeTrainingDataSurveyor = new BasicTrainingDataSurveyor<>(considerBooleanAttributes);
        Map<AttributeType, Set<String>> candidateAttributesByType = decTreeTrainingDataSurveyor.groupAttributesByType(trainingData);
        List<BranchFinderAndReducerFactory<I, MeanValueCounter>> branchFinderAndReducers = intializeBranchFindersAndReducers(candidateAttributesByType);
        return new RTreeContext<I>(
                (BranchingConditions<MeanValueCounter>) config.get(BRANCHING_CONDITIONS.name()),
                (ScorerFactory<MeanValueCounter>) config.get(SCORER_FACTORY.name()),
                branchFinderAndReducers,
                (LeafBuilder<MeanValueCounter>) config.get(LEAF_BUILDER.name()),
                getValueCounterProducer());
    }

    /**
     * Returns a new producer of {@link MeanValueCounter}s.
     *
     * @return a fresh {@link MeanValueCounterProducer}
     */
    @Override
    public MeanValueCounterProducer<I> getValueCounterProducer() {
        return new MeanValueCounterProducer<>();
    }

    /**
     * Creates a new builder whose config is a deep copy of this builder's config
     * (see {@link #deepCopyConfig(Map)}).
     *
     * @return the copied builder
     */
    @Override
    public synchronized RTreeContextBuilder<I> copy() {
        //TODO: should only copy the config, and make sure the others get updated. This is redundant.
        RTreeContextBuilder<I> copy = createTreeBuildContext();
        copy.config = deepCopyConfig(this.config);
        return copy;
    }

    /**
     * Copies each configured branch finder builder via its {@code copy()} method.
     *
     * <p>If no builders are configured (key absent, value {@code null}, or empty list), a fresh
     * list of default builders is returned instead.
     *
     * @param config the configuration to read from
     * @return a new list of branch finder builders
     */
    @SuppressWarnings("unchecked")
    private ArrayList<BranchFinderBuilder<MeanValueCounter>> copyBranchFinderBuilders(Map<String, Serializable> config) {
        // An absent key yields null from get(), which falls into the same default path.
        List<BranchFinderBuilder<MeanValueCounter>> configured =
                (List<BranchFinderBuilder<MeanValueCounter>>) config.get(BRANCH_FINDER_BUILDERS.name());
        if (configured == null || configured.isEmpty()) {
            return getDefaultBranchFinderBuilders();
        }
        ArrayList<BranchFinderBuilder<MeanValueCounter>> copiedBranchFinderBuilders = new ArrayList<>(configured.size());
        for (BranchFinderBuilder<MeanValueCounter> branchFinderBuilder : configured) {
            copiedBranchFinderBuilders.add(branchFinderBuilder.copy());
        }
        return copiedBranchFinderBuilders;
    }

    /**
     * Builds one branch finder per configured branch finder builder and pairs it with the
     * reducer factory that matches its branch type.
     *
     * <p>Branch finders must be paired with the correct reducers. Doing the pairing here means
     * users cannot get the pairings wrong.
     *
     * @param candidateAttributesByType attribute names grouped by type, from the data survey
     * @return the branch finder / reducer factory pairs, in builder order
     * @throws IllegalStateException if a builder's branch type has no regression-tree reducer
     */
    private List<BranchFinderAndReducerFactory<I, MeanValueCounter>> intializeBranchFindersAndReducers(Map<AttributeType, Set<String>> candidateAttributesByType) {
        List<BranchFinderAndReducerFactory<I, MeanValueCounter>> branchFindersAndReducers = Lists.newArrayList();
        Map<BranchType, ReducerFactory<I, MeanValueCounter>> reducerMap = getDefaultReducerFactories();
        for (BranchFinderBuilder<MeanValueCounter> branchFinderBuilder : getBranchFinderBuilders()) {
            BranchType branchType = branchFinderBuilder.getBranchType();
            AttributeType attributeType = AttributeType.convertBranchTypeToAttributeType(branchType);
            BranchFinder<MeanValueCounter> branchFinder = branchFinderBuilder.buildBranchFinder(null, candidateAttributesByType.get(attributeType));
            ReducerFactory<I, MeanValueCounter> reducerFactory = reducerMap.get(branchType);
            if (reducerFactory == null) {
                // Fail with a descriptive message instead of an NPE on the next line.
                throw new IllegalStateException("No regression tree reducer factory for branch type "
                        + branchType + "; supported branch types are " + reducerMap.keySet());
            }
            reducerFactory.updateBuilderConfig(config);
            branchFindersAndReducers.add(new BranchFinderAndReducerFactory<I, MeanValueCounter>(branchFinder, reducerFactory));
        }
        return branchFindersAndReducers;
    }

    /**
     * Returns a new map from branch type to the regression-tree reducer factory for that type.
     * Only {@code RT_CATEGORICAL} and {@code RT_NUMERIC} are mapped.
     *
     * @param <I> the regression instance type
     * @return a new, mutable map of reducer factories
     */
    public static <I extends RegressionInstance> Map<BranchType, ReducerFactory<I, MeanValueCounter>> getDefaultReducerFactories() {
        Map<BranchType, ReducerFactory<I, MeanValueCounter>> reducerFactories = Maps.newHashMap();
        reducerFactories.put(BranchType.RT_CATEGORICAL, new RTCatBranchReducerFactory<I>());
        reducerFactories.put(BranchType.RT_NUMERIC, new RTNumBranchReducerFactory<I>());
        return reducerFactories;
    }

    /**
     * Returns a new list holding one categorical and one numeric regression-tree branch finder
     * builder.
     *
     * @param <I> the regression instance type (unused; kept for signature compatibility)
     * @return a new, mutable list of default branch finder builders
     */
    public static <I extends RegressionInstance> ArrayList<BranchFinderBuilder<MeanValueCounter>> getDefaultBranchFinderBuilders() {
        ArrayList<BranchFinderBuilder<MeanValueCounter>> branchFinderBuilders = Lists.newArrayList();
        // Exactly one builder per branch type; a duplicate would search the same attributes twice.
        branchFinderBuilders.add(new RTCatBranchFinderBuilder());
        branchFinderBuilders.add(new RTNumBranchFinderBuilder());
        return branchFinderBuilders;
    }

    /** Fills in a default value for every option that is not already present in the config. */
    @Override
    public void setDefaultsAsNeeded() {
        if (!config.containsKey(BRANCH_FINDER_BUILDERS.name())) {
            branchFinderBuilders(RTreeContextBuilder.getDefaultBranchFinderBuilders());
        }
        if (!config.containsKey(MAX_DEPTH.name())) {
            maxDepth(DEFAULT_MAX_DEPTH);
        }
        if (!config.containsKey(MIN_SLPIT_FRACTION.name())) {
            minSplitFraction(DEFAULT_MIN_SPLIT_FRACTION);
        }
        if (!config.containsKey(ATTRIBUTE_IGNORING_STRATEGY.name())) {
            attributeIgnoringStrategy(DEFAULT_ATTRIBUTE_IGNORING_STRATEGY);
        }
        if (!config.containsKey(NUM_SAMPLES_PER_NUMERIC_BIN.name())) {
            numSamplesPerNumericBin(DEFAULT_NUM_SAMPLES_PER_NUMERIC_BIN);
        }
        if (!config.containsKey(NUM_NUMERIC_BINS.name())) {
            numNumericBins(DEFAULT_NUM_NUMERIC_BINS);
        }
        if (!config.containsKey(BRANCHING_CONDITIONS.name())) {
            branchingConditions(DEFAULT_BRANCHING_CONDITIONS);
        }
        if (!config.containsKey(SCORER_FACTORY.name())) {
            // NOTE(review): the default scorer is built from the DEFAULT_* penalty constants, not from
            // any DEGREE_OF_GAIN_RATIO_PENALTY / IMBALANCE_PENALTY_POWER the caller configured — confirm
            // whether the scorer factory picks those up from the config elsewhere.
            scorerFactory(getDefaultScorerFactory());
        }
        if (!config.containsKey(DEGREE_OF_GAIN_RATIO_PENALTY.name())) {
            degreeOfGainRatioPenalty(DEFAULT_DEGREE_OF_GAIN_RATIO_PENALTY);
        }
        if (!config.containsKey(IMBALANCE_PENALTY_POWER.name())) {
            imbalancePenaltyPower(DEFAULT_IMBALANCE_PENALTY_POWER);
        }
        if (!config.containsKey(MIN_ATTRIBUTE_VALUE_OCCURRENCES.name())) {
            minAttributeValueOccurences(DEFAULT_MIN_ATTRIBUTE_OCCURENCES);
        }
        if (!config.containsKey(LEAF_BUILDER.name())) {
            leafBuilder(DEFAULT_LEAF_BUILDER);
        }
        if (!config.containsKey(MIN_LEAF_INSTANCES.name())) {
            minLeafInstances(DEFAULT_MIN_LEAF_INSTANCES);
        }
        if (!config.containsKey(MIN_SCORE.name())) {
            minScore(DEFAULT_MIN_SCORE);
        }
    }

    /** Returns a penalized-MSE scorer factory built from the default penalty constants. */
    private ScorerFactory<MeanValueCounter> getDefaultScorerFactory() {
        return new RTPenalizedMSEScorerFactory(DEFAULT_DEGREE_OF_GAIN_RATIO_PENALTY, DEFAULT_IMBALANCE_PENALTY_POWER);
    }

    /**
     * Returns a copy of the recognised options in {@code config}.
     *
     * <p>Scalar options are copied by reference. Branch finder builders, the attribute ignoring
     * strategy, branching conditions, scorer factory, leaf builder and attribute value ignoring
     * strategy builder are copied via their {@code copy()} methods. The exempt-attribute set is
     * copied into a new {@code HashSet}. Keys not listed here are not carried over.
     *
     * @param config the configuration to copy
     * @return a new map holding the copied options
     */
    @Override
    @SuppressWarnings("unchecked")
    public synchronized Map<String, Serializable> deepCopyConfig(Map<String, Serializable> config) {
        Map<String, Serializable> copiedConfig = Maps.newHashMap();

        // The setters below store these as boxed numbers, which are immutable, so sharing is safe.
        String[] scalarKeys = {
                MAX_DEPTH.name(),
                MIN_SLPIT_FRACTION.name(),
                NUM_SAMPLES_PER_NUMERIC_BIN.name(),
                NUM_NUMERIC_BINS.name(),
                DEGREE_OF_GAIN_RATIO_PENALTY.name(),
                IMBALANCE_PENALTY_POWER.name(),
                MIN_ATTRIBUTE_VALUE_OCCURRENCES.name(),
                MIN_LEAF_INSTANCES.name(),
                MIN_SCORE.name()
        };
        for (String key : scalarKeys) {
            if (config.containsKey(key)) {
                copiedConfig.put(key, config.get(key));
            }
        }

        if (config.containsKey(BRANCH_FINDER_BUILDERS.name())) {
            copiedConfig.put(BRANCH_FINDER_BUILDERS.name(), copyBranchFinderBuilders(config));
        }
        if (config.containsKey(ATTRIBUTE_IGNORING_STRATEGY.name())) {
            copiedConfig.put(ATTRIBUTE_IGNORING_STRATEGY.name(), ((AttributeIgnoringStrategy) config.get(ATTRIBUTE_IGNORING_STRATEGY.name())).copy());
        }
        if (config.containsKey(BRANCHING_CONDITIONS.name())) {
            copiedConfig.put(BRANCHING_CONDITIONS.name(), ((BranchingConditions<MeanValueCounter>) config.get(BRANCHING_CONDITIONS.name())).copy());
        }
        if (config.containsKey(SCORER_FACTORY.name())) {
            copiedConfig.put(SCORER_FACTORY.name(), ((ScorerFactory<MeanValueCounter>) config.get(SCORER_FACTORY.name())).copy());
        }
        if (config.containsKey(LEAF_BUILDER.name())) {
            copiedConfig.put(LEAF_BUILDER.name(), ((LeafBuilder<MeanValueCounter>) config.get(LEAF_BUILDER.name())).copy());
        }
        if (config.containsKey(EXEMPT_ATTRIBUTES.name())) {
            copiedConfig.put(EXEMPT_ATTRIBUTES.name(), Sets.newHashSet(((Set<String>) config.get(EXEMPT_ATTRIBUTES.name()))));
        }
        String attrValIgnoringKey = ATTRIBUTE_VALUE_IGNORING_STRATEGY_BUILDER.name();
        if (config.containsKey(attrValIgnoringKey)) {
            // Read the same key that was checked and that attributeValueIgnoringStrategyBuilder() writes.
            copiedConfig.put(attrValIgnoringKey, ((AttributeValueIgnoringStrategyBuilder<MeanValueCounter>) config.get(attrValIgnoringKey)).copy());
        }
        return copiedConfig;
    }

    /** Sets the maximum tree depth. */
    public void maxDepth(int maxDepth) {
        config.put(MAX_DEPTH.name(), maxDepth);
    }

    /** Sets the builder used to create leaves. */
    public void leafBuilder(LeafBuilder<MeanValueCounter> leafBuilder) {
        config.put(LEAF_BUILDER.name(), leafBuilder);
    }

    /**
     * Replaces the attribute ignoring strategy with one that ignores attributes with the given
     * constant probability. This option has no default of its own.
     */
    public void ignoreAttributeProbability(double ignoreAttributeProbability) {
        config.put(ATTRIBUTE_IGNORING_STRATEGY.name(), new IgnoreAttributesWithConstantProbability(ignoreAttributeProbability));
    }

    /** Sets the minimum split fraction. */
    public void minSplitFraction(double minSplitFraction) {
        config.put(MIN_SLPIT_FRACTION.name(), minSplitFraction);
    }

    /** Sets the minimum number of instances per leaf. */
    public void minLeafInstances(int minLeafInstances) {
        config.put(MIN_LEAF_INSTANCES.name(), minLeafInstances);
    }

    /** Sets the exempt attribute names. This option has no default. */
    public void exemptAttributes(HashSet<String> exemptAttributes) {
        config.put(EXEMPT_ATTRIBUTES.name(), exemptAttributes);
    }

    /** Sets the attribute ignoring strategy. */
    public void attributeIgnoringStrategy(AttributeIgnoringStrategy attributeIgnoringStrategy) {
        config.put(ATTRIBUTE_IGNORING_STRATEGY.name(), attributeIgnoringStrategy);
    }

    /**
     * Sets the attribute value ignoring strategy builder. If it is not specified, the appropriate
     * strategy is chosen when the branch finders are built.
     */
    public void attributeValueIgnoringStrategyBuilder(AttributeValueIgnoringStrategyBuilder<MeanValueCounter> attributeValueIgnoringStrategyBuilder) {
        config.put(ATTRIBUTE_VALUE_IGNORING_STRATEGY_BUILDER.name(), attributeValueIgnoringStrategyBuilder);
    }

    /** Sets the number of samples per numeric bin. */
    public void numSamplesPerNumericBin(int numSamplesPerNumericBin) {
        config.put(NUM_SAMPLES_PER_NUMERIC_BIN.name(), numSamplesPerNumericBin);
    }

    /** Sets the number of numeric bins. */
    public void numNumericBins(int numNumericBins) {
        config.put(NUM_NUMERIC_BINS.name(), numNumericBins);
    }

    /**
     * Sets the branching conditions.
     *
     * <p>NOTE(review): this takes the decision-tree type {@code DTBranchingConditions}, while
     * {@link #buildContext(List)} reads the value back as {@code BranchingConditions<MeanValueCounter>}.
     * Confirm this pairing is intended for regression trees.
     */
    public void branchingConditions(DTBranchingConditions branchingConditions) {
        config.put(BRANCHING_CONDITIONS.name(), branchingConditions);
    }

    /** Sets the scorer factory. */
    public void scorerFactory(ScorerFactory<MeanValueCounter> scorerFactory) {
        config.put(SCORER_FACTORY.name(), scorerFactory);
    }

    /** Sets the degree of the gain-ratio penalty. */
    public void degreeOfGainRatioPenalty(double degreeOfGainRatioPenalty) {
        config.put(DEGREE_OF_GAIN_RATIO_PENALTY.name(), degreeOfGainRatioPenalty);
    }

    /** Sets the imbalance penalty power. */
    public void imbalancePenaltyPower(double imbalancePenaltyPower) {
        config.put(IMBALANCE_PENALTY_POWER.name(), imbalancePenaltyPower);
    }

    /** Sets the branch finder builders; the list is stored as given, not copied. */
    public void branchFinderBuilders(ArrayList<? extends BranchFinderBuilder<MeanValueCounter>> branchFinderBuilders) {
        config.put(BRANCH_FINDER_BUILDERS.name(), branchFinderBuilders);
    }

    /** Sets the minimum number of occurrences of an attribute value. */
    public void minAttributeValueOccurences(int minAttributeValueOccurences) {
        config.put(MIN_ATTRIBUTE_VALUE_OCCURRENCES.name(), minAttributeValueOccurences);
    }

    /** Sets the minimum score. */
    public void minScore(double minScore) {
        config.put(MIN_SCORE.name(), minScore);
    }
}