package quickml.supervised.tree.decisionTree.scorers;
import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter;
import quickml.supervised.tree.reducers.AttributeStats;
import quickml.supervised.tree.scorers.GRImbalancedScorer;
import java.io.Serializable;
import java.util.Map;
/**
 * Scores a candidate binary split by the reduction in Gini impurity it achieves, then
 * adjusts that reduction with the gain-ratio correction and imbalance penalty inherited
 * from {@link GRImbalancedScorer}.
 *
 * <p>Created by chrisreeves on 6/24/14.
 */
public class PenalizedGiniImpurityScorer extends GRImbalancedScorer<ClassificationCounter> {

    /**
     * Forwards all configuration to the {@link GRImbalancedScorer} constructor unchanged.
     *
     * @param degreeOfGainRatioPenalty passed through to the superclass
     * @param imbalancePenaltyPower    passed through to the superclass
     * @param attributeStats           passed through to the superclass
     */
    public PenalizedGiniImpurityScorer(double degreeOfGainRatioPenalty, double imbalancePenaltyPower, AttributeStats<ClassificationCounter> attributeStats) {
        super(degreeOfGainRatioPenalty, imbalancePenaltyPower, attributeStats);
    }

    /**
     * Computes the penalized impurity reduction for splitting into {@code a} and {@code b}.
     *
     * <p>Each side's Gini index is weighted by its share of the merged total and both
     * weighted values are subtracted from {@code unSplitScore}. The difference is then
     * passed through {@code correctForGainRatio} and multiplied by
     * {@code getPenaltyForImabalance(a, b)}.
     *
     * @param a counts for one side of the split
     * @param b counts for the other side of the split
     * @return the corrected and penalized score
     */
    @Override
    public double scoreSplit(ClassificationCounter a, ClassificationCounter b) {
        // NOTE(review): unSplitScore is not declared in this class; presumably an inherited field — confirm.
        double parentTotal = ClassificationCounter.merge(a, b).getTotal();
        double weightedA = weightedGini(a, parentTotal);
        double weightedB = weightedGini(b, parentTotal);
        double impurityReduction = unSplitScore - weightedA - weightedB;

        double gainRatioCorrected = correctForGainRatio(impurityReduction);
        double imbalancePenalty = getPenaltyForImabalance(a, b);
        return gainRatioCorrected * imbalancePenalty;
    }

    /**
     * Returns the Gini index of the given (un-split) counter.
     *
     * @param a counts to evaluate
     * @return {@code 1 - sum(p_i^2)} over the entries of {@code a}
     */
    @Override
    public double getUnSplitScore(ClassificationCounter a) {
        return getGiniIndex(a);
    }

    /** Gini index of {@code child} scaled by {@code child.getTotal() / parentTotal}. */
    private double weightedGini(ClassificationCounter child, double parentTotal) {
        // Multiply first, then divide, to keep the same floating-point evaluation order.
        return getGiniIndex(child) * child.getTotal() / parentTotal;
    }

    /**
     * Computes one minus the sum of squared per-entry proportions of {@code cc}.
     * When the total is not positive every proportion is treated as zero, so the
     * result is {@code 1.0}.
     */
    private double getGiniIndex(ClassificationCounter cc) {
        double total = cc.getTotal();
        Map<Serializable, Double> counts = cc.getCounts();

        double sumOfSquares = 0.0d;
        for (Double count : counts.values()) {
            double proportion = 0;
            if (total > 0) {
                proportion = count / total;
            }
            sumOfSquares += proportion * proportion;
        }
        return 1.0d - sumOfSquares;
    }

    /** Returns the fixed display name of this scorer. */
    @Override
    public String toString() {
        return "GiniImpurity";
    }
}