package quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst.oldScorers; import quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst.OldScorer; import quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst.oldTree.OldClassificationCounter; import java.io.Serializable; /** * A Scorer intended to estimate the impact on the Mean of the Squared Error (MSE) * of a branch existing versus not existing. The value returned is the MSE * without the branch minus the MSE with the branch (so higher is better, as * is required by the scoreSplit() interface. */ public class MSEOldScorerWithCrossValidationCorrection implements OldScorer { private double normalizedParentMSE = 0; private OldClassificationCounter trainingInsetCC; private OldClassificationCounter trainingOutsetCC; private double totalTestWeight; public MSEOldScorerWithCrossValidationCorrection(OldClassificationCounter trainingSetParent, OldClassificationCounter testSetCCParent) { totalTestWeight = testSetCCParent.getTotal(); normalizedParentMSE = getErrorOnTestSet(trainingSetParent, testSetCCParent) / totalTestWeight; } private double getErrorOnTestSet(OldClassificationCounter trainingSetCC, OldClassificationCounter testSetCC) { double totalTrainingWeight = trainingSetCC.getTotal(); double mse = 0; for (Serializable label : testSetCC.getCounts().keySet()) { double labelProb = trainingSetCC.getCounts().get(label) / totalTrainingWeight; double labelError = 1.0 - labelProb; double labelOccurencesInTestSet = testSetCC.getCount(label); mse += labelError * labelError * labelOccurencesInTestSet; } return mse; } public void updateTrainingSetClassificationCounters(OldClassificationCounter trainingSetInset, OldClassificationCounter trainingSetOutset) { this.trainingInsetCC = trainingSetInset; this.trainingOutsetCC = trainingSetOutset; } @Override public double scoreSplit(final OldClassificationCounter testInsetCC, final OldClassificationCounter testOutSetCC) { double normalizedSplitMSE = (getErrorOnTestSet(trainingInsetCC, testInsetCC) + getErrorOnTestSet(trainingOutsetCC, testOutSetCC)) / (totalTestWeight); return normalizedParentMSE - normalizedSplitMSE; } }