/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.cogroo.cmdline.chunker2;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
import opennlp.tools.util.Span;
import opennlp.tools.util.eval.EvaluationMonitor;
/**
 * This listener gathers detailed information about the samples under evaluation
 * and computes a precision/recall/F-measure breakdown per span size (length),
 * in addition to the overall totals.
 * <p>
 * <b>Note:</b> Do not use this class, internal use only!
 *
 * @param <T> the type of the samples under evaluation
 */
public abstract class DetailedFMeasureForSizeListener<T> implements
    EvaluationMonitor<T> {

  /** Number of samples processed so far (correctly classified + misclassified). */
  private int samples = 0;

  /** Aggregated statistics over every span size. */
  private final Stats generalStats = new Stats();

  /** Per-size statistics, keyed by span length. */
  private final Map<Integer, Stats> statsForOutcome = new HashMap<Integer, Stats>();

  /**
   * Extracts the spans contained in a sample.
   *
   * @param sample the sample under evaluation
   * @return the spans of the sample
   */
  protected abstract Span[] asSpanArray(T sample);

  public void correctlyClassified(T reference, T prediction) {
    samples++;
    // the whole sample matched, so every reference span is a true positive
    Span[] spans = asSpanArray(reference);
    for (Span span : spans) {
      addTruePositive(span.length());
    }
  }

  public void missclassified(T reference, T prediction) {
    samples++;
    Span[] references = asSpanArray(reference);
    Span[] predictions = asSpanArray(prediction);

    // NOTE(review): equal spans occurring more than once inside a single
    // sample are collapsed by the sets and therefore counted only once.
    Set<Span> refSet = new HashSet<Span>(Arrays.asList(references));
    Set<Span> predSet = new HashSet<Span>(Arrays.asList(predictions));

    for (Span ref : refSet) {
      if (predSet.contains(ref)) {
        addTruePositive(ref.length());
      } else {
        addFalseNegative(ref.length());
      }
    }

    for (Span pred : predSet) {
      if (!refSet.contains(pred)) {
        addFalsePositive(pred.length());
      }
    }
  }

  /** Records a true positive for the given span size (also counts as a target). */
  private void addTruePositive(Integer size) {
    Stats s = initStatsForOutcomeAndGet(size);
    s.incrementTruePositive();
    s.incrementTarget();

    generalStats.incrementTruePositive();
    generalStats.incrementTarget();
  }

  /** Records a false positive for the given span size. */
  private void addFalsePositive(Integer size) {
    Stats s = initStatsForOutcomeAndGet(size);
    s.incrementFalsePositive();
    generalStats.incrementFalsePositive();
  }

  /**
   * Records a false negative for the given span size. Only the target counter
   * is incremented: recall is computed as tp / target, so a missed reference
   * span needs no dedicated counter.
   */
  private void addFalseNegative(Integer size) {
    Stats s = initStatsForOutcomeAndGet(size);
    s.incrementTarget();
    generalStats.incrementTarget();
  }

  /** Returns the {@link Stats} for the given size, creating it on first use. */
  private Stats initStatsForOutcomeAndGet(Integer size) {
    if (!statsForOutcome.containsKey(size)) {
      statsForOutcome.put(size, new Stats());
    }
    return statsForOutcome.get(size);
  }

  // "% 7.2f%%" — \u0020 is a plain space, kept explicit as in the original
  private static final String PERCENT = "%\u00207.2f%%";
  private static final String FORMAT = "%12s: precision: " + PERCENT
      + "; recall: " + PERCENT + "; F1: " + PERCENT + ".";
  private static final String FORMAT_EXTRA = FORMAT
      + " [target: %3d; tp: %3d; fp: %3d]";

  /**
   * Creates the evaluation report using the default locale.
   *
   * @return a human readable report with the overall and per-size scores
   */
  public String createReport() {
    return createReport(Locale.getDefault());
  }

  /**
   * Creates the evaluation report.
   *
   * @param locale the locale used to format the numbers
   * @return a human readable report with the overall and per-size scores
   */
  public String createReport(Locale locale) {
    StringBuilder ret = new StringBuilder();
    int tp = generalStats.getTruePositives();
    int found = generalStats.getFalsePositives() + tp;

    ret.append("Evaluated " + samples + " samples with "
        + generalStats.getTarget() + " entities; found: " + found
        + " entities; correct: " + tp + ".\n");

    ret.append(String.format(locale, FORMAT, "TOTAL",
        zeroOrPositive(generalStats.getPrecisionScore() * 100),
        zeroOrPositive(generalStats.getRecallScore() * 100),
        zeroOrPositive(generalStats.getFMeasure() * 100)));
    ret.append("\n");

    // report the individual sizes in ascending order
    SortedSet<Integer> set = new TreeSet<Integer>(statsForOutcome.keySet());
    for (Integer size : set) {
      Stats stats = statsForOutcome.get(size);
      ret.append(String.format(locale, FORMAT_EXTRA, size,
          zeroOrPositive(stats.getPrecisionScore() * 100),
          zeroOrPositive(stats.getRecallScore() * 100),
          zeroOrPositive(stats.getFMeasure() * 100),
          stats.getTarget(), stats.getTruePositives(),
          stats.getFalsePositives()));
      ret.append("\n");
    }

    return ret.toString();
  }

  @Override
  public String toString() {
    return createReport();
  }

  /** Clamps negative scores (the -1 "undefined" marker of getFMeasure) to zero. */
  private double zeroOrPositive(double v) {
    if (v < 0) {
      return 0;
    }
    return v;
  }

  /**
   * Orders span sizes by descending F-measure, breaking ties by the natural
   * order of the sizes.
   * <p>
   * Fixed: this comparator was typed on {@code String} while
   * {@code statsForOutcome} is keyed by {@code Integer}, so the map lookups
   * could never match and it degenerated to plain string comparison. It also
   * returned {@code 1} for both argument orders when the two F-measures were
   * equal and positive, violating the {@link Comparator} contract.
   */
  private class F1Comparator implements Comparator<Integer> {
    public int compare(Integer o1, Integer o2) {
      if (o1.equals(o2))
        return 0;

      double t1 = 0;
      double t2 = 0;

      if (statsForOutcome.containsKey(o1))
        t1 += statsForOutcome.get(o1).getFMeasure();
      if (statsForOutcome.containsKey(o2))
        t2 += statsForOutcome.get(o2).getFMeasure();

      t1 = zeroOrPositive(t1);
      t2 = zeroOrPositive(t2);

      if (t1 != t2) {
        // higher F-measure comes first
        return t1 > t2 ? -1 : 1;
      }
      return o1.compareTo(o2);
    }
  }

  /**
   * Stores the counters needed to compute precision, recall and F-measure.
   * Static: it reads no state of the enclosing listener.
   */
  private static class Stats {

    // maybe we could use the FMeasure class, but it wouldn't allow us to get
    // details like total number of false positives and true positives.
    private int falsePositiveCounter = 0;
    private int truePositiveCounter = 0;
    private int targetCounter = 0;

    public void incrementFalsePositive() {
      falsePositiveCounter++;
    }

    public void incrementTruePositive() {
      truePositiveCounter++;
    }

    public void incrementTarget() {
      targetCounter++;
    }

    public int getFalsePositives() {
      return falsePositiveCounter;
    }

    public int getTruePositives() {
      return truePositiveCounter;
    }

    public int getTarget() {
      return targetCounter;
    }

    /**
     * Retrieves the micro-averaged precision: true positives divided by
     * everything that was selected (tp + fp).
     *
     * @return the precision score, or 0 if nothing was selected
     */
    public double getPrecisionScore() {
      int tp = getTruePositives();
      int selected = tp + getFalsePositives();
      return selected > 0 ? (double) tp / (double) selected : 0;
    }

    /**
     * Retrieves the micro-averaged recall: true positives divided by the
     * number of target spans.
     *
     * @return the recall score, or 0 if there were no targets
     */
    public double getRecallScore() {
      int target = getTarget();
      int tp = getTruePositives();
      return target > 0 ? (double) tp / (double) target : 0;
    }

    /**
     * Retrieves the f-measure score.
     *
     * f-measure = 2 * precision * recall / (precision + recall)
     *
     * @return the f-measure or -1 if precision + recall <= 0
     */
    public double getFMeasure() {
      double precision = getPrecisionScore();
      double recall = getRecallScore();
      if (precision + recall > 0) {
        return 2 * (precision * recall) / (precision + recall);
      }
      // cannot divide by zero, return error code
      return -1;
    }
  }
}