/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept. This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://www.cs.umass.edu/~mccallum/mallet This software is provided under the terms of the Common Public License, version 1.0, as published by http://www.opensource.org. For further information, see the file `LICENSE' included with this distribution. */ package cc.mallet.extract; import java.io.PrintStream; import java.io.OutputStream; import java.io.PrintWriter; import java.text.DecimalFormat; import java.util.Iterator; import java.util.Vector; import cc.mallet.fst.confidence.ConfidenceEvaluator; import cc.mallet.types.Label; import cc.mallet.types.LabelAlphabet; import cc.mallet.types.MatrixOps; /** * Constructs Accuracy-coverage graph using confidence values to sort Fields. * * Created: Nov 8, 2005 * * @author <A HREF="mailto:culotta@cs.umass.edu>culotta@cs.umass.edu</A> */ public class AccuracyCoverageEvaluator implements ExtractionEvaluator { private int numberBins; private FieldComparator comparator = new ExactMatchComparator (); private PrintStream errorOutputStream = null; public AccuracyCoverageEvaluator (int numberBins) { this.numberBins = 20; } public FieldComparator getComparator () { return comparator; } public void setComparator (FieldComparator comparator) { this.comparator = comparator; } public PrintStream getErrorOutputStream () { return errorOutputStream; } public void setErrorOutputStream (OutputStream errorOutputStream) { this.errorOutputStream = new PrintStream (errorOutputStream); } public void evaluate (Extraction extraction) { evaluate ("", extraction, System.out); } // Assumes that there are as many records as documents, indexed by docs. // Assumes that extractor returns at most one value public void evaluate (String description, Extraction extraction, PrintStream out) { int numDocs = extraction.getNumDocuments (); assert numDocs == extraction.getNumRecords (); Vector entityConfidences = new Vector(); int numTrueValues = 0; int numPredValues = 0; int numCorrValues = 0; for (int docnum = 0; docnum < numDocs; docnum++) { Record extracted = extraction.getRecord (docnum); Record target = extraction.getTargetRecord (docnum); Iterator it = extracted.fieldsIterator (); while (it.hasNext ()) { Field predField = (Field) it.next (); Field trueField = target.getField (predField.getName()); if (predField != null) numPredValues += predField.numValues(); for (int j = 0; j < predField.numValues(); j++) { LabeledSpan span = predField.span(j); boolean correct = (trueField != null && trueField.isValue (predField.value (j), comparator)); entityConfidences.add(new ConfidenceEvaluator.EntityConfidence (span.getConfidence(), correct, span.getText())); if (correct) numCorrValues++; } } it = target.fieldsIterator (); while (it.hasNext ()) { Field trueField = (Field) it.next (); numTrueValues += trueField.numValues (); } } ConfidenceEvaluator evaluator = new ConfidenceEvaluator(entityConfidences, this.numberBins); out.println("correlation: " + evaluator.correlation()); out.println("avg precision: " + evaluator.getAveragePrecision()); out.println("coverage\taccuracy:\n" + evaluator.accuracyCoverageValuesToString()); double[] ac = evaluator.getAccuracyCoverageValues(); for (int i=0; i < ac.length; i++) { int marks = (int)(ac[i]*25.0); for (int j=0; j < marks; j++) out.print("*"); out.println(); } out.println("nTrue:" + numTrueValues + " nCorr:" + numCorrValues + " nPred:" + numPredValues + "\n"); out.println("recall\taccuracy:\n" + evaluator.accuracyRecallValuesToString(numTrueValues)); } }