/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ /* * VisualizeROC.java * Copyright (C) 2009 University of Waikato, Hamilton, New Zealand */ package wekaexamples.gui.visualize; import weka.classifiers.evaluation.ThresholdCurve; import weka.core.Instances; import weka.core.Utils; import weka.core.converters.ConverterUtils.DataSource; import weka.gui.visualize.PlotData2D; import weka.gui.visualize.ThresholdVisualizePanel; import java.awt.BorderLayout; import javax.swing.JFrame; /** * Visualizes a previously saved ROC curve. Code taken from the * <code>weka.gui.explorer.ClassifierPanel</code> - involved methods: * <ul> * <li>visualize(String,int,int)</li> * </li>visualizeClassifierErrors(VisualizePanel)</li> * </ul> * * @author FracPete (fracpete at waikato dot ac dot nz) * @version $Revision$ * @see weka.gui.explorer.ClassifierPanel */ public class VisualizeROC { /** * Takes one argument: previously saved ROC curve data (ARFF file). * * @param args the commandline arguments * @throws Exception if something goes wrong */ public static void main(String[] args) throws Exception { Instances curve = DataSource.read(args[0]); curve.setClassIndex(curve.numAttributes() - 1); // method visualize ThresholdVisualizePanel tvp = new ThresholdVisualizePanel(); tvp.setROCString("(Area under ROC = " + Utils.doubleToString(ThresholdCurve.getROCArea(curve), 4) + ")"); tvp.setName(curve.relationName()); PlotData2D plotdata = new PlotData2D(curve); plotdata.setPlotName(curve.relationName()); plotdata.addInstanceNumberAttribute(); // specify which points are connected boolean[] cp = new boolean[curve.numInstances()]; for (int n = 1; n < cp.length; n++) cp[n] = true; plotdata.setConnectPoints(cp); // add plot tvp.addPlot(plotdata); // method visualizeClassifierErrors final JFrame jf = new JFrame("WEKA ROC: " + tvp.getName()); jf.setSize(500,400); jf.getContentPane().setLayout(new BorderLayout()); jf.getContentPane().add(tvp, BorderLayout.CENTER); jf.setDefaultCloseOperation(JFrame.DISPOSE_ON_CLOSE); jf.setVisible(true); } }