package edu.stanford.nlp.stats; import java.io.PrintWriter; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.Map; import java.util.Set; import java.util.regex.Matcher; import java.util.regex.Pattern; import edu.stanford.nlp.util.ErasureUtils; import edu.stanford.nlp.util.Generics; import edu.stanford.nlp.util.Pair; import edu.stanford.nlp.util.Sets; /** * A class for calculating precision and recall statistics based on * comparisons between two {@link Collection}s. * Allows flexible specification of: * <p/> * <ul> * <li>The criterion by which to evaluate whether two Objects are equivalent * for purposes of precision and recall * calculation (specified by an {@link EqualityChecker} instance) * <li>The criterion by which Objects are grouped into equivalence classes * for purposes of calculating subclass precision * and recall (specified by an {@link EquivalenceClasser} instance) * <li>Evaluation is set-based or bag-based (by default, it is set-based). For example, if a gold collection * has {a,a,b} and a guess collection has {a,b}, then recall is 100% in set-based * evaluation, but is 66.67% in bag-based evaluation. * </ul> * * Note that for set-based evaluation, sets are always constructed using object equality, NOT * equality on the basis of an {@link EqualityChecker} if one is given. If set-based evaluation * were conducted on the basis of an EqualityChecker, then there would be indeterminacy when it did not subsume the {@link EquivalenceClasser}, * if one was given. For example, if objects of the form * X:y were equivalence-classed by the left criterion and evaluated for equality on the right, then set-based * evaluation based on the equality checker would be indeterminate for a collection of {A:a,B:a} * because it would be unclear whether to use the first or second element of the collection. * * @author Roger Levy * @author Sarah Spikes (sdspikes@cs.stanford.edu) Attempt at templatization... this may be a failure. */ public class EquivalenceClassEval<IN, OUT> { /** If bagEval is set to <code>true</code>, then multiple instances of the same item will not be merged. For example, * gold (a,a,b) against guess (a,b) will be scored as 100% precision and 66.67% recall. It is <code>false</code> * by default.*/ public void setBagEval(boolean bagEval) { this.bagEval = bagEval; } protected boolean bagEval = false; /** * Maps all objects to the equivalence class <code>null</code> */ @SuppressWarnings("unchecked") public static final EquivalenceClasser NULL_EQUIVALENCE_CLASSER = o -> null; public static final <T,U> EquivalenceClasser<T,U> nullEquivalenceClasser() { return ErasureUtils.<EquivalenceClasser<T,U>>uncheckedCast(NULL_EQUIVALENCE_CLASSER); } private boolean verbose = false; EquivalenceClasser<IN, OUT> eq; Eval.CollectionContainsChecker<IN> checker; String summaryName; /** * Specifies a default EquivalenceClassEval, using {@link Object#equals(java.lang.Object)} as equality criterion * and grouping all items into the "null" equivalence class for reporting purposes */ public EquivalenceClassEval() { this(EquivalenceClassEval.<IN,OUT>nullEquivalenceClasser()); } /** * Specifies an EquivalenceClassEval using {@link Object#equals(java.lang.Object)} as equality criterion * and grouping all items according to the EquivalenceClasser argument. */ public EquivalenceClassEval(EquivalenceClasser<IN, OUT> eq) { this(eq, ""); } /** * Specifies an EquivalenceClassEval using the Eval.EqualityChecker argument as equality criterion * and grouping all items into a single equivalence class for reporting statistics. */ public EquivalenceClassEval(EqualityChecker<IN> e) { this(EquivalenceClassEval.<IN,OUT>nullEquivalenceClasser(), e); } /** * Specifies an EquivalenceClassEval using {@link Object#equals(java.lang.Object)} as equality criterion * and grouping all items according to the EquivalenceClasser argument. */ public EquivalenceClassEval(EquivalenceClasser<IN, OUT> eq, String name) { this(eq, EquivalenceClassEval.<IN>defaultChecker(), name); } /** * Specifies an EquivalenceClassEval using the Eval.EqualityChecker argument as equality criterion * and grouping all items according to the EquivalenceClasser argument. */ public EquivalenceClassEval(EquivalenceClasser<IN, OUT> eq, EqualityChecker<IN> e) { this(eq, e, ""); } /** * Specifies an EquivalenceClassEval using the Eval.EqualityChecker argument as equality criterion * and grouping all items according to the EquivalenceClasser argument. */ public EquivalenceClassEval(EquivalenceClasser<IN, OUT> eq, EqualityChecker<IN> e, String summaryName) { this(eq, new Eval.CollectionContainsChecker<>(e), summaryName); } EquivalenceClassEval(EquivalenceClasser<IN, OUT> eq, Eval.CollectionContainsChecker<IN> checker, String summaryName) { this.eq = eq; this.checker = checker; this.summaryName = summaryName; } ClassicCounter<OUT> guessed = new ClassicCounter<>(); ClassicCounter<OUT> guessedCorrect = new ClassicCounter<>(); ClassicCounter<OUT> gold = new ClassicCounter<>(); ClassicCounter<OUT> goldCorrect = new ClassicCounter<>(); private ClassicCounter<OUT> lastPrecision = new ClassicCounter<>(); private ClassicCounter<OUT> lastRecall = new ClassicCounter<>(); private ClassicCounter<OUT> lastF1 = new ClassicCounter<>(); private ClassicCounter<OUT> previousGuessed; private ClassicCounter<OUT> previousGuessedCorrect; private ClassicCounter<OUT> previousGold; private ClassicCounter<OUT> previousGoldCorrect; //Eval eval = new Eval(); /** * Adds a round of evaluation between guesses and golds {@link Collection}s to the tabulated statistics of * the evaluation. */ public void eval(Collection<IN> guesses, Collection<IN> golds) { eval(guesses, golds, new PrintWriter(System.out, true)); } // this one is all side effects /** * @param guesses Collection of guessed objects * @param golds Collection of gold-standard objects * @param pw {@link PrintWriter} to print eval stats */ public void eval(Collection<IN> guesses, Collection<IN> golds, PrintWriter pw) { if (verbose) { System.out.println("evaluating precision..."); } Pair<ClassicCounter<OUT>, ClassicCounter<OUT>> precision = evalPrecision(guesses, golds); previousGuessed = precision.first(); Counters.addInPlace(guessed, previousGuessed); previousGuessedCorrect = precision.second(); Counters.addInPlace(guessedCorrect, previousGuessedCorrect); if (verbose) { System.out.println("evaluating recall..."); } Pair<ClassicCounter<OUT>, ClassicCounter<OUT>> recall = evalPrecision(golds, guesses); previousGold = recall.first(); Counters.addInPlace(gold, previousGold); previousGoldCorrect = recall.second(); Counters.addInPlace(goldCorrect, previousGoldCorrect); } /* returns a Pair of each */ Pair<ClassicCounter<OUT>, ClassicCounter<OUT>> evalPrecision(Collection<IN> guesses, Collection<IN> golds) { Collection<IN> internalGuesses = null; Collection<IN> internalGolds = null; if(bagEval) { internalGuesses = new ArrayList<>(guesses.size()); internalGolds = new ArrayList<>(golds.size()); } else { internalGuesses = Generics.newHashSet(guesses.size()); internalGolds = Generics.newHashSet(golds.size()); } internalGuesses.addAll(guesses); internalGolds.addAll(golds); ClassicCounter<OUT> thisGuessed = new ClassicCounter<>(); ClassicCounter<OUT> thisCorrect = new ClassicCounter<>(); for (IN o : internalGuesses) { OUT equivalenceClass = eq.equivalenceClass(o); thisGuessed.incrementCount(equivalenceClass); if (checker.contained(o, internalGolds)) { thisCorrect.incrementCount(equivalenceClass); removeItem(o,internalGolds,checker); } else { if (verbose) { System.out.println("Eval missed " + o); } } } return Generics.newPair(thisGuessed, thisCorrect); } /* there is some discomfort here, we should really be using an EqualityChecker for checker, but * I screwed up the API. */ protected static <T> void removeItem(T o, Collection<T> c, Eval.CollectionContainsChecker<T> checker) { for(T o1 : c) { if(checker.contained(o,Collections.singleton(o1))) { c.remove(o1); return; } } } /** * Displays the cumulative results of the evaluation to {@link System#out}. */ public void display() { display(new PrintWriter(System.out, true)); } /** * Displays the cumulative results of the evaluation. */ public void display(PrintWriter pw) { pw.println("*********Final " + summaryName + " eval stats by antecedent category***********"); Set<OUT> keys = Generics.newHashSet(); keys.addAll(guessed.keySet()); keys.addAll(gold.keySet()); displayHelper(keys, pw, guessed, guessedCorrect, gold, goldCorrect); pw.println("Finished final " + summaryName + " eval stats."); } /** * Displays the results of the previous Collection pair evaluation to {@link System#out}. */ public void displayLast() { displayLast(new PrintWriter(System.out, true)); } /** * Displays the results of the previous Collection pair evaluation. */ public void displayLast(PrintWriter pw) { Set<OUT> keys = Generics.newHashSet(); keys.addAll(previousGuessed.keySet()); keys.addAll(previousGold.keySet()); displayHelper(keys, pw, previousGuessed, previousGuessedCorrect, previousGold, previousGoldCorrect); } public double precision(OUT key) { return percentage(key, guessed, guessedCorrect); } public double recall(OUT key) { return percentage(key, gold, goldCorrect); } public double lastPrecision(OUT key) { return percentage(key, previousGuessed, previousGuessedCorrect); } public ClassicCounter<OUT> lastPrecision() { ClassicCounter<OUT> result = new ClassicCounter<>(); Counters.addInPlace(result, previousGuessedCorrect); Counters.divideInPlace(result, previousGuessed); return result; } public double lastRecall(OUT key) { return percentage(key, previousGold, previousGoldCorrect); } public ClassicCounter<OUT> lastRecall() { ClassicCounter<OUT> result = new ClassicCounter<>(); Counters.addInPlace(result, previousGoldCorrect); Counters.divideInPlace(result, previousGold); return result; } public double lastNumGuessed(OUT key) { return previousGuessed.getCount(key); } public ClassicCounter<OUT> lastNumGuessed() { return previousGuessed; } public ClassicCounter<OUT> lastNumGuessedCorrect() { return previousGuessedCorrect; } public double lastNumGolds(OUT key) { return previousGold.getCount(key); } public ClassicCounter<OUT> lastNumGolds() { return previousGold; } public ClassicCounter<OUT> lastNumGoldsCorrect() { return previousGoldCorrect; } public double f1(OUT key) { return f1(precision(key), recall(key)); } public double lastF1(OUT key) { return f1(lastPrecision(key), lastRecall(key)); } public ClassicCounter<OUT> lastF1() { ClassicCounter<OUT> result = new ClassicCounter<>(); Set<OUT> keys = Sets.union(previousGuessed.keySet(),previousGold.keySet()); for(OUT key : keys) { result.setCount(key,lastF1(key)); } return result; } public static double f1(double precision, double recall) { return (precision == 0.0 || recall == 0.0) ? 0.0 : (2 * precision * recall) / (precision + recall); } public static <E> Counter<E> f1(Counter<E> precision, Counter<E> recall) { Counter<E> result = precision.getFactory().create(); for(E key : Sets.intersection(precision.keySet(),recall.keySet())) { result.setCount(key,f1(precision.getCount(key),recall.getCount(key))); } return result; } private double percentage(OUT key, ClassicCounter<OUT> guessed, ClassicCounter<OUT> guessedCorrect) { double thisGuessed = guessed.getCount(key); double thisGuessedCorrect = guessedCorrect.getCount(key); return (thisGuessed == 0.0) ? 0.0 : thisGuessedCorrect / thisGuessed; } private void displayHelper(Set<OUT> keys, PrintWriter pw, ClassicCounter<OUT> guessed, ClassicCounter<OUT> guessedCorrect, ClassicCounter<OUT> gold, ClassicCounter<OUT> goldCorrect) { Map<OUT, String> pads = getPads(keys); for (OUT key : keys) { double thisGuessed = guessed.getCount(key); double thisGuessedCorrect = guessedCorrect.getCount(key); double precision = (thisGuessed == 0.0) ? 0.0 : thisGuessedCorrect / thisGuessed; lastPrecision.setCount(key, precision); double thisGold = gold.getCount(key); double thisGoldCorrect = goldCorrect.getCount(key); double recall = (thisGold == 0.0) ? 0.0 : thisGoldCorrect / thisGold; lastRecall.setCount(key, recall); double f1 = f1(precision, recall); lastF1.setCount(key, f1); String pad = pads.get(key); pw.println(key + pad + "\t" + "P: " + formatNumber(precision) + "\ton " + formatCount(thisGuessed) + " objects\tR: " + formatNumber(recall) + "\ton " + formatCount(thisGold) + " objects\tF1: " + formatNumber(f1)); } } // public static String formatNumber(double d) { // double frac = d % 1.0; // int whole = (int) Math.round(d - frac); // int frac1 = (int) Math.round(frac * 1000); // String prePad = ""; // if(whole < 1000) // prePad += " "; // if(whole > 100) // prePad += " "; // if(whole > 10) // prePad += " "; // return pad + whole + "." + frac1; // } private static java.text.NumberFormat numberFormat = java.text.NumberFormat.getNumberInstance(); { numberFormat.setMaximumFractionDigits(4); numberFormat.setMinimumFractionDigits(4); numberFormat.setMinimumIntegerDigits(1); numberFormat.setMaximumIntegerDigits(1); } private static String formatNumber(double d) { return numberFormat.format(d); } private static int formatCount(double d) { return (int) Math.round(d); } /* find pads for each key based on length of longest key */ private static <OUT> Map<OUT, String> getPads(Set<OUT> keys) { Map<OUT, String> pads = Generics.newHashMap(); int max = 0; for (OUT key : keys) { String keyString = key==null ? "null" : key.toString(); if (keyString.length() > max) { max = keyString.length(); } } for (OUT key : keys) { String keyString = key==null ? "null" : key.toString(); int diff = max - keyString.length(); String pad = ""; for (int j = 0; j < diff; j++) { pad += " "; } pads.put(key, pad); } return pads; } public static void main(String[] args) { final Pattern p = Pattern.compile("^([^:]*):(.*)$"); Collection<String> guesses = Arrays.asList(new String[]{"S:a", "S:b", "VP:c", "VP:d", "S:a"}); Collection<String> golds = Arrays.asList(new String[]{"S:a", "S:b", "S:b", "VP:d", "VP:a"}); EqualityChecker<String> e = (o1, o2) -> { Matcher m1 = p.matcher(o1); m1.find(); String s1 = m1.group(2); System.out.println(s1); Matcher m2 = p.matcher(o2); m2.find(); String s2 = m2.group(2); System.out.println(s2); return s1.equals(s2); }; EquivalenceClasser<String, String> eq = o -> { Matcher m = p.matcher(o); m.find(); return m.group(1); }; EquivalenceClassEval<String, String> eval = new EquivalenceClassEval<>(eq, e, "testing"); eval.setBagEval(false); eval.eval(guesses, golds); eval.displayLast(); eval.display(); } /** * A strategy-type interface for specifying an equality criterion for pairs of {@link Object}s. * * @author Roger Levy */ public interface EqualityChecker<T> { /** * Returns <code>true</code> iff <code>o1</code> and <code>o2</code> are equal by the desired * evaluation criterion. */ public boolean areEqual(T o1, T o2); } /** * A default equality checker that uses {@link Object#equals} to determine equality. */ @SuppressWarnings("unchecked") public static final EqualityChecker DEFAULT_CHECKER = new EqualityChecker() { public boolean areEqual(Object o1, Object o2) { return o1.equals(o2); } }; @SuppressWarnings("unchecked") public static final <T> EqualityChecker<T> defaultChecker() { return DEFAULT_CHECKER; } static class Eval<T> { private boolean bagEval = false; public Eval(EqualityChecker<T> e) { this(false,e); } public Eval() { this(false); } public Eval(boolean bagEval) { this(bagEval,EquivalenceClassEval.<T>defaultChecker()); } public Eval(boolean bagEval, EqualityChecker<T> e) { checker = new CollectionContainsChecker<>(e); this.bagEval = bagEval; } CollectionContainsChecker<T> checker; /* a filter that returns true iff the object is a collection that contains currentItem */ static class CollectionContainsChecker<T> { EqualityChecker<T> e; public CollectionContainsChecker(EqualityChecker<T> e) { this.e = e; } public boolean contained(T obj, Collection<T> coll) { for (T o : coll) { if (e.areEqual(obj, o)) { return true; } } return false; } } // end class CollectionContainsChecker double guessed = 0.0; double guessedCorrect = 0.0; double gold = 0.0; double goldCorrect = 0.0; double lastPrecision; double lastRecall; double lastF1; public void eval(Collection<T> guesses, Collection<T> golds) { eval(guesses, golds, new PrintWriter(System.out, true)); } // this one is all side effects public void eval(Collection<T> guesses, Collection<T> golds, PrintWriter pw) { double precision = evalPrecision(guesses, golds); lastPrecision = precision; double recall = evalRecall(guesses, golds); lastRecall = recall; double f1 = (2 * precision * recall) / (precision + recall); lastF1 = f1; guessed += guesses.size(); guessedCorrect += (guesses.size() == 0.0 ? 0.0 : precision * guesses.size()); gold += golds.size(); goldCorrect += (golds.size() == 0.0 ? 0.0 : recall * golds.size()); pw.println("This example:\tP:\t" + precision + " R:\t" + recall + " F1:\t" + f1); double cumPrecision = guessedCorrect / guessed; double cumRecall = goldCorrect / gold; double cumF1 = (2 * cumPrecision * cumRecall) / (cumPrecision + cumRecall); pw.println("Cumulative:\tP:\t" + cumPrecision + " R:\t" + cumRecall + " F1:\t" + cumF1); } // this has no side effects! public double evalPrecision(Collection<T> guesses, Collection<T> golds) { Collection<T> internalGuesses; Collection<T> internalGolds; if(bagEval) { internalGuesses = new ArrayList<>(guesses.size()); internalGolds = new ArrayList<>(golds.size()); } else { internalGuesses = Generics.newHashSet(guesses.size()); internalGolds = Generics.newHashSet(golds.size()); } internalGuesses.addAll(guesses); internalGolds.addAll(golds); double thisGuessed = 0.0; double thisGuessedCorrect = 0.0; for (T o: internalGuesses) { thisGuessed += 1.0; if (checker.contained(o, internalGolds)) { thisGuessedCorrect += 1.0; removeItem(o,internalGolds,checker); } // else // System.out.println("Precision eval missed " + o); } return thisGuessedCorrect / thisGuessed; } // no side effects here either public double evalRecall(Collection<T> guesses, Collection<T> golds) { double thisGold = 0.0; double thisGoldCorrect = 0.0; for (T o : golds) { thisGold += 1.0; if (guesses.contains(o)) { thisGoldCorrect += 1.0; } // else // System.out.println("Recall eval missed " + o); } return thisGoldCorrect / thisGold; } public void display() { display(new PrintWriter(System.out, true)); } public void display(PrintWriter pw) { double precision = guessedCorrect / guessed; double recall = goldCorrect / gold; double f1 = (2 * precision * recall) / (precision + recall); pw.println("*********Final eval stats***********"); pw.println("P:\t" + precision + " R:\t" + recall + " F1:\t" + f1); } } public static interface Factory<IN, OUT> { public EquivalenceClassEval<IN, OUT> equivalenceClassEval(); } /** * returns a new {@link Factory} instance that vends new EquivalenceClassEval instances with * settings like <code>this</code> */ public Factory<IN, OUT> factory() { return new Factory<IN, OUT>() { boolean bagEval1 = bagEval; EquivalenceClasser<IN, OUT> eq1 = eq; Eval.CollectionContainsChecker<IN> checker1 = checker; String summaryName1 = summaryName; public EquivalenceClassEval<IN, OUT> equivalenceClassEval() { EquivalenceClassEval<IN, OUT> e = new EquivalenceClassEval<>(eq1, checker1, summaryName1); e.setBagEval(bagEval1); return e; } }; } }