package chipmunk.segmenter; import java.util.Collection; import java.util.Collections; import java.util.HashSet; import java.util.Iterator; import java.util.LinkedList; import java.util.List; import java.util.Set; import marmot.util.Numerics; public class Scorer { Score precision = new Score(); Score recall = new Score(); private static class Boundary{ @Override public int hashCode() { final int prime = 31; int result = 1; result = prime * result + position_; return result; } @Override public boolean equals(Object obj) { if (this == obj) return true; if (obj == null) return false; if (getClass() != obj.getClass()) return false; Boundary other = (Boundary) obj; if (position_ != other.position_) return false; return true; } public Boundary(int position) { position_ = position; } int position_; } private Set<Boundary> getBrackets(SegmentationReading reading, int length) { Set<Boundary> brackets = new HashSet<>(); int start = 0; Iterator<String> segment_iterator = reading.getSegments().iterator(); while (segment_iterator.hasNext()) { String segment = segment_iterator.next(); int end = start + segment.length(); if (end < length) brackets.add(new Boundary(end)); start = end; } return brackets; } private static Set<Boundary> getBoundary(SegmentationResult candidate, int length) { Set<Boundary> brackets = new HashSet<>(); Iterator<Integer> index_iterator = candidate.getInputIndexes().iterator(); while (index_iterator.hasNext()) { int end = index_iterator.next(); if (end < length) brackets.add(new Boundary(end)); } return brackets; } public void eval(List<Set<Boundary>> predicted, List<Set<Boundary>> reference) { eval(predicted, reference, recall); eval(reference, predicted, precision); } public void eval(Collection<Word> words, Segmenter segmenter) { for (Word word : words) { SegmentationReading reading = segmenter.segment(word); Set<Boundary> brackets = getBrackets(reading, word.getLength()); List<Set<Boundary>> predicted = Collections.singletonList(brackets); List<Set<Boundary>> reference = new LinkedList<>(); for (SegmentationReading ref_reading : word.getReadings()) { reference.add(getBrackets(ref_reading, word.getLength())); } eval(reference, predicted); } } public String report() { double p = getPrecision(); double r = getRecall(); double f = getFscore(); return String.format("F1: %g Pr: %g / %g = %g Re:%g / %g = %g", f, precision.score, precision.total, p, recall.score, recall.total, r); } private static class Score { double score; double total; } void eval(Collection<Set<Boundary>> predicted, Collection<Set<Boundary>> reference, Score s) { double max_score = 0; double max_total = -1; for (Set<Boundary> ref : reference) { double total = ref.size(); for (Set<Boundary> pre : predicted) { Score m_tmp = new Score(); eval_single(pre, ref, m_tmp); if (max_total == -1 || m_tmp.score > max_score) { max_score = m_tmp.score; max_total = total; } } } // Macro-average: // max_score is proportion of correct boundaries, max_total is one max_total = 1; s.total += max_total; s.score += max_score; } private void eval_single(Set<Boundary> pre, Set<Boundary> ref, Score s) { int total = ref.size(); if (total == 0) { s.score = 1.0; s.total = 0.0; return; } Set<Boundary> intersect = new HashSet<>(pre); intersect.retainAll(ref); s.score = intersect.size() / (double) total; s.total = total; } public double getFscore() { double p = getPrecision(); double r = getRecall(); if (Numerics.approximatelyLesserEqual(p + r, 0.0)) { return 0.0; } double f = (2. * p * r) / (p + r); return f; } private double getRecall() { return 100. * recall.score / recall.total; } private double getPrecision() { return 100. * precision.score / precision.total; } public static SegmentationResult closest(SegmentationResult result, Collection<SegmentationResult> results, int length) { double best_score = Double.NEGATIVE_INFINITY; SegmentationResult best_result = null; Set<Boundary> brackets = getBoundary(result, length); for (SegmentationResult candidate : results) { Set<Boundary> other = getBoundary(candidate, length); Scorer scorer = new Scorer(); scorer.eval(Collections.singletonList(brackets), Collections.singletonList(other)); double score = scorer.getFscore(); assert !Double.isNaN(score); assert !Double.isInfinite(score); if (score > best_score) { best_result = candidate; best_score = score; } } return best_result; } }