package edu.stanford.nlp.coref.statistical; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import java.util.stream.Collectors; import edu.stanford.nlp.coref.statistical.Clusterer.Cluster; import edu.stanford.nlp.stats.ClassicCounter; import edu.stanford.nlp.stats.Counter; import edu.stanford.nlp.util.Pair; /** * Utility classes for computing the B^3 and MUC coreference metrics * @author Kevin Clark */ public class EvalUtils { public static double getCombinedF1(double mucWeight, List<List<Integer>> gold, List<Cluster> clusters, Map<Integer, List<Integer>> mentionToGold, Map<Integer, Cluster> mentionToSystem) { CombinedEvaluator combined = new CombinedEvaluator(mucWeight); combined.update(gold, clusters, mentionToGold, mentionToSystem); return combined.getF1(); } public static double f1(double pNum, double pDen, double rNum, double rDen) { double p = pNum == 0 ? 0 : pNum / pDen; double r = rNum == 0 ? 0 : rNum / rDen; return p == 0 ? 0 : 2 * p * r / (p + r); } public interface Evaluator { public void update(List<List<Integer>> gold, List<Cluster> clusters, Map<Integer, List<Integer>> mentionToGold, Map<Integer, Cluster> mentionToSystem); public double getF1(); } public static class CombinedEvaluator implements Evaluator { private final B3Evaluator b3Evaluator; private final MUCEvaluator mucEvaluator; private final double mucWeight; public CombinedEvaluator(double mucWeight) { b3Evaluator = new B3Evaluator(); mucEvaluator = new MUCEvaluator(); this.mucWeight = mucWeight; } @Override public void update(List<List<Integer>> gold, List<Cluster> clusters, Map<Integer, List<Integer>> mentionToGold, Map<Integer, Cluster> mentionToSystem) { if (mucWeight != 1) { b3Evaluator.update(gold, clusters, mentionToGold, mentionToSystem); } if (mucWeight != 0) { mucEvaluator.update(gold, clusters, mentionToGold, mentionToSystem); } } @Override public double getF1() { return (mucWeight == 0 ? 0 : mucWeight * mucEvaluator.getF1()) + (mucWeight == 1 ? 0 : (1 - mucWeight) * b3Evaluator.getF1()); } } public static abstract class AbstractEvaluator implements Evaluator { public double pNum; public double pDen; public double rNum; public double rDen; @Override public void update(List<List<Integer>> gold, List<Cluster> clusters, Map<Integer, List<Integer>> mentionToGold, Map<Integer, Cluster> mentionToSystem) { List<List<Integer>> clustersAsList = clusters.stream().map(c -> c.mentions) .collect(Collectors.toList()); Map<Integer, List<Integer>> mentionToSystemLists = mentionToSystem.entrySet().stream() .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().mentions)); Pair<Double, Double> prec = getScore(clustersAsList, mentionToGold); Pair<Double, Double> rec = getScore(gold, mentionToSystemLists); pNum += prec.first; pDen += prec.second; rNum += rec.first; rDen += rec.second; } @Override public double getF1() { return f1(pNum, pDen, rNum, rDen); } public double getRecall() { return pNum == 0 ? 0 : pNum / pDen; } public double getPrecision() { return rNum == 0 ? 0 : rNum / rDen; } public abstract Pair<Double, Double> getScore(List<List<Integer>> clusters, Map<Integer, List<Integer>> mentionToGold); } public static class B3Evaluator extends AbstractEvaluator { @Override public Pair<Double, Double> getScore(List<List<Integer>> clusters, Map<Integer, List<Integer>> mentionToGold) { double num = 0; int dem = 0; for (List<Integer> c : clusters) { if (c.size() == 1) { continue; } Counter<List<Integer>> goldCounts = new ClassicCounter<>(); double correct = 0; for (int m : c) { List<Integer> goldCluster = mentionToGold.get(m); if (goldCluster != null) { goldCounts.incrementCount(goldCluster); } } for (Map.Entry<List<Integer>, Double> e : goldCounts.entrySet()) { if (e.getKey().size() != 1) { correct += e.getValue() * e.getValue(); } } num += correct / c.size(); dem += c.size(); } return new Pair<>(num, (double) dem); } } public static class MUCEvaluator extends AbstractEvaluator { @Override public Pair<Double, Double> getScore(List<List<Integer>> clusters, Map<Integer, List<Integer>> mentionToGold) { int tp = 0; int predictedPositive = 0; for (List<Integer> c : clusters) { predictedPositive += c.size() - 1; tp += c.size(); Set<List<Integer>> linked = new HashSet<>(); for (int m : c) { List<Integer> g = mentionToGold.get(m); if (g == null) { tp -= 1; } else { linked.add(g); } } tp -= linked.size(); } return new Pair<>((double) tp, (double) predictedPositive); } } }