/*
 *  RapidMiner
 *
 *  Copyright (C) 2001-2008 by Rapid-I and the contributors
 *
 *  Complete list of developers available at our web site:
 *
 *       http://rapid-i.com
 *
 *  This program is free software: you can redistribute it and/or modify
 *  it under the terms of the GNU Affero General Public License as published by
 *  the Free Software Foundation, either version 3 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Affero General Public License for more details.
 *
 *  You should have received a copy of the GNU Affero General Public License
 *  along with this program.  If not, see http://www.gnu.org/licenses/.
 */
package com.rapidminer.tools.math;

import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

import com.rapidminer.datatable.DataTable;
import com.rapidminer.datatable.SimpleDataTable;
import com.rapidminer.datatable.SimpleDataTableRow;
import com.rapidminer.example.Attribute;
import com.rapidminer.example.Example;
import com.rapidminer.example.ExampleSet;
import com.rapidminer.gui.plotter.SimplePlotterDialog;

/**
 * Helper class containing some methods for Lift plots.
 *
 * <p>An instance remembers the largest non-NaN lift seen by the most recent call of
 * {@link #createLiftDataList(ExampleSet)}; {@link #createLiftChartPlot(List)} uses that
 * value as a replacement for undefined (NaN) lift values.</p>
 *
 * @author Ingo Mierswa
 * @version $Id: LiftDataGenerator.java,v 1.3 2008/05/09 19:23:02 ingomierswa Exp $
 */
public class LiftDataGenerator {

    /** Defines the (approximate) maximum amount of points which is plotted in the lift chart. */
    public static final int MAX_LIFT_POINTS = 500;

    // Indices into the running confusion matrix array.
    private static final int TP = 0;

    private static final int FP = 1;

    private static final int FN = 2;

    private static final int TN = 3;

    /**
     * Largest non-NaN lift of the last generated data list; 0 if no list has been
     * generated yet or if the last list contained no defined lift value.
     */
    private double maxLift = 0;

    /** Creates a new Lift data generator. */
    public LiftDataGenerator() {}

    /**
     * Creates a list of lift data points from the given example set. The example set must
     * have a binary label attribute, a predicted label and a confidence value for the
     * positive class, i.e. a model must have been applied on the data.
     *
     * <p>The examples are sorted by their negated positive-class confidence and then
     * processed one by one while a weighted confusion matrix (TP, FP, FN, TN) is
     * accumulated. After each example one point <code>{ i, lift }</code> is appended,
     * where <code>i</code> is the zero-based position of the example in the sorted order
     * and <code>lift</code> may be NaN as long as the quotient is 0/0.</p>
     *
     * <p>As a side effect the largest non-NaN lift is stored in this generator.</p>
     *
     * @param exampleSet the labelled example set with predictions and confidences
     * @return one <code>double[] { index, lift }</code> per example, in sorted order
     */
    public List<double[]> createLiftDataList(ExampleSet exampleSet) {
        Attribute label = exampleSet.getAttributes().getLabel();
        Attribute predictedLabel = exampleSet.getAttributes().getPredictedLabel();
        Attribute weightAttr = exampleSet.getAttributes().getWeight();

        // the positive index does not change while iterating, so look it up only once
        int positiveIndex = label.getMapping().getPositiveIndex();
        String positiveClassName = label.getMapping().mapIndex(positiveIndex);

        // create sorted collection with all label values and example weights
        WeightedConfidenceAndLabel[] calArray = new WeightedConfidenceAndLabel[exampleSet.size()];
        int index = 0;
        Iterator<Example> reader = exampleSet.iterator();
        while (reader.hasNext()) {
            Example example = reader.next();
            // The confidence is negated before sorting; presumably the natural order of
            // WeightedConfidenceAndLabel is ascending by confidence, so that the most
            // confident examples come first -- verify against WeightedConfidenceAndLabel.
            double negatedConfidence = (-1) * example.getConfidence(positiveClassName);
            WeightedConfidenceAndLabel wcl;
            if (weightAttr == null) {
                wcl = new WeightedConfidenceAndLabel(negatedConfidence,
                        example.getValue(label),
                        example.getValue(predictedLabel));
            } else {
                wcl = new WeightedConfidenceAndLabel(negatedConfidence,
                        example.getValue(label),
                        example.getValue(weightAttr),
                        example.getValue(predictedLabel));
            }
            calArray[index++] = wcl;
        }
        Arrays.sort(calArray);

        List<double[]> tableData = new LinkedList<double[]>();
        double[] confidenceMatrix = new double[4];

        // Iterate through the example set sorted by predictions.
        // In each iteration the lift is calculated and added to the list.
        double largestLift = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < calArray.length; i++) {
            WeightedConfidenceAndLabel wcl = calArray[i];
            double weight = wcl.getWeight();
            boolean actualPositive = wcl.getLabel() == positiveIndex;
            boolean predictedPositive = wcl.getPrediction() == positiveIndex;

            if (actualPositive) {
                confidenceMatrix[predictedPositive ? TP : FN] += weight;
            } else {
                confidenceMatrix[predictedPositive ? FP : TN] += weight;
            }

            // lift = TP * (FP + TN) / ((TP + FP) * (TP + FN)) on the examples seen so far.
            // NOTE(review): this divides by the negatives (FP + TN) seen so far, whereas the
            // textbook lift relates precision to the share of positives among all examples;
            // kept unchanged here -- confirm that this definition is intended.
            double lift = (confidenceMatrix[TP] * (confidenceMatrix[FP] + confidenceMatrix[TN]))
                    / ((confidenceMatrix[TP] + confidenceMatrix[FP]) * (confidenceMatrix[TP] + confidenceMatrix[FN]));
            if (!Double.isNaN(lift)) {
                largestLift = Math.max(lift, largestLift);
            }
            tableData.add(new double[] { i, lift });
        }

        // Without any defined lift (empty example set or only NaN values) the maximum would
        // stay at negative infinity and later be plotted in place of NaN points; fall back
        // to 0, the initial value of the field, instead.
        this.maxLift = (largestLift == Double.NEGATIVE_INFINITY) ? 0.0d : largestLift;
        return tableData;
    }

    /**
     * Creates a dialog containing a plotter for a given list of lift data points and makes
     * it visible.
     *
     * <p>Roughly every <code>data.size() / MAX_LIFT_POINTS</code>-th point is plotted; the
     * first and the last point are always included. NaN lift values are replaced by the
     * maximum lift remembered by this generator.</p>
     *
     * @param data points of the form <code>double[] { x, lift }</code>, e.g. the result of
     *             {@link #createLiftDataList(ExampleSet)}
     */
    public void createLiftChartPlot(List<double[]> data) {
        // create data table
        DataTable dataTable = new SimpleDataTable("Lift Chart", new String[] { "Fraction", "Lift" });
        Iterator<double[]> i = data.iterator();
        int pointCounter = 0;
        int eachPoint = Math.max(1, (int) Math.round((double) data.size() / (double) MAX_LIFT_POINTS));
        while (i.hasNext()) {
            double[] point = i.next();
            // pointCounter == 0 is covered by the modulo test; !i.hasNext() keeps the last point
            if (((pointCounter % eachPoint) == 0) || (!i.hasNext())) {
                double fraction = point[0];
                double lift = point[1];
                if (Double.isNaN(lift)) {
                    lift = this.maxLift;
                }
                dataTable.add(new SimpleDataTableRow(new double[] { fraction, lift }));
            }
            pointCounter++;
        }

        // create plotter
        SimplePlotterDialog plotter = new SimplePlotterDialog(dataTable);
        plotter.setXAxis(0);
        plotter.plotColumn(1, true);
        plotter.setVisible(true);
    }
}