// Copyright 2015 Thomas Müller // This file is part of MarMoT, which is licensed under GPLv3. package lemming.lemma; import java.util.Collection; import java.util.LinkedList; import java.util.List; import java.util.logging.Logger; import marmot.morph.io.SentenceReader; public class LemmaResult { private int num_tokens_; private List<LemmaError> errors_; private int num_oov_tokens_; public LemmaResult(int num_tokens, int num_oov_tokens, List<LemmaError> errors) { num_tokens_ = num_tokens; num_oov_tokens_ = num_oov_tokens; errors_ = errors; } public static LemmaResult test(Lemmatizer lemmatizer, String file) { return test(lemmatizer, LemmaInstance.getInstances(new SentenceReader(file))); } public static LemmaResult test(Lemmatizer lemmatizer, Collection<LemmaInstance> instances) { int total = 0; int num_oovs = 0; List<LemmaError> errors = new LinkedList<>(); for (LemmaInstance instance : instances) { String predicted_lemma = lemmatizer.lemmatize(instance); if (lemmatizer.isOOV(instance)) { num_oovs += instance.getCount(); } if (predicted_lemma == null || !predicted_lemma.equals(instance.getLemma())) { errors.add(new LemmaError(instance, predicted_lemma, lemmatizer.isOOV(instance))); } total += instance.getCount(); } return new LemmaResult(total, num_oovs, errors); } public static void logTest(Lemmatizer lemmatizer, String file, int limit) { LemmaResult result = test(lemmatizer, file); result.logAccuracy(); result.logErrors(limit); } private String format(int correct, int total) { double acc = correct * 100. / total; return String.format("%6d / %6d = %g", correct, total, acc); } public void logAccuracy() { int errors = 0; int oov_errors = 0; for (LemmaError error : errors_) { errors += error.getInstance().getCount(); if (error.isOOV()) { oov_errors += error.getInstance().getCount(); } } int correct = num_tokens_ - errors; int oov_correct = num_oov_tokens_ - oov_errors; Logger.getLogger(getClass().getName()).info( String.format("%s (OOV: %s)", format(correct, num_tokens_), format(oov_correct, num_oov_tokens_))); } public void logErrors(int limit) { StringBuilder sb = new StringBuilder(); sb.append("Errors:\n"); int number = 0; for (LemmaError error : errors_) { sb.append(error); sb.append('\n'); number++; if (limit >= 0 && number >= limit) { break; } } Logger.getLogger(getClass().getName()).info(sb.toString()); } public double getTokenAccuracy() { int correct = num_tokens_; for (LemmaError error : errors_) { correct -= error.getInstance().getCount(); } return correct * 100. / num_tokens_; } public static LemmaResult testGenerator(LemmatizerGenerator generator, String filename) { return testGenerator(generator, LemmaInstance.getInstances(filename)); } public static LemmaResult testGenerator(LemmatizerGenerator generator, List<LemmaInstance> instances) { int total = 0; int oov_total = 0; List<LemmaError> errors = new LinkedList<>(); for (LemmaInstance instance : instances) { if (generator.isOOV(instance)) { oov_total += instance.getCount(); } LemmaCandidateSet set = new LemmaCandidateSet(); generator.addCandidates(instance, set); if (!set.contains(instance.getLemma())) { errors.add(new LemmaError(instance, null, generator.isOOV(instance))); } total += instance.getCount(); } return new LemmaResult(total, oov_total, errors); } }