package experimental.analyzer.simple; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.LinkedList; import java.util.List; import experimental.analyzer.AnalyzerInstance; import experimental.analyzer.AnalyzerResult; import experimental.analyzer.simple.SimpleAnalyzer.Mode; import marmot.util.Numerics; public class SimpleThresholdOptimizer { boolean simple_; private static class Entry implements Comparable<Entry>{ double prob; boolean active; @Override public int compareTo(Entry o) { return Double.compare(prob, o.prob); } } public SimpleThresholdOptimizer(boolean simple) { simple_ = simple; } public double findTreshold(SimpleAnalyzerModel model, Collection<AnalyzerInstance> instances, Mode mode) { if (simple_) { return simpleFindTreshold(model, instances, mode); } int num_tags = model.getNumTags(); int num_actives = 0; List<Entry> entries = new LinkedList<>(); for (AnalyzerInstance instance : instances) { SimpleAnalyzerInstance simple_instance = model .getInstance(instance); double[] scores = new double[num_tags]; model.score(simple_instance, scores); double sum = Double.NEGATIVE_INFINITY; if (mode == Mode.classifier) { for (double score : scores) { sum = Numerics.sumLogProb(score, sum); } } List<Entry> current_entries = new ArrayList<>(num_tags); for (int tag_index = 0; tag_index < num_tags; tag_index++) { Entry entry = new Entry(); entry.active = false; if (mode == Mode.classifier) entry.prob = Math.exp(scores[tag_index] - sum); else entry.prob = Math.exp(scores[tag_index] - Numerics.sumLogProb(scores[tag_index], 0)); assert entry.prob >= 0.0 && entry.prob < 1.0; current_entries.add(entry); } for (int tag_index : simple_instance.getTagIndexes()) { current_entries.get(tag_index).active = true; num_actives ++; } entries.addAll(current_entries); } Collections.sort(entries); // Correct at threshold = 0.0; int best_correct = 0; double best_threshold = 0.0; int correct = num_actives; for (Entry entry : entries) { if (entry.active) { correct --; } else { correct ++; } if (correct > best_correct) { best_correct = correct; best_threshold = entry.prob + 1e-10; } } System.err.println("Correct: " + best_correct); return best_threshold; } private double simpleFindTreshold(SimpleAnalyzerModel model, Collection<AnalyzerInstance> instances, Mode mode) { double[] thresholds = { 0.5, 0.35, 0.3, 0.25, 0.20, 0.15, 0.1, 0.05 }; System.err.println("Thresholds: " + Arrays.toString(thresholds)); double best_threshold = 0.0; double best_fscore = -1; for (double threshold : thresholds) { double fscore = getFscore(model, instances, threshold, mode); if (fscore > best_fscore) { best_fscore = fscore; best_threshold = threshold; } System.err.format("Threshold: %g F1-Score on train: %g\n", threshold, fscore); } return best_threshold; } private double getFscore(SimpleAnalyzerModel model, Collection<AnalyzerInstance> instances, double threshold, Mode tag_mode) { SimpleAnalyzer analyzer = new SimpleAnalyzer(model, threshold, tag_mode, null); AnalyzerResult result = AnalyzerResult.test(analyzer, instances); double fscore = result.getFscore(); return fscore; } }