package edu.stanford.nlp.stats; import edu.stanford.nlp.classify.Classifier; import edu.stanford.nlp.classify.GeneralDataset; import edu.stanford.nlp.classify.ProbabilisticClassifier; import edu.stanford.nlp.ling.Datum; import edu.stanford.nlp.util.HashIndex; import edu.stanford.nlp.util.Index; import edu.stanford.nlp.util.Triple; import java.text.NumberFormat; import java.util.ArrayList; import java.util.List; /** * @author Jenny Finkel */ public class MultiClassPrecisionRecallStats<L> implements Scorer<L> { /** * Count of true positives. */ protected int[] tpCount; /** * Count of false positives. */ protected int[] fpCount; /** * Count of false negatives. */ protected int[] fnCount; protected Index<L> labelIndex; protected L negLabel; protected int negIndex = -1; public <F> MultiClassPrecisionRecallStats(Classifier<L,F> classifier, GeneralDataset<L,F> data, L negLabel) { this.negLabel = negLabel; score(classifier, data); } public MultiClassPrecisionRecallStats(L negLabel) { this.negLabel = negLabel; } public L getNegLabel() { return negLabel; } public <F> double score(ProbabilisticClassifier<L,F> classifier, GeneralDataset<L,F> data) { return score((Classifier<L,F>)classifier, data); } public <F> double score(Classifier<L,F> classifier, GeneralDataset<L,F> data) { List<L> guesses = new ArrayList<>(); List<L> labels = new ArrayList<>(); for (int i = 0; i < data.size(); i++) { Datum<L, F> d = data.getRVFDatum(i); L guess = classifier.classOf(d); guesses.add(guess); } int[] labelsArr = data.getLabelsArray(); labelIndex = data.labelIndex; for (int i = 0; i < data.size(); i++) { labels.add(labelIndex.get(labelsArr[i])); } labelIndex = new HashIndex<>(); labelIndex.addAll(data.labelIndex().objectsList()); labelIndex.addAll(classifier.labels()); int numClasses = labelIndex.size(); tpCount = new int[numClasses]; fpCount = new int[numClasses]; fnCount = new int[numClasses]; negIndex = labelIndex.indexOf(negLabel); for (int i=0; i < guesses.size(); ++i) { L guess = guesses.get(i); int guessIndex = labelIndex.indexOf(guess); L label = labels.get(i); int trueIndex = labelIndex.indexOf(label); if (guessIndex == trueIndex) { if (guessIndex != negIndex) { tpCount[guessIndex]++; } } else { if (guessIndex != negIndex) { fpCount[guessIndex]++; } if (trueIndex != negIndex) { fnCount[trueIndex]++; } } } return getFMeasure(); } /** * Returns the current precision: <tt>tp/(tp+fp)</tt>. * Returns 1.0 if tp and fp are both 0. */ public Triple<Double, Integer, Integer> getPrecisionInfo(L label) { int i = labelIndex.indexOf(label); if (tpCount[i] == 0 && fpCount[i] == 0) { return new Triple<>(1.0, tpCount[i], fpCount[i]); } return new Triple<>((((double) tpCount[i]) / (tpCount[i] + fpCount[i])), tpCount[i], fpCount[i]); } public double getPrecision(L label) { return getPrecisionInfo(label).first(); } public Triple<Double, Integer, Integer> getPrecisionInfo() { int tp = 0, fp = 0; for (int i = 0; i < labelIndex.size(); i++) { if (i == negIndex) { continue; } tp += tpCount[i]; fp += fpCount[i]; } return new Triple<>((((double) tp) / (tp + fp)), tp, fp); } public double getPrecision() { return getPrecisionInfo().first(); } /** * Returns a String summarizing precision that will print nicely. */ public String getPrecisionDescription(int numDigits) { NumberFormat nf = NumberFormat.getNumberInstance(); nf.setMaximumFractionDigits(numDigits); Triple<Double, Integer, Integer> prec = getPrecisionInfo(); return nf.format(prec.first()) + " (" + prec.second() + "/" + (prec.second() + prec.third()) + ")"; } public String getPrecisionDescription(int numDigits, L label) { NumberFormat nf = NumberFormat.getNumberInstance(); nf.setMaximumFractionDigits(numDigits); Triple<Double, Integer, Integer> prec = getPrecisionInfo(label); return nf.format(prec.first()) + " (" + prec.second() + "/" + (prec.second() + prec.third()) + ")"; } public Triple<Double, Integer, Integer> getRecallInfo(L label) { int i = labelIndex.indexOf(label); if (tpCount[i] == 0 && fnCount[i] == 0) { return new Triple<>(1.0, tpCount[i], fnCount[i]); } return new Triple<>((((double) tpCount[i]) / (tpCount[i] + fnCount[i])), tpCount[i], fnCount[i]); } public double getRecall(L label) { return getRecallInfo(label).first(); } public Triple<Double, Integer, Integer> getRecallInfo() { int tp = 0, fn = 0; for (int i = 0; i < labelIndex.size(); i++) { if (i == negIndex) { continue; } tp += tpCount[i]; fn += fnCount[i]; } return new Triple<>((((double) tp) / (tp + fn)), tp, fn); } public double getRecall() { return getRecallInfo().first(); } /** * Returns a String summarizing precision that will print nicely. */ public String getRecallDescription(int numDigits) { NumberFormat nf = NumberFormat.getNumberInstance(); nf.setMaximumFractionDigits(numDigits); Triple<Double, Integer, Integer> recall = getRecallInfo(); return nf.format(recall.first()) + " (" + recall.second() + "/" + (recall.second() + recall.third()) + ")"; } public String getRecallDescription(int numDigits, L label) { NumberFormat nf = NumberFormat.getNumberInstance(); nf.setMaximumFractionDigits(numDigits); Triple<Double, Integer, Integer> recall = getRecallInfo(label); return nf.format(recall.first()) + " (" + recall.second() + "/" + (recall.second() + recall.third()) + ")"; } public double getFMeasure(L label) { double p = getPrecision(label); double r = getRecall(label); double f = (2 * p * r) / (p + r); return f; } public double getFMeasure() { double p = getPrecision(); double r = getRecall(); double f = (2 * p * r) / (p + r); return f; } /** * Returns a String summarizing F1 that will print nicely. */ public String getF1Description(int numDigits) { NumberFormat nf = NumberFormat.getNumberInstance(); nf.setMaximumFractionDigits(numDigits); return nf.format(getFMeasure()); } public String getF1Description(int numDigits, L label) { NumberFormat nf = NumberFormat.getNumberInstance(); nf.setMaximumFractionDigits(numDigits); return nf.format(getFMeasure(label)); } /** * Returns a String summarizing F1 that will print nicely. */ public String getDescription(int numDigits) { StringBuffer sb = new StringBuffer(); sb.append("--- PR Stats ---").append("\n"); for (L label : labelIndex) { if (label == null || label.equals(negLabel)) { continue; } sb.append("** ").append(label.toString()).append(" **\n"); sb.append("\tPrec: ").append(getPrecisionDescription(numDigits, label)).append("\n"); sb.append("\tRecall: ").append(getRecallDescription(numDigits, label)).append("\n"); sb.append("\tF1: ").append(getF1Description(numDigits, label)).append("\n"); } sb.append("** Overall **\n"); sb.append("\tPrec: ").append(getPrecisionDescription(numDigits)).append("\n"); sb.append("\tRecall: ").append(getRecallDescription(numDigits)).append("\n"); sb.append("\tF1: ").append(getF1Description(numDigits)); return sb.toString(); } }