package quickml.supervised.tree.decisionTree.branchFinders;
import com.google.common.base.Optional;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import quickml.supervised.tree.attributeIgnoringStrategies.AttributeIgnoringStrategy;
import quickml.supervised.tree.attributeValueIgnoringStrategies.AttributeValueIgnoringStrategy;
import quickml.supervised.tree.branchFinders.BranchFinder;
import quickml.supervised.tree.branchingConditions.BranchingConditions;
import quickml.supervised.tree.constants.BranchType;
import quickml.supervised.tree.constants.MissingValue;
import quickml.supervised.tree.decisionTree.nodes.DTCatBranch;
import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter;
import quickml.supervised.tree.nodes.*;
import quickml.supervised.tree.reducers.AttributeStats;
import quickml.supervised.tree.scorers.Scorer;
import quickml.supervised.tree.scorers.ScorerFactory;
import java.io.Serializable;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
* Created by alexanderhawk on 6/20/15.
*/
public class DTNClassCatBranchFinder extends BranchFinder<ClassificationCounter> {
public DTNClassCatBranchFinder(Collection<String> candidateAttributes, BranchingConditions<ClassificationCounter> branchingConditions, ScorerFactory<ClassificationCounter> scorerFactory, AttributeValueIgnoringStrategy<ClassificationCounter> attributeValueIgnoringStrategy, AttributeIgnoringStrategy attributeIgnoringStrategy) {
super(candidateAttributes, branchingConditions, scorerFactory, attributeValueIgnoringStrategy, attributeIgnoringStrategy);
}
@Override
public BranchType getBranchType() {
return BranchType.CATEGORICAL;
}
@Override
public Optional<? extends Branch<ClassificationCounter>> getBranch(Branch<ClassificationCounter> parent,
AttributeStats<ClassificationCounter> attributeStats) {
final Set<Serializable> trueSet = Sets.newHashSet();
ClassificationCounter trueClassificationCounts = new ClassificationCounter();
ClassificationCounter falseClassificationCounts = attributeStats.getAggregateStats();
final List<ClassificationCounter> valueOutcomeCounts = attributeStats.getStatsOnEachValue();
Map<Serializable, ClassificationCounter> attrValToCCMap = Maps.newHashMap();
for (ClassificationCounter classificationCounter: valueOutcomeCounts) {
attrValToCCMap.put(classificationCounter.getAttrVal(), classificationCounter);
}
Scorer<ClassificationCounter> scorer = scorerFactory.getScorer(attributeStats);
double scoreWithCurrentTrueSet = 0;
while (true) {
Optional<ScoreValuePair> bestValueAndScore = getNextBestAttributeValueToAddToTrueSet(trueClassificationCounts, falseClassificationCounts, attrValToCCMap, scorer);
if (bestValueAndScore.isPresent() && bestValueAndScore.get().getScore() > scoreWithCurrentTrueSet) {
scoreWithCurrentTrueSet = bestValueAndScore.get().getScore();
final Serializable bestValue = bestValueAndScore.get().getValue();
trueSet.add(bestValue);
final ClassificationCounter bestValOutcomeCounts = attrValToCCMap.get(bestValue);
trueClassificationCounts = trueClassificationCounts.add(bestValOutcomeCounts);
falseClassificationCounts = falseClassificationCounts.subtract(bestValOutcomeCounts);
attrValToCCMap.remove(bestValue);
} else {
break;
}
}
if (branchingConditions.isInvalidSplit(trueClassificationCounts, falseClassificationCounts, attributeStats.getAttribute()) || branchingConditions.isInvalidSplit(scoreWithCurrentTrueSet)) {
return Optional.absent();
}
//because trueClassificationCounts is only mutated to better insets during the for loop...it corresponds to the actual inset here.
double probabilityOfBeingInTrueSet = trueClassificationCounts.getTotal() / (trueClassificationCounts.getTotal() + falseClassificationCounts.getTotal());
return Optional.of(new DTCatBranch(parent, attributeStats.getAttribute(), trueSet, probabilityOfBeingInTrueSet, scoreWithCurrentTrueSet,attributeStats.getAggregateStats()));
}
private Optional<ScoreValuePair> getNextBestAttributeValueToAddToTrueSet(ClassificationCounter trueClassificationCounts, ClassificationCounter falseClassificationCounts, Map<Serializable, ClassificationCounter> attrValToCCMap, Scorer<ClassificationCounter> scorer) {
Optional<ScoreValuePair> bestValueAndScore = Optional.absent();
//values should be greater than 1
for (final Serializable attrVal : attrValToCCMap.keySet()) {
ClassificationCounter cc = attrValToCCMap.get(attrVal);
if ( attrVal== null
|| attrVal.equals(MissingValue.MISSING_VALUE)
|| attributeValueIgnoringStrategy.shouldWeIgnoreThisValue(cc)) {
continue;
}
final ClassificationCounter testInCounts = trueClassificationCounts.add(cc);
final ClassificationCounter testOutCounts = falseClassificationCounts.subtract(cc);
double scoreWithThisValueAddedToTrueSet = scorer.scoreSplit(testInCounts, testOutCounts);
if (!bestValueAndScore.isPresent() || scoreWithThisValueAddedToTrueSet > bestValueAndScore.get().getScore()) {
bestValueAndScore = Optional.of(new ScoreValuePair(scoreWithThisValueAddedToTrueSet, attrVal));
}
}
return bestValueAndScore;
}
static class ScoreValuePair {
double score;
Serializable value;
public ScoreValuePair(double score, Serializable value) {
this.score = score;
this.value = value;
}
public double getScore() {
return score;
}
public Serializable getValue() {
return value;
}
}
}