package experimental.analyzer.simple;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;

import experimental.analyzer.AnalyzerInstance;
import experimental.analyzer.simple.SimpleAnalyzer.Mode;
import marmot.util.Numerics;

/**
 * Reports coverage of a {@link SimpleAnalyzer} at requested ambiguity levels.
 *
 * <p>
 * Every (instance, tag) pair of the evaluated instances is turned into an
 * {@link Entry}; the entries are ranked by decreasing probability and walked
 * in that order while ambiguity and coverage are accumulated. Results are
 * written to {@code System.err}.
 */
public class SimpleEvaluator {

	/**
	 * One (instance, tag) candidate. The natural ordering is by decreasing
	 * {@link #prob}, so sorting puts the most probable candidates first.
	 */
	private static class Entry implements Comparable<Entry> {
		/** Probability derived from the model score of this tag. */
		double prob;
		/** True if the tag index is among the instance's own tag indexes. */
		boolean active;
		/** Number of tag indexes of the instance this entry was built from. */
		int num_tags;

		@Override
		public int compareTo(Entry o) {
			// Negated so that higher probabilities sort first.
			return -Double.compare(prob, o.prob);
		}
	}

	/**
	 * Evaluates the analyzer on the instances of {@code in_instances} for
	 * which {@code analyzer.isUnknown} returns true.
	 *
	 * <p>
	 * Walking the ranked entries, each entry adds {@code 1 / #instances} to
	 * the running ambiguity; each active entry adds
	 * {@code 1 / (num_tags * #instances)} to the running coverage. Whenever
	 * the running ambiguity crosses one of the values in {@code ambiguities},
	 * a line with the ambiguity, the coverage and the probability of the
	 * current entry (the threshold) is printed to {@code System.err}.
	 *
	 * @param analyzer
	 *            analyzer supplying the model, the mode and the unknown test
	 * @param in_instances
	 *            candidate instances; only the unknown ones are evaluated
	 * @param ambiguities
	 *            ambiguity levels at which a report line is printed
	 */
	public void eval(SimpleAnalyzer analyzer, Collection<AnalyzerInstance> in_instances, List<Double> ambiguities) {
		// Fix: filter the *input* collection. The original iterated over the
		// freshly created (empty) result list, so nothing was ever evaluated.
		List<AnalyzerInstance> instances = new ArrayList<>();
		for (AnalyzerInstance instance : in_instances) {
			if (analyzer.isUnknown(instance)) {
				instances.add(instance);
			}
		}

		SimpleAnalyzerModel model = analyzer.getModel();
		Mode mode = analyzer.getMode();
		int num_tags = model.getNumTags();

		List<Entry> entries = new ArrayList<>();
		for (AnalyzerInstance instance : instances) {
			entries.addAll(scoreInstance(model, mode, num_tags, instance));
		}

		Collections.sort(entries);

		int num_instances = instances.size();
		double current_coverage = 0.0;
		double current_ambiguity = 0.0;

		for (Entry entry : entries) {
			if (entry.active) {
				current_coverage += 1. / (entry.num_tags * num_instances);
			}

			double prev_ambiguity = current_ambiguity;
			current_ambiguity += 1. / (num_instances);

			for (double ambiguity_value : ambiguities) {
				// Report when this step moved the ambiguity across the level.
				if (prev_ambiguity <= ambiguity_value && current_ambiguity >= ambiguity_value) {
					System.err.format("Amb: %g Cov: %g (Th: %g)\n", current_ambiguity, current_coverage, entry.prob);
				}
			}
		}
	}

	/**
	 * Builds one entry per tag index for a single instance, with probabilities
	 * computed from the model scores and the instance's own tags marked active.
	 */
	private List<Entry> scoreInstance(SimpleAnalyzerModel model, Mode mode, int num_tags,
			AnalyzerInstance instance) {
		SimpleAnalyzerInstance simple_instance = model.getInstance(instance);

		double[] scores = new double[num_tags];
		model.score(simple_instance, scores);

		// Classifier mode normalizes over all tags; sumLogProb is presumably a
		// log-space addition (log-sum-exp) -- confirm against marmot.util.Numerics.
		double sum = Double.NEGATIVE_INFINITY;
		if (mode == Mode.classifier) {
			for (double score : scores) {
				sum = Numerics.sumLogProb(score, sum);
			}
		}

		int num_readings = simple_instance.getTagIndexes().size();

		List<Entry> current_entries = new ArrayList<>(num_tags);
		for (int tag_index = 0; tag_index < num_tags; tag_index++) {
			Entry entry = new Entry();
			entry.active = false;

			if (mode == Mode.classifier) {
				entry.prob = Math.exp(scores[tag_index] - sum);
			} else {
				// Each tag is normalized on its own against a zero score
				// (presumably a logistic sigmoid of the score).
				entry.prob = Math.exp(scores[tag_index] - Numerics.sumLogProb(scores[tag_index], 0));
			}

			// Exactly 1.0 is legitimate (single tag, or a saturated score), so
			// the upper bound is inclusive.
			assert entry.prob >= 0.0 && entry.prob <= 1.0;

			entry.num_tags = num_readings;
			current_entries.add(entry);
		}

		for (int tag_index : simple_instance.getTagIndexes()) {
			current_entries.get(tag_index).active = true;
		}

		return current_entries;
	}
}