package de.berlin.hu.wbi.common.research;
import java.io.Serializable;
import java.util.*;
/**
 * Generic evaluator that computes precision, recall and F-measure by comparing
 * a result (the predictions) against a gold standard.
 *
 * <p>Both inputs are copied into {@link HashSet}s, so duplicates are collapsed
 * and elements are matched via {@code equals}/{@code hashCode}. A result element
 * is a true positive if the gold standard contains an equal element, otherwise a
 * false positive. A gold-standard element that no result element equals is a
 * false negative. Result and standard elements may have different types, as long
 * as their {@code equals}/{@code hashCode} implementations are compatible.
 *
 * <p>The TP/FP/FN lists are built by {@link #evaluate()}. They are built
 * lazily on first access if {@code evaluate()} was not called explicitly. The
 * measures are computed from them on first access and then cached.
 *
 * <p>Input checks (non-null, non-empty, not set twice) are written as
 * {@code assert} statements. They are only active when the JVM runs with
 * assertions enabled ({@code -ea}).
 *
 * @author arzt
 *
 * @param <ResultType> type of the predicted (result) elements
 * @param <StandardType> type of the gold-standard elements
 */
public class Evaluator<ResultType, StandardType> implements Serializable {

    private static final long serialVersionUID = 3009808474569044273L;

    // TP/FP/FN lists; all three are null until evaluate() has completed.
    private ArrayList<ResultType> truePositives;
    private ArrayList<ResultType> falsePositives;
    private ArrayList<StandardType> falseNegatives;

    // Deduplicated inputs. Transient: not serialized, and released by trim().
    private transient Set<? extends ResultType> result;
    private transient Set<? extends StandardType> standard;

    // Cached measures; NaN is the "not computed yet" sentinel.
    private double precision;
    private double recall;
    private double fMeasure;

    /**
     * Creates an evaluator without inputs. Supply them with
     * {@link #setResultPositives(Collection)} and
     * {@link #setStandardPositives(Collection)} before requesting results.
     */
    public Evaluator() {
        super();
        invalidateMeasures();
    }

    /**
     * Creates an evaluator for the given predictions and gold standard.
     *
     * @param result predictions
     * @param standard gold standard
     */
    public Evaluator(Collection<? extends ResultType> result, Collection<? extends StandardType> standard) {
        this();
        setResultPositives(result);
        setStandardPositives(standard);
    }

    /**
     * Shrinks the TP/FP/FN lists to their size and drops the input sets to
     * free memory once evaluation is done.
     *
     * <p>If this is called before the lists have been built, the inputs are lost.
     * A later evaluation then fails: an assertion error with {@code -ea}, a
     * {@code NullPointerException} otherwise. New inputs can afterwards be
     * supplied again through the setters.
     */
    public void trim() {
        if (truePositives != null) {
            truePositives.trimToSize();
        }
        if (falseNegatives != null) {
            falseNegatives.trimToSize();
        }
        if (falsePositives != null) {
            falsePositives.trimToSize();
        }
        result = null;
        standard = null;
    }

    /**
     * Sets the predictions to evaluate. The collection is copied into a set.
     *
     * @param result predictions. Must be non-null and non-empty, and must not be
     *        set while a result is already present (checked only by assertions).
     * @return this evaluator, for chaining
     */
    public Evaluator<ResultType, StandardType> setResultPositives(Collection<? extends ResultType> result) {
        assert this.result == null : "Do not overwrite the result collection!";
        assert result != null : "Your result collection is null!";
        assert result.size() > 0 : "Your result collection is empty!";
        this.result = new HashSet<ResultType>(result);
        return this;
    }

    /**
     * Sets the gold standard. The collection is copied into a set.
     *
     * @param standard gold standard. Must be non-null and non-empty, and must
     *        not be set while a standard is already present (checked only by
     *        assertions).
     * @return this evaluator, for chaining
     */
    public Evaluator<ResultType, StandardType> setStandardPositives(Collection<? extends StandardType> standard) {
        assert this.standard == null : "Do not overwrite the standard collection!";
        assert standard != null : "Your standard collection is null!";
        assert standard.size() > 0 : "Your standard collection is empty!";
        this.standard = new HashSet<StandardType>(standard);
        return this;
    }

    /**
     * Partitions the inputs into true positives, false positives and false
     * negatives. Any previous lists are replaced and cached measures are discarded.
     *
     * @return this evaluator, for chaining
     */
    public Evaluator<ResultType, StandardType> evaluate() {
        assert result != null : "You forgot to set the result collection!";
        assert standard != null : "You forgot to set the standard collection!";
        ArrayList<ResultType> tps = new ArrayList<ResultType>(result.size());
        ArrayList<ResultType> fps = new ArrayList<ResultType>(result.size());
        ArrayList<StandardType> fns = new ArrayList<StandardType>(standard.size());
        // TP / FP: classify every prediction by membership in the gold standard.
        for (ResultType resultSample : result) {
            if (standard.contains(resultSample)) {
                tps.add(resultSample);
            } else {
                fps.add(resultSample);
            }
        }
        // FN: gold-standard elements that no prediction matched.
        for (StandardType standardSample : standard) {
            if (!result.contains(standardSample)) {
                fns.add(standardSample);
            }
        }
        // Publish the lists only once they are complete.
        this.truePositives = tps;
        this.falsePositives = fps;
        this.falseNegatives = fns;
        // The lists may differ from a previous run (e.g. new inputs after
        // trim()), so any cached measure derived from old lists is stale.
        invalidateMeasures();
        return this;
    }

    /** Resets all cached measures to the "not computed" sentinel (NaN). */
    private void invalidateMeasures() {
        this.precision = Double.NaN;
        this.recall = Double.NaN;
        this.fMeasure = Double.NaN;
    }

    /** Returns whether the TP/FP/FN lists have been built. */
    private boolean isEvaluated() {
        return truePositives != null && falsePositives != null && falseNegatives != null;
    }

    /** Runs {@link #evaluate()} unless the TP/FP/FN lists already exist. */
    private void ensureEvaluated() {
        if (!isEvaluated()) {
            evaluate();
        }
    }

    private void computeFMeasure() {
        double p = getPrecision();
        double r = getRecall();
        fMeasure = EvalMeasures.getFMeasure(p, r);
    }

    private void computeRecall() {
        ensureEvaluated();
        int tp = truePositives.size();
        int fn = falseNegatives.size();
        recall = EvalMeasures.getRecall(tp, fn);
    }

    private void computePrecision() {
        ensureEvaluated();
        int tp = truePositives.size();
        int fp = falsePositives.size();
        precision = EvalMeasures.getPrecision(tp, fp);
    }

    /**
     * Returns the precision, computing and caching it on first call. Runs
     * {@link #evaluate()} first if that has not happened yet.
     *
     * @return the value of {@code EvalMeasures.getPrecision(#TP, #FP)}
     */
    public double getPrecision() {
        if (Double.isNaN(precision)) {
            computePrecision();
        }
        return precision;
    }

    /**
     * Returns the recall, computing and caching it on first call. Runs
     * {@link #evaluate()} first if that has not happened yet.
     *
     * @return the value of {@code EvalMeasures.getRecall(#TP, #FN)}
     */
    public double getRecall() {
        if (Double.isNaN(recall)) {
            computeRecall();
        }
        return recall;
    }

    /**
     * Returns the F-measure, computing and caching it on first call.
     *
     * @return the value of {@code EvalMeasures.getFMeasure(precision, recall)}
     */
    public double getFMeasure() {
        if (Double.isNaN(fMeasure)) {
            computeFMeasure();
        }
        return fMeasure;
    }

    /**
     * Returns the true positives, evaluating first if necessary.
     *
     * @return read-only view of the result elements found in the gold standard
     */
    public Collection<ResultType> getTruePositives() {
        ensureEvaluated();
        return Collections.unmodifiableCollection(truePositives);
    }

    /**
     * Returns the false positives, evaluating first if necessary.
     *
     * @return read-only view of the result elements missing from the gold standard
     */
    public Collection<ResultType> getFalsePositives() {
        ensureEvaluated();
        return Collections.unmodifiableCollection(falsePositives);
    }

    /**
     * Returns the false negatives, evaluating first if necessary.
     *
     * @return read-only view of the gold-standard elements missing from the result
     */
    public Collection<StandardType> getFalseNegatives() {
        ensureEvaluated();
        return Collections.unmodifiableCollection(falseNegatives);
    }

    /**
     * Creates an evaluator for the given predictions and gold standard.
     *
     * @param result predictions
     * @param standard gold standard
     * @return a new evaluator with both inputs set
     */
    public static <R, S> Evaluator<R, S> create(Collection<? extends R> result, Collection<? extends S> standard) {
        return new Evaluator<R, S>(result, standard);
    }

    /**
     * Creates an evaluator without inputs, for fluent configuration.
     *
     * @return a new, empty evaluator
     */
    public static <R, S> Evaluator<R, S> create() {
        return new Evaluator<R, S>();
    }

    /** Convenience: precision of {@code result} against {@code standard}. */
    public static <T> double getPrecision(Collection<T> result, Collection<T> standard) {
        return Evaluator.<T, T>create(result, standard).getPrecision();
    }

    /** Convenience: recall of {@code result} against {@code standard}. */
    public static <T> double getRecall(Collection<T> result, Collection<T> standard) {
        return Evaluator.<T, T>create(result, standard).getRecall();
    }

    /** Convenience: F-measure of {@code result} against {@code standard}. */
    public static <T> double getFMeasure(Collection<T> result, Collection<T> standard) {
        return Evaluator.<T, T>create(result, standard).getFMeasure();
    }

    /**
     * Misspelled alias kept for backward compatibility.
     *
     * @deprecated use {@link #getFMeasure(Collection, Collection)}
     */
    @Deprecated
    public static <T> double getFMeasur(Collection<T> result, Collection<T> standard) {
        return getFMeasure(result, standard);
    }

    /** Convenience: true positives of {@code result} against {@code standard}. */
    public static <T> Collection<T> getTruePositives(Collection<T> result, Collection<T> standard) {
        return Evaluator.<T, T>create(result, standard).getTruePositives();
    }

    /** Convenience: false positives of {@code result} against {@code standard}. */
    public static <T> Collection<T> getFalsePositives(Collection<T> result, Collection<T> standard) {
        return Evaluator.<T, T>create(result, standard).getFalsePositives();
    }

    /** Convenience: false negatives of {@code result} against {@code standard}. */
    public static <T> Collection<T> getFalseNegatives(Collection<T> result, Collection<T> standard) {
        return Evaluator.<T, T>create(result, standard).getFalseNegatives();
    }

    /**
     * Misspelled alias kept for backward compatibility. It previously returned
     * the false positives by mistake; it now returns the false negatives.
     *
     * @deprecated use {@link #getFalseNegatives(Collection, Collection)}
     */
    @Deprecated
    public static <T> Collection<T> getFalseNagatives(Collection<T> result, Collection<T> standard) {
        return getFalseNegatives(result, standard);
    }

    /** @return number of true positives, evaluating first if necessary */
    public int getNumberOfTP() {
        ensureEvaluated();
        return truePositives.size();
    }

    /** @return number of false positives, evaluating first if necessary */
    public int getNumberOfFP() {
        ensureEvaluated();
        return falsePositives.size();
    }

    /** @return number of false negatives, evaluating first if necessary */
    public int getNumberOfFN() {
        ensureEvaluated();
        return falseNegatives.size();
    }

    /**
     * Summarises the TP/FP/FN counts. It does not trigger evaluation, so it is
     * safe to call on an evaluator without inputs.
     */
    @Override
    public String toString() {
        if (!isEvaluated()) {
            return "Evaluator [not evaluated]";
        }
        StringBuilder builder = new StringBuilder();
        builder.append("Evaluator [truePositives=");
        builder.append(truePositives.size());
        builder.append(", falsePositives=");
        builder.append(falsePositives.size());
        builder.append(", falseNegatives=");
        builder.append(falseNegatives.size());
        builder.append("]");
        return builder.toString();
    }
}