/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    NoiseGrapher.java
 *    Copyright (C) 2002 Raymond J. Mooney
 *    Modified to read Key of Noise
 */

package weka.experiment;

import java.util.*;
import java.io.*;
import weka.core.*;

/**
 * Class for producing performance graphs for any metric from learning-curve
 * (noise-level) results. Currently supports gnuplot format with various types
 * of error bars.
 */
public class NoiseGrapher {

  /** errorBar value for no error bars */
  public static final short NONE = 0;

  /** errorBar value for error bars using standard deviations */
  public static final short STD_DEV = 1;

  /** errorBar value for error bars using 95% confidence intervals */
  public static final short CONF_INF = 2;

  /** errorBar value for error bars using min and max values */
  public static final short MIN_MAX = 3;

  /** Experimental result data in arff format */
  protected Instances data;

  /** Names of datasets in data */
  protected String[] datasets;

  /**
   * Map from scheme+options name to result data in the form of an array of
   * Stats, one for each learning-curve point in {@link #points}.
   */
  protected HashMap<String, Stats[]> schemeMap;

  /** Ordered array of noise-level points on the learning curve */
  protected int[] points;

  /** Name of original file of experimental result data in arff format */
  protected String arffFileName;

  /** The name of the performance metric to plot */
  public String metric = "Percent_correct";

  /** Desired error-bar type: one of NONE, STD_DEV, CONF_INF, MIN_MAX */
  public short errorBars = NONE;

  /**
   * Set if desire error bars based on 95% confidence intervals.
   * NOTE(review): appears unused — superseded by {@link #errorBars};
   * retained only for backward compatibility.
   */
  public boolean confIntErrorBars = false;

  /** The name of the dataset to plot performance for */
  public String dataset;

  /**
   * Create an initial grapher: load the result data, the dataset names, and
   * the set of noise-level points, and default to the first dataset.
   *
   * @param arffFileName name of an arff file of experimental results
   * @param errorBars one of NONE, STD_DEV, CONF_INF, MIN_MAX
   * @throws Exception if the result file cannot be read or parsed
   */
  public NoiseGrapher(String arffFileName, short errorBars) throws Exception {
    this.arffFileName = arffFileName;
    this.errorBars = errorBars;
    setData();
    setDatasets();
    setPoints();
    dataset = datasets[0];
  }

  /**
   * Load data for the graph from the given Experiment result file in arff
   * format.
   *
   * @throws Exception if the file cannot be read or parsed
   */
  protected void setData() throws Exception {
    BufferedReader reader = new BufferedReader(new FileReader(arffFileName));
    try {
      data = new Instances(reader);
    } finally {
      // Always release the file handle (the original version leaked it).
      reader.close();
    }
  }

  /**
   * Set the array of learning-curve points from the Key_Noise_levels values
   * in the data, sorted ascending so binarySearch can be used later.
   *
   * @throws Exception if a noise-level value is not a parsable integer
   */
  protected void setPoints() throws Exception {
    Attribute attr = data.attribute("Key_Noise_levels");
    points = new int[attr.numValues()];
    for (int i = 0; i < points.length; i++)
      points[i] = Integer.parseInt(attr.value(i));
    Arrays.sort(points);
  }

  /**
   * Set the array of dataset names from the Key_Dataset values in the data.
   *
   * @throws Exception if the attribute values cannot be read
   */
  protected void setDatasets() throws Exception {
    Attribute attr = data.attribute("Key_Dataset");
    datasets = new String[attr.numValues()];
    for (int i = 0; i < datasets.length; i++)
      datasets[i] = attr.value(i);
  }

  /**
   * Read in data for the current values of dataset and metric by indexing,
   * for each scheme+options name, an array of Stats objects — one for each
   * point on the learning curve.
   *
   * @throws Exception if the data cannot be processed
   * @throws IllegalArgumentException if the current metric is not an
   *           attribute of the result data
   */
  protected void processData() throws Exception {
    schemeMap = new HashMap<String, Stats[]>();
    // Go through each data line in the data.
    // NOTE: the loop variable was originally named "enum", which has been a
    // reserved keyword since Java 5 and no longer compiles.
    Enumeration instEnum = data.enumerateInstances();
    while (instEnum.hasMoreElements()) {
      Instance inst = (Instance) instEnum.nextElement();
      // If this is not a line for the current dataset, skip it
      if (!inst.stringValue(data.attribute("Key_Dataset")).equals(dataset))
        continue;
      // Get the full name of the scheme by concatenating the system
      // name and the set of system options
      String name = inst.stringValue(data.attribute("Key_Scheme"))
        + inst.stringValue(data.attribute("Key_Scheme_options"));
      // See if this scheme already has a Stats vector for points
      Stats[] pointsStats = schemeMap.get(name);
      if (pointsStats == null) {
        // If not, create one
        pointsStats = new Stats[points.length];
        schemeMap.put(name, pointsStats);
      }
      // Get the noise level for this line
      int point =
        Integer.parseInt(inst.stringValue(data.attribute("Key_Noise_levels")));
      // Find the position in the (sorted) array of points for this point;
      // points was built from the same attribute, so the value must be there.
      int pointPos = Arrays.binarySearch(points, point);
      // Get the Stats performance-metric object for this point
      Stats stats = pointsStats[pointPos];
      if (stats == null) {
        // If there is none, create one
        stats = new Stats();
        pointsStats[pointPos] = stats;
      }
      Attribute metricAttr = data.attribute(metric);
      if (metricAttr == null)
        throw new IllegalArgumentException("Unrecognized metric:" + metric);
      // Get the value of the performance metric for this line
      double metricValue = inst.value(metricAttr);
      // Add this value to the Stats object for this scheme and point
      // that keeps track of the running sum to eventually compute an average
      stats.add(metricValue);
    }
  }

  /**
   * Generate gnuplot files for plotting a learning curve for the current
   * dataset and metric. Assumes {@link #processData} was last performed for
   * this dataset and metric.
   *
   * @throws Exception if the plot or data files cannot be written
   */
  public void gnuplot() throws Exception {
    // Find min and max values of the performance metric
    double yMin = Double.POSITIVE_INFINITY, yMax = Double.NEGATIVE_INFINITY;
    // Iterate through each scheme and each of its plot points
    Iterator<Map.Entry<String, Stats[]>> schemeEntries =
      schemeMap.entrySet().iterator();
    // Index of last point on the learning curve (this may differ
    // for different datasets).
    int last_point = -1, last_index = 0;
    while (schemeEntries.hasNext()) {
      Map.Entry<String, Stats[]> schemeEntry = schemeEntries.next();
      Stats[] pointsStats = schemeEntry.getValue();
      for (int i = 0; i < points.length; i++) {
        // First calculate final mean and other summary stats
        //PM
        if (pointsStats[i] == null)
          continue;
        // Keep track of which is the last point on the
        // learning curve on this dataset
        if (points[i] > last_point) {
          last_point = points[i];
          last_index = i;
        }
        pointsStats[i].calculateDerived();
        if (pointsStats[i].mean < yMin)
          yMin = pointsStats[i].mean;
        if (pointsStats[i].mean > yMax)
          yMax = pointsStats[i].mean;
      }
    }
    // Use result file name stem as a stem for plot files
    String fileStem = removeFileExtension(arffFileName);
    // Also include the name of the dataset in the plot-file stem if
    // there are results for more than one dataset in this result file
    if (datasets.length > 1)
      fileStem = fileStem + dataset;
    String fileName = fileStem + "_" + metric + ".gplot";
    // Create a file for the gnuplot
    PrintWriter out = new PrintWriter(new FileWriter(fileName));
    // Write proper gnuplot commands in this file
    out.println("set xlabel \"Percentage of Noise in Data\"");
    out.println("set ylabel \"" + metric.replace('_', ' ') + "\"");
    out.println("\nset terminal postscript color\nset size 0.75,0.75\n\nset data style linespoints");
    // Move the key of curve names to the lower right corner, good for learning
    // curves and train time plots that go from lower left to top right
    out.println("set key " + 0.85 * points[last_index] + ","
      + (yMin + 0.25 * (yMax - yMin)));
    out.print("\nplot ");
    // For each scheme, add it to the plot command to plot this scheme's
    // learning curve for the metric, and create a data file with the average
    // data for the learning-curve points
    schemeEntries = schemeMap.entrySet().iterator();
    while (schemeEntries.hasNext()) {
      Map.Entry<String, Stats[]> schemeEntry = schemeEntries.next();
      String scheme = cleanSchemeName(schemeEntry.getKey());
      Stats[] pointsStats = schemeEntry.getValue();
      // Create a data file for this scheme
      String dataFileName = fileStem + "_" + metric + "_" + scheme;
      out.print("'" + dataFileName + "' title \"" + scheme + "\"");
      if (errorBars != NONE)
        out.print(", '" + dataFileName + "' notitle with errorbars");
      if (schemeEntries.hasNext())
        out.print(", ");
      PrintWriter dataOut = new PrintWriter(new FileWriter(dataFileName));
      // Write out a line for each data point on the learning curve
      for (int i = 0; i <= last_index; i++) {
        // Skip points with no collected stats — mirrors the null guard in the
        // summary loop above; without it this dereference throws NPE for any
        // scheme that is missing a point.
        if (pointsStats[i] == null)
          continue;
        dataOut.print(points[i] + " " + pointsStats[i].mean);
        // Add a third (and maybe fourth) entry for the error bar.
        // Just a third indicates a delta about the mean, a third
        // and fourth indicates a lower and upper bound
        if (errorBars == STD_DEV) {
          dataOut.print(" " + pointsStats[i].stdDev);
        } else if (errorBars == CONF_INF) {
          // a 95% confidence interval is a delta of 1.96 standard deviations
          dataOut.print(" " + 1.96 * pointsStats[i].stdDev);
        } else if (errorBars == MIN_MAX) {
          dataOut.print(" " + pointsStats[i].min + " " + pointsStats[i].max);
        }
        dataOut.println("");
      }
      dataOut.close();
    }
    out.close();
  }

  /**
   * Clean the name of a scheme to make it appropriate for a file name.
   *
   * @param schemeName the raw scheme+options name
   * @return the name with the weka package prefix removed and spaces
   *         replaced by underscores
   */
  private String cleanSchemeName(String schemeName) {
    return Utils.removeSubstring(schemeName, "weka.classifiers.")
      .replace(' ', '_');
  }

  /**
   * Return the name of a file with the extension removed.
   *
   * @param fileName the file name, possibly with an extension
   * @return the name up to (not including) the last '.', or the whole name
   *         if there is no '.'
   */
  public static String removeFileExtension(String fileName) {
    int pos = fileName.lastIndexOf(".");
    if (pos == -1)
      return fileName;
    else
      return fileName.substring(0, pos);
  }

  /**
   * Produce a gnuplot for each dataset in the result file.
   *
   * @throws Exception if processing or plotting fails
   */
  public void gnuplotAllDatasets() throws Exception {
    for (int i = 0; i < datasets.length; i++) {
      dataset = datasets[i];
      processData();
      gnuplot();
    }
  }

  /**
   * Create gnuplot graphs of learning curves. The first argument should
   * be the name of an arff file of experimental results for a learning-curve
   * experiment. If present, the second argument should be the name of a
   * performance metric in the result file to plot (which defaults to
   * Percent_correct). Options are:
   * <ul>
   * <li> -s: Plot error bars of standard deviations.
   * <li> -c: Plot error bars of 95% confidence intervals.
   * <li> -m: Plot error bars of min and max values.
   * </ul>
   */
  public static void main(String[] args) throws Exception {
    int current = 0;
    short errorBars = NONE;
    // Guard against a missing file argument instead of throwing
    // ArrayIndexOutOfBoundsException.
    if (args.length == 0) {
      System.err.println(
        "Usage: NoiseGrapher [-s|-c|-m] <result.arff> [metric]");
      return;
    }
    if (args[current].equals("-s")) {
      errorBars = STD_DEV;
      current++;
    } else if (args[current].equals("-c")) {
      errorBars = CONF_INF;
      current++;
    } else if (args[current].equals("-m")) {
      errorBars = MIN_MAX;
      current++;
    }
    if (args.length <= current) {
      System.err.println(
        "Usage: NoiseGrapher [-s|-c|-m] <result.arff> [metric]");
      return;
    }
    NoiseGrapher noisegrapher = new NoiseGrapher(args[current++], errorBars);
    if (args.length > current)
      noisegrapher.metric = args[current++];
    noisegrapher.gnuplotAllDatasets();
  }
}