package quickml.supervised.tree.branchFinders; import com.google.common.base.Optional; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.collect.Sets; import quickml.supervised.tree.attributeIgnoringStrategies.IgnoreAttributesWithConstantProbability; import quickml.supervised.tree.attributeValueIgnoringStrategies.AttributeValueIgnoringStrategy; import quickml.supervised.tree.reducers.AttributeStatisticsProducer; import quickml.supervised.tree.scorers.ScorerFactory; import quickml.supervised.tree.summaryStatistics.ValueCounter; import quickml.supervised.tree.constants.BranchType; import quickml.supervised.tree.attributeIgnoringStrategies.AttributeIgnoringStrategy; import quickml.supervised.tree.reducers.AttributeStats; import quickml.supervised.tree.nodes.Branch; import quickml.supervised.tree.branchingConditions.BranchingConditions; import java.util.*; /** * Created by alexanderhawk on 3/24/15. */ public abstract class BranchFinder<VC extends ValueCounter<VC>> { protected Set<String> candidateAttributes; protected BranchingConditions<VC> branchingConditions; protected ScorerFactory<VC> scorerFactory; protected AttributeValueIgnoringStrategy<VC> attributeValueIgnoringStrategy; protected AttributeIgnoringStrategy attributeIgnoringStrategy; public BranchFinder(Collection<String> candidateAttributes, BranchingConditions<VC> branchingConditions, ScorerFactory<VC> scorerFactory, AttributeValueIgnoringStrategy<VC> attributeValueIgnoringStrategy, AttributeIgnoringStrategy attributeIgnoringStrategy) { this.candidateAttributes = Sets.newHashSet(candidateAttributes); this.branchingConditions = branchingConditions; this.scorerFactory = scorerFactory; this.attributeValueIgnoringStrategy = attributeValueIgnoringStrategy; this.attributeIgnoringStrategy = attributeIgnoringStrategy; } public abstract BranchType getBranchType(); protected List<String> getCandidateAttributesWithIgnoringApplied(Branch<VC> parent) { List<String> attributes = Lists.newArrayList(); for (String attribute : candidateAttributes) { if (!attributeIgnoringStrategy.ignoreAttribute(attribute, parent)) { attributes.add(attribute); } } return attributes; } protected List<String> alternativeGetCandidateAttributesWithIgnoringApplied(Branch<VC> parent) { double ignoreProb = ((IgnoreAttributesWithConstantProbability) attributeIgnoringStrategy).getIgnoreAttributeProbability(); ArrayList<String> candidates = Lists.newArrayList(candidateAttributes); if (ignoreProb == 0.0) { return candidates; } int numTrialAttributes = (int)((1.0-ignoreProb)*candidates.size()); //O(N) way of shuffling the attributes to make all permutations equally likely. Collections.shuffle(candidates); return candidates.subList(0,numTrialAttributes); } public Optional<? extends Branch<VC>> findBestBranch(Branch<VC> parent, AttributeStatisticsProducer<VC> attributeStatisticsProducer) { double bestScore = 0; Optional<? extends Branch<VC>> bestBranchOptional = Optional.absent(); // for (String attribute : getCandidateAttributesWithIgnoringApplied(parent)) { for (String attribute : alternativeGetCandidateAttributesWithIgnoringApplied(parent)) { Optional<AttributeStats<VC>> attributeStatsOptional = attributeStatisticsProducer.getAttributeStats(attribute); if (!attributeStatsOptional.isPresent()) { continue; } AttributeStats<VC> attributeStats = attributeStatsOptional.get(); Optional<? extends Branch<VC>> thisBranchOptional = getBranch(parent, attributeStats); if (thisBranchOptional.isPresent()) { Branch<VC> thisBranch = thisBranchOptional.get(); if (thisBranch.score > bestScore) { bestScore = thisBranch.score; bestBranchOptional = thisBranchOptional; } } } return bestBranchOptional; } public abstract Optional<? extends Branch<VC>> getBranch(Branch<VC> parent, AttributeStats<VC> attributeStats); }