package edu.stanford.nlp.stats;
import edu.stanford.nlp.classify.Classifier;
import edu.stanford.nlp.classify.GeneralDataset;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.ling.Datum;
import java.util.function.Function;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Triple;
import java.io.BufferedReader;
import java.io.IOException;
import java.text.NumberFormat;
import java.util.*;
import java.util.regex.Pattern;
/**
 * Extension of {@link MultiClassPrecisionRecallStats} that also computes per-token accuracy
 * and per-label counts, and can render a CoNLL-style evaluation summary.
 *
 * <p>Counts are accumulated with {@link #addGuess(Object, Object)} / {@link #addGuesses(List, List)}
 * (or by one of the {@code score} methods, which clear the counts first). They are converted into
 * the superclass's per-class true-positive / false-positive / false-negative arrays by
 * {@link #finalizeCounts()}.
 *
 * @author Angel Chang
 */
public class MultiClassPrecisionRecallExtendedStats<L> extends MultiClassPrecisionRecallStats<L> {

  /** Column of the token in lines read by {@link #score(BufferedReader, String, String)}. */
  private static final int TOKEN_INDEX = 0;
  /** Column of the gold answer in lines read from text. */
  private static final int ANSWER_INDEX = 1;
  /** Column of the guessed label in lines read from text. */
  private static final int GUESS_INDEX = 2;

  /** Number of guesses equal to the gold label, keyed by label (includes {@code negLabel}). */
  protected IntCounter<L> correctGuesses;
  /** Number of gold labels that are not {@code negLabel}, keyed by label. */
  protected IntCounter<L> foundCorrect;
  /** Number of guesses that are not {@code negLabel}, keyed by label. */
  protected IntCounter<L> foundGuessed;

  // NOTE: the int counters below intentionally have no explicit "= 0" initializer (Java already
  // defaults them to 0). If the superclass constructor invokes the overridden score(...) method,
  // an explicit initializer would run *after* it and silently wipe the counts just accumulated.
  /** Number of tokens with a non-null gold label that have been counted. */
  protected int tokensCount;
  /** Number of counted tokens whose guess equals the gold label. */
  protected int tokensCorrect;
  /** Number of tokens skipped because their gold label was {@code null}. */
  protected int noLabel;

  /** Converts label strings read from text into labels; required by the reader-based scorers. */
  protected Function<String,L> stringConverter;

  /**
   * Delegates to the corresponding superclass constructor (which presumably scores
   * {@code classifier} on {@code data}; see {@link MultiClassPrecisionRecallStats}).
   *
   * @param classifier classifier to evaluate
   * @param data dataset with gold labels
   * @param negLabel the "background" label excluded from precision/recall counts
   */
  public <F> MultiClassPrecisionRecallExtendedStats(Classifier<L,F> classifier, GeneralDataset<L,F> data, L negLabel)
  {
    super(classifier, data, negLabel);
    // No-op if the superclass constructor already scored (and thereby created the counters).
    ensureCounters();
  }

  /**
   * Creates empty stats; feed them with {@link #addGuess(Object, Object)} or a {@code score} method.
   *
   * @param negLabel the "background" label excluded from precision/recall counts
   */
  public MultiClassPrecisionRecallExtendedStats(L negLabel)
  {
    super(negLabel);
    ensureCounters();
  }

  /**
   * Creates empty stats that use the given label index.
   *
   * @param dataLabelIndex index of the labels to report on
   * @param negLabel the "background" label excluded from precision/recall counts
   */
  public MultiClassPrecisionRecallExtendedStats(Index<L> dataLabelIndex, L negLabel)
  {
    this(negLabel);
    setLabelIndex(dataLabelIndex);
  }

  /** Creates any counter that has not been created yet, leaving existing ones untouched. */
  private void ensureCounters() {
    if (correctGuesses == null) {
      correctGuesses = new IntCounter<>();
    }
    if (foundCorrect == null) {
      foundCorrect = new IntCounter<>();
    }
    if (foundGuessed == null) {
      foundGuessed = new IntCounter<>();
    }
  }

  /**
   * Uses {@code dataLabelIndex} as the label index and recomputes the index of {@code negLabel}.
   *
   * @param dataLabelIndex index of the labels to report on
   */
  public void setLabelIndex(Index<L> dataLabelIndex) {
    labelIndex = dataLabelIndex;
    negIndex = labelIndex.indexOf(negLabel);
  }

  /**
   * Clears all counts, classifies every datum of {@code data} and returns the resulting F1.
   *
   * @param classifier classifier to evaluate
   * @param data dataset with gold labels
   * @return F1 score over the non-negative labels
   */
  public <F> double score(Classifier<L,F> classifier, GeneralDataset<L,F> data) {
    labelIndex = new HashIndex<>();
    labelIndex.addAll(classifier.labels());
    labelIndex.addAll(data.labelIndex.objectsList());
    clearCounts();
    int[] labelsArr = data.getLabelsArray();
    for (int i = 0; i < data.size(); i++) {
      Datum<L, F> d = data.getRVFDatum(i);
      L guess = classifier.classOf(d);
      // labelsArr holds indices into the dataset's own label index, not into the freshly built
      // labelIndex (which lists the classifier's labels first), so decode with data.labelIndex.
      addGuess(guess, data.labelIndex.get(labelsArr[i]));
    }
    finalizeCounts();
    return getFMeasure();
  }

  /**
   * Returns the score (F1) for the given list of guesses
   * @param guesses - Guesses by classifier
   * @param trueLabels - Gold labels to compare guesses against
   * @param dataLabelIndex - Index of labels
   * @return F1 score
   */
  public double score(List<L> guesses, List<L> trueLabels, Index<L> dataLabelIndex) {
    setLabelIndex(dataLabelIndex);
    return score(guesses, trueLabels);
  }

  /**
   * Returns the score (F1) for the given list of guesses
   * @param guesses - Guesses by classifier
   * @param trueLabels - Gold labels to compare guesses against
   * @return F1 score
   */
  public double score(List<L> guesses, List<L> trueLabels) {
    clearCounts();
    addGuesses(guesses, trueLabels);
    finalizeCounts();
    return getFMeasure();
  }

  /**
   * Finalizes the counts accumulated so far (without clearing them) and returns F1.
   *
   * @return F1 score over the non-negative labels
   */
  public double score()
  {
    finalizeCounts();
    return getFMeasure();
  }

  /** Resets every counter (per-label counts, per-class arrays, token and no-label counts) to zero. */
  public void clearCounts()
  {
    ensureCounters();
    foundCorrect.clear();
    foundGuessed.clear();
    correctGuesses.clear();
    if (tpCount != null) {
      Arrays.fill(tpCount, 0);
    }
    if (fnCount != null) {
      Arrays.fill(fnCount, 0);
    }
    if (fpCount != null) {
      Arrays.fill(fpCount, 0);
    }
    tokensCount = 0;
    tokensCorrect = 0;
    noLabel = 0;
  }

  /**
   * Converts the per-label counters into the superclass's {@code tpCount}/{@code fpCount}/
   * {@code fnCount} arrays (one slot per label in {@code labelIndex}), resizing them if needed.
   */
  protected void finalizeCounts()
  {
    negIndex = labelIndex.indexOf(negLabel);
    int numClasses = labelIndex.size();
    if (tpCount == null || tpCount.length != numClasses) {
      tpCount = new int[numClasses];
    }
    if (fpCount == null || fpCount.length != numClasses) {
      fpCount = new int[numClasses];
    }
    if (fnCount == null || fnCount.length != numClasses) {
      fnCount = new int[numClasses];
    }
    for (int i = 0; i < numClasses; i++) {
      L label = labelIndex.get(i);
      tpCount[i] = correctGuesses.getIntCount(label);
      fnCount[i] = foundCorrect.getIntCount(label) - tpCount[i];
      fpCount[i] = foundGuessed.getIntCount(label) - tpCount[i];
    }
  }

  /**
   * Hook invoked at sequence boundaries (blank lines or boundary tokens) while reading text.
   * Does nothing here; subclasses may override it.
   */
  protected void markBoundary()
  {
  }

  /**
   * Records one guess, adding previously unseen labels to the label index.
   *
   * @param guess guessed label
   * @param label gold label; {@code null} means "no label" and the token is only tallied in
   *     {@link #noLabel}
   */
  protected void addGuess(L guess, L label)
  {
    addGuess(guess, label, true);
  }

  /**
   * Records one guess.
   *
   * @param guess guessed label
   * @param label gold label; {@code null} means "no label" and the token is only tallied in
   *     {@link #noLabel}
   * @param addUnknownLabels whether to add {@code guess} and {@code label} to the label index
   *     (creating the index if needed)
   */
  protected void addGuess(L guess, L label, boolean addUnknownLabels)
  {
    if (label == null) {
      noLabel++;
      return;
    }
    if (addUnknownLabels) {
      if (labelIndex == null) {
        labelIndex = new HashIndex<>();
      }
      labelIndex.add(guess);
      labelIndex.add(label);
    }
    if (guess.equals(label)) {
      correctGuesses.incrementCount(label);
      tokensCorrect++;
    }
    if (!guess.equals(negLabel)) {
      foundGuessed.incrementCount(guess);
    }
    if (!label.equals(negLabel)) {
      foundCorrect.incrementCount(label);
    }
    tokensCount++;
  }

  /**
   * Records a sequence of guesses against their gold labels (pairwise, in order).
   *
   * @param guesses guessed labels
   * @param trueLabels gold labels, one per guess
   * @throws IllegalArgumentException if the two lists differ in length (checked before any
   *     guess is recorded)
   */
  public void addGuesses(List<L> guesses, List<L> trueLabels)
  {
    if (guesses.size() != trueLabels.size()) {
      throw new IllegalArgumentException("guesses and trueLabels must have the same size, but got "
          + guesses.size() + " and " + trueLabels.size());
    }
    // Iterate rather than index so that non-random-access lists stay linear.
    Iterator<L> labelIter = trueLabels.iterator();
    for (L guess : guesses) {
      addGuess(guess, labelIter.next());
    }
  }

  /**
   * Return overall number of correct answers (guesses equal to the gold label, including
   * correct guesses of {@code negLabel}).
   */
  public int getCorrect()
  {
    return correctGuesses.totalIntCount();
  }

  /** Returns the number of correct guesses of {@code label}. */
  public int getCorrect(L label)
  {
    return correctGuesses.getIntCount(label);
  }

  /** Returns the number of times {@code label} was guessed (0 for {@code negLabel}). */
  public int getRetrieved(L label)
  {
    return foundGuessed.getIntCount(label);
  }

  /** Returns the total number of guesses that were not {@code negLabel}. */
  public int getRetrieved()
  {
    return foundGuessed.totalIntCount();
  }

  /** Returns the number of gold occurrences of {@code label} (0 for {@code negLabel}). */
  public int getRelevant(L label)
  {
    return foundCorrect.getIntCount(label);
  }

  /** Returns the total number of gold labels that were not {@code negLabel}. */
  public int getRelevant()
  {
    return foundCorrect.totalIntCount();
  }

  /**
   * Return overall per token accuracy as (accuracy, #correct, #wrong). The accuracy is NaN
   * when no tokens have been counted.
   */
  public Triple<Double, Integer, Integer> getAccuracyInfo()
  {
    int totalCorrect = tokensCorrect;
    int totalWrong = tokensCount - tokensCorrect;
    return new Triple<>((((double) totalCorrect) / tokensCount),
        totalCorrect, totalWrong);
  }

  /** Returns the overall per-token accuracy (NaN when no tokens have been counted). */
  public double getAccuracy() {
    return getAccuracyInfo().first();
  }

  /**
   * Returns a String summarizing overall accuracy that will print nicely, e.g.
   * {@code "0.75 (3/4)"}.
   *
   * @param numDigits maximum number of fraction digits for the accuracy
   */
  public String getAccuracyDescription(int numDigits) {
    NumberFormat nf = NumberFormat.getNumberInstance();
    nf.setMaximumFractionDigits(numDigits);
    Triple<Double, Integer, Integer> accu = getAccuracyInfo();
    return nf.format(accu.first()) + " (" + accu.second() + "/" + (accu.second() + accu.third()) + ")";
  }

  /** Scores the delimited file {@code filename}; see {@link #score(BufferedReader, String, String)}. */
  public double score(String filename, String delimiter) throws IOException {
    return score(filename, delimiter, null);
  }

  /**
   * Scores the delimited file {@code filename}; see {@link #score(BufferedReader, String, String)}.
   * The file is always closed before returning.
   */
  public double score(String filename, String delimiter, String boundary) throws IOException {
    try (BufferedReader br = IOUtils.getBufferedFileReader(filename)) {
      return score(br, delimiter, boundary);
    }
  }

  /** Scores delimited lines from {@code br}; see {@link #score(BufferedReader, String, String)}. */
  public double score(BufferedReader br, String delimiter) throws IOException
  {
    return score(br, delimiter, null);
  }

  /**
   * Clears the counts, then scores lines of the form {@code token DELIM answer DELIM guess}.
   * Blank lines and lines whose token equals {@code boundary} trigger {@link #markBoundary()}.
   * The reader is not closed.
   *
   * @param br source of lines
   * @param delimiter regular expression separating the columns
   * @param boundary token marking a sequence boundary, or {@code null} for none
   * @return F1 score over the non-negative labels
   * @throws IllegalStateException if {@link #stringConverter} has not been set
   * @throws IllegalArgumentException if a non-boundary line has fewer than three columns
   */
  public double score(BufferedReader br, String delimiter, String boundary) throws IOException
  {
    if (stringConverter == null) {
      throw new IllegalStateException("stringConverter must be set to score labels read from text");
    }
    Pattern delimPattern = Pattern.compile(delimiter);
    clearCounts();
    String line;
    while ((line = br.readLine()) != null) {
      line = line.trim();
      if (line.isEmpty()) {
        markBoundary();
        continue;
      }
      String[] fields = delimPattern.split(line);
      if (boundary != null && fields.length > TOKEN_INDEX && boundary.equals(fields[TOKEN_INDEX])) {
        markBoundary();
        continue;
      }
      if (fields.length <= GUESS_INDEX) {
        throw new IllegalArgumentException("Expected at least " + (GUESS_INDEX + 1)
            + " fields (token, answer, guess) but got " + fields.length + " in line: " + line);
      }
      L answer = stringConverter.apply(fields[ANSWER_INDEX]);
      L guess = stringConverter.apply(fields[GUESS_INDEX]);
      addGuess(guess, answer);
    }
    finalizeCounts();
    return getFMeasure();
  }

  /**
   * Returns the labels of the label index, in index order. NOTE(review): this is whatever
   * {@code labelIndex.objectsList()} returns, possibly a live view of the index; do not modify it.
   */
  public List<L> getLabels() {
    return labelIndex.objectsList();
  }

  /** Returns a CoNLL-style evaluation summary, omitting the {@code negLabel} row. */
  public String getConllEvalString()
  {
    return getConllEvalString(true);
  }

  /**
   * Returns a CoNLL-style evaluation summary with one row per label, sorted by natural order
   * when the labels are {@link Comparable}.
   *
   * @param ignoreNegLabel whether to omit the {@code negLabel} row
   */
  public String getConllEvalString(boolean ignoreNegLabel)
  {
    // Sort a copy: getLabels() may expose the index's own backing list, and reordering that
    // in place would desynchronize the index's get(i) from indexOf(label).
    List<L> labels = new ArrayList<>(getLabels());
    sortIfComparable(labels);
    return getConllEvalString(labels, ignoreNegLabel);
  }

  /** Sorts {@code labels} in place by natural order if its first element is {@link Comparable}. */
  @SuppressWarnings({"unchecked", "rawtypes"})
  private static <T> void sortIfComparable(List<T> labels) {
    if (labels.size() > 1 && labels.get(0) instanceof Comparable) {
      Collections.sort((List<Comparable>) (List<?>) labels);
    }
  }

  /** Builds the CoNLL-style summary, emitting per-label rows in the order of {@code orderedLabels}. */
  private String getConllEvalString(List<L> orderedLabels, boolean ignoreNegLabel)
  {
    StringBuilder sb = new StringBuilder();
    int correctPhrases = getCorrect() - getCorrect(negLabel);
    Triple<Double,Integer,Integer> accuracyInfo = getAccuracyInfo();
    int totalCount = accuracyInfo.second() + accuracyInfo.third();
    sb.append("processed " + totalCount + " tokens with " + getRelevant() + " phrases; ");
    sb.append("found: " + getRetrieved() + " phrases; correct: " + correctPhrases + "\n");
    // Locale.US keeps '.' as the decimal separator regardless of the default locale.
    Formatter formatter = new Formatter(sb, Locale.US);
    formatter.format("accuracy: %6.2f%%; ", accuracyInfo.first() * 100);
    formatter.format("precision: %6.2f%%; ", getPrecision() * 100);
    formatter.format("recall: %6.2f%%; ", getRecall() * 100);
    formatter.format("FB1: %6.2f\n", getFMeasure() * 100);
    for (L label: orderedLabels) {
      if (ignoreNegLabel && label.equals(negLabel)) { continue; }
      formatter.format("%17s: ", label);
      formatter.format("precision: %6.2f%%; ", getPrecision(label) * 100);
      formatter.format("recall: %6.2f%%; ", getRecall(label) * 100);
      formatter.format("FB1: %6.2f %d\n", getFMeasure(label) * 100, getRetrieved(label));
    }
    return sb.toString();
  }

  /** Identity converter for {@code String} labels. */
  public static class StringStringConverter implements Function<String,String>
  {
    @Override
    public String apply(String str) { return str; }
  }

  /** Extended stats over {@code String} labels, able to score delimited text directly. */
  public static class MultiClassStringLabelStats extends MultiClassPrecisionRecallExtendedStats<String>
  {
    /** See the corresponding constructor of {@link MultiClassPrecisionRecallExtendedStats}. */
    public <F> MultiClassStringLabelStats(Classifier<String,F> classifier, GeneralDataset<String,F> data, String negLabel)
    {
      super(classifier, data, negLabel);
      stringConverter = new StringStringConverter();
    }

    /** Creates empty stats for {@code String} labels. */
    public MultiClassStringLabelStats(String negLabel)
    {
      super(negLabel);
      stringConverter = new StringStringConverter();
    }

    /** Creates empty stats for {@code String} labels that use the given label index. */
    public MultiClassStringLabelStats(Index<String> dataLabelIndex, String negLabel)
    {
      this(negLabel);
      setLabelIndex(dataLabelIndex);
    }
  }
}