package quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst.oldScorers; import quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst.OldScorer; import quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst.oldTree.OldClassificationCounter; import java.io.Serializable; import java.util.Map; /** * A Scorer intended to estimate the impact on the Mean of the Squared Error (MSE) * of a branch existing versus not existing. The value returned is the MSE * without the branch minus the MSE with the branch (so higher is better, as * is required by the scoreSplit() interface. */ public class MSEOldScorer implements OldScorer { private final double crossValidationInstanceCorrection; public MSEOldScorer(CrossValidationCorrection crossValidationCorrection) { if (crossValidationCorrection.equals(CrossValidationCorrection.TRUE)) { crossValidationInstanceCorrection = 1.0; } else { crossValidationInstanceCorrection = 0.0; } } @Override public double scoreSplit(final OldClassificationCounter a, final OldClassificationCounter b) { OldClassificationCounter parent = OldClassificationCounter.merge(a, b); double parentMSE = getTotalError(parent) / parent.getTotal(); double splitMSE = (getTotalError(a) + getTotalError(b)) / (a.getTotal() + b.getTotal()); return parentMSE - splitMSE; } private double getTotalError(OldClassificationCounter cc) { double totalError = 0; for (Map.Entry<Serializable, Double> e : cc.getCounts().entrySet()) { double error = (cc.getTotal()>0) ? 1.0 - e.getValue()/cc.getTotal() : 0; double errorSquared = error*error; totalError += errorSquared * e.getValue(); } return totalError; } public enum CrossValidationCorrection { TRUE, FALSE } @Override public String toString() { final StringBuilder sb = new StringBuilder("MSEScorer{"); sb.append("cvic=").append(crossValidationInstanceCorrection); sb.append('}'); return sb.toString(); } }