/*
 *  RapidMiner
 *
 *  Copyright (C) 2001-2011 by Rapid-I and the contributors
 *
 *  Complete list of developers available at our web site:
 *
 *       http://rapid-i.com
 *
 *  This program is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Affero General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Affero General Public License for more details.
 *
 *  You should have received a copy of the GNU Affero General Public License
 *  along with this program.  If not, see http://www.gnu.org/licenses/.
 */
package com.rapidminer.operator.performance;

import java.io.ObjectStreamException;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;

import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.operator.OperatorException;
import com.rapidminer.tools.math.Averagable;
import com.rapidminer.tools.math.ROCBias;
import com.rapidminer.tools.math.ROCData;
import com.rapidminer.tools.math.ROCDataGenerator;

/**
 * This criterion calculates the area under the ROC curve.
 *
 * @author Ingo Mierswa, Martin Scholz
 */
public class AreaUnderCurve extends MeasuredPerformance {

    private static final long serialVersionUID = 6877715214974493828L;

    /** AUC criterion that builds its ROC data with {@link ROCBias#OPTIMISTIC}. */
    public static class Optimistic extends AreaUnderCurve {

        private static final long serialVersionUID = 1L;

        public Optimistic() {
            super(ROCBias.OPTIMISTIC);
        }
    }

    /** AUC criterion that builds its ROC data with {@link ROCBias#PESSIMISTIC}. */
    public static class Pessimistic extends AreaUnderCurve {

        private static final long serialVersionUID = 1L;

        public Pessimistic() {
            super(ROCBias.PESSIMISTIC);
        }
    }

    /** AUC criterion that builds its ROC data with {@link ROCBias#NEUTRAL}. */
    public static class Neutral extends AreaUnderCurve {

        private static final long serialVersionUID = 1L;

        public Neutral() {
            super(ROCBias.NEUTRAL);
        }
    }

    /** The value of the AUC (summed up over all criteria merged via averaging). */
    private double auc = Double.NaN;

    /**
     * The data generator for this ROC curve. The field is transient and is
     * therefore re-created in {@link #readResolve()} after deserialization.
     */
    private transient ROCDataGenerator rocDataGenerator = new ROCDataGenerator(1.0d, 1.0d);

    /** The data for the ROC curve. */
    private LinkedList<ROCData> rocData = new LinkedList<ROCData>();

    /** A counter for average building. */
    private int counter = 1;

    /** The positive class name. */
    private String positiveClass;

    /** The bias handed to the ROC data generator when the ROC data is created. */
    private ROCBias method;

    /** Creates a criterion which uses {@link ROCBias#OPTIMISTIC}. */
    public AreaUnderCurve() {
        method = ROCBias.OPTIMISTIC;
    }

    /**
     * Creates a criterion which uses the given bias.
     *
     * @param method the bias passed on to the ROC data generator
     */
    public AreaUnderCurve(ROCBias method) {
        this.method = method;
    }

    /**
     * Clone constructor. Copies the AUC value, the averaging counter, the
     * positive class name and the bias of the given object. The list of ROC
     * data is not copied; the new object starts with an empty list.
     *
     * @param aucObject the criterion to copy
     */
    public AreaUnderCurve(AreaUnderCurve aucObject) {
        super(aucObject);
        this.auc = aucObject.auc;
        this.counter = aucObject.counter;
        this.positiveClass = aucObject.positiveClass;
        this.method = aucObject.method;
    }

    /**
     * Calculates the AUC. The ROC data created for the given example set is
     * appended to the ROC data list, the AUC of that data is stored, and the
     * positive class name is read from the mapping of the predicted label.
     */
    @Override
    public void startCounting(ExampleSet exampleSet, boolean useExampleWeights) throws OperatorException {
        super.startCounting(exampleSet, useExampleWeights);

        // create ROC data
        this.rocData.add(rocDataGenerator.createROCData(exampleSet, useExampleWeights, method));
        this.auc = rocDataGenerator.calculateAUC(this.rocData.getLast());
        this.positiveClass = exampleSet.getAttributes().getPredictedLabel().getMapping().getPositiveString();
    }

    /** Does nothing. Everything is done in {@link #startCounting(ExampleSet, boolean)}. */
    @Override
    public void countExample(Example example) {}

    /** Always returns 1. */
    @Override
    public double getExampleCount() {
        return 1.0d;
    }

    /** Always returns {@link Double#NaN}. */
    @Override
    public double getMikroVariance() {
        return Double.NaN;
    }

    /** Returns the accumulated AUC divided by the averaging counter. */
    @Override
    public double getMikroAverage() {
        return auc / counter;
    }

    /** Returns the fitness. */
    @Override
    public double getFitness() {
        return getAverage();
    }

    /**
     * Returns "AUC" for the neutral bias and "AUC (&lt;bias&gt;)" with the
     * lower-cased bias name otherwise.
     */
    @Override
    public String getName() {
        if (method == ROCBias.NEUTRAL) {
            return "AUC";
        }
        // Fixed locale: the name must not depend on the default locale
        // (e.g. Turkish lower-casing turns 'I' into a dotless 'i').
        return "AUC (" + method.toString().toLowerCase(Locale.ENGLISH) + ")";
    }

    @Override
    public String getDescription() {
        return "The area under a ROC curve. Given example weights are also considered. "
                + "Please note that the second class is considered to be positive.";
    }

    /**
     * Merges the given criterion into this one by adding up the counters and
     * AUC values and appending the other criterion's ROC data.
     *
     * @param performance the criterion to merge; must be an {@link AreaUnderCurve}
     */
    @Override
    public void buildSingleAverage(Averagable performance) {
        AreaUnderCurve other = (AreaUnderCurve) performance;
        this.counter += other.counter;
        this.auc += other.auc;
        this.rocData.addAll(other.rocData);
    }

    @Override
    public String toString() {
        return super.toString() + " (positive class: " + positiveClass + ")";
    }

    /** Returns the list of ROC data collected so far (the internal list, not a copy). */
    public List<ROCData> getRocData() {
        return rocData;
    }

    /** Returns the generator used to create the ROC data. */
    public ROCDataGenerator getRocDataGenerator() {
        return rocDataGenerator;
    }

    /**
     * Restores the transient ROC data generator after deserialization.
     * <p>
     * The method returns {@code Object} because the serialization mechanism
     * only invokes a no-argument {@code readResolve} with that return type;
     * a {@code void} variant is never called and would leave the generator
     * {@code null}.
     *
     * @return this object
     * @throws ObjectStreamException declared as required by the serialization contract
     */
    public Object readResolve() throws ObjectStreamException {
        rocDataGenerator = new ROCDataGenerator(1.0d, 1.0d);
        return this;
    }
}