package quickml.supervised.tree.decisionTree.branchFinders;
import com.google.common.base.Optional;
import com.google.common.base.Preconditions;
import com.google.common.collect.Sets;
import quickml.supervised.tree.scorers.GRImbalancedScorer;
import quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst.OldScorer;
import quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst.oldTree.OldClassificationCounter;
import quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst.oldScorers.GiniImpurityOldScorer;
import quickml.supervised.tree.attributeIgnoringStrategies.AttributeIgnoringStrategy;
import quickml.supervised.tree.attributeValueIgnoringStrategies.AttributeValueIgnoringStrategy;
import quickml.supervised.tree.branchFinders.SortableLabelsCategoricalBranchFinder;
import quickml.supervised.tree.branchFinders.SplittingUtils;
import quickml.supervised.tree.branchingConditions.BranchingConditions;
import quickml.supervised.tree.constants.BranchType;
import quickml.supervised.tree.constants.MissingValue;
import quickml.supervised.tree.decisionTree.nodes.DTCatBranch;
import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter;
import quickml.supervised.tree.reducers.AttributeStats;
import quickml.supervised.tree.scorers.ScorerFactory;
import java.io.Serializable;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
 * Created by alexanderhawk on 7/5/15.
 *
 * <p>Binary categorical branch finder that scores candidate splits with an {@link OldScorer}
 * (a {@link GiniImpurityOldScorer}) instead of the {@link ScorerFactory} handed to the constructor.
 * NOTE(review): judging by the imported packages this presumably reproduces the behaviour of an
 * earlier implementation for benchmarking purposes - confirm before changing its numeric behaviour.
 *
 * <p>All tuning values below are hard coded; none of them is read from the constructor arguments.
 */
public class OldBinCatBranchFinder extends SortableLabelsCategoricalBranchFinder<ClassificationCounter> {

    // NOTE(review): the original author marked this group of fields with "problem" - presumably
    // because the settings are hard coded rather than configurable. TODO confirm.

    /** Base threshold used when deciding whether an attribute value has enough observations. */
    private int minDiscreteAttributeValueOccurances = 2;
    /** Label looked up in a counter's counts map as the minority class. */
    private Serializable minorityClassification = 1.0;
    /** Label looked up in a counter's counts map as the majority class. */
    private Serializable majorityClassification = 0.0;
    /** Multiplier applied to the base threshold for the majority class. Original note: "check". */
    private double majorityToMinorityRatio = 20.0896;
    /** Minimum total required on each side of a candidate split. */
    private int minLeafInstances = 0;
    /** When true, split scores are penalized by the intrinsic value of the attribute. */
    private boolean penalizeCategoricalSplitsBySplitAttributeIntrinsicValue = true;
    /** Blend factor between the raw score (0) and the score divided by the intrinsic value (1). */
    private double degreeOfGainRatioPenalty = 1.0;
    /** Minimum fraction of the instances that must fall on each side of a candidate split. */
    private double minSplitFraction = 0.005;
    /** Scorer used to evaluate every candidate split. */
    private OldScorer oldScorer = new GiniImpurityOldScorer();

    /**
     * Creates the branch finder; every argument is passed straight to the super class.
     */
    public OldBinCatBranchFinder(Set<String> candidateAttributes, BranchingConditions<ClassificationCounter> branchingConditions,
                                 ScorerFactory<ClassificationCounter> scorerFactory, AttributeValueIgnoringStrategy<ClassificationCounter> attributeValueIgnoringStrategy,
                                 AttributeIgnoringStrategy attributeIgnoringStrategy) {
        super(candidateAttributes, branchingConditions, scorerFactory, attributeValueIgnoringStrategy, attributeIgnoringStrategy);
    }

    /**
     * Finds the best binary categorical split for one attribute.
     *
     * @param parent         branch the new branch is attached to
     * @param attributeStats per-value counters and aggregate counter of the attribute
     * @return the branch for the best split, or absent when the attribute has at most one value
     *         or no candidate split obtained a positive score
     */
    @Override
    public Optional<? extends quickml.supervised.tree.nodes.Branch<ClassificationCounter>> getBranch(quickml.supervised.tree.nodes.Branch<ClassificationCounter> parent, AttributeStats<ClassificationCounter> attributeStats) {
        if (attributeStats.getStatsOnEachValue().size() <= 1) {
            return Optional.absent();
        }
        Optional<SplittingUtils.SplitScore> splitScoreOptional = oldCreateTwoClassCategoricalNode(attributeStats, null, branchingConditions, attributeValueIgnoringStrategy, true);
        // The null test is purely defensive: this class never returns null, but an overriding
        // subclass might still do so.
        if (splitScoreOptional == null || !splitScoreOptional.isPresent()) {
            return Optional.absent();
        }
        return createBranch(parent, attributeStats, splitScoreOptional.get());
    }

    /**
     * @return {@link BranchType#BINARY_CATEGORICAL}
     */
    @Override
    public BranchType getBranchType() {
        return BranchType.BINARY_CATEGORICAL;
    }

    /**
     * Wraps a split score in a {@link DTCatBranch} for the attribute described by {@code attributeStats}.
     */
    @Override
    protected Optional<? extends quickml.supervised.tree.nodes.Branch<ClassificationCounter>> createBranch(quickml.supervised.tree.nodes.Branch<ClassificationCounter> parent, AttributeStats<ClassificationCounter> attributeStats, SplittingUtils.SplitScore splitScore) {
        return Optional.of(new DTCatBranch(parent, attributeStats.getAttribute(), splitScore.trueSet,
                splitScore.probabilityOfBeingInTrueSet, splitScore.score,
                attributeStats.getAggregateStats()));
    }

    /**
     * Walks the per-value counters in the order supplied, growing the "in" set one value at a time,
     * and remembers the prefix that yields the highest score.
     * NOTE(review): the approach presumably relies on the counters arriving sorted - verify against
     * the producer of {@code attributeStats}.
     *
     * <p>As a side effect, counters judged to have too few observations are flagged via
     * {@code setHasSufficientData(false)}.
     *
     * @param attributeStats per-value counters and aggregate counter of the attribute
     * @param scorer                          unused; scoring is done by the internal {@link OldScorer}
     * @param branchingConditions             unused
     * @param attributeValueIgnoringStrategy  unused
     * @param doNotUseAttributeValuesWithInsuffientStatistics unused
     * @return the best split, or absent when there are no counters, fewer than two values have
     *         sufficient data, the intrinsic value is not positive while the penalty is enabled,
     *         or no candidate scored above zero; never {@code null}
     */
    public Optional<SplittingUtils.SplitScore> oldCreateTwoClassCategoricalNode(AttributeStats<ClassificationCounter> attributeStats, GRImbalancedScorer<ClassificationCounter> scorer,
                                                                               BranchingConditions<ClassificationCounter> branchingConditions,
                                                                               AttributeValueIgnoringStrategy<ClassificationCounter> attributeValueIgnoringStrategy,
                                                                               boolean doNotUseAttributeValuesWithInsuffientStatistics) {
        List<ClassificationCounter> ccs = attributeStats.getStatsOnEachValue();
        if (ccs == null || ccs.isEmpty()) {
            return Optional.absent();
        }

        ClassificationCounter outCounts = attributeStats.getAggregateStats();
        ClassificationCounter inCounts = outCounts.subtract(outCounts); // empty true set
        double numTrainingExamples = outCounts.getTotal();

        int attributesWithSufficientValues = labelAttributeValuesWithInsufficientData(ccs);
        if (attributesWithSufficientValues <= 1) {
            return Optional.absent(); // there is just 1 value available.
        }

        double intrinsicValueOfAttribute = getIntrinsicValueOfAttribute(ccs, numTrainingExamples);
        if (penalizeCategoricalSplitsBySplitAttributeIntrinsicValue && !(intrinsicValueOfAttribute > 0)) {
            // Dividing by a zero (or NaN) intrinsic value cannot produce a meaningful score.
            return Optional.absent();
        }

        double bestScore = 0;
        Serializable lastValOfInset = null;
        double probabilityOfBeingInInset = 0;

        for (final ClassificationCounter cc : ccs) {
            if (cc == null || cc.attrVal.equals(MissingValue.MISSING_VALUE)) { // Also a kludge, figure out why
                continue;
            }
            if (this.minDiscreteAttributeValueOccurances > 0 && !cc.hasSufficientData()) {
                continue;
            }
            // The counts are moved before any of the checks below, so they stay cumulative even
            // when this particular split point is rejected.
            inCounts = inCounts.add(cc);
            outCounts = outCounts.subtract(cc);
            double numInstances = inCounts.getTotal() + outCounts.getTotal();

            if (inCounts.getTotal() / numInstances < minSplitFraction
                    || outCounts.getTotal() / numInstances < minSplitFraction) {
                continue;
            }
            if (inCounts.getTotal() < minLeafInstances || outCounts.getTotal() < minLeafInstances) {
                continue;
            }

            double thisScore = this.oldScorer.scoreSplit(new OldClassificationCounter(inCounts),
                    new OldClassificationCounter(outCounts));
            if (penalizeCategoricalSplitsBySplitAttributeIntrinsicValue) {
                thisScore = thisScore * (1 - degreeOfGainRatioPenalty)
                        + degreeOfGainRatioPenalty * (thisScore / intrinsicValueOfAttribute);
            }
            if (thisScore > bestScore) {
                bestScore = thisScore;
                lastValOfInset = cc.attrVal;
                probabilityOfBeingInInset = inCounts.getTotal() / numInstances;
            }
        }

        if (bestScore == 0) {
            return Optional.absent();
        }

        // Rebuild the winning set: every value with sufficient data up to and including the last
        // value of the best prefix. The counter counts all positions visited, whether or not the
        // value at that position was added to the set.
        final Set<Serializable> inSet = Sets.newHashSet();
        int indexOfLastValueCounterInTrueSet = 0;
        for (ClassificationCounter cc : ccs) {
            indexOfLastValueCounterInTrueSet++;
            if (cc.hasSufficientData()) {
                inSet.add(cc.attrVal);
                if (cc.getAttrVal().equals(lastValOfInset)) {
                    break;
                }
            }
        }
        return Optional.of(new SplittingUtils.SplitScore(bestScore, indexOfLastValueCounterInTrueSet, probabilityOfBeingInInset, inSet));
    }

    /**
     * Flags every counter that lacks sufficient statistics and counts the remaining ones.
     * When {@code minDiscreteAttributeValueOccurances} is not positive nothing is flagged and
     * every counter is counted.
     *
     * @return number of counters that were not flagged as insufficient
     */
    private int labelAttributeValuesWithInsufficientData(List<ClassificationCounter> valuesWithClassificationCounters) {
        int attributesWithSuffValues = 0;
        for (final ClassificationCounter cc : valuesWithClassificationCounters) {
            if (this.minDiscreteAttributeValueOccurances > 0
                    && attributeValueOrIntervalOfValuesHasInsufficientStatistics(cc)) {
                cc.setHasSufficientData(false);
            } else {
                attributesWithSuffValues++;
            }
        }
        return attributesWithSuffValues;
    }

    /**
     * Computes the base-2 entropy of the distribution of training examples over attribute values.
     * Values whose share is not positive contribute nothing; evaluating {@code 0 * log(0)} would
     * otherwise yield NaN and invalidate every score derived from the result.
     *
     * @param valuesWithCCs       one counter per attribute value
     * @param numTrainingExamples total used as the denominator of each value's share
     * @return the entropy in bits
     */
    private double getIntrinsicValueOfAttribute(List<ClassificationCounter> valuesWithCCs, double numTrainingExamples) {
        final double logOfTwo = Math.log(2);
        double informationValue = 0;
        for (ClassificationCounter classificationCounter : valuesWithCCs) {
            double attributeValProb = classificationCounter.getTotal() / numTrainingExamples;
            if (attributeValProb > 0) {
                informationValue -= attributeValProb * Math.log(attributeValProb) / logOfTwo;
            }
        }
        return informationValue;
    }

    /**
     * Decides whether a counter has too few observations to be used in a split.
     * The counter is considered sufficient when any of the following holds:
     * the minority count exceeds the base threshold; the majority count exceeds the base threshold
     * scaled by {@code majorityToMinorityRatio}; or both classes are present and each exceeds
     * 60% of its respective threshold.
     *
     * @return true when none of the sufficiency conditions holds
     */
    private boolean attributeValueOrIntervalOfValuesHasInsufficientStatistics(final ClassificationCounter testValCounts) {
        Preconditions.checkArgument(majorityClassification != null && minorityClassification != null);
        Map<Serializable, Double> counts = testValCounts.getCounts();
        if (counts.containsKey(minorityClassification) &&
                counts.get(minorityClassification) > minDiscreteAttributeValueOccurances) {
            return false;
        }
        if (counts.containsKey(majorityClassification) &&
                counts.get(majorityClassification) > majorityToMinorityRatio * minDiscreteAttributeValueOccurances) {
            return false;
        }
        if (hasBothMinorityAndMajorityClassifications(counts)
                && hasSufficientStatisticsForBothClassifications(counts)) {
            return false;
        }
        return true;
    }

    /**
     * Checks the relaxed (60%) thresholds for both classes. Callers must ensure both keys are
     * present in {@code counts}; otherwise unboxing the missing entry throws.
     */
    private boolean hasSufficientStatisticsForBothClassifications(Map<Serializable, Double> counts) {
        return counts.get(majorityClassification) > 0.6 * majorityToMinorityRatio * minDiscreteAttributeValueOccurances
                && counts.get(minorityClassification) > 0.6 * minDiscreteAttributeValueOccurances;
    }

    /**
     * @return true when {@code counts} has an entry for both the majority and the minority label
     */
    private boolean hasBothMinorityAndMajorityClassifications(Map<Serializable, Double> counts) {
        return counts.containsKey(majorityClassification) && counts.containsKey(minorityClassification);
    }
}