package quickml.supervised.tree.treeBuildContexts;

import com.google.common.base.Optional;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import quickml.data.instances.InstanceWithAttributesMap;
import quickml.supervised.tree.scorers.ScorerFactory;
import quickml.supervised.tree.summaryStatistics.ValueCounterProducer;
import quickml.supervised.tree.summaryStatistics.ValueCounter;
import quickml.supervised.tree.constants.BranchType;
import quickml.supervised.tree.nodes.LeafBuilder;
import quickml.supervised.tree.branchFinders.branchFinderBuilders.BranchFinderBuilder;
import quickml.supervised.tree.branchingConditions.BranchingConditions;
import static quickml.supervised.tree.constants.ForestOptions.*;

import java.io.Serializable;
import java.util.List;
import java.util.Map;

/**
 * Base class for builders that assemble a {@link TreeContext} from a string-keyed
 * configuration map. All typed getters read straight from that map, using the
 * {@code ForestOptions} constant names as keys.
 *
 * Created by alexanderhawk on 3/20/15.
 *
 * @param <I>  the training instance type
 * @param <VC> the value counter type handled by the configured components
 */
public abstract class TreeContextBuilder<I extends InstanceWithAttributesMap<?>, VC extends ValueCounter<VC>> {

    /** Backing configuration, keyed by {@code ForestOptions} constant name. */
    protected Map<String, Serializable> config = Maps.newHashMap();

    /**
     * Returns the branch finder builders stored in the config.
     *
     * @return the configured list, or a new empty list when the entry is missing or
     *         mapped to {@code null}; never {@code null}
     */
    @SuppressWarnings("unchecked") // the config map is untyped; the entry is trusted to hold this type
    public List<? extends BranchFinderBuilder<VC>> getBranchFinderBuilders() {
        List<? extends BranchFinderBuilder<VC>> configured =
                (List<? extends BranchFinderBuilder<VC>>) config.get(BRANCH_FINDER_BUILDERS.name());
        if (configured != null) {
            return configured;
        }
        // A key that is absent and a key mapped to null are treated alike, so callers
        // that iterate the result (e.g. getBranchFinderBuilder) never see null.
        List<? extends BranchFinderBuilder<VC>> emptyList = Lists.newArrayList();
        return emptyList;
    }

    /**
     * @return the scorer factory stored in the config, or {@code null} if none is set
     */
    @SuppressWarnings("unchecked") // the config map is untyped; the entry is trusted to hold this type
    public ScorerFactory<VC> getScorerFactory() {
        return (ScorerFactory<VC>) config.get(SCORER_FACTORY.name());
    }

    /**
     * @return the branching conditions stored in the config, or {@code null} if none are set
     */
    @SuppressWarnings("unchecked") // the config map is untyped; the entry is trusted to hold this type
    public BranchingConditions<VC> getBranchingConditions() {
        return (BranchingConditions<VC>) config.get(BRANCHING_CONDITIONS.name());
    }

    /**
     * @return the leaf builder stored in the config, or {@code null} if none is set
     */
    @SuppressWarnings("unchecked") // the config map is untyped; the entry is trusted to hold this type
    public LeafBuilder<VC> getLeafBuilder() {
        return (LeafBuilder<VC>) config.get(LEAF_BUILDER.name());
    }

    /**
     * Creates a new builder via {@link #createTreeBuildContext()} and gives it the result of
     * {@link #deepCopyConfig(Map)} applied to this builder's config.
     *
     * @return the newly created builder
     */
    public TreeContextBuilder<I, VC> copy() {
        TreeContextBuilder<I, VC> copy = createTreeBuildContext();
        copy.config = deepCopyConfig(this.config);
        return copy;
    }

    /**
     * @param branchType the branch type to look for
     * @return {@code true} if a configured branch finder builder reports this branch type
     */
    public boolean hasBranchFinderBuilder(BranchType branchType) {
        return getBranchFinderBuilder(branchType).isPresent();
    }

    /**
     * Finds the first configured branch finder builder whose branch type equals the given one.
     *
     * @param branchType the branch type to look for
     * @return the first matching builder, or {@link Optional#absent()} when there is none
     */
    public Optional<BranchFinderBuilder<VC>> getBranchFinderBuilder(BranchType branchType) {
        for (BranchFinderBuilder<VC> branchFinderBuilder : getBranchFinderBuilders()) {
            if (branchFinderBuilder.getBranchType().equals(branchType)) {
                return Optional.of(branchFinderBuilder);
            }
        }
        return Optional.absent();
    }

    /**
     * Passes the current config to {@code update} on each configured component: the scorer
     * factory first, then the branching conditions, then every branch finder builder.
     * Entries that are missing or mapped to {@code null} are skipped.
     */
    @SuppressWarnings("unchecked") // the config map is untyped; entries are trusted to hold these types
    public void updateEachConfigElement() {
        // Each entry is looked up only after the preceding update has run, in case that
        // update modified the config.
        ScorerFactory<VC> scorerFactory = (ScorerFactory<VC>) config.get(SCORER_FACTORY.name());
        if (scorerFactory != null) {
            scorerFactory.update(config);
        }

        BranchingConditions<VC> branchingConditions =
                (BranchingConditions<VC>) config.get(BRANCHING_CONDITIONS.name());
        if (branchingConditions != null) {
            branchingConditions.update(config);
        }

        // Setting branchFinderBuilders must occur after the branching conditions and scorers are updated.
        List<? extends BranchFinderBuilder<VC>> branchFinderBuilders =
                (List<? extends BranchFinderBuilder<VC>>) config.get(BRANCH_FINDER_BUILDERS.name());
        if (branchFinderBuilders != null) {
            for (BranchFinderBuilder<VC> branchFinderBuilder : branchFinderBuilders) {
                branchFinderBuilder.update(config);
            }
        }
    }

    /**
     * Replaces this builder's config with the result of {@link #deepCopyConfig(Map)} applied
     * to the given map.
     *
     * @param config the configuration to copy from
     */
    public void setConfig(Map<String, Serializable> config) {
        this.config = deepCopyConfig(config);
    }

    /**
     * Calls {@link #setDefaultsAsNeeded()} and then {@link #updateEachConfigElement()}.
     */
    public void initializeConfig() {
        setDefaultsAsNeeded();
        updateEachConfigElement();
    }

    /**
     * @return the value counter producer supplied by the concrete builder
     */
    public abstract ValueCounterProducer<I, VC> getValueCounterProducer();

    /**
     * Creates the builder instance that {@link #copy()} then populates with a copied config.
     *
     * @return a new builder
     */
    public abstract TreeContextBuilder<I, VC> createTreeBuildContext();

    /**
     * Builds a tree context for the given training data.
     *
     * @param trainingData the training instances
     * @return the resulting tree context
     */
    public abstract TreeContext<I, VC> buildContext(List<I> trainingData);

    /**
     * Hook invoked by {@link #initializeConfig()} before the config elements are updated.
     * NOTE(review): presumably fills in config entries that have not been set - confirm
     * against the concrete implementations.
     */
    public abstract void setDefaultsAsNeeded();

    /**
     * Produces the copy of a configuration map that {@link #setConfig(Map)} and
     * {@link #copy()} store in a builder.
     *
     * @param config the configuration to copy
     * @return the copied configuration
     */
    public abstract Map<String, Serializable> deepCopyConfig(Map<String, Serializable> config);
}