package arkref.ext.fig.basic;
import java.util.*;
/**
* Allows measuring precision and recall.
*/
public class EvalResult {
// True, pred: p = positive, n = negative
private double pp, pn, np, nn;
private double count;
public EvalResult() { }
public EvalResult(double numTrueAndPred, double numTrue, double numPred) {
pp = numTrueAndPred;
pn = numTrue - numTrueAndPred;
np = numPred - numTrueAndPred;
nn = 0;
count = numTrue + numPred - numTrueAndPred;
}
// Probability of being positive
public void add(double trueProb, double predProb) {
pp += trueProb * predProb;
pn += trueProb * (1-predProb);
np += (1-trueProb) * predProb;
nn += (1-trueProb) * (1-predProb);
count++;
}
public void add(boolean trueVal, boolean predVal) {
add(trueVal ? 1 : 0, predVal ? 1 : 0);
}
public void add(EvalResult r) {
pp += r.pp;
pn += r.pn;
np += r.np;
nn += r.nn;
count += r.count;
}
public <T> void add(HashSet<T> trueSet, HashSet<T> predSet) {
for(T x : trueSet)
add(true, predSet.contains(x));
for(T x : predSet)
if(!trueSet.contains(x))
add(false, true);
}
public double precision() { return pp / (pp + np); }
public double recall() { return pp / (pp + pn); }
public double falsePos() { return np / (pp + np); }
public double trueNeg() { return pn / (pp + pn); }
public double count() { return count; }
public double numTrue() { return pp+pn; }
public double numPred() { return pp+np; }
public double f1() {
double p = precision(), r = recall();
return 2 * p * r / (p + r);
}
public String toString() {
return String.format("Precision = %s, recall = %s, F1 = %s (%s)",
Fmt.D(precision()), Fmt.D(recall()), Fmt.D(f1()), Fmt.D(count()));
}
}