package quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst;

import com.google.common.base.Preconditions;
import com.google.common.base.Predicate;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.twitter.common.stats.ReservoirSampler;
import com.twitter.common.util.Random;
import org.apache.commons.lang.mutable.MutableInt;
import org.javatuples.Pair;
import quickml.collections.MapUtils;
import quickml.data.instances.ClassifierInstance;
import quickml.supervised.PredictiveModelBuilder;
import quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst.oldScorers.GiniImpurityOldScorer;
import quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst.oldTree.oldAttributeIgnoringStrategies.AttributeIgnoringStrategy;
import quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst.oldTree.oldAttributeIgnoringStrategies.IgnoreAttributesWithConstantProbability;
import quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst.oldTree.*;

import javax.annotation.Nullable;
import java.io.Serializable;
import java.util.*;
import java.util.Map.Entry;

public final class OldTreeBuilder<T extends ClassifierInstance> implements PredictiveModelBuilder<OldTree, T> {

    public static final String MAX_DEPTH = "maxDepth";
    public static final String MIN_SCORE = "minScore";
    // The minimum number of times a categorical attribute value must be observed to be considered during splitting.
    // Also the minimum number of times a numeric attribute must be observed to fall inside a closed interval
    // for that interval to be considered in a split decision.
    public static final String MIN_OCCURRENCES_OF_ATTRIBUTE_VALUE = "minOccurrencesOfAttributeValue";
    public static final String MIN_LEAF_INSTANCES = "minLeafInstances";
    public static final String SCORER = "scorerFactory";
    public static final String PENALIZE_CATEGORICAL_SPLITS = "penalizeCategoricalSplitsBySplitAttributeIntrinsicValue";
    public static final String ATTRIBUTE_IGNORING_STRATEGY = "attributeIgnoringStrategy";
    public static final String DEGREE_OF_GAIN_RATIO_PENALTY = "degreeOfGainRatioPenalty";
    public static final String IGNORE_ATTR_PROB = "ignoreAttributeAtNodeProbability";
    public static final String ORDINAL_TEST_SPLITS = "ordinalTestSpilts";
    public static final String MIN_SPLIT_FRACTION = "minSplitFraction";
    public static final String EXEMPT_ATTRIBUTES = "exemptAttributes";

    public static final int SMALL_TRAINING_SET_LIMIT = 9;
    public static final int RESERVOIR_SIZE = 100;
    public static final Serializable MISSING_VALUE = "%missingVALUE%83257";
    private static final int HARD_MINIMUM_INSTANCES_PER_CATEGORICAL_VALUE = 10;

    private OldScorer oldScorer;
    private int maxDepth = 5;
    private double minimumScore = 0.00000000000001;
    private int minDiscreteAttributeValueOccurances = 0;
    private double minSplitFraction = .005;
    private HashSet<String> exemptAttributes = Sets.newHashSet();
    private int minLeafInstances = 0;
    private Random rand = Random.Util.fromSystemRandom(MapUtils.random);
    private boolean penalizeCategoricalSplitsBySplitAttributeIntrinsicValue = true;
    private double degreeOfGainRatioPenalty = 1.0;
    private int ordinalTestSpilts = 5;
    private double fractionOfDataToUseInHoldOutSet;
    private AttributeIgnoringStrategy attributeIgnoringStrategy = new IgnoreAttributesWithConstantProbability(0.0);
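    /*
     * Illustrative usage (not part of the original source): the builder is configured
     * through its fluent setters and then asked for a model. The training data here is
     * a hypothetical List<ClassifierInstance> called trainingSet.
     *
     *   OldTreeBuilder<ClassifierInstance> builder =
     *           new OldTreeBuilder<>(new GiniImpurityOldScorer())
     *                   .maxDepth(8)
     *                   .minLeafInstances(10)
     *                   .minCategoricalAttributeValueOccurances(5)
     *                   .degreeOfGainRatioPenalty(1.0);
     *   OldTree tree = builder.buildPredictiveModel(trainingSet);
     */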
    // TODO: make it so only one thread computes the below 4 values, since all trees compute the same values.
    private Serializable minorityClassification;
    private Serializable majorityClassification;
    private double majorityToMinorityRatio = 1;
    private boolean binaryClassifications = true;
    Map<String, AttributeCharacteristics> attributeCharacteristics;

    public OldTreeBuilder() {
        this(new GiniImpurityOldScorer());
    }

    public OldTreeBuilder(final OldScorer oldScorer) {
        this.oldScorer = oldScorer;
    }

    public OldTreeBuilder attributeIgnoringStrategy(AttributeIgnoringStrategy attributeIgnoringStrategy) {
        this.attributeIgnoringStrategy = attributeIgnoringStrategy;
        return this;
    }

    public OldTreeBuilder exemptAttributes(HashSet<String> exemptAttributes) {
        this.exemptAttributes = exemptAttributes;
        return this;
    }

    public OldTreeBuilder minSplitFraction(double minSplitFraction) {
        this.minSplitFraction = minSplitFraction;
        return this;
    }

    @Deprecated
    public OldTreeBuilder ignoreAttributeAtNodeProbability(double ignoreAttributeAtNodeProbability) {
        attributeIgnoringStrategy(new IgnoreAttributesWithConstantProbability(ignoreAttributeAtNodeProbability));
        return this;
    }

    public synchronized OldTreeBuilder<T> copy() {
        OldTreeBuilder<T> copy = new OldTreeBuilder<>();
        copy.oldScorer = oldScorer;
        copy.maxDepth = maxDepth;
        copy.minimumScore = minimumScore;
        copy.minDiscreteAttributeValueOccurances = minDiscreteAttributeValueOccurances;
        copy.minLeafInstances = minLeafInstances;
        copy.penalizeCategoricalSplitsBySplitAttributeIntrinsicValue = penalizeCategoricalSplitsBySplitAttributeIntrinsicValue;
        copy.degreeOfGainRatioPenalty = degreeOfGainRatioPenalty;
        copy.ordinalTestSpilts = ordinalTestSpilts;
        copy.attributeIgnoringStrategy = attributeIgnoringStrategy.copy();
        copy.fractionOfDataToUseInHoldOutSet = fractionOfDataToUseInHoldOutSet;
        copy.minSplitFraction = minSplitFraction;
        copy.exemptAttributes = Sets.newHashSet(exemptAttributes);
        return copy;
    }

    public void updateBuilderConfig(final Map<String, Serializable> cfg) {
        if (cfg.containsKey(SCORER))
            scorer((OldScorer) cfg.get(SCORER));
        if (cfg.containsKey(MAX_DEPTH))
            maxDepth((Integer) cfg.get(MAX_DEPTH));
        if (cfg.containsKey(MIN_SCORE))
            minimumScore((Double) cfg.get(MIN_SCORE));
        if (cfg.containsKey(MIN_OCCURRENCES_OF_ATTRIBUTE_VALUE))
            minCategoricalAttributeValueOccurances((Integer) cfg.get(MIN_OCCURRENCES_OF_ATTRIBUTE_VALUE));
        if (cfg.containsKey(MIN_LEAF_INSTANCES))
            minLeafInstances((Integer) cfg.get(MIN_LEAF_INSTANCES));
        if (cfg.containsKey(MIN_SPLIT_FRACTION))
            minSplitFraction((Double) cfg.get(MIN_SPLIT_FRACTION));
        if (cfg.containsKey(ORDINAL_TEST_SPLITS))
            ordinalTestSplits((Integer) cfg.get(ORDINAL_TEST_SPLITS));
        if (cfg.containsKey(EXEMPT_ATTRIBUTES))
            exemptAttributes((HashSet<String>) cfg.get(EXEMPT_ATTRIBUTES));
        if (cfg.containsKey(DEGREE_OF_GAIN_RATIO_PENALTY))
            degreeOfGainRatioPenalty((Double) cfg.get(DEGREE_OF_GAIN_RATIO_PENALTY));
        if (cfg.containsKey(ATTRIBUTE_IGNORING_STRATEGY))
            attributeIgnoringStrategy((AttributeIgnoringStrategy) cfg.get(ATTRIBUTE_IGNORING_STRATEGY));
        if (cfg.containsKey(IGNORE_ATTR_PROB))
            ignoreAttributeAtNodeProbability((Double) cfg.get(IGNORE_ATTR_PROB));
        penalizeCategoricalSplitsBySplitAttributeIntrinsicValue(
                cfg.containsKey(PENALIZE_CATEGORICAL_SPLITS) ? (Boolean) cfg.get(PENALIZE_CATEGORICAL_SPLITS) : true);
    }
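    /*
     * Illustrative sketch (not in the original source): updateBuilderConfig lets the same
     * settings be supplied as a map, keyed by the String constants declared above. The
     * map contents here are hypothetical.
     *
     *   Map<String, Serializable> cfg = new HashMap<>();
     *   cfg.put(OldTreeBuilder.MAX_DEPTH, 8);
     *   cfg.put(OldTreeBuilder.MIN_LEAF_INSTANCES, 10);
     *   cfg.put(OldTreeBuilder.PENALIZE_CATEGORICAL_SPLITS, true);
     *   builder.updateBuilderConfig(cfg);  // keys absent from the map leave settings unchanged
     */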
    public OldTreeBuilder degreeOfGainRatioPenalty(double degreeOfGainRatioPenalty) {
        this.degreeOfGainRatioPenalty = degreeOfGainRatioPenalty;
        return this;
    }

    public OldTreeBuilder ordinalTestSplits(int ordinalTestSpilts) {
        this.ordinalTestSpilts = ordinalTestSpilts;
        return this;
    }

    public OldTreeBuilder<T> scorer(final OldScorer oldScorer) {
        this.oldScorer = oldScorer;
        return this;
    }

    public OldTreeBuilder<T> maxDepth(int maxDepth) {
        this.maxDepth = maxDepth;
        return this;
    }

    public OldTreeBuilder<T> binaryClassification(boolean binaryClassification) {
        this.binaryClassifications = binaryClassification;
        return this;
    }

    public OldTreeBuilder<T> minLeafInstances(int minLeafInstances) {
        this.minLeafInstances = minLeafInstances;
        return this;
    }

    public OldTreeBuilder<T> penalizeCategoricalSplitsBySplitAttributeIntrinsicValue(boolean useGainRatio) {
        this.penalizeCategoricalSplitsBySplitAttributeIntrinsicValue = useGainRatio;
        return this;
    }

    public OldTreeBuilder<T> minCategoricalAttributeValueOccurances(int occurances) {
        this.minDiscreteAttributeValueOccurances = occurances;
        return this;
    }

    public OldTreeBuilder<T> minimumScore(double minimumScore) {
        this.minimumScore = minimumScore;
        return this;
    }

    @Override
    public OldTree buildPredictiveModel(Iterable<T> trainingData) {
        List<T> trainingDataList = Lists.newArrayList();
        for (T instance : trainingData) {
            trainingDataList.add(instance);
        }
        Set<Serializable> classifications = getClassificationProperties(trainingDataList);
        attributeCharacteristics = surveyTrainingData(trainingData);
        return new OldTree(growTree(null, trainingDataList, 0), classifications);
    }

    private Set<Serializable> getClassificationProperties(Iterable<T> trainingData) {
        HashMap<Serializable, MutableInt> classificationsAndCounts = Maps.newHashMap();
        Serializable minorityClassification = null;
        Serializable majorityClassification = null;
        boolean binaryClassifications = true;
        double majorityToMinorityRatio = 1;
        for (T instance : trainingData) {
            Serializable classification = instance.getLabel();
            if (classificationsAndCounts.containsKey(classification)) {
                classificationsAndCounts.get(classification).increment();
            } else {
                classificationsAndCounts.put(classification, new MutableInt(1));
            }
        }
        if (classificationsAndCounts.size() > 2) {
            setBinaryClassifications(false);
            return new HashSet<>(classificationsAndCounts.keySet());
        }
        double minorityClassificationCount = 0;
        double majorityClassificationCount = 0;
        for (Serializable val : classificationsAndCounts.keySet()) {
            if (minorityClassification == null || classificationsAndCounts.get(val).doubleValue() < minorityClassificationCount) {
                minorityClassification = val;
                minorityClassificationCount = classificationsAndCounts.get(val).doubleValue();
            }
            if (majorityClassification == null || classificationsAndCounts.get(val).doubleValue() > majorityClassificationCount) {
                majorityClassification = val;
                majorityClassificationCount = classificationsAndCounts.get(val).doubleValue();
            }
        }
        majorityToMinorityRatio = classificationsAndCounts.get(majorityClassification).doubleValue()
                / classificationsAndCounts.get(minorityClassification).doubleValue();
        writeClassificationPropertiesOfDataSet(minorityClassification, majorityClassification, binaryClassifications, majorityToMinorityRatio);
        return new HashSet<>(classificationsAndCounts.keySet());
    }

    private void setBinaryClassifications(boolean binaryClassifications) {
        this.binaryClassifications = binaryClassifications;
    }
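    /*
     * Worked example (illustrative, not from the original source): with binary labels and
     * counts {true: 900, false: 100}, the minority classification is false, the majority
     * is true, and majorityToMinorityRatio = 900 / 100 = 9. This ratio scales how much
     * evidence the majority class must supply in the sufficient-statistics checks further
     * down in this class.
     */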
    private void writeClassificationPropertiesOfDataSet(Serializable minorityClassification, Serializable majorityClassification,
                                                        boolean binaryClassifications, double majorityToMinorityRatio) {
        this.minorityClassification = minorityClassification;
        this.majorityClassification = majorityClassification;
        this.binaryClassifications = binaryClassifications;
        this.majorityToMinorityRatio = majorityToMinorityRatio;
    }

    private double[] createNumericSplit(final List<T> trainingData, final String attribute) {
        int numSamples = Math.min(RESERVOIR_SIZE, trainingData.size());
        final ReservoirSampler<Double> reservoirSampler = new ReservoirSampler<Double>(numSamples, rand);
        int samplesToSkipPerStep = Math.max(1, trainingData.size() / RESERVOIR_SIZE);
        for (int i = 0; i < trainingData.size(); i += samplesToSkipPerStep) {
            Serializable value = trainingData.get(i).getAttributes().get(attribute);
            if (value == null) {
                continue;
            }
            reservoirSampler.sample(((Number) value).doubleValue());
        }
        return getSplit(reservoirSampler);
    }

    private Map<String, double[]> createNumericSplits(final List<T> trainingData) {
        final Map<String, ReservoirSampler<Double>> rsm = Maps.newHashMap();
        int numSamples = Math.min(RESERVOIR_SIZE, trainingData.size());
        int samplesToSkipPerStep = Math.max(1, trainingData.size() / RESERVOIR_SIZE);
        // Note: unlike createNumericSplit, this loop bounds i by numSamples rather than
        // trainingData.size(), so with large training sets it inspects far fewer instances.
        // Preserved as-is, since this class exists to benchmark a previous version's behavior.
        for (int i = 0; i < numSamples; i += samplesToSkipPerStep) {
            for (final Entry<String, Serializable> attributeEntry : trainingData.get(i).getAttributes().entrySet()) {
                if (attributeEntry.getValue() instanceof Number) {
                    ReservoirSampler<Double> reservoirSampler = rsm.get(attributeEntry.getKey());
                    if (reservoirSampler == null) {
                        reservoirSampler = new ReservoirSampler<>(numSamples, rand);
                        rsm.put(attributeEntry.getKey(), reservoirSampler);
                    }
                    reservoirSampler.sample(((Number) attributeEntry.getValue()).doubleValue());
                }
            }
        }
        final Map<String, double[]> splits = Maps.newHashMap();
        for (final Entry<String, ReservoirSampler<Double>> e : rsm.entrySet()) {
            final double[] split = getSplit(e.getValue());
            splits.put(e.getKey(), split);
        }
        return splits;
    }

    private double[] getSplit(ReservoirSampler<Double> reservoirSampler) {
        final ArrayList<Double> splitList = Lists.newArrayList();
        for (final Double sample : reservoirSampler.getSamples()) {
            splitList.add(sample);
        }
        if (splitList.isEmpty()) {
            throw new RuntimeException("Split list empty");
        }
        Collections.sort(splitList);
        final double[] split = new double[ordinalTestSpilts - 1];
        final int indexMultiplier = splitList.size() / (split.length + 1); // num elements / num bins
        for (int x = 0; x < split.length; x++) {
            split[x] = splitList.get((x + 1) * indexMultiplier);
        }
        return split;
    }
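    /*
     * Worked example (illustrative): with ordinalTestSpilts = 5 and 100 sorted samples in
     * the reservoir, getSplit returns 4 candidate thresholds taken at sorted indices
     * 20, 40, 60, and 80, i.e. the approximate quintile boundaries of the sampled
     * attribute values.
     */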
    private OldNode growTree(OldBranch parent, List<T> trainingData, final int depth) {
        Preconditions.checkArgument(!Iterables.isEmpty(trainingData),
                "At Depth: " + depth + ". Can't build an OldTree with no training data");
        if (depth >= maxDepth) {
            return getLeaf(parent, trainingData, depth);
        }
        Pair<? extends OldBranch, Double> bestPair = getBestNodePair(parent, trainingData);
        OldBranch bestNode = bestPair != null ? bestPair.getValue0() : null;
        double bestScore = bestPair != null ? bestPair.getValue1() : 0;
        if (bestNode == null || bestScore < minimumScore) {
            // bestNode will be null if no attribute could provide a split with enough statistically
            // significant attribute values to produce 2 children that each have at least minLeafInstances.
            // The best-score condition naturally catches the situation where all instances have the same classification.
            return getLeaf(parent, trainingData, depth);
        }
        ArrayList<T> trueTrainingSet = Lists.newArrayList();
        ArrayList<T> falseTrainingSet = Lists.newArrayList();
        setTrueAndFalseTrainingSets(trainingData, bestNode, trueTrainingSet, falseTrainingSet);
        bestNode.trueChild = growTree(bestNode, trueTrainingSet, depth + 1);
        bestNode.falseChild = growTree(bestNode, falseTrainingSet, depth + 1);
        return bestNode;
    }

    private OldLeaf getLeaf(OldNode parent, List<T> trainingData, int depth) {
        return new OldLeaf(parent, trainingData, depth);
    }

    private void setTrueAndFalseTrainingSets(Iterable<T> trainingData, OldBranch bestNode, List<T> trueTrainingSet, List<T> falseTrainingSet) {
        // Put each instance into the appropriate child training set based on the branch's decision.
        for (T instance : trainingData) {
            if (bestNode.decide(instance.getAttributes())) {
                trueTrainingSet.add(instance);
            } else {
                falseTrainingSet.add(instance);
            }
        }
    }

    private Pair<? extends OldBranch, Double> getBestNodePair(OldBranch parent, List<T> trainingData) {
        // Should not be doing the following operation every time we call growTree.
        boolean smallTrainingSet = isSmallTrainingSet(trainingData);
        Pair<? extends OldBranch, Double> bestPair = null;
        // TODO: make this lazy in the sense that only numeric attributes that are not randomly ignored should have this done.
        for (final Entry<String, AttributeCharacteristics> attributeCharacteristicsEntry : attributeCharacteristics.entrySet()) {
            if (this.attributeIgnoringStrategy.ignoreAttribute(attributeCharacteristicsEntry.getKey(), parent)) {
                continue;
            }
            Pair<? extends OldBranch, Double> thisPair = null;
            Pair<? extends OldBranch, Double> numericPair = null;
            Pair<? extends OldBranch, Double> categoricalPair = null;
            if (!smallTrainingSet && attributeCharacteristicsEntry.getValue().isNumber) {
                numericPair = createNumericBranch(parent, attributeCharacteristicsEntry.getKey(), trainingData);
            } else if (!attributeCharacteristicsEntry.getValue().isNumber) {
                categoricalPair = createCategoricalNode(parent, attributeCharacteristicsEntry.getKey(), trainingData);
            }
            if (numericPair != null) {
                thisPair = numericPair;
            } else {
                thisPair = categoricalPair;
            }
            if (thisPair != null && (bestPair == null || thisPair.getValue1() > bestPair.getValue1())) {
                bestPair = thisPair;
            }
        }
        return bestPair;
    }

    private boolean isSmallTrainingSet(Iterable<T> trainingData) {
        boolean smallTrainingSet = true;
        int tsCount = 0;
        for (T instance : trainingData) {
            tsCount++;
            if (tsCount > SMALL_TRAINING_SET_LIMIT) {
                smallTrainingSet = false;
                break;
            }
        }
        return smallTrainingSet;
    }

    private Map<String, AttributeCharacteristics> surveyTrainingData(final Iterable<T> trainingData) {
        // Tells us whether each attribute is numeric: an attribute is treated as numeric
        // only if every observed value for it is a Number.
        Map<String, AttributeCharacteristics> attributeCharacteristics = Maps.newHashMap();
        for (T instance : trainingData) {
            for (Entry<String, Serializable> e : instance.getAttributes().entrySet()) {
                AttributeCharacteristics attributeCharacteristic = attributeCharacteristics.get(e.getKey());
                if (attributeCharacteristic == null) {
                    attributeCharacteristic = new AttributeCharacteristics();
                    attributeCharacteristics.put(e.getKey(), attributeCharacteristic);
                }
                if (!(e.getValue() instanceof Number)) {
                    attributeCharacteristic.isNumber = false;
                }
            }
        }
        return attributeCharacteristics;
    }
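    /*
     * Illustrative example (not from the original source): given instances with attributes
     * {"age": 31, "country": "DE"} and {"age": 45.5, "country": 7}, surveyTrainingData marks
     * "age" as numeric (every observed value is a Number) and "country" as categorical,
     * because at least one of its observed values is not a Number.
     */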
    private Pair<? extends OldBranch, Double> createCategoricalNode(OldNode parent, String attribute, Iterable<T> instances) {
        if (binaryClassifications) {
            return createTwoClassCategoricalNode(parent, attribute, instances);
        } else {
            return createNClassCategoricalNode(parent, attribute, instances);
        }
    }

    private Pair<? extends OldBranch, Double> createTwoClassCategoricalNode(OldNode parent, final String attribute, final Iterable<T> instances) {
        double bestScore = 0;
        final Pair<OldClassificationCounter, List<OldAttributeValueWithClassificationCounter>> valueOutcomeCountsPairs =
                OldClassificationCounter.getSortedListOfAttributeValuesWithClassificationCounters(instances, attribute, minorityClassification);

        // Classification counter treating all values the same.
        OldClassificationCounter outCounts = new OldClassificationCounter(valueOutcomeCountsPairs.getValue0());
        // The histogram of counts by classification for the in-set.
        OldClassificationCounter inCounts = new OldClassificationCounter();
        // Sorted list of (value -> classificationCounter) pairs.
        final List<OldAttributeValueWithClassificationCounter> valuesWithClassificationCounters = valueOutcomeCountsPairs.getValue1();

        double numTrainingExamples = valueOutcomeCountsPairs.getValue0().getTotal();
        Serializable lastValOfInset = valuesWithClassificationCounters.get(0).attributeValue;
        double probabilityOfBeingInInset = 0;
        int valuesInTheInset = 0;
        int attributesWithSufficientValues = labelAttributeValuesWithInsufficientData(valuesWithClassificationCounters);
        if (attributesWithSufficientValues <= 1) {
            return null; // There is just 1 value available.
        }
        double intrinsicValueOfAttribute = getIntrinsicValueOfAttribute(valuesWithClassificationCounters, numTrainingExamples);

        for (final OldAttributeValueWithClassificationCounter valueWithClassificationCounter : valuesWithClassificationCounters) {
            final OldClassificationCounter testValCounts = valueWithClassificationCounter.classificationCounter;
            if (testValCounts == null || valueWithClassificationCounter.attributeValue.equals(MISSING_VALUE)) { // Also a kludge; figure out why.
                continue;
            }
            if (this.minDiscreteAttributeValueOccurances > 0) {
                if (!testValCounts.hasSufficientData()) continue;
            }
            inCounts = inCounts.add(testValCounts);
            outCounts = outCounts.subtract(testValCounts);
            double numInstances = inCounts.getTotal() + outCounts.getTotal();
            if (!exemptAttributes.contains(attribute)
                    && (inCounts.getTotal() / numInstances < minSplitFraction
                        || outCounts.getTotal() / numInstances < minSplitFraction)) {
                continue;
            }
            if (inCounts.getTotal() < minLeafInstances || outCounts.getTotal() < minLeafInstances) {
                continue;
            }
            double thisScore = oldScorer.scoreSplit(inCounts, outCounts);
            valuesInTheInset++;
            if (penalizeCategoricalSplitsBySplitAttributeIntrinsicValue) {
                thisScore = thisScore * (1 - degreeOfGainRatioPenalty) + degreeOfGainRatioPenalty * (thisScore / intrinsicValueOfAttribute);
            }
            if (thisScore > bestScore) {
                bestScore = thisScore;
                lastValOfInset = valueWithClassificationCounter.attributeValue;
                probabilityOfBeingInInset = inCounts.getTotal() / (inCounts.getTotal() + outCounts.getTotal());
            }
        }

        final Set<Serializable> inSet = Sets.newHashSet();
        final Set<Serializable> outSet = Sets.newHashSet();
        boolean insetIsBuiltNowBuildingOutset = false;
        inCounts = new OldClassificationCounter();
        outCounts = new OldClassificationCounter();
        for (OldAttributeValueWithClassificationCounter oldAttributeValueWithClassificationCounter : valuesWithClassificationCounters) {
            if (!insetIsBuiltNowBuildingOutset && oldAttributeValueWithClassificationCounter.classificationCounter.hasSufficientData()) {
                inSet.add(oldAttributeValueWithClassificationCounter.attributeValue);
                inCounts.add(oldAttributeValueWithClassificationCounter.classificationCounter);
                if (oldAttributeValueWithClassificationCounter.attributeValue.equals(lastValOfInset)) {
                    insetIsBuiltNowBuildingOutset = true;
                }
            } else {
                outCounts.add(oldAttributeValueWithClassificationCounter.classificationCounter);
                //outSet.add(attributeValueWithClassificationCounter.attributeValue);
            }
        }
        if (bestScore == 0) {
            return null;
        } else {
            return Pair.with(new OldCategoricalOldBranch(parent, attribute, inSet, probabilityOfBeingInInset), bestScore);
        }
    }
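    /*
     * Worked example of the gain-ratio penalty above (illustrative, not from the original
     * source): with degreeOfGainRatioPenalty = 1.0, a raw split score of 0.12, and an
     * attribute intrinsic value of 2.0, the penalized score is
     * 0.12 * (1 - 1.0) + 1.0 * (0.12 / 2.0) = 0.06. With degreeOfGainRatioPenalty = 0.0
     * the raw score is used unchanged, so the parameter interpolates between plain
     * information gain and gain ratio.
     */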
    private int labelAttributeValuesWithInsufficientData(List<OldAttributeValueWithClassificationCounter> valuesWithClassificationCounters) {
        int attributesWithSuffValues = 0;
        for (final OldAttributeValueWithClassificationCounter valueWithClassificationCounter : valuesWithClassificationCounters) {
            if (this.minDiscreteAttributeValueOccurances > 0) {
                OldClassificationCounter testValCounts = valueWithClassificationCounter.classificationCounter;
                if (attributeValueOrIntervalOfValuesHasInsufficientStatistics(testValCounts)) {
                    testValCounts.setHasSufficientData(false);
                } else {
                    attributesWithSuffValues++;
                }
            } else {
                attributesWithSuffValues++;
            }
        }
        return attributesWithSuffValues;
    }

    private double getIntrinsicValueOfAttribute(List<OldAttributeValueWithClassificationCounter> valuesWithCCs, double numTrainingExamples) {
        double informationValue = 0;
        double attributeValProb = 0;
        for (OldAttributeValueWithClassificationCounter oldAttributeValueWithClassificationCounter : valuesWithCCs) {
            OldClassificationCounter classificationCounter = oldAttributeValueWithClassificationCounter.classificationCounter;
            attributeValProb = classificationCounter.getTotal() / numTrainingExamples; // (could subtract insufficientDataInstances)
            informationValue -= attributeValProb * Math.log(attributeValProb) / Math.log(2);
        }
        return informationValue;
    }

    private double getIntrinsicValueOfNumericAttribute() {
        double informationValue = 0;
        double attributeValProb = 1.0 / ordinalTestSpilts;
        // The factor of (1.0 / ordinalTestSpilts) * ordinalTestSpilts cancels.
        informationValue -= Math.log(attributeValProb) / Math.log(2);
        return informationValue;
    }
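    /*
     * Worked example (illustrative): for a categorical attribute whose 4 values each cover
     * a quarter of the training examples, the intrinsic value is
     * -4 * (0.25 * log2(0.25)) = 2 bits. For a numeric attribute with ordinalTestSpilts = 5
     * equally likely bins, getIntrinsicValueOfNumericAttribute returns -log2(1/5) ≈ 2.32 bits.
     */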
    private Pair<? extends OldBranch, Double> createNClassCategoricalNode(OldNode parent, final String attribute, final Iterable<T> instances) {
        final Set<Serializable> values = getAttributeValues(instances, attribute);
        if (insufficientTrainingDataGivenNumberOfAttributeValues(instances, values)) return null;

        final Set<Serializable> inValueSet = Sets.newHashSet(); // The in-set.
        // The histogram of counts by classification for the in-set.
        OldClassificationCounter inSetClassificationCounts = new OldClassificationCounter();
        final Pair<OldClassificationCounter, Map<Serializable, OldClassificationCounter>> valueOutcomeCountsPair =
                OldClassificationCounter.countAllByAttributeValues(instances, attribute);
        // Classification counter treating all values the same.
        OldClassificationCounter outSetClassificationCounts = valueOutcomeCountsPair.getValue0();
        // Map of value -> classificationCounter.
        final Map<Serializable, OldClassificationCounter> valueOutcomeCounts = valueOutcomeCountsPair.getValue1();
        double insetScore = 0;
        while (true) {
            com.google.common.base.Optional<ScoreValuePair> bestValueAndScore = com.google.common.base.Optional.absent();
            // values should be greater than 1
            for (final Serializable thisValue : values) {
                final OldClassificationCounter testValCounts = valueOutcomeCounts.get(thisValue);
                // TODO: the next 3 lines may no longer be needed. Verify.
                if (testValCounts == null || thisValue == null || thisValue.equals(MISSING_VALUE)) {
                    continue;
                }
                if (this.minDiscreteAttributeValueOccurances > 0) {
                    if (shouldWeIgnoreThisValue(testValCounts)) continue;
                }
                final OldClassificationCounter testInCounts = inSetClassificationCounts.add(testValCounts);
                final OldClassificationCounter testOutCounts = outSetClassificationCounts.subtract(testValCounts);
                double scoreWithThisValueAddedToInset = oldScorer.scoreSplit(testInCounts, testOutCounts);
                if (!bestValueAndScore.isPresent() || scoreWithThisValueAddedToInset > bestValueAndScore.get().getScore()) {
                    bestValueAndScore = com.google.common.base.Optional.of(new ScoreValuePair(scoreWithThisValueAddedToInset, thisValue));
                }
            }
            if (bestValueAndScore.isPresent() && bestValueAndScore.get().getScore() > insetScore) {
                insetScore = bestValueAndScore.get().getScore();
                final Serializable bestValue = bestValueAndScore.get().getValue();
                inValueSet.add(bestValue);
                values.remove(bestValue);
                final OldClassificationCounter bestValOutcomeCounts = valueOutcomeCounts.get(bestValue);
                inSetClassificationCounts = inSetClassificationCounts.add(bestValOutcomeCounts);
                outSetClassificationCounts = outSetClassificationCounts.subtract(bestValOutcomeCounts);
            } else {
                break;
            }
        }
        if (inSetClassificationCounts.getTotal() < minLeafInstances || outSetClassificationCounts.getTotal() < minLeafInstances) {
            return null;
        }
        // Because inSetClassificationCounts is only mutated for better in-sets during the loop,
        // it corresponds to the actual in-set here.
        double probabilityOfBeingInInset = inSetClassificationCounts.getTotal()
                / (inSetClassificationCounts.getTotal() + outSetClassificationCounts.getTotal());
        return Pair.with(new OldCategoricalOldBranch(parent, attribute, inValueSet, probabilityOfBeingInInset), insetScore);
    }
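    /*
     * Illustrative trace of the greedy loop above (hypothetical numbers): with values
     * {a, b, c}, the first pass might find that adding b alone gives the best split score,
     * say 0.10; the second pass then tries {b, a} and {b, c}, and if {b, c} scores 0.13 it
     * is kept. If no further addition improves on 0.13, the loop stops and the branch
     * tests membership in the in-set {b, c}.
     */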
    private boolean insufficientTrainingDataGivenNumberOfAttributeValues(final Iterable<T> trainingData, final Set<Serializable> values) {
        final int averageInstancesPerValue = Iterables.size(trainingData) / values.size();
        return averageInstancesPerValue < Math.max(this.minDiscreteAttributeValueOccurances, HARD_MINIMUM_INSTANCES_PER_CATEGORICAL_VALUE);
    }

    private Set<Serializable> getAttributeValues(final Iterable<T> trainingData, final String attribute) {
        final Set<Serializable> values = Sets.newHashSet();
        for (T instance : trainingData) {
            Serializable value = instance.getAttributes().get(attribute);
            if (value == null) value = MISSING_VALUE;
            values.add(value);
        }
        return values;
    }

    private boolean attributeValueOrIntervalOfValuesHasInsufficientStatistics(final OldClassificationCounter testValCounts) {
        Preconditions.checkArgument(majorityClassification != null && minorityClassification != null);
        Map<Serializable, Double> counts = testValCounts.getCounts();
        if (counts.containsKey(minorityClassification)
                && counts.get(minorityClassification) > minDiscreteAttributeValueOccurances) {
            return false;
        }
        if (counts.containsKey(majorityClassification)
                && counts.get(majorityClassification) > majorityToMinorityRatio * minDiscreteAttributeValueOccurances) {
            return false;
        }
        if (hasBothMinorityAndMajorityClassifications(counts)
                && hasSufficientStatisticsForBothClassifications(counts)) {
            return false;
        }
        return true;
    }

    private boolean shouldWeIgnoreThisValue(final OldClassificationCounter testValCounts) {
        Map<Serializable, Double> counts = testValCounts.getCounts();
        for (Serializable key : counts.keySet()) {
            if (counts.get(key).doubleValue() < minDiscreteAttributeValueOccurances) {
                return true;
            }
        }
        return false;
    }

    private boolean hasSufficientStatisticsForBothClassifications(Map<Serializable, Double> counts) {
        return counts.get(majorityClassification) > 0.6 * majorityToMinorityRatio * minDiscreteAttributeValueOccurances
                && counts.get(minorityClassification) > 0.6 * minDiscreteAttributeValueOccurances;
    }

    private boolean hasBothMinorityAndMajorityClassifications(Map<Serializable, Double> counts) {
        return counts.containsKey(majorityClassification) && counts.containsKey(minorityClassification);
    }
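    /*
     * Worked example (illustrative): with minDiscreteAttributeValueOccurances = 10 and
     * majorityToMinorityRatio = 9, an attribute value's statistics are sufficient if its
     * minority count exceeds 10, or its majority count exceeds 90, or both classes are
     * present with minority > 6 and majority > 54 (the 0.6-relaxed thresholds).
     */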
    private Pair<? extends OldBranch, Double> createNumericBranch(OldNode parent, final String attribute, List<T> instances) {
        double bestScore = 0;
        double bestThreshold = 0;
        double lastThreshold = Double.MIN_VALUE;
        double probabilityOfBeingInInset = 0;
        final double[] splits = createNumericSplit(instances, attribute);
        for (final double threshold : splits) {
            if (threshold == lastThreshold) {
                continue;
            }
            lastThreshold = threshold;
            Iterable<T> inSet = Iterables.filter(instances, new GreaterThanThresholdPredicate(attribute, threshold));
            Iterable<T> outSet = Iterables.filter(instances, new LessThanEqualThresholdPredicate(attribute, threshold));
            OldClassificationCounter inClassificationCounts = OldClassificationCounter.countAll(inSet);
            OldClassificationCounter outClassificationCounts = OldClassificationCounter.countAll(outSet);
            double numInstances = inClassificationCounts.getTotal() + outClassificationCounts.getTotal();
            if (!exemptAttributes.contains(attribute)
                    && (inClassificationCounts.getTotal() / numInstances < minSplitFraction
                        || outClassificationCounts.getTotal() / numInstances < minSplitFraction)) {
                continue;
            }
            if (binaryClassifications) {
                if (attributeValueOrIntervalOfValuesHasInsufficientStatistics(inClassificationCounts)
                        || inClassificationCounts.getTotal() < minLeafInstances
                        || attributeValueOrIntervalOfValuesHasInsufficientStatistics(outClassificationCounts)
                        || outClassificationCounts.getTotal() < minLeafInstances) {
                    continue;
                }
            } else if (shouldWeIgnoreThisValue(inClassificationCounts) || shouldWeIgnoreThisValue(outClassificationCounts)) {
                continue;
            }
            double thisScore = oldScorer.scoreSplit(inClassificationCounts, outClassificationCounts);
            if (thisScore > bestScore) {
                bestScore = thisScore;
                bestThreshold = threshold;
                probabilityOfBeingInInset = inClassificationCounts.getTotal() / numInstances;
            }
        }
        if (bestScore == 0) {
            return null;
        }
        double penalizedBestScore = bestScore / getIntrinsicValueOfNumericAttribute();
        return Pair.with(new OldNumericBranch(parent, attribute, bestThreshold, probabilityOfBeingInInset), penalizedBestScore);
    }

    public static class AttributeCharacteristics {
        public boolean isNumber = true;
    }

    private class GreaterThanThresholdPredicate implements Predicate<T> {
        private final String attribute;
        private final double threshold;

        public GreaterThanThresholdPredicate(String attribute, double threshold) {
            this.attribute = attribute;
            this.threshold = threshold;
        }

        @Override
        public boolean apply(@Nullable T input) {
            try {
                if (input == null) { // Consider deleting.
                    return false;
                }
                Serializable value = input.getAttributes().get(attribute);
                if (value == null) {
                    value = 0;
                }
                return ((Number) value).doubleValue() > threshold;
            } catch (final ClassCastException e) { // Kludge, need to handle better.
                return false;
            }
        }
    }

    private class LessThanEqualThresholdPredicate implements Predicate<T> {
        private final String attribute;
        private final double threshold;

        public LessThanEqualThresholdPredicate(String attribute, double threshold) {
            this.attribute = attribute;
            this.threshold = threshold;
        }

        @Override
        public boolean apply(@Nullable T input) {
            try {
                if (input == null) {
                    return false;
                }
                Serializable value = input.getAttributes().get(attribute);
                if (value == null) {
                    value = Double.MIN_VALUE;
                }
                // Missing values should go the way of the out-set. A future improvement
                // should allow missing values to go the way of either the in-set or the out-set.
                return ((Number) value).doubleValue() <= threshold;
            } catch (final ClassCastException e) { // Kludge, need to handle better.
                return false;
            }
        }
    }
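    /*
     * Illustrative example (hypothetical data): for attribute "age" and threshold 40.0,
     * an instance with age = 45 satisfies GreaterThanThresholdPredicate and falls in the
     * in-set, while instances with age = 40, age = 35, or a missing "age" value satisfy
     * LessThanEqualThresholdPredicate and fall in the out-set, so the two predicates
     * partition the training data.
     */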
    private class ScoreValuePair {
        private double score;
        private Serializable value;

        private ScoreValuePair(final double score, final Serializable value) {
            this.score = score;
            this.value = value;
        }

        public double getScore() {
            return score;
        }

        public Serializable getValue() {
            return value;
        }
    }
}