package edu.stanford.nlp.stats;
import edu.stanford.nlp.classify.Classifier;
import edu.stanford.nlp.classify.Dataset;
import edu.stanford.nlp.ling.Datum;
import java.text.NumberFormat;
/**
* Utility class for aggregating counts of true positives, false positives, and
* false negatives and computing precision/recall/F1 stats. Can be used for a single
* collection of stats, or to aggregate stats from a bunch of runs.
*
* @author Joseph Smarr
*/
public class PrecisionRecallStats {
/**
* Count of true positives.
*/
protected int tpCount = 0;
/**
* Count of false positives.
*/
protected int fpCount = 0;
/**
* Count of false negatives.
*/
protected int fnCount = 0;
/**
* Constructs a new PrecisionRecallStats with initially 0 counts.
*/
public PrecisionRecallStats() {
this(0, 0, 0);
}
public <L,F> PrecisionRecallStats(Classifier<L,F> classifier,Dataset<L,F> data,L positiveClass)
{
for (int i=0; i < data.size(); ++i)
{
Datum<L,F> d = data.getDatum(i);
L guess = classifier.classOf(d);
L label = d.label();
boolean guessPositive = guess.equals(positiveClass);
boolean isPositive = label.equals(positiveClass);
if (isPositive && guessPositive) tpCount++;
if (isPositive && !guessPositive) fnCount++;
if (!isPositive && guessPositive) fpCount++;
}
}
/**
* Constructs a new PrecisionRecallStats with the given initial counts.
*/
public PrecisionRecallStats(int tp, int fp, int fn) {
tpCount = tp;
fpCount = fp;
fnCount = fn;
}
/**
* Returns the current count of true positives.
*/
public int getTP() {
return tpCount;
}
/**
* Returns the current count of false positives.
*/
public int getFP() {
return fpCount;
}
/**
* Returns the current count of false negatives.
*/
public int getFN() {
return fnCount;
}
/**
* Adds the given number to the count of true positives.
*/
public void addTP(int count) {
tpCount += count;
}
/**
* Adds one to the count of true positives.
*/
public void incrementTP() {
addTP(1);
}
/**
* Adds the given number to the count of false positives.
*/
public void addFP(int count) {
fpCount += count;
}
/**
* Adds one to the count of false positives.
*/
public void incrementFP() {
addFP(1);
}
/**
* Adds the given number to the count of false negatives.
*/
public void addFN(int count) {
fnCount += count;
}
/**
* Adds one to the count of false negatives.
*/
public void incrementFN() {
addFN(1);
}
/**
* Adds the counts from the given stats to the counts of this stats.
*/
public void addCounts(PrecisionRecallStats prs) {
addTP(prs.getTP());
addFP(prs.getFP());
addFN(prs.getFN());
}
/**
* Returns the current precision: <tt>tp/(tp+fp)</tt>.
* Returns 1.0 if tp and fp are both 0.
*/
public double getPrecision() {
if (tpCount == 0 && fpCount == 0) {
return 1.0;
}
return ((double) tpCount) / (tpCount + fpCount);
}
/**
* Returns a String summarizing precision that will print nicely.
*/
public String getPrecisionDescription(int numDigits) {
NumberFormat nf = NumberFormat.getNumberInstance();
nf.setMaximumFractionDigits(numDigits);
return nf.format(getPrecision()) + " (" + tpCount + "/" + (tpCount + fpCount) + ")";
}
/**
* Returns the current recall: <tt>tp/(tp+fn)</tt>.
* Returns 1.0 if tp and fn are both 0.
*/
public double getRecall() {
if (tpCount == 0 && fnCount == 0) {
return 1.0;
}
return ((double) tpCount) / (tpCount + fnCount);
}
/**
* Returns a String summarizing recall that will print nicely.
*/
public String getRecallDescription(int numDigits) {
NumberFormat nf = NumberFormat.getNumberInstance();
nf.setMaximumFractionDigits(numDigits);
return nf.format(getRecall()) + " (" + tpCount + "/" + (tpCount + fnCount) + ")";
}
/**
* Returns the current F1 measure (<tt>alpha=0.5</tt>).
*/
public double getFMeasure() {
return getFMeasure(0.5);
}
/**
* Returns the F-Measure with the given mixing parameter (must be between 0 and 1).
* If either precision or recall are 0, return 0.0.
* <tt>F(alpha) = 1/(alpha/precision + (1-alpha)/recall)</tt>
*/
public double getFMeasure(double alpha) {
double pr = getPrecision();
double re = getRecall();
if (pr == 0 || re == 0) {
return 0.0;
}
return 1.0 / ((alpha / pr) + (1.0 - alpha) / re);
}
/**
* Returns a String summarizing F1 that will print nicely.
*/
public String getF1Description(int numDigits) {
NumberFormat nf = NumberFormat.getNumberInstance();
nf.setMaximumFractionDigits(numDigits);
return nf.format(getFMeasure());
}
/**
* Returns a String representation of this PrecisionRecallStats, indicating the number of tp, fp, fn counts.
*/
@Override
public String toString() {
return "PrecisionRecallStats[tp=" + getTP() + ",fp=" + getFP() + ",fn=" + getFN() + "]";
}
public String toString(int numDigits) {
return "PrecisionRecallStats[tp=" + getTP() + ",fp=" + getFP() + ",fn=" + getFN() +
",p=" + getPrecisionDescription(numDigits) + ",r=" + getRecallDescription(numDigits) +
",f1=" + getF1Description(numDigits) + "]";
}
}