package quickml.supervised.tree.regressionTree.scorers; //TODO: fix oldScorers import quickml.supervised.tree.reducers.AttributeStats; import quickml.supervised.tree.regressionTree.valueCounters.MeanValueCounter; import quickml.supervised.tree.scorers.GRImbalancedScorer; /** * A Scorer intended to estimate the impact on the Mean of the Squared Error (MSE) * of a branch existing versus not existing. The value returned is the MSE * without the branch minus the MSE with the branch (so higher is better, as * is required by the scoreSplit() interface. */ public class PenalizedMSEScorer extends GRImbalancedScorer<MeanValueCounter> { @Override protected double getUnSplitScore(MeanValueCounter a) { return getTotalError(a)/a.getTotal(); } public PenalizedMSEScorer(double degreeOfGainRatioPenalty, double imbalancePenaltyPower, AttributeStats<MeanValueCounter> attributeStats) { super(degreeOfGainRatioPenalty, imbalancePenaltyPower, attributeStats); } @Override public double scoreSplit(final MeanValueCounter a, final MeanValueCounter b) { double splitMSE = (getTotalError(a) + getTotalError(b)) / (a.getTotal() + b.getTotal()); return correctForGainRatio(unSplitScore - splitMSE) * getPenaltyForImabalance(a, b); } private double getTotalError(MeanValueCounter mvc) { //below: total MSE for using the mvc as a leaf is Sum( (yi- mean)^2 ) = accumulatedSquares - mean^2 *numSamples double totalError = (mvc.getAccumulatedSquares() - mvc.getAccumulatedValue()*mvc.getAccumulatedValue()/mvc.getTotal()); return totalError; } @Override public String toString() { final StringBuilder sb = new StringBuilder("MSEScorer{"); sb.append('}'); return sb.toString(); } }