/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package cc.mallet.fst.confidence;

import java.util.Vector;
import java.util.Collections;
import java.util.Comparator;

import cc.mallet.fst.*;
import cc.mallet.types.*;

/**
 * Evaluates how well a set of confidence scores separates correct from
 * incorrect extractions.
 *
 * <p>Internally holds {@link EntityConfidence} records sorted in ascending
 * order of confidence, so the highest-confidence record is the last element.
 * Methods that walk "from most to least confident" therefore iterate from the
 * end of {@link #confidences} backwards.
 */
public class ConfidenceEvaluator
{
  static int DEFAULT_NUM_BINS = 20;

  /** {@link EntityConfidence} records, sorted by ascending confidence. */
  Vector confidences;
  /** Number of coverage bins used by the accuracy/coverage and accuracy/recall curves. */
  int nBins;
  /** Number of records whose {@code correct()} flag is true. */
  int numCorrect;

  /**
   * Builds an evaluator over existing records.
   *
   * @param confidences vector of {@link EntityConfidence}; it is sorted
   *        <em>in place</em> by ascending confidence, so the caller's vector
   *        is reordered
   * @param nBins number of coverage bins used by the curve methods
   */
  public ConfidenceEvaluator (Vector confidences, int nBins)
  {
    this.confidences = confidences;
    this.nBins = nBins;
    this.numCorrect = getNumCorrectEntities ();
    // sort confidences by score (ascending)
    Collections.sort (confidences, new ConfidenceComparator ());
  }

  /** Same as {@link #ConfidenceEvaluator(Vector, int)} with {@link #DEFAULT_NUM_BINS} bins. */
  public ConfidenceEvaluator (Vector confidences)
  {
    this (confidences, DEFAULT_NUM_BINS);
  }

  /**
   * Builds an evaluator from extracted segments.
   *
   * @param segments segments carrying a confidence and a correctness flag
   * @param sorted true if {@code segments} are already in ascending
   *        confidence order, in which case no sort is performed
   */
  public ConfidenceEvaluator (Segment[] segments, boolean sorted)
  {
    this.confidences = new Vector ();
    for (int i = 0; i < segments.length; i++) {
      Segment s = segments[i];
      confidences.add (new EntityConfidence (s.getConfidence (), s.correct (),
                                             s.getInput (), s.getStart (), s.getEnd ()));
    }
    if (!sorted)
      Collections.sort (confidences, new ConfidenceComparator ());
    this.nBins = DEFAULT_NUM_BINS;
    this.numCorrect = getNumCorrectEntities ();
  }

  /**
   * Builds an evaluator from whole-instance confidences. The entity text of
   * each record spans the entire input sequence.
   *
   * @param instances instances carrying a confidence and a correctness flag
   * @param sorted true if {@code instances} are already in ascending
   *        confidence order, in which case no sort is performed
   */
  public ConfidenceEvaluator (InstanceWithConfidence[] instances, boolean sorted)
  {
    this.confidences = new Vector ();
    for (int i = 0; i < instances.length; i++) {
      Sequence input = (Sequence) instances[i].getInstance ().getData ();
      confidences.add (new EntityConfidence (instances[i].getConfidence (), instances[i].correct (),
                                             input, 0, input.size () - 1));
    }
    if (!sorted)
      Collections.sort (confidences, new ConfidenceComparator ());
    this.nBins = DEFAULT_NUM_BINS;
    this.numCorrect = getNumCorrectEntities ();
  }

  /**
   * Builds an evaluator from piped-instance confidences. No input sequence is
   * attached, so each record's entity text is empty.
   *
   * @param instances instances carrying a confidence and a correctness flag
   * @param sorted true if {@code instances} are already in ascending
   *        confidence order, in which case no sort is performed
   */
  public ConfidenceEvaluator (PipedInstanceWithConfidence[] instances, boolean sorted)
  {
    this.confidences = new Vector ();
    for (int i = 0; i < instances.length; i++) {
      confidences.add (new EntityConfidence (instances[i].getConfidence (), instances[i].correct (),
                                             null, 0, 1));
    }
    if (!sorted)
      Collections.sort (confidences, new ConfidenceComparator ());
    this.nBins = DEFAULT_NUM_BINS;
    this.numCorrect = getNumCorrectEntities ();
  }

  /**
   * Point-biserial correlation, i.e. correlation when one variable (Y) is
   * binary:
   * <pre>
   *   r = (bar(x1) - bar(x0)) * sqrt(p(1-p)) / sx
   * </pre>
   * where bar(x1) is the mean of X when Y is 1, bar(x0) the mean of X when Y
   * is 0, sx the (population) standard deviation of X, and p the proportion
   * of values where Y = 1.
   *
   * <p>Here Y = {incorrect = 0, correct = 1} and X = confidence. The result is
   * NaN when all records are correct, all are incorrect, or all confidences
   * are equal.
   */
  public double pointBiserialCorrelation ()
  {
    double x0bar = getAverageIncorrectConfidence ();
    double x1bar = getAverageCorrectConfidence ();
    double p = (double) this.numCorrect / size ();
    double sx = getConfidenceStandardDeviation ();
    return (x1bar - x0bar) * Math.sqrt (p * (1 - p)) / sx;
  }

  /**
   * IR average precision: records are ranked by descending confidence and
   * precision is averaged at the rank of each correct record. This is
   * analogous to ranking _correct_ documents by confidence score.
   *
   * @return average precision, or NaN if there are no correct records
   */
  public double getAveragePrecision ()
  {
    int nc = 0;
    int ni = 0;
    double totalPrecision = 0.0;
    // walk from highest confidence (end of the sorted vector) to lowest
    for (int i = confidences.size () - 1; i >= 0; i--) {
      EntityConfidence c = (EntityConfidence) confidences.get (i);
      if (c.correct ()) {
        nc++;
        totalPrecision += (double) nc / (nc + ni);
      }
      else
        ni++;
    }
    return totalPrecision / nc;
  }

  /**
   * For comparison, the average precision of the worst possible ranking
   * (all "incorrect" records before all "correct" ones).
   *
   * @return worst-case average precision, or NaN if there are no correct records
   */
  public double getWorstAveragePrecision ()
  {
    int ni = confidences.size () - this.numCorrect;
    double totalPrecision = 0.0;
    for (int nc = 1; nc <= this.numCorrect; nc++) {
      totalPrecision += (double) nc / (nc + ni);
    }
    return totalPrecision / this.numCorrect;
  }

  /** Sum of all confidence scores. */
  public double getConfidenceSum ()
  {
    double sum = 0.0;
    for (int i = 0; i < size (); i++)
      sum += ((EntityConfidence) confidences.get (i)).confidence ();
    return sum;
  }

  /** Mean confidence score (NaN when there are no records). */
  public double getConfidenceMean ()
  {
    return getConfidenceSum () / size ();
  }

  /** Population standard deviation (divides by N) of the confidence scores. */
  public double getConfidenceStandardDeviation ()
  {
    double mean = getConfidenceMean ();
    double sumSquaredDifference = 0.0;
    for (int i = 0; i < size (); i++) {
      double conf = ((EntityConfidence) confidences.get (i)).confidence ();
      sumSquaredDifference += ((conf - mean) * (conf - mean));
    }
    return Math.sqrt (sumSquaredDifference / (double) size ());
  }

  /**
   * Pearson's r for the correlation between confidence and correctness,
   * where correctness is encoded as 1 = correct and -1 = incorrect.
   *
   * @return the correlation coefficient, or NaN when either variable is constant
   */
  public double correlation ()
  {
    double xSum = 0;
    double xSumOfSquares = 0;
    double ySum = 0;
    double ySumOfSquares = 0;
    double xySum = 0; // product of x and y
    for (int i = 0; i < size (); i++) {
      EntityConfidence ec = (EntityConfidence) confidences.get (i);
      double value = ec.correct () ? 1.0 : -1.0;
      xSum += value;
      xSumOfSquares += (value * value);
      double conf = ec.confidence ();
      ySum += conf;
      ySumOfSquares += (conf * conf);
      xySum += value * conf;
    }
    double xVariance = xSumOfSquares - (xSum * xSum / size ());
    double yVariance = ySumOfSquares - (ySum * ySum / size ());
    double crossVariance = xySum - (xSum * ySum / size ());
    return crossVariance / Math.sqrt (xVariance * yVariance);
  }

  /**
   * Coverage fraction for bin {@code i} (0-based): {@code (i+1)/nBins}, so the
   * last bin is exactly 1.0.
   *
   * <p>Computed in floating point rather than via an integer percentage step
   * ({@code 100 / nBins}). That step truncated when nBins does not divide 100,
   * so the curve never reached full coverage. It also became 0 when
   * nBins &gt; 100. When nBins does divide 100 the two forms yield identical
   * doubles.
   */
  private double binCoverage (int i)
  {
    return (double) (i + 1) / this.nBins;
  }

  /**
   * Accuracy at coverage for each bin.
   *
   * @return array where element i is the accuracy over the top
   *         {@code (i+1)/nBins} fraction of records by confidence
   */
  public double[] getAccuracyCoverageValues ()
  {
    double[] values = new double[this.nBins];
    for (int i = 0; i < values.length; i++) {
      values[i] = accuracyAtCoverage (binCoverage (i));
    }
    return values;
  }

  /** Tab-separated "coverage\taccuracy" lines, one per bin. */
  public String accuracyCoverageValuesToString ()
  {
    StringBuilder buf = new StringBuilder ();
    double[] vals = getAccuracyCoverageValues ();
    for (int i = 0; i < vals.length; i++) {
      buf.append (binCoverage (i)).append ('\t').append (vals[i]).append ('\n');
    }
    return buf.toString ();
  }

  /**
   * Accuracy at recall for each bin of coverage values.
   *
   * @param totalTrue total number of true Segments
   * @return 2-d array where values[i][0] is recall (correct records within the
   *         bin's coverage divided by {@code totalTrue}) and values[i][1] is
   *         accuracy at that coverage
   */
  public double[][] getAccuracyRecallValues (int totalTrue)
  {
    double[][] values = new double[this.nBins][2];
    for (int i = 0; i < this.nBins; i++) {
      double coverage = binCoverage (i);
      values[i][0] = (double) numCorrectAtCoverage (coverage) / totalTrue;
      values[i][1] = accuracyAtCoverage (coverage);
    }
    return values;
  }

  /** Tab-separated "recall\taccuracy" lines, one per bin. */
  public String accuracyRecallValuesToString (int totalTrue)
  {
    StringBuilder buf = new StringBuilder ();
    double[][] vals = getAccuracyRecallValues (totalTrue);
    for (int i = 0; i < this.nBins; i++) {
      buf.append (vals[i][0]).append ('\t').append (vals[i][1]).append ('\n');
    }
    return buf.toString ();
  }

  /**
   * Fraction of correct records among the {@code round(size * cov)} most
   * confident records.
   *
   * @param cov coverage fraction in (0, 1]
   * @return accuracy at that coverage; NaN if the coverage rounds to zero records
   */
  public double accuracyAtCoverage (double cov)
  {
    assert (cov <= 1 && cov > 0);
    int numPoints = (int) (Math.round ((double) size () * cov));
    return ((double) numCorrectAtCoverage (cov) / numPoints);
  }

  /**
   * Number of correct records among the {@code round(size * cov)} most
   * confident records.
   *
   * @param cov coverage fraction in (0, 1]
   */
  public int numCorrectAtCoverage (double cov)
  {
    assert (cov <= 1 && cov > 0);
    // number of top-confidence records to inspect for this value of cov
    int numPoints = (int) (Math.round ((double) size () * cov));
    int correctCount = 0;
    for (int i = 0; i < numPoints; i++) {
      if (((EntityConfidence) confidences.get (size () - i - 1)).correct ())
        correctCount++;
    }
    return correctCount;
  }

  /**
   * Mean, over every cutoff k = 1..N in descending-confidence order, of the
   * accuracy of the top k records.
   */
  public double getAverageAccuracy ()
  {
    int correctSoFar = 0;
    double totalArea = 0.0;
    for (int i = confidences.size () - 1; i >= 0; i--) {
      if (((EntityConfidence) confidences.get (i)).correct ())
        correctSoFar++;
      totalArea += (double) correctSoFar / (confidences.size () - i);
    }
    return totalArea / confidences.size ();
  }

  /** Number of records marked correct. */
  public int numCorrect ()
  {
    return this.numCorrect;
  }

  /** Counts the entities that were correctly extracted. */
  private int getNumCorrectEntities ()
  {
    int sum = 0;
    for (int i = 0; i < confidences.size (); i++) {
      EntityConfidence ec = (EntityConfidence) confidences.get (i);
      if (ec.correct ()) {
        sum++;
      }
    }
    return sum;
  }

  /** Average confidence score for the incorrect entities (NaN if there are none). */
  public double getAverageIncorrectConfidence ()
  {
    double sum = 0.0;
    for (int i = 0; i < confidences.size (); i++) {
      EntityConfidence ec = (EntityConfidence) confidences.get (i);
      if (!ec.correct ()) {
        sum += ec.confidence ();
      }
    }
    return sum / ((double) size () - (double) this.numCorrect);
  }

  /** Average confidence score for the correct entities (NaN if there are none). */
  public double getAverageCorrectConfidence ()
  {
    double sum = 0.0;
    for (int i = 0; i < confidences.size (); i++) {
      EntityConfidence ec = (EntityConfidence) confidences.get (i);
      if (ec.correct ()) {
        sum += ec.confidence ();
      }
    }
    return sum / (double) this.numCorrect;
  }

  /** Number of records. */
  public int size ()
  {
    return confidences.size ();
  }

  /** Concatenation of every record's string form, each followed by a space. */
  public String toString ()
  {
    StringBuilder toReturn = new StringBuilder ();
    for (int i = 0; i < size (); i++) {
      toReturn.append (((EntityConfidence) confidences.get (i)).toString ()).append (' ');
    }
    return toReturn.toString ();
  }

  /**
   * A simple class to store a confidence score, whether or not this labeling
   * is correct, and a textual rendering of the entity.
   */
  public static class EntityConfidence
  {
    double confidence;
    boolean correct;
    String entity;

    public EntityConfidence (double conf, boolean corr, String text)
    {
      this.confidence = conf;
      this.correct = corr;
      this.entity = text;
    }

    /**
     * Builds the entity text from positions {@code start..end} (inclusive) of
     * {@code input}. For each position, it appends the text after '=' of every
     * feature whose name starts with "W=" and contains no '@'. Features are
     * separated by spaces. If {@code input} is null the entity text is empty.
     */
    public EntityConfidence (double conf, boolean corr, Sequence input, int start, int end)
    {
      this.confidence = conf;
      this.correct = corr;
      StringBuilder buff = new StringBuilder ();
      if (input != null) {
        for (int j = start; j <= end; j++) {
          FeatureVector fv = (FeatureVector) input.get (j);
          for (int k = 0; k < fv.numLocations (); k++) {
            String featureName = fv.getAlphabet ().lookupObject (fv.indexAtLocation (k)).toString ();
            if (featureName.startsWith ("W=") && featureName.indexOf ("@") == -1) {
              buff.append (featureName.substring (featureName.indexOf ('=') + 1)).append (' ');
            }
          }
        }
      }
      this.entity = buff.toString ();
    }

    public double confidence () { return confidence; }

    public boolean correct () { return correct; }

    /** "entity / confidence / correct|incorrect" followed by a newline. */
    public String toString ()
    {
      return this.entity + " / " + this.confidence + " / "
          + (this.correct ? "correct" : "incorrect") + "\n";
    }
  }

  /**
   * Orders {@link EntityConfidence} records by ascending confidence.
   *
   * <p>Uses {@link Double#compare}, which defines a total order (NaN sorts
   * last; -0.0 sorts before 0.0). The former subtraction-based comparison
   * treated NaN as equal to every value. That broke the comparator contract
   * and could make {@link Collections#sort} throw.
   */
  private static class ConfidenceComparator implements Comparator
  {
    public final int compare (Object a, Object b)
    {
      double x = ((EntityConfidence) a).confidence ();
      double y = ((EntityConfidence) b).confidence ();
      return Double.compare (x, y);
    }
  }
}