package edu.stanford.nlp.classify; import edu.stanford.nlp.ling.Datum; import edu.stanford.nlp.ling.RVFDatum; import edu.stanford.nlp.stats.Counter; import edu.stanford.nlp.util.Pair; import java.io.Serializable; import java.util.Collection; /** * A simple interface for classifying and scoring data points, implemented * by most of the classifiers in this package. A basic Classifier * works over a List of categorical features. For classifiers over * real-valued features, see {@link RVFClassifier}. * * @author Dan Klein * @author Sarah Spikes (sdspikes@cs.stanford.edu) (Templatization) * * @param <L> The type of the label(s) in each Datum * @param <F> The type of the features in each Datum */ public interface Classifier<L, F> extends Serializable { public L classOf(Datum<L, F> example); public Counter<L> scoresOf(Datum<L, F> example); public Collection<L> labels(); /** * Evaluates the precision and recall of this classifier against a dataset, and the target label. * * @param testData The dataset to evaluate the classifier on. * @param targetLabel The target label (e.g., for relation extraction, this is the relation we're interested in). * @return A pair of the precision (first) and recall (second) of the classifier on the target label. */ public default Pair<Double, Double> evaluatePrecisionAndRecall(GeneralDataset<L, F> testData, L targetLabel) { if (targetLabel == null) { throw new IllegalArgumentException("Must supply a target label to compute precision and recall against"); } // Variables to count int numCorrectAndTarget = 0; int numTargetGuess = 0; int numTargetGold = 0; // Iterate over dataset for (RVFDatum<L, F> datum : testData) { // Get the gold label L label = datum.label(); if (label == null) { throw new IllegalArgumentException("Cannot compute precision and recall on unlabelled dataset. Offending datum: " + datum); } // Get the guess label L guess = classOf(datum); // Compute statistics on datum if (label.equals(targetLabel)) { numTargetGold += 1; } if (guess.equals(targetLabel)) { numTargetGuess += 1; if (guess.equals(label)) { numCorrectAndTarget += 1; } } } // Aggregate statistics double precision = numTargetGuess == 0 ? 0.0 : ((double) numCorrectAndTarget) / ((double) numTargetGuess); double recall = numTargetGold == 0 ? 1.0 : ((double) numCorrectAndTarget) / ((double) numTargetGold); return Pair.makePair(precision, recall); } /** * Evaluate the accuracy of this classifier on the given dataset. * * @param testData The dataset to evaluate the classifier on. * @return The accuracy of the classifier on the given dataset. */ public default double evaluateAccuracy(GeneralDataset<L, F> testData) { int numCorrect = 0; for (RVFDatum<L, F> datum : testData) { // Get the gold label L label = datum.label(); if (label == null) { throw new IllegalArgumentException("Cannot compute precision and recall on unlabelled dataset. Offending datum: " + datum); } // Get the guess L guess = classOf(datum); // Compute statistics if (label.equals(guess)) { numCorrect += 1; } } return ((double) numCorrect) / ((double) testData.size); } }