/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept. This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://www.cs.umass.edu/~mccallum/mallet This software is provided under the terms of the Common Public License, version 1.0, as published by http://www.opensource.org. For further information, see the file `LICENSE' included with this distribution. */ /** @author Aron Culotta <a href="mailto:culotta@cs.umass.edu">culotta@cs.umass.edu</a> */ package cc.mallet.classify.evaluate; import java.awt.*; import java.awt.event.*; import javax.swing.*; import cc.mallet.classify.Classification; import cc.mallet.classify.Classifier; import cc.mallet.classify.Trial; import cc.mallet.classify.evaluate.GraphItem; import cc.mallet.types.Instance; import cc.mallet.types.InstanceList; import cc.mallet.types.LabelVector; import cc.mallet.util.MalletLogger; import cc.mallet.util.PrintUtilities; import java.util.*; import java.util.logging.*; import java.text.DecimalFormat; /** * Methods for calculating and displaying the accuracy v. * coverage data for a Trial */ public class AccuracyCoverage implements ActionListener { private static Logger logger = MalletLogger.getLogger(AccuracyCoverage.class.getName()); static final int DEFAULT_NUM_BUCKETS = 20; static final int DEFAULT_MAX_X = 100; private ArrayList classifications; private double [] accuracyValues; private int numBuckets; private double step; private Graph2 graph; private JFrame frame; /** * Constructs object, sorts classifications, and creates * accuracyValues array * @param t trial to get data from * @param numBuckets number of x-axis measurements to find accuracy */ public AccuracyCoverage(Trial t, int numBuckets, String title, String dataName) { this.classifications = t; this.numBuckets = numBuckets; this.step = (double)DEFAULT_MAX_X/numBuckets; this.accuracyValues = new double[numBuckets]; this.frame = null; logger.info("Constructing AccCov with " + this.classifications.size()); sortClassifications(); /* for(int i=0; i<classifications.size(); i++) { Classification c = (Classification)this.classifications.get(i); LabelVector distr = c.getLabelVector(); System.out.println(distr.getBestValue()); } */ createAccuracyArray(); this.graph = new Graph2( title, 0, 100, "Coverage", "Accuracy"); addDataToGraph(this.accuracyValues, numBuckets, dataName); } public AccuracyCoverage(Trial t, String title, String name) { this(t, DEFAULT_NUM_BUCKETS, title, name); } public AccuracyCoverage(Trial t, String title) { this(t, DEFAULT_NUM_BUCKETS, title, "unnamed"); } public AccuracyCoverage(Classifier C, InstanceList ilist, String title) { this(new Trial(C, ilist), DEFAULT_NUM_BUCKETS, title, "unnamed"); } public AccuracyCoverage(Classifier C, InstanceList ilist, int numBuckets, String title) { this(new Trial(C, ilist), numBuckets, title, "unnamed"); } /** * Finds the "area under the acc/cov curve" * steps by one percentage point and calcs area * of trapezoid */ public double cumulativeAccuracy() { double area = 0.0; for(int i=1; i<100; i++) { double leftAccuracy = accuracyAtCoverage((double)i/100); double rightAccuracy = accuracyAtCoverage((double)(i+1)/100); area += .5*(leftAccuracy + rightAccuracy); } return area; } /** * Creates array of accuracy values for coverage * at each step as defined by numBuckets. */ public void createAccuracyArray() { // System.out.println("Creating accuracyArray. Step= "+step); for(int i=0 ; i<numBuckets; i++) { accuracyValues[i] = accuracyAtCoverage(step *(double)(i+1)/100.0); } } /** * accuracy at a given coverage percentage * @param cov coverage percentage * @return accuracy value */ public double accuracyAtCoverage(double cov) { assert(cov <= 1 && cov > 0); int numTrials = (int)(Math.round((double)classifications.size()*cov)); int numCorrect = 0; // System.out.println("NumTrials="+numTrials); for(int i= classifications.size()-1; i >= classifications.size()-numTrials; i--) { Classification temp = (Classification)classifications.get(i); if(temp.bestLabelIsCorrect()) numCorrect++; } // System.out.println("Accuracy at cov "+cov+" is "+ //(double)numCorrect/numTrials); return((double)numCorrect/numTrials); } /** * Sort classifications ArrayList * by winner's value */ public void sortClassifications() { Collections.sort(classifications, new ClassificationComparator()); } public void addDataToGraph(double [] accValues, int nBuckets, String name) { Vector values = new Vector(nBuckets); for(int i=0; i<nBuckets; i++) { GraphItem temp = new GraphItem("", (int)(accValues[i]*100), Color.black); values.add(temp); } logger.info("Sending "+values.size()+" elements to graph"); this.graph.addItemVector(values, name); } /** * Displays the accuracy v. coverage graph */ public void displayGraph() { Vector values = new Vector(this.numBuckets); JButton printButton = new JButton("Print"); frame = new JFrame("Graph"); DecimalFormat df = new DecimalFormat(); printButton.addActionListener(this); frame.addWindowListener (new WindowAdapter() { public void windowClosing(WindowEvent e) { System.exit(0); } } ); // Get content pane Container pane = frame.getContentPane(); // Set layout manager pane.setLayout( new FlowLayout() ); assert(graph!= null); // make sure we've got data in the graph // Add to pane pane.add( graph ); pane.add( printButton ); frame.pack(); // Center the frame Toolkit toolkit = Toolkit.getDefaultToolkit(); // Get the current screen size Dimension scrnsize = toolkit.getScreenSize(); // Get the frame size Dimension framesize= frame.getSize(); // Set X,Y location frame.setLocation ( (int) (scrnsize.getWidth() - frame.getWidth() ) / 2 , (int) (scrnsize.getHeight() - frame.getHeight()) / 2); frame.setVisible(true); } public void actionPerformed(ActionEvent event) { PrintUtilities.printComponent(graph); } public void addTrial(Trial t, String name) { addTrial(t, DEFAULT_NUM_BUCKETS, name); } public void addTrial(Trial t, int nBuckets, String name) { AccuracyCoverage newData = new AccuracyCoverage(t, nBuckets, "untitled", name); double [] accValues = newData.accuracyValues(); addDataToGraph(accValues, nBuckets, name); } public double[] accuracyValues() { return this.accuracyValues; } public class ClassificationComparator implements Comparator { public final int compare (Object a, Object b) { LabelVector x = (LabelVector) (((Classification)a).getLabelVector()); LabelVector y = (LabelVector) (((Classification)b).getLabelVector()); double difference = x.getBestValue() - y.getBestValue(); int toReturn = 0; if(difference > 0) toReturn = 1; else if (difference < 0) toReturn = -1; return(toReturn); } } }