// Copyright 2013 Thomas Müller
// This file is part of MarMoT, which is licensed under GPLv3.
package marmot.morph.cmd;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import marmot.core.Sequence;
import marmot.morph.Word;
import marmot.morph.io.SentenceReader;
/**
 * Command-line evaluator that scores predicted POS/morphological tags against
 * gold tags using "soft" similarity measures (see {@link Mode}).
 *
 * <p>Gold and predicted annotations are read from the same file through two
 * {@link SentenceReader} instances configured with different column indexes.
 * Scores are accumulated in a {@link Result}.
 */
public class SoftEvaluator {

    /** Similarity measure used to compare one gold token with one predicted token. */
    enum Mode {
        Jaccard,
        Acc,
        Cosine,
        Fscore
    }

    /**
     * Accumulates the weighted token similarity for a single {@link Mode}.
     *
     * <p>Note on scale: {@code Acc}, {@code Cosine} and {@code Fscore} yield
     * per-token values on a 0..100 scale, while {@code Jaccard} yields a
     * fraction in 0..1 (it is not multiplied by 100).
     */
    class Result {

        /** Running sum of {@code pair_sim * factor} over all evaluated tokens. */
        double sim = 0;

        private Mode _mode;

        /**
         * @param mode the similarity measure applied by {@link #eval(Word, Word, double)}
         */
        public Result(Mode mode) {
            _mode = mode;
        }

        /**
         * Builds a unit-length feature vector from a POS tag and a morph tag.
         *
         * <p>The POS tag and every {@code |}-separated morpheme become keys with
         * weight 1.0; the vector is then divided by its Euclidean norm. Unlike
         * {@link #toSet(String, String)}, the POS key is not prefixed.
         *
         * @param pos   the POS tag
         * @param morph the morph tag; {@code null} or {@code "_"} means "no morphology"
         * @return the normalized feature vector
         */
        protected Map<String, Double> toVector(String pos, String morph) {
            Map<String, Double> map = new HashMap<String, Double>();
            map.put(pos, 1.0);
            if (morph != null && !morph.equals("_")) {
                for (String morpheme : morph.split("\\|")) {
                    map.put(morpheme, 1.0);
                }
            }

            // Normalize to unit length so that the dot product is the cosine.
            double norm = 0;
            for (double d : map.values()) {
                norm += d * d;
            }
            norm = Math.sqrt(norm);
            for (Map.Entry<String, Double> entry : map.entrySet()) {
                entry.setValue(entry.getValue() / norm);
            }
            return map;
        }

        /**
         * Jaccard similarity of the gold and predicted feature sets.
         *
         * @return a value in 0..1
         */
        protected double calc_jaccard(String gold_pos, String gold_morph,
                String pred_pos, String pred_morph) {
            return jaccard(toSet(gold_pos, gold_morph), toSet(pred_pos, pred_morph));
        }

        /**
         * F-score of the predicted feature set against the gold feature set.
         *
         * @return a value in 0..100
         */
        protected double calc_fscore(String gold_pos, String gold_morph,
                String pred_pos, String pred_morph) {
            return fscore(toSet(gold_pos, gold_morph), toSet(pred_pos, pred_morph));
        }

        /**
         * Computes the harmonic mean of precision and recall, scaled to 0..100.
         *
         * @param set  the gold feature set (denominator of recall)
         * @param set2 the predicted feature set (denominator of precision)
         * @return {@code 100 * 2pr / (p + r)}, or 0 when precision or recall is 0
         */
        private double fscore(Set<String> set, Set<String> set2) {
            Set<String> intersection = new HashSet<String>(set);
            intersection.retainAll(set2);
            assert intersection.size() <= set2.size();

            double p = intersection.size() / (double) set2.size();
            double r = intersection.size() / (double) set.size();
            assert (p <= 1.0);
            assert (r <= 1.0);

            // Guard against 0/0 in the harmonic mean below.
            if (p < 1e-10) {
                return 0;
            }
            if (r < 1e-10) {
                return 0;
            }

            double f = p * r * 2 / (p + r);
            assert (f <= 1.0 + 1e-5);
            // Return the computed F-score, not a constant.
            return 100 * f;
        }

        /**
         * Computes |intersection| / |union| of the two sets.
         *
         * @return a value in 0..1 (not scaled by 100)
         */
        private double jaccard(Set<String> set, Set<String> set2) {
            Set<String> intersection = new HashSet<String>(set);
            intersection.retainAll(set2);
            Set<String> union = new HashSet<String>(set);
            union.addAll(set2);
            return intersection.size() / (double) union.size();
        }

        /**
         * Builds a feature set from a POS tag and a morph tag.
         *
         * <p>The POS tag is stored as {@code "POS=" + pos}; every
         * {@code |}-separated morpheme is stored verbatim. The set is never
         * empty because the POS entry is always added.
         *
         * @param pos   the POS tag
         * @param morph the morph tag; {@code null} or {@code "_"} means "no morphology"
         */
        private Set<String> toSet(String pos, String morph) {
            Set<String> set = new HashSet<String>();
            set.add("POS=" + pos);
            if (morph != null && !morph.equals("_")) {
                for (String morpheme : morph.split("\\|")) {
                    set.add(morpheme);
                }
            }
            return set;
        }

        /**
         * Dot product of two sparse vectors, scaled by 100.
         *
         * <p>This equals the cosine similarity (times 100) when both vectors
         * have unit length, as those produced by {@link #toVector(String, String)}.
         */
        protected double cosineSim(Map<String, Double> vec, Map<String, Double> vec2) {
            double sim = 0;
            for (Map.Entry<String, Double> entry : vec.entrySet()) {
                Double d2 = vec2.get(entry.getKey());
                if (d2 != null) {
                    sim += entry.getValue() * d2;
                }
            }
            return sim * 100.;
        }

        /**
         * Scores one gold/predicted token pair with the configured mode and adds
         * the weighted score to {@link #sim}.
         *
         * @param gold_token the reference token
         * @param pred_token the predicted token
         * @param factor     weight applied to the pair's similarity
         */
        public void eval(Word gold_token, Word pred_token, double factor) {
            String gold_pos = gold_token.getPosTag();
            String pred_pos = pred_token.getPosTag();
            String gold_morph = gold_token.getMorphTag();
            String pred_morph = pred_token.getMorphTag();

            double pair_sim = 0.0;
            switch (_mode) {
            case Acc:
                pair_sim = calc_acc(gold_pos, gold_morph, pred_pos, pred_morph);
                break;
            case Jaccard:
                pair_sim = calc_jaccard(gold_pos, gold_morph, pred_pos, pred_morph);
                break;
            case Cosine:
                pair_sim = calc_cosineSim(gold_pos, gold_morph, pred_pos, pred_morph);
                break;
            case Fscore:
                pair_sim = calc_fscore(gold_pos, gold_morph, pred_pos, pred_morph);
                break;
            default:
                System.err.println("What?");
            }
            sim += pair_sim * factor;
        }

        /**
         * Exact-match accuracy: 100 when both the POS tag and the morph tag are
         * equal, 0 otherwise.
         *
         * <p>The morph tags are compared null-safely: two {@code null} tags are
         * equal, and a {@code null} tag never equals a non-null one.
         */
        private double calc_acc(String gold_pos, String gold_morph, String pred_pos,
                String pred_morph) {
            if (gold_pos.equals(pred_pos) && nullSafeEquals(gold_morph, pred_morph)) {
                return 100.0;
            }
            return 0.0;
        }

        /** Returns true when both strings are null or both are equal. */
        private boolean nullSafeEquals(String a, String b) {
            return (a == null) ? (b == null) : a.equals(b);
        }

        /** Cosine similarity (0..100) of the gold and predicted feature vectors. */
        private double calc_cosineSim(String gold_pos, String gold_morph,
                String pred_pos, String pred_morph) {
            return cosineSim(toVector(gold_pos, gold_morph), toVector(pred_pos, pred_morph));
        }

        /**
         * @return the mode name followed by the accumulated score, formatted
         *         with {@code "%s: %g"}
         */
        public String report() {
            return String.format("%s: %g", _mode.toString(), sim);
        }
    }

    /**
     * Scores every token pair of two aligned sentences, weighting each pair by
     * {@code 1 / num_tokens}.
     *
     * @param gold_sentence the reference sentence
     * @param pred_sentence the predicted sentence; must have the same size
     *                      (only checked by an assertion)
     * @param result        the accumulator to update
     * @param num_tokens    the total token count used as the normalizer
     */
    void eval(Sequence gold_sentence, Sequence pred_sentence, Result result, int num_tokens) {
        assert gold_sentence.size() == pred_sentence.size();
        for (int i = 0; i < gold_sentence.size(); i++) {
            Word gold_token = (Word) gold_sentence.get(i);
            Word pred_token = (Word) pred_sentence.get(i);
            result.eval(gold_token, pred_token, 1. / num_tokens);
        }
    }

    /**
     * Evaluates a single file that holds both gold and predicted annotations.
     *
     * <p>The file is opened twice: the gold reader uses
     * {@code tag-index=4,morph-index=6}, the prediction reader uses
     * {@code tag-index=5,morph-index=7}. The exact meaning of these options
     * is defined by {@link SentenceReader}.
     */
    void eval(String pred_file, Result result) {
        Iterable<Sequence> gold_sentences = new SentenceReader("form-index=1,tag-index=4,morph-index=6," + pred_file);
        Iterable<Sequence> pred_sentences = new SentenceReader("form-index=1,tag-index=5,morph-index=7," + pred_file);
        eval(gold_sentences, pred_sentences, result);
    }

    /**
     * Evaluates two parallel sentence collections.
     *
     * <p>{@code gold_sentences} is iterated twice: once to count the tokens
     * (the normalizer) and once to score. The prediction iterator is expected
     * to yield exactly as many sentences as the gold iterator; a surplus is
     * only detected by an assertion.
     */
    void eval(Iterable<Sequence> gold_sentences,
            Iterable<Sequence> pred_sentences, Result result) {
        int num_tokens = 0;
        for (Sequence seq : gold_sentences) {
            num_tokens += seq.size();
        }

        Iterator<Sequence> gold_iter = gold_sentences.iterator();
        Iterator<Sequence> pred_iter = pred_sentences.iterator();
        while (gold_iter.hasNext()) {
            eval(gold_iter.next(), pred_iter.next(), result, num_tokens);
        }
        assert !pred_iter.hasNext();
    }

    /**
     * Prints the accuracy and Jaccard scores for the files given as arguments
     * on a single line.
     *
     * @param args paths of the files to evaluate
     */
    public static void main(String[] args) {
        SoftEvaluator evaluator = new SoftEvaluator();

        Result result = evaluator.new Result(Mode.Acc);
        evaluator.eval(args, result);
        System.out.print(result.report());

        result = evaluator.new Result(Mode.Jaccard);
        evaluator.eval(args, result);
        System.out.println(" " + result.report());

        // Cosine reporting is currently disabled:
        // result = evaluator.new Result(Mode.Cosine);
        // evaluator.eval(args, result);
        // System.out.println(result.report());
    }

    /**
     * Evaluates several files and replaces {@code result.sim} with the mean of
     * the per-file scores.
     *
     * <p>With an empty array the division is 0/0 and {@code result.sim}
     * becomes NaN.
     */
    public void eval(String[] pred_files, Result result) {
        for (String pred_file : pred_files) {
            eval(pred_file, result);
        }
        result.sim = result.sim / pred_files.length;
    }
}