package quickml.supervised.tree.branchFinders;

import com.google.common.base.Optional;
import com.google.common.collect.Sets;
import quickml.supervised.tree.attributeValueIgnoringStrategies.AttributeValueIgnoringStrategy;
import quickml.supervised.tree.constants.MissingValue;
import quickml.supervised.tree.scorers.Scorer;
import quickml.supervised.tree.scorers.ScorerFactory;
import quickml.supervised.tree.summaryStatistics.ValueCounter;
import quickml.supervised.tree.reducers.AttributeStats;
import quickml.supervised.tree.branchingConditions.BranchingConditions;

import java.io.Serializable;
import java.util.List;
import java.util.Set;

/**
 * Static helpers for choosing a binary split over an ordered list of per-value statistics.
 *
 * Created by alexanderhawk on 4/23/15.
 */
public class SplittingUtils {

    /**
     * Walks the per-value statistics of an attribute in list order and looks for the best
     * prefix split: entries up to some index form the "true" set, the remainder forms the
     * "false" set.
     *
     * <p>The true set starts empty and the false set starts as the aggregate statistics. Each
     * entry that is not ignored is added to the true set and subtracted from the false set, and
     * the resulting partition is scored. The final list entry is never moved to the true side
     * (the loop stops one short of the end). Ignored entries stay on the false side.
     *
     * <p>NOTE(review): the method name implies the caller supplies the list already sorted;
     * nothing here sorts or verifies the order.
     *
     * @param attributeStats                  per-value and aggregate statistics for one attribute
     * @param scorerFactory                   supplies the scorer used to rate each trial split
     * @param branchingConditions             rejects trial splits, both by set contents and by score
     * @param attributeValueIgnoringStrategy  decides whether a counter (or an accumulated set) is ignored
     * @param doNotUseAttributeValuesWithInsuffientStatistics
     *                                        when true, individual entries flagged by the ignoring
     *                                        strategy are skipped
     * @return the best split found, or {@code Optional.absent()} when no accepted trial split
     *         scored strictly above zero
     */
    public static <VC extends ValueCounter<VC>> Optional<SplitScore> splitSortedAttributeStats(
            AttributeStats<VC> attributeStats,
            ScorerFactory<VC> scorerFactory,
            BranchingConditions<VC> branchingConditions,
            AttributeValueIgnoringStrategy<VC> attributeValueIgnoringStrategy,
            boolean doNotUseAttributeValuesWithInsuffientStatistics) {

        double topScore = 0;
        int topIndex = 0;
        double topTrueProbability = 0;
        boolean foundSplit = false;
        int consideredCount = 0;

        List<VC> perValueStats = attributeStats.getStatsOnEachValue();
        VC falseSide = attributeStats.getAggregateStats();
        // Subtracting the aggregate from itself yields an empty counter to accumulate into.
        VC trueSide = falseSide.subtract(falseSide);
        Scorer<VC> splitScorer = scorerFactory.getScorer(attributeStats);

        for (int idx = 0; idx < perValueStats.size() - 1; idx++) {
            VC candidate = perValueStats.get(idx);
            if (shouldWeIgnoreValueCounter(attributeValueIgnoringStrategy,
                    doNotUseAttributeValuesWithInsuffientStatistics, candidate)) {
                continue;
            }

            consideredCount++;
            trueSide = trueSide.add(candidate);
            falseSide = falseSide.subtract(candidate);

            // TODO: could stop early, since every later trial split will also fail once the
            // false set becomes small enough.
            boolean partitionRejected =
                    branchingConditions.isInvalidSplit(trueSide, falseSide, attributeStats.getAttribute())
                            || attributeValueIgnoringStrategy.shouldWeIgnoreThisValue(trueSide)
                            || attributeValueIgnoringStrategy.shouldWeIgnoreThisValue(falseSide);

            if (!partitionRejected) {
                double candidateScore = splitScorer.scoreSplit(trueSide, falseSide);
                boolean scoreAccepted = !branchingConditions.isInvalidSplit(candidateScore);
                if (scoreAccepted && candidateScore > topScore) {
                    topScore = candidateScore;
                    topIndex = idx;
                    topTrueProbability =
                            trueSide.getTotal() / (trueSide.getTotal() + falseSide.getTotal());
                    foundSplit = true;
                }
            }
        }

        if (foundSplit && consideredCount >= 1) {
            Set<Serializable> trueSetVals = createTrueSetVals(topIndex, perValueStats,
                    attributeValueIgnoringStrategy, doNotUseAttributeValuesWithInsuffientStatistics);
            return Optional.of(new SplitScore(topScore, topIndex, topTrueProbability, trueSetVals));
        }
        return Optional.absent();
    }

    /**
     * Collects the attribute values of every non-ignored counter at indexes
     * {@code 0..indexOfLastTermStatsInTrueSet} (inclusive) into a new hash set.
     */
    private static <VC extends ValueCounter<VC>> Set<Serializable> createTrueSetVals(
            int indexOfLastTermStatsInTrueSet,
            List<VC> valueCounters,
            AttributeValueIgnoringStrategy<VC> attributeValueIgnoringStrategy,
            boolean doNotUseAttributeValuesWithInsuffientStatistics) {

        Set<Serializable> collected = Sets.newHashSet();
        for (int pos = 0; pos <= indexOfLastTermStatsInTrueSet; pos++) {
            VC counter = valueCounters.get(pos);
            boolean skip = shouldWeIgnoreValueCounter(attributeValueIgnoringStrategy,
                    doNotUseAttributeValuesWithInsuffientStatistics, counter);
            if (!skip) {
                collected.add(counter.getAttrVal());
            }
        }
        return collected;
    }

    /**
     * Decides whether a single per-value counter is skipped. A counter is skipped when it is
     * {@code null}, when its attribute value equals {@link MissingValue#MISSING_VALUE}, or when
     * the ignoring strategy flags it and the caller asked for flagged values to be dropped.
     */
    private static <VC extends ValueCounter<VC>> boolean shouldWeIgnoreValueCounter(
            AttributeValueIgnoringStrategy<VC> attributeValueIgnoringStrategy,
            boolean doNotUseAttributeValuesWithInsuffientStatistics,
            VC valueCounterForAttrVal) {

        boolean absentOrMissing = valueCounterForAttrVal == null
                || valueCounterForAttrVal.attrVal.equals(MissingValue.MISSING_VALUE);
        if (absentOrMissing) {
            return true;
        }
        // The strategy is consulted first, then gated by the flag (same order as before).
        boolean flaggedByStrategy =
                attributeValueIgnoringStrategy.shouldWeIgnoreThisValue(valueCounterForAttrVal);
        return flaggedByStrategy && doNotUseAttributeValuesWithInsuffientStatistics;
    }

    /**
     * Plain holder for the outcome of a split search. Field meanings below describe how
     * {@code splitSortedAttributeStats} populates them.
     */
    public static class SplitScore {
        /** Score of the winning trial split. */
        public double score;
        /** Index, in the per-value list, of the last counter placed in the true set. */
        public int indexOfLastValueCounterInTrueSet;
        /** True-set total divided by the sum of the true-set and false-set totals. */
        public double probabilityOfBeingInTrueSet;
        /** Attribute values that make up the true set. */
        public Set<Serializable> trueSet;

        public SplitScore(double score,
                          int indexOfLastValueCounterInTrueSet,
                          double probabilityOfBeingInTrueSet,
                          Set<Serializable> trueSet) {
            this.score = score;
            this.indexOfLastValueCounterInTrueSet = indexOfLastValueCounterInTrueSet;
            this.probabilityOfBeingInTrueSet = probabilityOfBeingInTrueSet;
            this.trueSet = trueSet;
        }
    }
}