package quickml.supervised.tree.decisionTree.scorers;
//TODO: fix oldScorers
import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter;
import quickml.supervised.tree.reducers.AttributeStats;
import quickml.supervised.tree.scorers.GRImbalancedScorer;
import java.io.Serializable;
import java.util.Map;
/**
 * A Scorer intended to estimate the impact on the Mean of the Squared Error
 * (MSE) of a branch existing versus not existing. The raw value is the MSE
 * without the branch minus the MSE with the branch (so higher is better, as
 * is required by the scoreSplit() interface). That raw value is then passed
 * through the inherited gain-ratio correction and multiplied by the inherited
 * imbalance penalty.
 */
public class PenalizedMSEScorer extends GRImbalancedScorer<ClassificationCounter> {

    /**
     * Creates a scorer; all three arguments are forwarded unchanged to the
     * GRImbalancedScorer constructor.
     *
     * @param degreeOfGainRatioPenalty passed through to the superclass
     * @param imbalancePenaltyPower    passed through to the superclass
     * @param attributeStats           passed through to the superclass
     */
    public PenalizedMSEScorer(double degreeOfGainRatioPenalty, double imbalancePenaltyPower, AttributeStats<ClassificationCounter> attributeStats) {
        super(degreeOfGainRatioPenalty, imbalancePenaltyPower, attributeStats);
    }

    /**
     * Computes the score of a node before any split: the summed squared error
     * of the counter divided by the counter's total.
     *
     * @param a counter describing the unsplit node
     * @return total squared error of {@code a} divided by {@code a.getTotal()}
     */
    @Override
    protected double getUnSplitScore(ClassificationCounter a) {
        final double summedError = getTotalError(a);
        return summedError / a.getTotal();
    }

    /**
     * Scores the split of a node into the two branches {@code a} and {@code b}.
     *
     * @param a counter for the first branch
     * @param b counter for the second branch
     * @return the gain-ratio-corrected difference between the inherited
     *         unSplitScore and the pooled MSE of both branches, multiplied by
     *         the imbalance penalty for {@code a} and {@code b}
     */
    @Override
    public double scoreSplit(final ClassificationCounter a, final ClassificationCounter b) {
        // Pool the squared error and the instance totals of both branches.
        final double errorAcrossBranches = getTotalError(a) + getTotalError(b);
        final double totalAcrossBranches = a.getTotal() + b.getTotal();
        final double splitMSE = errorAcrossBranches / totalAcrossBranches;

        // unSplitScore is inherited; presumably it was derived from
        // getUnSplitScore by the superclass -- confirm there.
        final double correctedGain = correctForGainRatio(unSplitScore - splitMSE);
        return correctedGain * getPenaltyForImabalance(a, b);
    }

    /**
     * Sums the squared error over every entry of the counter. For an entry
     * with count {@code c} and counter total {@code t}, the error is
     * {@code 1 - c/t} (or 0 when {@code t} is not positive); its square is
     * weighted by {@code c} before being added to the sum.
     *
     * @param cc counter whose entries are summed
     * @return the count-weighted sum of squared errors
     */
    private double getTotalError(ClassificationCounter cc) {
        // Read once per call; assumes getTotal() is a plain accessor.
        final double total = cc.getTotal();
        double accumulated = 0;
        for (Double classCount : cc.getCounts().values()) {
            double miss;
            if (total > 0) {
                miss = 1.0 - classCount / total;
            } else {
                // A non-positive total leaves nothing to divide by.
                miss = 0;
            }
            accumulated += miss * miss * classCount;
        }
        return accumulated;
    }

    /**
     * Two-valued flag type. It is not referenced anywhere else in this class.
     */
    public enum CrossValidationCorrection {
        TRUE, FALSE
    }

    /**
     * @return the fixed label {@code MSEScorer{}}; no field values are included
     */
    @Override
    public String toString() {
        return "MSEScorer{" + '}';
    }
}