package edu.stanford.nlp.sentiment;

import edu.stanford.nlp.util.logging.Redwood;
import edu.stanford.nlp.neural.rnn.RNNCoreAnnotations;
import edu.stanford.nlp.neural.rnn.TopNGramRecord;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.IntCounter;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.ConfusionMatrix;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.StringUtils;

import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.List;
import java.util.Set;

/**
 * Accumulates and reports accuracy statistics (node, root, and per-length
 * accuracy, plus confusion matrices) for sentiment trees whose nodes carry
 * gold and predicted class annotations.
 *
 * @author John Bauer
 * @author Michael Haas &lt;haas@cl.uni-heidelberg.de&gt; (extracted this abstract class from Evaluate)
 */
public abstract class AbstractEvaluate {

  /** A logger for this class */
  private static Redwood.RedwoodChannels log = Redwood.channels(AbstractEvaluate.class);

  String[] equivalenceClassNames;

  int labelsCorrect;
  int labelsIncorrect;

  // the matrix will be [gold][predicted]
  int[][] labelConfusion;

  int rootLabelsCorrect;
  int rootLabelsIncorrect;
  int[][] rootLabelConfusion;

  IntCounter<Integer> lengthLabelsCorrect;
  IntCounter<Integer> lengthLabelsIncorrect;

  TopNGramRecord ngrams;

  // TODO: make this an option
  static final int NUM_NGRAMS = 5;

  int[][] equivalenceClasses;

  protected static final NumberFormat NF = new DecimalFormat("0.000000");

  private RNNOptions op = null;

  public AbstractEvaluate(RNNOptions options) {
    this.op = options;
    this.reset();
  }

  protected static void printConfusionMatrix(String name, int[][] confusion) {
    log.info(name + " confusion matrix");
    ConfusionMatrix<Integer> confusionMatrix = new ConfusionMatrix<>();
    confusionMatrix.setUseRealLabels(true);
    for (int i = 0; i < confusion.length; ++i) {
      for (int j = 0; j < confusion[i].length; ++j) {
        confusionMatrix.add(j, i, confusion[i][j]);
      }
    }
    log.info(confusionMatrix);
  }

  /**
   * For each equivalence class (a set of labels treated as interchangeable,
   * such as the two negative sentiment labels), computes the fraction of
   * gold nodes in that class whose prediction also falls inside the class.
   */
  protected static double[] approxAccuracy(int[][] confusion, int[][] classes) {
    int[] correct = new int[classes.length];
    int[] total = new int[classes.length];
    double[] results = new double[classes.length];
    for (int i = 0; i < classes.length; ++i) {
      for (int j = 0; j < classes[i].length; ++j) {
        // a prediction anywhere within the same equivalence class counts as correct
        for (int k = 0; k < classes[i].length; ++k) {
          correct[i] += confusion[classes[i][j]][classes[i][k]];
        }
        for (int k = 0; k < confusion[classes[i][j]].length; ++k) {
          total[i] += confusion[classes[i][j]][k];
        }
      }
      results[i] = ((double) correct[i]) / ((double) (total[i]));
    }
    return results;
  }

  /**
   * Like {@link #approxAccuracy}, but pools the counts of all equivalence
   * classes into a single combined accuracy.
   */
  protected static double approxCombinedAccuracy(int[][] confusion, int[][] classes) {
    int correct = 0;
    int total = 0;
    for (int[] aClass : classes) {
      for (int j = 0; j < aClass.length; ++j) {
        for (int k = 0; k < aClass.length; ++k) {
          correct += confusion[aClass[j]][aClass[k]];
        }
        for (int k = 0; k < confusion[aClass[j]].length; ++k) {
          total += confusion[aClass[j]][k];
        }
      }
    }
    return ((double) correct) / ((double) (total));
  }

  public void reset() {
    labelsCorrect = 0;
    labelsIncorrect = 0;
    labelConfusion = new int[op.numClasses][op.numClasses];
    rootLabelsCorrect = 0;
    rootLabelsIncorrect = 0;
    rootLabelConfusion = new int[op.numClasses][op.numClasses];
    lengthLabelsCorrect = new IntCounter<>();
    lengthLabelsIncorrect = new IntCounter<>();
    equivalenceClasses = op.equivalenceClasses;
    equivalenceClassNames = op.equivalenceClassNames;
    if (op.testOptions.ngramRecordSize > 0) {
      ngrams = new TopNGramRecord(op.numClasses, op.testOptions.ngramRecordSize,
                                  op.testOptions.ngramRecordMaximumLength);
    } else {
      ngrams = null;
    }
  }
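
  /**
   * Runs the model over {@code trees} (via
   * {@link #populatePredictedLabels(List)}) and then accumulates accuracy
   * statistics for every tree.
   *
   * A hypothetical usage sketch, assuming some concrete subclass
   * {@code MyEvaluate} (not part of this file):
   *
   * <pre>{@code
   * AbstractEvaluate eval = new MyEvaluate(options);
   * eval.eval(testTrees);
   * eval.printSummary();
   * }</pre>
   */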
  public void eval(List<Tree> trees) {
    this.populatePredictedLabels(trees);
    for (Tree tree : trees) {
      eval(tree);
    }
  }

  public void eval(Tree tree) {
    //cag.forwardPropagateTree(tree);
    countTree(tree);
    countRoot(tree);
    countLengthAccuracy(tree);
    if (ngrams != null) {
      ngrams.countTree(tree);
    }
  }

  /**
   * Counts correct and incorrect labels bucketed by the length (number of
   * words) of the subtree under each node; returns the length of this tree.
   */
  protected int countLengthAccuracy(Tree tree) {
    if (tree.isLeaf()) {
      return 0;
    }
    Integer gold = RNNCoreAnnotations.getGoldClass(tree);
    Integer predicted = RNNCoreAnnotations.getPredictedClass(tree);
    int length;
    if (tree.isPreTerminal()) {
      length = 1;
    } else {
      length = 0;
      for (Tree child : tree.children()) {
        length += countLengthAccuracy(child);
      }
    }
    if (gold >= 0) {
      if (gold.equals(predicted)) {
        lengthLabelsCorrect.incrementCount(length);
      } else {
        lengthLabelsIncorrect.incrementCount(length);
      }
    }
    return length;
  }

  protected void countTree(Tree tree) {
    if (tree.isLeaf()) {
      return;
    }
    for (Tree child : tree.children()) {
      countTree(child);
    }
    Integer gold = RNNCoreAnnotations.getGoldClass(tree);
    Integer predicted = RNNCoreAnnotations.getPredictedClass(tree);
    if (gold >= 0) {
      if (gold.equals(predicted)) {
        labelsCorrect++;
      } else {
        labelsIncorrect++;
      }
      labelConfusion[gold][predicted]++;
    }
  }

  protected void countRoot(Tree tree) {
    Integer gold = RNNCoreAnnotations.getGoldClass(tree);
    Integer predicted = RNNCoreAnnotations.getPredictedClass(tree);
    if (gold >= 0) {
      if (gold.equals(predicted)) {
        rootLabelsCorrect++;
      } else {
        rootLabelsIncorrect++;
      }
      rootLabelConfusion[gold][predicted]++;
    }
  }

  public double exactNodeAccuracy() {
    return (double) labelsCorrect / ((double) (labelsCorrect + labelsIncorrect));
  }

  public double exactRootAccuracy() {
    return (double) rootLabelsCorrect / ((double) (rootLabelsCorrect + rootLabelsIncorrect));
  }

  public Counter<Integer> lengthAccuracies() {
    Set<Integer> keys = Generics.newHashSet();
    keys.addAll(lengthLabelsCorrect.keySet());
    keys.addAll(lengthLabelsIncorrect.keySet());

    Counter<Integer> results = new ClassicCounter<>();
    for (Integer key : keys) {
      results.setCount(key, lengthLabelsCorrect.getCount(key) /
                            (lengthLabelsCorrect.getCount(key) + lengthLabelsIncorrect.getCount(key)));
    }
    return results;
  }

  public void printLengthAccuracies() {
    Counter<Integer> accuracies = lengthAccuracies();
    Set<Integer> keys = Generics.newTreeSet();
    keys.addAll(accuracies.keySet());
    log.info("Label accuracy at various lengths:");
    for (Integer key : keys) {
      log.info(StringUtils.padLeft(Integer.toString(key), 4) + ": " + NF.format(accuracies.getCount(key)));
    }
  }
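
  /**
   * Logs an overall evaluation summary: node-level and root-level accuracy,
   * both confusion matrices, and, when equivalence classes are configured,
   * the approximate accuracies obtained by merging labels (e.g. collapsing
   * the two negative and the two positive labels of a five-class sentiment
   * scheme into single "negative" and "positive" classes).
   */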
  public void printSummary() {
    log.info("EVALUATION SUMMARY");
    log.info("Tested " + (labelsCorrect + labelsIncorrect) + " labels");
    log.info(" " + labelsCorrect + " correct");
    log.info(" " + labelsIncorrect + " incorrect");
    log.info(" " + NF.format(exactNodeAccuracy()) + " accuracy");
    log.info("Tested " + (rootLabelsCorrect + rootLabelsIncorrect) + " roots");
    log.info(" " + rootLabelsCorrect + " correct");
    log.info(" " + rootLabelsIncorrect + " incorrect");
    log.info(" " + NF.format(exactRootAccuracy()) + " accuracy");

    printConfusionMatrix("Label", labelConfusion);
    printConfusionMatrix("Root label", rootLabelConfusion);

    if (equivalenceClasses != null && equivalenceClassNames != null) {
      double[] approxLabelAccuracy = approxAccuracy(labelConfusion, equivalenceClasses);
      for (int i = 0; i < equivalenceClassNames.length; ++i) {
        log.info("Approximate " + equivalenceClassNames[i] + " label accuracy: " + NF.format(approxLabelAccuracy[i]));
      }
      log.info("Combined approximate label accuracy: " + NF.format(approxCombinedAccuracy(labelConfusion, equivalenceClasses)));

      double[] approxRootLabelAccuracy = approxAccuracy(rootLabelConfusion, equivalenceClasses);
      for (int i = 0; i < equivalenceClassNames.length; ++i) {
        log.info("Approximate " + equivalenceClassNames[i] + " root label accuracy: " + NF.format(approxRootLabelAccuracy[i]));
      }
      log.info("Combined approximate root label accuracy: " + NF.format(approxCombinedAccuracy(rootLabelConfusion, equivalenceClasses)));
      log.info();
    }

    if (op.testOptions.ngramRecordSize > 0) {
      log.info(ngrams);
    }

    if (op.testOptions.printLengthAccuracies) {
      printLengthAccuracies();
    }
  }

  /**
   * Sets the predicted sentiment label for all trees given.
   *
   * This method sets the {@link RNNCoreAnnotations.PredictedClass} annotation
   * for all nodes in all trees.
   *
   * @param trees List of Trees to be annotated
   */
  public abstract void populatePredictedLabels(List<Tree> trees);

}
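
/**
 * A minimal illustrative subclass, hypothetical and not part of the original
 * source: it fills in {@link RNNCoreAnnotations.PredictedClass} on every
 * non-leaf node using a caller-supplied per-node classifier. The real
 * {@code Evaluate} subclass instead forward-propagates a trained sentiment
 * model over each tree to produce these predictions.
 */
class FunctionEvaluate extends AbstractEvaluate {

  // Hypothetical: any function from a subtree to a predicted class index.
  private final java.util.function.Function<Tree, Integer> classifier;

  FunctionEvaluate(RNNOptions op, java.util.function.Function<Tree, Integer> classifier) {
    super(op);
    this.classifier = classifier;
  }

  @Override
  public void populatePredictedLabels(List<Tree> trees) {
    for (Tree tree : trees) {
      annotate(tree);
    }
  }

  // Recursively annotate every non-leaf node with the classifier's prediction.
  private void annotate(Tree tree) {
    if (tree.isLeaf()) {
      return;
    }
    for (Tree child : tree.children()) {
      annotate(child);
    }
    ((edu.stanford.nlp.ling.CoreLabel) tree.label())
        .set(RNNCoreAnnotations.PredictedClass.class, classifier.apply(tree));
  }

}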