package edu.berkeley.cs.nlp.ocular.eval;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import tberg.murphy.counter.Counter;
import tberg.murphy.counter.CounterMap;
import edu.berkeley.cs.nlp.ocular.data.textreader.Charset;
import edu.berkeley.cs.nlp.ocular.eval.MarkovEditDistanceComputer.EditDistanceParams;
import tberg.murphy.tuple.Pair;
import tberg.murphy.util.Iterators;
/**
 * Evaluation utilities comparing a guessed transcription against a gold transcription:
 * character error rate (CER), word error rate (WER), and a word-level error analysis.
 *
 * <p>Transcriptions are passed as arrays of lines, where each line is a list of
 * single-character strings.
 *
 * @author Taylor Berg-Kirkpatrick (tberg@eecs.berkeley.edu)
 */
public class Evaluator {

  /**
   * Accumulator for a document-count-weighted average score.
   *
   * <p>{@link #increment} merges another accumulator by weighting each side's score by its
   * document count, so repeatedly incrementing with single-document stats yields the mean of
   * the per-document scores.
   */
  public static class EvalSuffStats {
    private double score;
    private double docCount;

    /** Creates an empty accumulator (score 0 over 0 documents). */
    public EvalSuffStats() {
      this.score = 0;
      this.docCount = 0;
    }

    /** Creates an accumulator holding {@code score} averaged over {@code docCount} documents. */
    public EvalSuffStats(double score, double docCount) {
      this.score = score;
      this.docCount = docCount;
    }

    /**
     * Creates a single-document accumulator whose score is numerator / denominator, e.g. the
     * pair returned by {@link Evaluator#getCERSuffStats} or {@link Evaluator#getWERSuffStats}.
     * A zero denominator yields an infinite or NaN score.
     */
    public EvalSuffStats(Pair<Integer,Integer> numerDenom) {
      this.score = ((double) numerDenom.getFirst()) / ((double) numerDenom.getSecond());
      this.docCount = 1;
    }

    public double getScore() {
      return score;
    }

    public double getDocCount() {
      return docCount;
    }

    /**
     * Merges {@code suffStats} into this accumulator: the new score is the docCount-weighted
     * average of both scores and the new docCount is their sum.
     */
    public void increment(EvalSuffStats suffStats) {
      double myCount = getDocCount();
      double otherCount = suffStats.getDocCount();
      double nextDocCount = myCount + otherCount;
      if (nextDocCount == 0) {
        // Both sides are empty: nothing to average, and the weights below would be 0/0 = NaN,
        // permanently poisoning the accumulated score.
        return;
      }
      double nextScore = ((myCount / nextDocCount) * getScore()) + ((otherCount / nextDocCount) * suffStats.getScore());
      this.score = nextScore;
      this.docCount = nextDocCount;
    }
  }

  /**
   * Renders one {@code "<metric>: <score>"} line per entry, sorted by metric name.
   */
  public static String renderEval(Map<String,EvalSuffStats> evals) {
    List<String> evalTypes = new ArrayList<String>(evals.keySet());
    Collections.sort(evalTypes);
    StringBuilder buf = new StringBuilder();
    for (String evalType : evalTypes) {
      buf.append(evalType).append(": ").append(evals.get(evalType).getScore()).append('\n');
    }
    return buf.toString();
  }

  /**
   * Computes CER and WER for every combination of keeping/removing punctuation and
   * allowing/disallowing the f/s confusion option.
   *
   * <p>The metric keys are used verbatim by callers; two of them carry a trailing space
   * ("CER, keep punc " and "WER, keep punc "), which is preserved here.
   */
  public static Map<String,EvalSuffStats> getUnsegmentedEval(List<String>[] guessChars, List<String>[] goldChars, boolean charIncludesDiacritic) {
    Map<String,EvalSuffStats> evals = new HashMap<String,EvalSuffStats>();
    evals.put("CER, keep punc, allow f->s", new EvalSuffStats(getCERSuffStats(guessChars, goldChars, false, true, charIncludesDiacritic)));
    evals.put("CER, keep punc ", new EvalSuffStats(getCERSuffStats(guessChars, goldChars, false, false, charIncludesDiacritic)));
    evals.put("CER, remove punc, allow f->s", new EvalSuffStats(getCERSuffStats(guessChars, goldChars, true, true, charIncludesDiacritic)));
    evals.put("CER, remove punc", new EvalSuffStats(getCERSuffStats(guessChars, goldChars, true, false, charIncludesDiacritic)));
    evals.put("WER, keep punc, allow f->s", new EvalSuffStats(getWERSuffStats(guessChars, goldChars, false, true)));
    evals.put("WER, keep punc ", new EvalSuffStats(getWERSuffStats(guessChars, goldChars, false, false)));
    evals.put("WER, remove punc, allow f->s", new EvalSuffStats(getWERSuffStats(guessChars, goldChars, true, true)));
    evals.put("WER, remove punc", new EvalSuffStats(getWERSuffStats(guessChars, goldChars, true, false)));
    return evals;
  }

  /**
   * Character-level edit distance between guess and gold after normalization.
   *
   * @param removePunc whether punctuation characters are dropped before aligning
   * @param allowFSConfusion forwarded to {@code EditDistanceParams.getStandardParams}
   *     (presumably makes f/s substitutions free, per the metric names — confirm)
   * @param charIncludesDiacritic forwarded to {@code Form.charsAsGlyphs}
   * @return (alignment cost truncated to int, gold length in glyphs)
   */
  public static Pair<Integer,Integer> getCERSuffStats(List<String>[] guessChars, List<String>[] goldChars, boolean removePunc, boolean allowFSConfusion, boolean charIncludesDiacritic) {
    String guessStr = fullyNormalize(guessChars, removePunc);
    String goldStr = fullyNormalize(goldChars, removePunc);
    Form guessForm = Form.charsAsGlyphs(guessStr, charIncludesDiacritic);
    Form goldForm = Form.charsAsGlyphs(goldStr, charIncludesDiacritic);
    EditDistanceParams params = EditDistanceParams.getStandardParams(guessForm, goldForm, allowFSConfusion);
    MarkovEditDistanceComputer medc = new MarkovEditDistanceComputer(params);
    AlignedFormPair alignedPair = medc.runEditDistance();
    return Pair.makePair((int) alignedPair.cost, goldForm.length());
  }

  /**
   * Word-level edit distance between guess and gold after normalization.
   *
   * @return (alignment cost truncated to int, number of gold words)
   */
  public static Pair<Integer,Integer> getWERSuffStats(List<String>[] guessChars, List<String>[] goldChars, boolean removePunc, boolean allowFSConfusion) {
    AlignedFormPair alignedPair = getWordAlignments(guessChars, goldChars, removePunc, allowFSConfusion);
    return Pair.makePair((int) alignedPair.cost, alignedPair.trg.length());
  }

  /**
   * Produces a human-readable report of word-level alignment errors: operation counts,
   * histograms of gold word length and per-word character edit distance for isolated
   * substitutions, and up to 20 of the most frequently missed gold words.
   *
   * <p>A substitution is "isolated" when its neighbouring operations (if any) are EQUAL.
   */
  public static String errorAnalyze(List<String>[] guessChars, List<String>[] goldChars, boolean removePunc, boolean allowFSConfusion) {
    AlignedFormPair alignedPair = getWordAlignments(guessChars, goldChars, removePunc, allowFSConfusion);
    assert alignedPair != null;
    assert alignedPair.ops != null;

    // gold word -> guessed word -> count, recorded for isolated substitutions only.
    CounterMap<String,String> recallConfusions = new CounterMap<String,String>();
    Counter<String> recallErrors = new Counter<String>();
    // Positions in the guess (src) and gold (trg) word sequences as ops are consumed.
    int guessIndex = 0;
    int goldIndex = 0;
    int insertions = 0;
    int deletions = 0;
    int isolatedSubstitutions = 0;
    int nonIsolatedSubstitutions = 0;
    int numOps = alignedPair.ops.size();
    for (int i = 0; i < numOps; i++) {
      Operation op = alignedPair.ops.get(i);
      switch (op) {
        case EQUAL:
          guessIndex++;
          goldIndex++;
          break;
        case SUBST:
          boolean prevIsEqual = (i == 0 || alignedPair.ops.get(i - 1) == Operation.EQUAL);
          boolean nextIsEqual = (i == numOps - 1 || alignedPair.ops.get(i + 1) == Operation.EQUAL);
          if (prevIsEqual && nextIsEqual) {
            isolatedSubstitutions++;
            recallConfusions.incrementCount(alignedPair.trg.charAt(goldIndex).toString(), alignedPair.src.charAt(guessIndex).toString(), 1.0);
          } else {
            nonIsolatedSubstitutions++;
          }
          guessIndex++;
          goldIndex++;
          break;
        case INSERT:
          // Advances only the gold position: a gold word with no guess counterpart.
          insertions++;
          goldIndex++;
          break;
        case DELETE:
          // Advances only the guess position: a guessed word with no gold counterpart.
          deletions++;
          guessIndex++;
          break;
        default:
          throw new RuntimeException("Unrecognized operation: " + op);
      }
    }
    for (String word : recallConfusions.keySet()) {
      recallErrors.incrementCount(word, recallConfusions.getCount(word));
    }

    StringBuilder analysis = new StringBuilder();
    analysis.append(isolatedSubstitutions).append(" isolated substitutions, ")
        .append(nonIsolatedSubstitutions).append(" non-isolated substitutions, ")
        .append(insertions).append(" insertions, ")
        .append(deletions).append(" deletions\n");

    // Histogram bucket k (0-based) holds value k+1; the last bucket collects values >= 10.
    int[] wordLengthErrorCounts = new int[10];
    int[] editDistancePerWordCounts = new int[10];
    for (Pair<String,String> wordPair : Iterators.able(recallConfusions.getPairIterator())) {
      String goldStr = wordPair.getFirst();
      String guessStr = wordPair.getSecond();
      int count = (int) recallConfusions.getCount(goldStr, guessStr);
      // An empty gold word (possible when the gold text is empty) would otherwise index -1.
      int goldLen = Math.max(1, Math.min(10, goldStr.length()));
      // NOTE(review): this histogram counts distinct (gold, guess) pairs, while the edit-distance
      // histogram below weights each pair by its occurrence count — confirm which is intended.
      wordLengthErrorCounts[goldLen - 1] += 1;
      Form guessForm = Form.charsAsGlyphs(guessStr);
      Form goldForm = Form.charsAsGlyphs(goldStr);
      EditDistanceParams params = EditDistanceParams.getStandardParams(guessForm, goldForm, allowFSConfusion);
      MarkovEditDistanceComputer medc = new MarkovEditDistanceComputer(params);
      int cost = (int) medc.runEditDistance().cost;
      assert cost > 0 : "zero-cost isolated substitution: " + goldStr + " -> " + guessStr;
      // Clamp into [1, 10] so a zero cost cannot index -1 when assertions are disabled.
      cost = Math.max(1, Math.min(10, cost));
      editDistancePerWordCounts[cost - 1] += count;
    }
    analysis.append("Errors by word length (starts at 1): ").append(Arrays.toString(wordLengthErrorCounts)).append('\n');
    analysis.append("Edit distance per error (starts at 1): ").append(Arrays.toString(editDistancePerWordCounts)).append('\n');
    analysis.append("Most frequent missed words\n");
    int numPrinted = 0;
    for (String word : Iterators.able(recallErrors.asPriorityQueue())) {
      analysis.append(" ").append(word).append(": ").append(recallErrors.getCount(word)).append('\n');
      numPrinted++;
      if (numPrinted >= 20) {
        analysis.append(" ...").append(recallErrors.size()).append(" total word types missed");
        break;
      }
    }
    return analysis.toString();
  }

  /**
   * Aligns guess words (src) against gold words (trg), splitting the normalized text on runs
   * of whitespace.
   *
   * @param splitOutPunc whether punctuation characters are removed before splitting
   */
  public static AlignedFormPair getWordAlignments(List<String>[] guessChars, List<String>[] goldChars, boolean splitOutPunc, boolean allowFSConfusion) {
    String guessStr = fullyNormalize(guessChars, splitOutPunc);
    String goldStr = fullyNormalize(goldChars, splitOutPunc);
    Form guessForm = Form.wordsAsGlyphs(Arrays.asList(guessStr.split("\\s+")));
    Form goldForm = Form.wordsAsGlyphs(Arrays.asList(goldStr.split("\\s+")));
    EditDistanceParams params = EditDistanceParams.getStandardParams(guessForm, goldForm, allowFSConfusion);
    MarkovEditDistanceComputer medc = new MarkovEditDistanceComputer(params);
    AlignedFormPair alignedPair = medc.runEditDistance();
    assert alignedPair.trg.length() == goldForm.length();
    return alignedPair;
  }

  /**
   * Joins all lines with single spaces, optionally removes punctuation, and collapses
   * whitespace. (An alternative joiner, {@code convertToOneLineRemoveDashes}, rejoins words
   * hyphenated across line breaks but is currently not used.)
   */
  private static String fullyNormalize(List<String>[] chars, boolean splitOutPunc) {
    String str = convertToOneLine(chars);
    if (splitOutPunc) {
      str = splitOutPunc(str);
    }
    return normalizeWhitespace(str);
  }

  /**
   * Joins trimmed lines with a space, except that a line ending in '-' is joined directly to
   * the next line with the dash removed. The result starts with a space unless the first line
   * is appended to a dash (never, since the buffer starts empty).
   */
  @SuppressWarnings("unused")
  private static String convertToOneLineRemoveDashes(List<String>[] chars) {
    StringBuilder sb = new StringBuilder();
    for (List<String> line : chars) {
      StringBuilder lineBuf = new StringBuilder();
      for (int i = 0; i < line.size(); i++) {
        lineBuf.append(line.get(i));
      }
      String lineString = lineBuf.toString().trim();
      int len = sb.length();
      if (len > 0 && sb.charAt(len - 1) == '-') {
        sb.setLength(len - 1);
      } else {
        sb.append(' ');
      }
      sb.append(lineString);
    }
    return sb.toString();
  }

  /** Concatenates every line's characters, appending one space after each line. */
  private static String convertToOneLine(List<String>[] chars) {
    StringBuilder sb = new StringBuilder();
    for (List<String> line : chars) {
      for (int i = 0; i < line.size(); i++) {
        sb.append(line.get(i));
      }
      sb.append(' ');
    }
    return sb.toString();
  }

  /** Trims and collapses every run of whitespace to a single space. */
  private static String normalizeWhitespace(String str) {
    return str.trim().replaceAll("\\s+", " ");
  }

  /**
   * Removes every character that {@code Charset.isPunctuationChar} classifies as punctuation
   * (despite the name, punctuation is dropped, not split out), then normalizes whitespace.
   */
  private static String splitOutPunc(String str) {
    StringBuilder buf = new StringBuilder();
    for (String c : Charset.readNormalizeCharacters(str)) {
      if (!Charset.isPunctuationChar(c)) {
        buf.append(c);
      }
    }
    return normalizeWhitespace(buf.toString());
  }

  /** Manual smoke test printing the evaluation of several guess/gold pairs. */
  public static void main(String[] args) {
    String guess = "this is a longer, more nuanced test of the system";
    String gold = "tis is a logner, more nunced test of the sstem";
    System.out.println(renderEval(getUnsegmentedEval(convertToLines(guess), convertToLines(gold), true)));
    String guess2 = "deletion deletion this is a longer, more nuanced test of the system";
    String gold2 = "tis is a logner, more nunced test of the sstem insertion insertion";
    System.out.println(renderEval(getUnsegmentedEval(convertToLines(guess2), convertToLines(gold2), true)));
    String guess3 = "this is a longer, more nuanced test of the system deletion deletion";
    String gold3 = "insertion insertion tis is a logner, more nunced test of the sstem";
    System.out.println(renderEval(getUnsegmentedEval(convertToLines(guess3), convertToLines(gold3), true)));
    String guess4 = "this is \n a longer, more\n nuan-\nced \n test of the system deletion deletion";
    String gold4 = "this is a lon- \n ger, more nuanced test of the system deletion deletion";
    System.out.println(renderEval(getUnsegmentedEval(convertToLines(guess4), convertToLines(gold4), true)));
    String guess5 = "this is a longer, more nuanced t\\'est of the system";
    String gold5 = "tis is a logner, more nunced t\\'est of the sstem";
    System.out.println(renderEval(getUnsegmentedEval(convertToLines(guess5), convertToLines(gold5), true)));
  }

  /** Splits raw text on '\n' and each line into one-character strings. */
  private static List<String>[] convertToLines(String rawStr) {
    String[] lines = rawStr.split("\n");
    @SuppressWarnings("unchecked")
    List<String>[] charsPerLine = new List[lines.length];
    for (int i = 0; i < lines.length; i++) {
      charsPerLine[i] = Arrays.asList(split(lines[i]));
    }
    return charsPerLine;
  }

  /**
   * Splits {@code str} into one-element strings, one per UTF-16 code unit (so a supplementary
   * character is split into its two surrogate halves).
   */
  public static String[] split(String str) {
    String[] result = new String[str.length()];
    for (int i = 0; i < result.length; ++i) {
      result[i] = str.substring(i, i + 1);
    }
    return result;
  }
}