package edu.stanford.nlp.stats; import edu.stanford.nlp.classify.GeneralDataset; import edu.stanford.nlp.classify.ProbabilisticClassifier; import edu.stanford.nlp.ling.Datum; import edu.stanford.nlp.util.BinaryHeapPriorityQueue; import edu.stanford.nlp.util.Pair; import edu.stanford.nlp.util.PriorityQueue; import edu.stanford.nlp.util.StringUtils; import java.text.NumberFormat; import java.util.List; /** * @author Jenny Finkel */ public class MultiClassAccuracyStats<L> implements Scorer<L> { double[] scores; //sorted scores boolean[] isCorrect; // is the i-th example correct double logLikelihood; double accuracy; static String saveFile = null; static int saveIndex = 1; public static final int USE_ACCURACY = 1; public static final int USE_LOGLIKELIHOOD = 2; private int scoreType = USE_ACCURACY; public MultiClassAccuracyStats(){ } public MultiClassAccuracyStats(int scoreType){ this.scoreType = scoreType; } public MultiClassAccuracyStats(String file){ this(file, USE_ACCURACY); } public MultiClassAccuracyStats(String file, int scoreType){ saveFile=file; this.scoreType = scoreType; } public <F> MultiClassAccuracyStats(ProbabilisticClassifier<L,F> classifier, GeneralDataset<L,F> data,String file) { this(classifier, data, file, USE_ACCURACY); } public <F> MultiClassAccuracyStats(ProbabilisticClassifier<L,F> classifier, GeneralDataset<L,F> data,String file, int scoreType) { saveFile=file; this.scoreType = scoreType; initMC(classifier, data); } int correct = 0; int total = 0; public <F> double score(ProbabilisticClassifier<L,F> classifier, GeneralDataset<L,F> data) { initMC(classifier,data); return score(); } public double score() { if (scoreType == USE_ACCURACY) { return accuracy; } else if (scoreType == USE_LOGLIKELIHOOD) { return logLikelihood; } else { throw new RuntimeException("Unknown score type: "+scoreType); } } public int numSamples() { return scores.length; } public double confidenceWeightedAccuracy() { double acc = 0; for (int recall = 1; recall <= numSamples(); recall++) { acc += numCorrect(recall) / (double) recall; } return acc / numSamples(); } public <F> void initMC(ProbabilisticClassifier<L,F> classifier, GeneralDataset<L,F> data) { //if (!(gData instanceof Dataset)) { // throw new UnsupportedOperationException("Can only handle Datasets, not "+gData.getClass().getName()); //} // //Dataset data = (Dataset)gData; PriorityQueue<Pair<Integer, Pair<Double, Boolean>>> q = new BinaryHeapPriorityQueue<>(); total = 0; correct = 0; logLikelihood = 0.0; for (int i = 0; i < data.size(); i++) { Datum<L,F> d = data.getRVFDatum(i); Counter<L> scores = classifier.logProbabilityOf(d); L guess = Counters.argmax(scores); L correctLab = d.label(); double guessScore = scores.getCount(guess); double correctScore = scores.getCount(correctLab); int guessInd = data.labelIndex().indexOf(guess); int correctInd = data.labelIndex().indexOf(correctLab); total++; if (guessInd == correctInd) { correct++; } logLikelihood += correctScore; q.add(new Pair<>(Integer.valueOf(i), new Pair<>(new Double(guessScore), Boolean.valueOf(guessInd == correctInd))), -guessScore); } accuracy = (double) correct / (double) total; List<Pair<Integer, Pair<Double, Boolean>>> sorted = q.toSortedList(); scores = new double[sorted.size()]; isCorrect = new boolean[sorted.size()]; for (int i = 0; i < sorted.size(); i++) { Pair<Double, Boolean> next = sorted.get(i).second(); scores[i] = next.first().doubleValue(); isCorrect[i] = next.second().booleanValue(); } } /** * how many correct do we have if we return the most confident num recall ones * */ public int numCorrect(int recall) { int correct = 0; for (int j = scores.length - 1; j >= scores.length - recall; j--) { if (isCorrect[j]) { correct++; } } return correct; } public int[] getAccCoverage() { int[] arr = new int[numSamples()]; for (int recall = 1; recall <= numSamples(); recall++) { arr[recall - 1] = numCorrect(recall); } return arr; } public String getDescription(int numDigits) { NumberFormat nf = NumberFormat.getNumberInstance(); nf.setMaximumFractionDigits(numDigits); StringBuilder sb = new StringBuilder(); double confWeightedAccuracy = confidenceWeightedAccuracy(); sb.append("--- Accuracy Stats ---").append("\n"); sb.append("accuracy: ").append(nf.format(accuracy)).append(" (").append(correct).append("/").append(total).append(")\n"); sb.append("confidence weighted accuracy :").append(nf.format(confWeightedAccuracy)).append("\n"); sb.append("log-likelihood: ").append(logLikelihood).append("\n"); if (saveFile != null) { String f = saveFile + "-" + saveIndex; sb.append("saving accuracy info to ").append(f).append(".accuracy\n"); StringUtils.printToFile(f + ".accuracy", AccuracyStats.toStringArr(getAccCoverage())); saveIndex++; //sb.append("accuracy coverage: ").append(toStringArr(accrecall)).append("\n"); //sb.append("optimal accuracy coverage: ").append(toStringArr(optaccrecall)); } return sb.toString(); } @Override public String toString() { String accuracyType = null; if(scoreType == USE_ACCURACY) accuracyType = "classification_accuracy"; else if(scoreType == USE_LOGLIKELIHOOD) accuracyType = "log_likelihood"; else accuracyType = "unknown"; return "MultiClassAccuracyStats(" + accuracyType + ")" + scoreType + USE_ACCURACY + USE_LOGLIKELIHOOD; } }