/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept. This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://www.cs.umass.edu/~mccallum/mallet This software is provided under the terms of the Common Public License, version 1.0, as published by http://www.opensource.org. For further information, see the file `LICENSE' included with this distribution. */ package cc.mallet.extract; import java.io.PrintStream; import java.io.OutputStream; import java.io.PrintWriter; import java.io.OutputStreamWriter; import java.text.DecimalFormat; import java.util.Iterator; import cc.mallet.types.Label; import cc.mallet.types.LabelAlphabet; import cc.mallet.types.MatrixOps; /** * Created: Oct 8, 2004 * * @author <A HREF="mailto:casutton@cs.umass.edu>casutton@cs.umass.edu</A> * @version $Id: PerDocumentF1Evaluator.java,v 1.1 2007/10/22 21:37:44 mccallum Exp $ */ public class PerDocumentF1Evaluator implements ExtractionEvaluator { private FieldComparator comparator = new ExactMatchComparator (); private PrintStream errorOutputStream = null; public FieldComparator getComparator () { return comparator; } public void setComparator (FieldComparator comparator) { this.comparator = comparator; } public PrintStream getErrorOutputStream () { return errorOutputStream; } public void setErrorOutputStream (OutputStream errorOutputStream) { // Work around java bug when wrapping System.out if (errorOutputStream instanceof PrintStream) { this.errorOutputStream = (PrintStream) errorOutputStream; } else { this.errorOutputStream = new PrintStream (errorOutputStream); } } public void evaluate (Extraction extraction) { evaluate (extraction, System.out); } public void evaluate (Extraction extraction, PrintStream out) { evaluate ("", extraction, new PrintWriter (new OutputStreamWriter (out), true)); } public void evaluate (Extraction extraction, PrintWriter out) { evaluate ("", extraction, out); } // Assumes that there are as many records as documents, indexed by docs. // Assumes that extractor returns at most one value public void evaluate (String description, Extraction extraction, PrintWriter out) { int numDocs = extraction.getNumDocuments (); assert numDocs == extraction.getNumRecords (); LabelAlphabet dict = extraction.getLabelAlphabet(); int numLabels = dict.size(); int[] numCorr = new int [numLabels]; int[] numPred = new int [numLabels]; int[] numTrue = new int [numLabels]; for (int docnum = 0; docnum < numDocs; docnum++) { Record extracted = extraction.getRecord (docnum); Record target = extraction.getTargetRecord (docnum); // Calc precision Iterator it = extracted.fieldsIterator (); while (it.hasNext ()) { Field predField = (Field) it.next (); Label name = predField.getName (); Field trueField = target.getField (name); int idx = name.getIndex (); numPred [idx]++; if (predField.numValues() > 1) System.err.println ("Warning: Field "+predField+" has more than one extracted value. Picking arbitrarily..."); if (trueField != null && trueField.isValue (predField.value (0), comparator)) { numCorr [idx]++; } else { // We have an error, report if necessary if (errorOutputStream != null) { //xxx TODO: Display name of supporting document errorOutputStream.println ("Error in extraction! Document "+extraction.getDocumentExtraction (docnum).getName ()); errorOutputStream.println ("Predicted "+predField); errorOutputStream.println ("True "+trueField); errorOutputStream.println (); } } } // Calc true it = target.fieldsIterator (); while (it.hasNext ()) { Field trueField = (Field) it.next (); Label name = trueField.getName (); numTrue [name.getIndex ()]++; } } DecimalFormat f = new DecimalFormat ("0.####"); double totalF1 = 0; int totalFields = 0; out.println (description+" per-document F1"); out.println ("Name\tP\tR\tF1"); for (int i = 0; i < numLabels; i++) { double P = (numPred[i] == 0) ? 0 : ((double)numCorr[i]) / numPred [i]; double R = (numTrue[i] == 0) ? 1 : ((double)numCorr[i]) / numTrue [i]; double F1 = (P + R == 0) ? 0 : (2 * P * R) / (P + R); if ((numPred[i] > 0) || (numTrue[i] > 0)) { totalF1 += F1; totalFields++; } Label name = dict.lookupLabel (i); out.println (name+"\t"+f.format(P)+"\t"+f.format(R)+"\t"+f.format(F1)); } int totalCorr = MatrixOps.sum (numCorr); int totalPred = MatrixOps.sum (numPred); int totalTrue = MatrixOps.sum (numTrue); double P = ((double)totalCorr) / totalPred; double R = ((double)totalCorr) / totalTrue; double F1 = (2 * P * R) / (P + R); out.println ("OVERALL (micro-averaged) P="+f.format(P)+" R="+f.format(R)+" F1="+f.format(F1)); out.println ("OVERALL (macro-averaged) F1="+f.format(totalF1/totalFields)); out.println (); } }