package edu.hawaii.jmotif.performance.cbf; import java.awt.BorderLayout; import java.io.BufferedReader; import java.io.FileReader; import java.util.Random; import weka.classifiers.Classifier; import weka.classifiers.Evaluation; import weka.classifiers.bayes.NaiveBayes; import weka.classifiers.evaluation.ThresholdCurve; import weka.core.Instances; import weka.core.Utils; import weka.gui.visualize.PlotData2D; import weka.gui.visualize.ThresholdVisualizePanel; import edu.hawaii.jmotif.performance.SAXVSMClassifier; /** * Generates and displays a ROC curve from a dataset. Uses a default NaiveBayes to generate the ROC * data. * * @author FracPete */ public class CBFROC { /** * takes one argument: dataset in ARFF format (expects class to be last attribute) */ public static void main(String[] args) throws Exception { // load data Instances data = new Instances(new BufferedReader(new FileReader("data/gun.arff"))); data.setClassIndex(data.numAttributes() - 1); // train classifier Classifier cl = new NaiveBayes(); SAXVSMClassifier cl1 = new SAXVSMClassifier(); cl1.setSAXParams(32, 12, 9, "CLASSIC"); // cl1.setSAXParams(42, 5, 5, "CLASSIC"); // CBF // Evaluation eval = new Evaluation(data); // eval.crossValidateModel(cl, data, 8, new Random(1)); Evaluation eval = new Evaluation(data); eval.crossValidateModel(cl1, data, 10, new Random(1)); // generate curve ThresholdCurve tc = new ThresholdCurve(); int classIndex = 0; Instances result = tc.getCurve(eval.predictions(), classIndex); // plot curve ThresholdVisualizePanel vmc = new ThresholdVisualizePanel(); vmc.setROCString("(Area under ROC = " + Utils.doubleToString(ThresholdCurve.getROCArea(result), 4) + ")"); vmc.setName(result.relationName()); PlotData2D tempd = new PlotData2D(result); tempd.setPlotName(result.relationName()); tempd.addInstanceNumberAttribute(); // specify which points are connected boolean[] cp = new boolean[result.numInstances()]; for (int n = 1; n < cp.length; n++) cp[n] = true; tempd.setConnectPoints(cp); // add plot vmc.addPlot(tempd); // display curve String plotName = vmc.getName(); final javax.swing.JFrame jf = new javax.swing.JFrame("Weka Classifier Visualize: " + plotName); jf.setSize(500, 400); jf.getContentPane().setLayout(new BorderLayout()); jf.getContentPane().add(vmc, BorderLayout.CENTER); jf.addWindowListener(new java.awt.event.WindowAdapter() { public void windowClosing(java.awt.event.WindowEvent e) { jf.dispose(); } }); jf.setVisible(true); } }