/*
 *  RapidMiner
 *
 *  Copyright (C) 2001-2008 by Rapid-I and the contributors
 *
 *  Complete list of developers available at our web site:
 *
 *       http://rapid-i.com
 *
 *  This program is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Affero General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Affero General Public License for more details.
 *
 *  You should have received a copy of the GNU Affero General Public License
 *  along with this program.  If not, see http://www.gnu.org/licenses/.
 */
package com.rapidminer.operator.postprocessing;

import java.util.Iterator;

import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.IOObject;
import com.rapidminer.operator.Model;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.OperatorDescription;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.operator.UserError;
import com.rapidminer.operator.learner.LearnerCapability;
import com.rapidminer.operator.learner.PredictionModel;
import com.rapidminer.tools.LogService;

/**
 * A scaling operator, applying the original algorithm by Platt (1999) to turn
 * confidence scores of boolean classifiers into probability estimates.
 * 
 * Unlike the original version this operator assumes that the confidence scores
 * are already in the interval of [0,1], as e.g. given for the RapidMiner boosting
 * operators. The crude estimates are then transformed into log odds, and scaled
 * by the original transformation of Platt.
 * 
 * The operator assumes a model and an example set for scaling. It outputs a
 * PlattScalingModel that contains both the supplied model and the scaling
 * step. If the example set contains a weight attribute, then this operator is
 * able to fit a model to the weighted examples.
 * 
 * @author Martin Scholz
 * @version $Id: PlattScaling.java,v 1.7 2008/07/07 07:06:46 ingomierswa Exp $
 */
public class PlattScaling extends Operator {

	public PlattScaling(OperatorDescription description) {
		super(description);
	}

	/**
	 * Platt's scaling only applies to binary (boolean) classification problems,
	 * so numerical and polynominal labels are not supported.
	 */
	public boolean supportsCapability(LearnerCapability lc) {
		if ((lc == LearnerCapability.NUMERICAL_CLASS) || (lc == LearnerCapability.POLYNOMINAL_CLASS)) {
			return false;
		} else {
			return true;
		}
	}

	/** Returns an array with two elements: ExampleSet and Model. */
	public Class<?>[] getInputClasses() {
		return new Class[] { ExampleSet.class, Model.class };
	}

	/** Returns an array with one element: Model. */
	public Class<?>[] getOutputClasses() {
		return new Class[] { Model.class };
	}

	/**
	 * Applies the input model to a clone of the input example set, fits Platt's
	 * sigmoid parameters to the resulting confidences and wraps the model and the
	 * parameters into a {@link PlattScalingModel}.
	 *
	 * @throws UserError
	 *             (105) if the example set has no label, (106) if it has no
	 *             regular attributes
	 */
	public IOObject[] apply() throws OperatorException {
		ExampleSet exampleSet = getInput(ExampleSet.class);
		Model model = getInput(Model.class);

		// some checks
		if (exampleSet.getAttributes().getLabel() == null) {
			throw new UserError(this, 105, new Object[0]);
		}
		if (exampleSet.getAttributes().size() == 0) {
			throw new UserError(this, 106, new Object[0]);
		}

		final Attribute label = this.extractLabel(model, exampleSet);

		PlattParameters plattParams;
		{
			// Work on a clone so that the predictions added by the model do not
			// leak into the caller's example set.
			ExampleSet calibrationSet = (ExampleSet) exampleSet.clone();
			calibrationSet = model.apply(calibrationSet);
			plattParams = computeParameters(calibrationSet, label);
			PredictionModel.removePredictedLabel(calibrationSet);
		}

		PlattScalingModel scalingModel = new PlattScalingModel(exampleSet, model, plattParams);
		return new IOObject[] { scalingModel };
	}

	/**
	 * Returns the label attribute stored in the model if it is a
	 * {@link PredictionModel}; otherwise logs a warning and falls back to the
	 * label of the given example set.
	 */
	private Attribute extractLabel(Model model, ExampleSet exampleSet) {
		if (model instanceof PredictionModel) {
			return ((PredictionModel) model).getLabel();
		}
		logWarning("Could not find label in model for Platt's Scaling, using Label of provided ExampleSet instead.");
		return exampleSet.getAttributes().getLabel();
	}

	/**
	 * Implementation of Platt's scaling algorithm as found in [Platt, 1999].
	 * 
	 * @param exampleSet
	 *            the example set for finding the model parameters. It needs to
	 *            contain a predicted label and confidence scores. Please note
	 *            that the confidence values are expected to range from 0 to 1,
	 *            i.e. already take the form of coarse probability estimates.
	 * @param label
	 *            the label attribute the model was trained with; its positive
	 *            value string determines which class is treated as positive
	 * @return an object containing the parameters A and B of Platt's scaling
	 */
	public static PlattParameters computeParameters(ExampleSet exampleSet, Attribute label) {
		// The current label indices may be different from the expected ones
		// (label stored in model). The current ones are used when accessing the
		// true label, the confidences are accessed via the String representations.
		final String posLabelS = label.getMapping().getPositiveString();
		final int posLabel = exampleSet.getAttributes().getLabel().getMapping().mapString(posLabelS);

		// Prefetch the weight attribute of the example set, may be null.
		final Attribute weightAttr = exampleSet.getAttributes().getWeight();

		// Read the data exactly once: transformed confidences, class membership
		// and weights. The iterative fit below then works on these arrays.
		final int size = exampleSet.size();
		final double[] out = new double[size];
		final boolean[] positive = new boolean[size];
		final double[] weights = new double[size];
		double posPrior = 0.0d;
		double negPrior = 0.0d;
		Iterator<Example> reader = exampleSet.iterator();
		int index = 0;
		while (reader.hasNext()) {
			Example example = reader.next();
			// translate predictions (confidences) into expected log odds format
			out[index] = getLogOddsPosConfidence(example.getConfidence(posLabelS));
			positive[index] = (example.getLabel() == posLabel);
			weights[index] = (weightAttr == null) ? 1.0d : example.getWeight();
			// Every example that is not positive counts as negative, the same
			// rule that is used for the target values below.
			if (positive[index]) {
				posPrior += weights[index];
			} else {
				negPrior += weights[index];
			}
			index++;
		}

		// Regularized targets as proposed by Platt, computed per example.
		final double hiTarget = (posPrior + 1.0d) / (posPrior + 2.0d);
		final double loTarget = 1.0d / (negPrior + 2.0d);
		final double[] targets = new double[size];
		for (int i = 0; i < size; i++) {
			targets[i] = positive[i] ? hiTarget : loTarget;
		}

		// initialize values to be computed: A, B
		double A = 0.0d;
		double B = Math.log((negPrior + 1.0d) / (posPrior + 1.0d));
		double lambda = 1E-3;
		double olderr = 1E300;

		// temp array storing the current probability estimates
		final double[] pp = new double[size];
		final double initialP = (posPrior + 1.0d) / (negPrior + posPrior + 2.0d);
		for (int i = 0; i < size; i++) {
			pp[i] = initialP;
		}

		int count = 0;
		for (int it = 1; it <= 100; it++) {
			double a = 0;
			double b = 0;
			double c = 0;
			double d = 0;
			double e = 0;

			// compute Hessian & gradient of the error function w.r.t. A and B
			for (int i = 0; i < size; i++) {
				double d1 = weights[i] * (pp[i] - targets[i]);
				double d2 = weights[i] * (pp[i] * (1.0d - pp[i]));
				a += out[i] * out[i] * d2;
				b += d2;
				c += out[i] * d2;
				d += out[i] * d1;
				e += d1;
			}

			// stop if gradient is tiny
			if (Math.abs(d) < 1E-9 && Math.abs(e) < 1E-9) {
				break;
			}

			double oldA = A;
			double oldB = B;
			double err = 0;

			// Loop until the goodness of fit improves (the error decreases),
			// increasing the Levenberg-Marquardt style stabilizer on failure.
			while (true) {
				double det = (a + lambda) * (b + lambda) - c * c;
				if (det == 0) {
					lambda *= 10;
					continue;
				}
				A = oldA + ((b + lambda) * d - c * e) / det;
				B = oldB + ((a + lambda) * e - c * d) / det;

				err = evaluateFit(out, targets, weights, A, B, pp);

				if (err < olderr * (1.0d + 1E-7)) {
					lambda *= 0.1;
					break;
				}
				// error did not decrease: increase stabilizer and try again
				lambda *= 10;
				if (lambda >= 1E6) {
					// something is broken, give up
					break;
				}
			}

			double diff = err - olderr;
			double scale = 0.5 * (err + olderr + 1);
			if ((diff > -1E-3 * scale) && (diff < 1E-7 * scale)) {
				count++;
			} else {
				count = 0;
			}
			olderr = err;
			if (count == 3) {
				break;
			}
		}

		if (Double.isNaN(A) || Double.isNaN(B)) {
			A = 1.0d;
			B = 0.0d;
			exampleSet.getLog().logWarning("Discarding invalid result of Platt's scaling, using identity instead.");
		}

		return new PlattParameters(A, B);
	}

	/**
	 * Computes the weighted cross-entropy error of the sigmoid
	 * p = 1 / (1 + exp(A * out + B)) against the given targets, and stores
	 * the resulting probability estimates in {@code pp}.
	 *
	 * Both branches are algebraically equal to
	 * -[t * log(p) + (1 - t) * log(1 - p)], but only ever exponentiate a
	 * non-positive number, so neither overflow nor log(0) can occur.
	 */
	private static double evaluateFit(double[] out, double[] targets, double[] weights, double A, double B, double[] pp) {
		double err = 0.0d;
		for (int i = 0; i < out.length; i++) {
			final double fApB = out[i] * A + B;
			if (fApB >= 0) {
				final double expNeg = Math.exp(-fApB);
				pp[i] = expNeg / (1.0d + expNeg);
				err += weights[i] * (targets[i] * fApB + Math.log1p(expNeg));
			} else {
				final double expPos = Math.exp(fApB);
				pp[i] = 1.0d / (1.0d + expPos);
				err += weights[i] * ((targets[i] - 1.0d) * fApB + Math.log1p(expPos));
			}
		}
		return err;
	}

	/**
	 * Translates a confidence score c in [0, 1] for the positive class into the
	 * log odds value log((1 - c) / c), the input format expected by Platt's
	 * scaling. An input of 0.5 results in 0, confidences above 0.5 result in
	 * negative values and confidences below 0.5 in positive values; the sign
	 * convention is absorbed by parameter A of the sigmoid
	 * 1 / (1 + exp(A * f + B)). The absolute value indicates the confidence in
	 * the prediction.
	 *
	 * Inputs outside [0, 1] are clamped, both c and 1 - c are bounded below by
	 * 1E-30 so the result is always finite, and NaN is treated as 0.5 (with a
	 * logged warning).
	 */
	public static double getLogOddsPosConfidence(double originalConfidence) {
		// avoid infinite or meaningless results by not allowing arbitrarily
		// small or large values:
		final double epsilon = 1E-30;
		// Note: 1.0 - 1E-30 rounds to exactly 1.0 in double precision, so the
		// upper end cannot be guarded here; the complement is clamped below.
		double confidence = Math.min(Math.max(epsilon, originalConfidence), 1.0d);
		if (Double.isNaN(confidence)) {
			// error, just try to continue
			confidence = 0.5;
			LogService.getGlobal().log("Found a NaN confidence during Platt's Scaling.", LogService.WARNING);
		}
		final double complement = Math.max(epsilon, 1.0d - confidence);
		double odds = complement / confidence;
		// All we need to do is compute the logarithm,
		// the choice of the base is implicitly left to the scaling part:
		return Math.log(odds);
	}
}