package quickml.supervised.tree.decisionTree.treeBuildContexts;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import quickml.data.instances.ClassifierInstance;
import quickml.supervised.tree.attributeIgnoringStrategies.AttributeIgnoringStrategy;
import quickml.supervised.tree.attributeIgnoringStrategies.IgnoreAttributesWithConstantProbability;
import quickml.supervised.tree.attributeValueIgnoringStrategies.AttributeValueIgnoringStrategyBuilder;
import quickml.supervised.tree.branchFinders.BranchFinder;
import quickml.supervised.tree.branchFinders.BranchFinderAndReducerFactory;
import quickml.supervised.tree.branchFinders.branchFinderBuilders.BranchFinderBuilder;
import quickml.supervised.tree.branchingConditions.BranchingConditions;
import quickml.supervised.tree.constants.AttributeType;
import quickml.supervised.tree.constants.BranchType;
import quickml.supervised.dataProcessing.BasicTrainingDataSurveyor;
import quickml.supervised.tree.decisionTree.branchFinders.branchFinderBuilders.DTBinaryCatBranchFinderBuilder;
import quickml.supervised.tree.decisionTree.branchFinders.branchFinderBuilders.DTCatBranchFinderBuilder;
import quickml.supervised.tree.decisionTree.branchFinders.branchFinderBuilders.DTNumBranchFinderBuilder;
import quickml.supervised.tree.decisionTree.branchingConditions.DTBranchingConditions;
import quickml.supervised.tree.decisionTree.reducers.reducerFactories.DTBinaryCatBranchReducerFactory;
import quickml.supervised.tree.decisionTree.reducers.reducerFactories.DTCatBranchReducerFactory;
import quickml.supervised.tree.decisionTree.reducers.reducerFactories.DTNumBranchReducerFactory;
import quickml.supervised.tree.decisionTree.scorers.PenalizedGiniImpurityScorerFactory;
import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter;
import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounterProducer;
import quickml.supervised.tree.nodes.LeafBuilder;
import quickml.supervised.tree.reducers.ReducerFactory;
import quickml.supervised.tree.scorers.ScorerFactory;
import quickml.supervised.tree.treeBuildContexts.TreeContextBuilder;
import java.io.Serializable;
import java.util.*;
import static quickml.supervised.tree.constants.ForestOptions.*;
import static quickml.supervised.tree.decisionTree.DecisionTreeBuilder.*;
/**
 * Tree-build-context builder for classification decision trees.
 *
 * <p>All settings live in the inherited {@code config} map, keyed by the {@code name()} of the
 * statically imported {@code ForestOptions} constants. The setter methods at the end of this class
 * write into that map, {@link #setDefaultsAsNeeded()} fills in keys that were not set, and
 * {@link #buildContext(List)} reads the map to assemble a {@link DTreeContext}.
 *
 * <p>Originally created by alexanderhawk on 6/20/15.
 */
public class DTreeContextBuilder<I extends ClassifierInstance> extends TreeContextBuilder<I, ClassificationCounter> {

    /**
     * Returns a new {@code DTreeContextBuilder}. {@link #copy()} uses this and then replaces the new
     * builder's config with a deep copy of this one's.
     */
    @Override
    public DTreeContextBuilder<I> createTreeBuildContext() {
        return new DTreeContextBuilder<>();
    }

    /**
     * Builds the context used to grow a classification tree on {@code trainingData}.
     *
     * <p>The method does the following:
     * <ul>
     *   <li>groups the attributes of the training data by type;</li>
     *   <li>counts classifications over the whole data set;</li>
     *   <li>pairs every applicable branch finder with its reducer factory;</li>
     *   <li>reads the branching conditions, scorer factory and leaf builder from {@code config}.</li>
     * </ul>
     * Those three config entries are passed through as-is; a missing key yields {@code null}. Call
     * {@link #setDefaultsAsNeeded()} first if they may be unset.
     *
     * @param trainingData the instances the tree will be trained on
     * @return the assembled decision-tree context
     */
    @Override
    @SuppressWarnings("unchecked")
    public DTreeContext<I> buildContext(List<I> trainingData) {
        // The surveyor is told whether boolean attributes matter, i.e. whether a BOOLEAN branch
        // finder builder is configured.
        boolean considerBooleanAttributes = hasBranchFinderBuilder(BranchType.BOOLEAN);
        BasicTrainingDataSurveyor<I> decTreeTrainingDataSurveyor = new BasicTrainingDataSurveyor<>(considerBooleanAttributes);
        Map<AttributeType, Set<String>> candidateAttributesByType = decTreeTrainingDataSurveyor.groupAttributesByType(trainingData);
        ClassificationCounter classificationCounts = getValueCounterProducer().getValueCounter(trainingData);
        List<BranchFinderAndReducerFactory<I, ClassificationCounter>> branchFinderAndReducers =
                initializeBranchFindersAndReducers(classificationCounts, candidateAttributesByType);
        return new DTreeContext<I>(classificationCounts.allClassifications(),
                (BranchingConditions<ClassificationCounter>) config.get(BRANCHING_CONDITIONS.name()),
                (ScorerFactory<ClassificationCounter>) config.get(SCORER_FACTORY.name()),
                branchFinderAndReducers,
                (LeafBuilder<ClassificationCounter>) config.get(LEAF_BUILDER.name()),
                getValueCounterProducer());
    }

    /**
     * Returns a new {@link ClassificationCounterProducer}, which produces the
     * {@link ClassificationCounter}s this builder works with.
     */
    @Override
    public ClassificationCounterProducer<I> getValueCounterProducer() {
        return new ClassificationCounterProducer<>();
    }

    /**
     * Returns a new builder whose config is a deep copy of this builder's config
     * (see {@link #deepCopyConfig(Map)}).
     */
    @Override
    public synchronized DTreeContextBuilder<I> copy() {
        //TODO: should only copy the config, and make sure the others get updated. This is redundant.
        DTreeContextBuilder<I> copy = createTreeBuildContext();
        copy.config = deepCopyConfig(this.config);
        return copy;
    }

    /**
     * Copies each configured branch finder builder via its {@code copy()} method.
     *
     * @param sourceConfig the config to read {@code BRANCH_FINDER_BUILDERS} from
     * @return the copied builders. If the key is absent, or maps to {@code null} or an empty list,
     *     a fresh list of {@link #getDefaultBranchFinderBuilders() default builders} is returned instead.
     */
    private ArrayList<BranchFinderBuilder<ClassificationCounter>> copyBranchFinderBuilders(Map<String, Serializable> sourceConfig) {
        @SuppressWarnings("unchecked")
        List<BranchFinderBuilder<ClassificationCounter>> originals =
                (List<BranchFinderBuilder<ClassificationCounter>>) sourceConfig.get(BRANCH_FINDER_BUILDERS.name());
        // A missing key and an explicit null both come back as null from get().
        if (originals == null || originals.isEmpty()) {
            return getDefaultBranchFinderBuilders();
        }
        ArrayList<BranchFinderBuilder<ClassificationCounter>> copies = new ArrayList<>(originals.size());
        for (BranchFinderBuilder<ClassificationCounter> original : originals) {
            copies.add(original.copy());
        }
        return copies;
    }

    /**
     * Builds one branch finder per applicable configured branch finder builder and pairs it with the
     * reducer factory for the same {@link BranchType}.
     *
     * <p>Reducers are chosen here by branch type rather than being supplied by the user. This way a
     * branch finder cannot be paired with the wrong reducer by mistake.
     *
     * @param classificationCounts classification counts over the full training set
     * @param candidateAttributesByType attribute names grouped by attribute type
     * @return the finder/reducer pairs, in the order of {@code getBranchFinderBuilders()}
     * @throws IllegalStateException if a configured builder's branch type has no reducer factory
     */
    private List<BranchFinderAndReducerFactory<I, ClassificationCounter>> initializeBranchFindersAndReducers(
            ClassificationCounter classificationCounts, Map<AttributeType, Set<String>> candidateAttributesByType) {
        List<BranchFinderAndReducerFactory<I, ClassificationCounter>> branchFindersAndReducers = Lists.newArrayList();
        int numClasses = classificationCounts.allClassifications().size();
        Serializable minorityClassification = ClassificationCounter.getLeastPopularClass(classificationCounts);
        Map<BranchType, ReducerFactory<I, ClassificationCounter>> reducerMap = getDefaultReducerFactories(minorityClassification);
        for (BranchFinderBuilder<ClassificationCounter> branchFinderBuilder : getBranchFinderBuilders()) {
            if (!useBranchFinder(branchFinderBuilder, numClasses)) {
                continue;
            }
            BranchType branchType = branchFinderBuilder.getBranchType();
            ReducerFactory<I, ClassificationCounter> reducerFactory = reducerMap.get(branchType);
            // Fail with a descriptive message instead of an NPE further down when a custom builder
            // uses a branch type that getDefaultReducerFactories does not cover.
            if (reducerFactory == null) {
                throw new IllegalStateException("No reducer factory is registered for branch type " + branchType);
            }
            AttributeType attributeType = AttributeType.convertBranchTypeToAttributeType(branchType);
            BranchFinder<ClassificationCounter> branchFinder =
                    branchFinderBuilder.buildBranchFinder(classificationCounts, candidateAttributesByType.get(attributeType));
            reducerFactory.updateBuilderConfig(config);
            branchFindersAndReducers.add(new BranchFinderAndReducerFactory<I, ClassificationCounter>(branchFinder, reducerFactory));
        }
        return branchFindersAndReducers;
    }

    /**
     * Decides whether a branch finder builder applies to a problem with {@code numClasses} classes.
     *
     * <ul>
     *   <li>BINARY_CATEGORICAL finders are used only when there are exactly two classes.</li>
     *   <li>CATEGORICAL finders are used only when there are not exactly two classes.</li>
     *   <li>Every other branch type is always used.</li>
     * </ul>
     */
    private boolean useBranchFinder(BranchFinderBuilder<ClassificationCounter> branchFinderBuilder, int numClasses) {
        if (branchFinderBuilder.getBranchType().equals(BranchType.BINARY_CATEGORICAL) && numClasses != 2) {
            return false;
        }
        if (branchFinderBuilder.getBranchType().equals(BranchType.CATEGORICAL) && numClasses == 2) {
            return false;
        }
        return true;
    }

    /**
     * Returns a new map from branch type to the reducer factory used with that branch type.
     *
     * <p>BOOLEAN branches share the categorical reducer type ({@link DTCatBranchReducerFactory}).
     * Each entry is a separate instance.
     *
     * @param minorityClassification the least popular class, passed to the binary-categorical reducer factory
     */
    public static <I extends ClassifierInstance> Map<BranchType, ReducerFactory<I, ClassificationCounter>> getDefaultReducerFactories(Serializable minorityClassification) {
        Map<BranchType, ReducerFactory<I, ClassificationCounter>> reducerFactories = Maps.newHashMap();
        reducerFactories.put(BranchType.BINARY_CATEGORICAL, new DTBinaryCatBranchReducerFactory<I>(minorityClassification));
        reducerFactories.put(BranchType.CATEGORICAL, new DTCatBranchReducerFactory<I>());
        reducerFactories.put(BranchType.NUMERIC, new DTNumBranchReducerFactory<I>());
        reducerFactories.put(BranchType.BOOLEAN, new DTCatBranchReducerFactory<I>());
        return reducerFactories;
    }

    /**
     * Returns a new mutable list containing the default branch finder builders: binary-categorical,
     * categorical and numeric. No BOOLEAN builder is included.
     */
    public static <I extends ClassifierInstance> ArrayList<BranchFinderBuilder<ClassificationCounter>> getDefaultBranchFinderBuilders() {
        ArrayList<BranchFinderBuilder<ClassificationCounter>> branchFinderBuilders = Lists.newArrayList();
        branchFinderBuilders.add(new DTBinaryCatBranchFinderBuilder());
        branchFinderBuilders.add(new DTCatBranchFinderBuilder());
        branchFinderBuilders.add(new DTNumBranchFinderBuilder());
        return branchFinderBuilders;
    }

    /**
     * Stores a default value, through the matching setter, for every option whose key is absent from
     * {@code config}.
     *
     * <p>{@code EXEMPT_ATTRIBUTES} and {@code ATTRIBUTE_VALUE_IGNORING_STRATEGY_BUILDER} get no default here.
     */
    @Override
    public void setDefaultsAsNeeded() {
        if (!config.containsKey(BRANCH_FINDER_BUILDERS.name())) {
            branchFinderBuilders(DTreeContextBuilder.getDefaultBranchFinderBuilders());
        }
        if (!config.containsKey(MAX_DEPTH.name())) {
            maxDepth(DEFAULT_MAX_DEPTH);
        }
        if (!config.containsKey(MIN_SLPIT_FRACTION.name())) {
            minSplitFraction(DEFAULT_MIN_SPLIT_FRACTION);
        }
        if (!config.containsKey(ATTRIBUTE_IGNORING_STRATEGY.name())) {
            attributeIgnoringStrategy(DEFAULT_ATTRIBUTE_IGNORING_STRATEGY);
        }
        if (!config.containsKey(NUM_SAMPLES_PER_NUMERIC_BIN.name())) {
            numSamplesPerNumericBin(DEFAULT_NUM_SAMPLES_PER_NUMERIC_BIN);
        }
        if (!config.containsKey(NUM_NUMERIC_BINS.name())) {
            numNumericBins(DEFAULT_NUM_NUMERIC_BINS);
        }
        if (!config.containsKey(BRANCHING_CONDITIONS.name())) {
            branchingConditions(DEFAULT_BRANCHING_CONDITIONS);
        }
        if (!config.containsKey(SCORER_FACTORY.name())) {
            // NOTE(review): the default scorer factory is built from the DEFAULT_* penalty constants,
            // not from any DEGREE_OF_GAIN_RATIO_PENALTY / IMBALANCE_PENALTY_POWER already in config.
            // Presumably the factory is updated from config later; confirm.
            scorerFactory(getDefaultScorerFactory());
        }
        if (!config.containsKey(DEGREE_OF_GAIN_RATIO_PENALTY.name())) {
            degreeOfGainRatioPenalty(DEFAULT_DEGREE_OF_GAIN_RATIO_PENALTY);
        }
        if (!config.containsKey(IMBALANCE_PENALTY_POWER.name())) {
            imbalancePenaltyPower(DEFAULT_IMBALANCE_PENALTY_POWER);
        }
        if (!config.containsKey(MIN_ATTRIBUTE_VALUE_OCCURRENCES.name())) {
            minAttributeValueOccurences(DEFAULT_MIN_ATTRIBUTE_OCCURENCES);
        }
        if (!config.containsKey(LEAF_BUILDER.name())) {
            leafBuilder(DEFAULT_LEAF_BUILDER);
        }
        if (!config.containsKey(MIN_LEAF_INSTANCES.name())) {
            minLeafInstances(DEFAULT_MIN_LEAF_INSTANCES);
        }
        if (!config.containsKey(MIN_SCORE.name())) {
            minScore(DEFAULT_MIN_SCORE);
        }
    }

    /** Default scorer: penalized Gini impurity with the default penalty parameters. */
    private ScorerFactory<ClassificationCounter> getDefaultScorerFactory() {
        return new PenalizedGiniImpurityScorerFactory(DEFAULT_DEGREE_OF_GAIN_RATIO_PENALTY, DEFAULT_IMBALANCE_PENALTY_POWER);
    }

    /**
     * Returns a new map holding a copy of every option this builder knows about that is present in
     * {@code config}.
     *
     * <p>How values are copied depends on the option:
     * <ul>
     *   <li>Scalar options are shared by reference.</li>
     *   <li>Builders, strategies, conditions and factories are copied via their {@code copy()} methods.</li>
     *   <li>The exempt-attribute set is copied into a new {@link HashSet}.</li>
     * </ul>
     * Keys not listed here are not carried over.
     */
    @Override
    @SuppressWarnings("unchecked")
    public synchronized Map<String, Serializable> deepCopyConfig(Map<String, Serializable> config) {
        Map<String, Serializable> copiedConfig = Maps.newHashMap();

        // The setters in this class store these as boxed int/double values. Those are immutable, so
        // sharing the references between the two configs is safe.
        String[] passThroughKeys = {
                MAX_DEPTH.name(),
                MIN_SLPIT_FRACTION.name(),
                NUM_SAMPLES_PER_NUMERIC_BIN.name(),
                NUM_NUMERIC_BINS.name(),
                DEGREE_OF_GAIN_RATIO_PENALTY.name(),
                IMBALANCE_PENALTY_POWER.name(),
                MIN_ATTRIBUTE_VALUE_OCCURRENCES.name(),
                MIN_LEAF_INSTANCES.name(),
                MIN_SCORE.name()};
        for (String key : passThroughKeys) {
            if (config.containsKey(key)) {
                copiedConfig.put(key, config.get(key));
            }
        }

        if (config.containsKey(BRANCH_FINDER_BUILDERS.name())) {
            copiedConfig.put(BRANCH_FINDER_BUILDERS.name(), copyBranchFinderBuilders(config));
        }
        if (config.containsKey(ATTRIBUTE_IGNORING_STRATEGY.name())) {
            copiedConfig.put(ATTRIBUTE_IGNORING_STRATEGY.name(),
                    ((AttributeIgnoringStrategy) config.get(ATTRIBUTE_IGNORING_STRATEGY.name())).copy());
        }
        if (config.containsKey(BRANCHING_CONDITIONS.name())) {
            copiedConfig.put(BRANCHING_CONDITIONS.name(),
                    ((BranchingConditions<ClassificationCounter>) config.get(BRANCHING_CONDITIONS.name())).copy());
        }
        if (config.containsKey(SCORER_FACTORY.name())) {
            copiedConfig.put(SCORER_FACTORY.name(),
                    ((ScorerFactory<ClassificationCounter>) config.get(SCORER_FACTORY.name())).copy());
        }
        if (config.containsKey(LEAF_BUILDER.name())) {
            copiedConfig.put(LEAF_BUILDER.name(),
                    ((LeafBuilder<ClassificationCounter>) config.get(LEAF_BUILDER.name())).copy());
        }
        if (config.containsKey(EXEMPT_ATTRIBUTES.name())) {
            copiedConfig.put(EXEMPT_ATTRIBUTES.name(), Sets.newHashSet(((Set<String>) config.get(EXEMPT_ATTRIBUTES.name()))));
        }
        if (config.containsKey(ATTRIBUTE_VALUE_IGNORING_STRATEGY_BUILDER.name())) {
            // Read the same key that was just checked. Reading ATTRIBUTE_VALUE_IGNORING_STRATEGY here
            // yields null (NPE on copy()) or an object of the wrong type (ClassCastException).
            copiedConfig.put(ATTRIBUTE_VALUE_IGNORING_STRATEGY_BUILDER.name(),
                    ((AttributeValueIgnoringStrategyBuilder<ClassificationCounter>) config.get(ATTRIBUTE_VALUE_IGNORING_STRATEGY_BUILDER.name())).copy());
        }
        return copiedConfig;
    }

    /** Sets {@code MAX_DEPTH}. */
    public void maxDepth(int maxDepth) {
        config.put(MAX_DEPTH.name(), maxDepth);
    }

    /** Sets {@code LEAF_BUILDER}. */
    public void leafBuilder(LeafBuilder<ClassificationCounter> leafBuilder) {
        config.put(LEAF_BUILDER.name(), leafBuilder);
    }

    /**
     * Convenience for {@link #attributeIgnoringStrategy} with an
     * {@link IgnoreAttributesWithConstantProbability}. It writes the same {@code ATTRIBUTE_IGNORING_STRATEGY}
     * key, so it replaces any strategy set earlier.
     */
    public void ignoreAttributeProbability(double ignoreAttributeProbability) {
        config.put(ATTRIBUTE_IGNORING_STRATEGY.name(), new IgnoreAttributesWithConstantProbability(ignoreAttributeProbability));
    }

    /** Sets {@code MIN_SLPIT_FRACTION} (the key's spelling comes from {@code ForestOptions}). */
    public void minSplitFraction(double minSplitFraction) {
        config.put(MIN_SLPIT_FRACTION.name(), minSplitFraction);
    }

    /** Sets {@code MIN_LEAF_INSTANCES}. */
    public void minLeafInstances(int minLeafInstances) {
        config.put(MIN_LEAF_INSTANCES.name(), minLeafInstances);
    }

    /** Sets {@code EXEMPT_ATTRIBUTES}. {@link #setDefaultsAsNeeded()} provides no default for this key. */
    public void exemptAttributes(HashSet<String> exemptAttributes) {
        config.put(EXEMPT_ATTRIBUTES.name(), exemptAttributes);
    }

    /** Sets {@code ATTRIBUTE_IGNORING_STRATEGY}. */
    public void attributeIgnoringStrategy(AttributeIgnoringStrategy attributeIgnoringStrategy) {
        config.put(ATTRIBUTE_IGNORING_STRATEGY.name(), attributeIgnoringStrategy);
    }

    /**
     * Sets {@code ATTRIBUTE_VALUE_IGNORING_STRATEGY_BUILDER}. {@link #setDefaultsAsNeeded()} provides no
     * default for this key. Per the original author's note, if it is unset an appropriate strategy is
     * chosen when the branch finders are built (not verifiable from this class).
     */
    public void attributeValueIgnoringStrategyBuilder(AttributeValueIgnoringStrategyBuilder<ClassificationCounter> attributeValueIgnoringStrategyBuilder) {
        config.put(ATTRIBUTE_VALUE_IGNORING_STRATEGY_BUILDER.name(), attributeValueIgnoringStrategyBuilder);
    }

    /** Sets {@code NUM_SAMPLES_PER_NUMERIC_BIN}. */
    public void numSamplesPerNumericBin(int numSamplesPerNumericBin) {
        config.put(NUM_SAMPLES_PER_NUMERIC_BIN.name(), numSamplesPerNumericBin);
    }

    /** Sets {@code NUM_NUMERIC_BINS}. */
    public void numNumericBins(int numNumericBins) {
        config.put(NUM_NUMERIC_BINS.name(), numNumericBins);
    }

    /** Sets {@code BRANCHING_CONDITIONS}. */
    public void branchingConditions(DTBranchingConditions branchingConditions) {
        config.put(BRANCHING_CONDITIONS.name(), branchingConditions);
    }

    /** Sets {@code SCORER_FACTORY}. */
    public void scorerFactory(ScorerFactory<ClassificationCounter> scorerFactory) {
        config.put(SCORER_FACTORY.name(), scorerFactory);
    }

    /** Sets {@code DEGREE_OF_GAIN_RATIO_PENALTY}. */
    public void degreeOfGainRatioPenalty(double degreeOfGainRatioPenalty) {
        config.put(DEGREE_OF_GAIN_RATIO_PENALTY.name(), degreeOfGainRatioPenalty);
    }

    /** Sets {@code IMBALANCE_PENALTY_POWER}. */
    public void imbalancePenaltyPower(double imbalancePenaltyPower) {
        config.put(IMBALANCE_PENALTY_POWER.name(), imbalancePenaltyPower);
    }

    /** Sets {@code BRANCH_FINDER_BUILDERS}. The list is stored as given, not copied. */
    public void branchFinderBuilders(ArrayList<? extends BranchFinderBuilder<ClassificationCounter>> branchFinderBuilders) {
        config.put(BRANCH_FINDER_BUILDERS.name(), branchFinderBuilders);
    }

    /** Sets {@code MIN_ATTRIBUTE_VALUE_OCCURRENCES}. */
    public void minAttributeValueOccurences(int minAttributeValueOccurences) {
        config.put(MIN_ATTRIBUTE_VALUE_OCCURRENCES.name(), minAttributeValueOccurences);
    }

    /** Sets {@code MIN_SCORE}. */
    public void minScore(double minScore) {
        config.put(MIN_SCORE.name(), minScore);
    }
}