package quickml.supervised.tree.branchingConditions; import com.google.common.collect.Sets; import quickml.supervised.tree.nodes.Branch; import quickml.supervised.tree.nodes.Node; import quickml.supervised.tree.summaryStatistics.ValueCounter; import static quickml.supervised.tree.constants.ForestOptions.*; import java.io.Serializable; import java.util.Map; import java.util.Set; /** * Created by alexanderhawk on 4/4/15. */ public class StandardBranchingConditions<VC extends ValueCounter<VC>> implements BranchingConditions<VC> { private double minScore=0; private int maxDepth = Integer.MAX_VALUE; private int minLeafInstances = 0; private double minSplitFraction = 0; private Set<String> exemptAttributes = Sets.newHashSet(); public StandardBranchingConditions(double minScore, int maxDepth, int minLeafInstances, double minSplitFraction) { this(minScore, maxDepth, minLeafInstances, minSplitFraction, Sets.<String>newHashSet()); } public StandardBranchingConditions(double minScore, int maxDepth, int minLeafInstances, double minSplitFraction, Set<String> exemptAttributes) { this.minScore = minScore; this.maxDepth = maxDepth; this.minLeafInstances = minLeafInstances; this.minSplitFraction = minSplitFraction; this.exemptAttributes = exemptAttributes; } public StandardBranchingConditions() {} public StandardBranchingConditions minScore(double minScore) { this.minScore = minScore; return this; } public StandardBranchingConditions maxDepth(int maxDepth) { this.maxDepth = maxDepth; return this; } public StandardBranchingConditions minLeafInstances(int minLeafInstances) { this.minLeafInstances = minLeafInstances; return this; } public StandardBranchingConditions minSplitFraction(double minSplitFraction) { this.minSplitFraction = minSplitFraction; return this; } public StandardBranchingConditions exemptAttributes(Set<String> exemptAttributes) { this.exemptAttributes = exemptAttributes; return this; } @Override public boolean isInvalidSplit(VC trueSet, VC falseSet, String attribute) { return ( (getSplitFraction(trueSet, falseSet) < minSplitFraction && !exemptAttributes.contains(attribute))) || violatesMinLeafInstances(trueSet, falseSet); } @Override public boolean isInvalidSplit(VC trueSet, VC falseSet) { return (getSplitFraction(trueSet, falseSet) < minSplitFraction) || violatesMinLeafInstances(trueSet, falseSet); } private double getSplitFraction(VC trueSet, VC falseSet) { return Math.min(trueSet.getTotal(), falseSet.getTotal())/ (trueSet.getTotal() + falseSet.getTotal()); } private boolean violatesMinLeafInstances(VC trueSet, VC falseSet) { return trueSet.getTotal() < minLeafInstances || falseSet.getTotal() < minLeafInstances; } public boolean isInvalidSplit(double score) { return score <= minScore; } @Override public boolean canTryAddingChildren(Branch<VC> parent, VC totals){ return (parent==null || parent.getDepth() < maxDepth-1 && totals.getTotal() >= 2 * minLeafInstances); } @Override public void update(Map<String, Serializable> cfg) { if (cfg.containsKey(MAX_DEPTH.name())) maxDepth = (Integer) cfg.get(MAX_DEPTH.name()); if (cfg.containsKey(MIN_SCORE.name())) minScore = (Double) cfg.get(MIN_SCORE.name()); if (cfg.containsKey(MIN_LEAF_INSTANCES.name())) minLeafInstances = (Integer) cfg.get(MIN_LEAF_INSTANCES.name()); if (cfg.containsKey(MIN_SLPIT_FRACTION.name())) minSplitFraction = (Double) cfg.get(MIN_SLPIT_FRACTION.name()); if (cfg.containsKey(EXEMPT_ATTRIBUTES.name())) exemptAttributes = (Set<String>) cfg.get(EXEMPT_ATTRIBUTES.name()); } @Override public synchronized StandardBranchingConditions copy(){ return new StandardBranchingConditions(this.minScore, this.maxDepth, this.minLeafInstances, this.minSplitFraction, Sets.newHashSet(this.exemptAttributes)); } }