package experimental.analyzer; import java.util.Collection; import java.util.Collections; import java.util.HashSet; import java.util.LinkedList; import java.util.logging.Logger; public class AnalyzerResult { private int num_errors_; private int total_; private double macro_pre_; private double macro_rec_; private static class Error { private AnalyzerInstance instance_; private Collection<AnalyzerTag> missed_; private Collection<AnalyzerTag> toomuch_; public Error(AnalyzerInstance instance, Collection<AnalyzerTag> missed, Collection<AnalyzerTag> tomuch) { instance_ = instance; missed_ = missed; toomuch_ = tomuch; } @Override public String toString() { StringBuilder sb = new StringBuilder(String.format("%s:", instance_.getForm())); if (!missed_.isEmpty()) { sb.append(String.format(" missed: %s", missed_)); } if (!toomuch_.isEmpty()) { sb.append(String.format(" toomuch: %s", toomuch_)); } return sb.toString(); } }; private Collection<Error> errors_; private int label_correct_; private int label_total_; public AnalyzerResult() { this(0, 0, 0., 0., new LinkedList<Error>(), 0, 0); } public AnalyzerResult(int num_errors, int total, double macro_pre, double macro_rec, Collection<Error> errors, int label_correct, int label_total) { num_errors_ = num_errors; total_ = total; macro_pre_ = macro_pre; macro_rec_ = macro_rec; errors_ = errors; label_correct_ = label_correct; label_total_ = label_total; } public void increment(AnalyzerResult result) { num_errors_ += result.num_errors_; total_ += result.total_; macro_pre_ += result.macro_pre_; macro_rec_ += result.macro_rec_; errors_.addAll(result.errors_); label_correct_ += result.label_correct_; label_total_ += result.label_total_; } public static void logResult(Analyzer analyzer, String filename) { logResult(analyzer, filename, 100); } public static AnalyzerResult test(Analyzer analyzer, String filename) { return test(analyzer, AnalyzerInstance.getInstances(filename)); } public static AnalyzerResult test(Analyzer analyzer, Collection<AnalyzerInstance> instances) { AnalyzerResult result = new AnalyzerResult(); for (AnalyzerInstance instance : instances) { result.increment(test(analyzer, instance)); } return result; } public static AnalyzerResult test(Analyzer analyzer, AnalyzerInstance instance) { Collection<AnalyzerTag> actual = new HashSet<>(AnalyzerReading.toTags(analyzer.analyze(instance))); Collection<AnalyzerTag> expected = new HashSet<>(AnalyzerReading.toTags(instance.getReadings())); Collection<AnalyzerTag> missed = new LinkedList<>(); Collection<AnalyzerTag> toomuch = new LinkedList<>(); int correct = 0; for (AnalyzerTag tag : actual) { if (expected.contains(tag)) { correct ++; } else { toomuch.add(tag); } } for (AnalyzerTag tag : expected) { if (!actual.contains(tag)) { missed.add(tag); } } int label_total = analyzer.getNumTags(); int label_correct = label_total - (toomuch.size() + missed.size()); double macro_pre; if (actual.isEmpty()) macro_pre = 1.0; else macro_pre = correct / (double) actual.size(); double macro_rec = correct / (double) expected.size(); int total = 1; int num_errors = (correct == actual.size() && actual.size() == expected.size()) ? 0 : 1; Collection<Error> errors; if (missed.isEmpty() && toomuch.isEmpty()) { errors = Collections.emptyList(); } else { errors = Collections.singletonList(new Error(instance, missed, toomuch)); } return new AnalyzerResult(num_errors, total, macro_pre, macro_rec, errors, label_correct, label_total); } public void logFscore() { Logger logger = Logger.getLogger(getClass().getName()); double recall = macro_rec_ / total_; double prec = macro_pre_ / total_; double macro_fsc = getFscore(); logger.info(String.format("F1: %g Pr: %g Re %g", macro_fsc * 100., prec * 100., recall * 100.)); } public void logAcc() { Logger logger = Logger.getLogger(getClass().getName()); logger.info(String.format("Acc: %g", 100. * (total_ - num_errors_) / total_)); } public void logLabelAcc() { Logger logger = Logger.getLogger(getClass().getName()); logger.info(String.format("Label Acc: %d / %d = %g", label_correct_, label_total_, 100. * (label_correct_) / label_total_)); } public void logErrors(int num_errors) { logSubList(errors_, 0); } private void logSubList(Collection<Error> errors, int first) { Logger logger = Logger.getLogger(getClass().getName()); if (errors.size() > first) { errors = new LinkedList<Error>(errors).subList(0, first); } StringBuilder sb = new StringBuilder("Errors:\n"); for (Error error : errors) { sb.append(error.toString()); sb.append('\n'); } logger.info(sb.toString()); } public double getFscore() { double recall = macro_rec_ / total_; double prec = macro_pre_ / total_; double macro_fsc; if (recall + prec < 1e-5) { macro_fsc = 0.0; } else { macro_fsc = 2. * prec * recall / (prec + recall); } return macro_fsc; } public static void logResult(Analyzer analyzer, String filename, int num_errors) { AnalyzerResult result = test(analyzer, filename); result.logAcc(); result.logLabelAcc(); result.logFscore(); result.logErrors(num_errors); } }