package is2.util;
import is2.data.Parse;
import is2.data.SentenceData09;
import is2.io.CONLLReader09;
import is2.parser.Parser;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map.Entry;
import org.apache.commons.math.MathException;
import org.apache.commons.math.stat.inference.TestUtils;
public class Evaluator {
/**
 * Command-line entry point.
 *
 * <ul>
 * <li>{@code eval} set and no {@code significant1}: evaluates {@code outfile} against
 * {@code goldfile}.</li>
 * <li>{@code significant1} and {@code significant2} both set: evaluates each of them against
 * {@code goldfile} and prints the value returned by a paired t-test over the per-sentence
 * scores of the two runs.</li>
 * <li>only {@code significant1} set: evaluates {@code outfile} against {@code goldfile}.</li>
 * </ul>
 * In every other case nothing is done.
 *
 * @param args command-line arguments, handed to {@code Options}
 */
public static void main(String[] args) {
    Options options = new Options(args);
    final boolean hasFirst = options.significant1 != null;

    if (options.eval && !hasFirst) {
        evaluate(options.goldfile, options.outfile);
        return;
    }
    if (!hasFirst) {
        return;
    }
    if (options.significant2 == null) {
        // Only one comparison file was given: plain evaluation of the output file.
        evaluate(options.goldfile, options.outfile, true);
        return;
    }

    Parser.out.println("compare1 " + options.significant1);
    Parser.out.println("compare2 " + options.significant2);
    Parser.out.println("gold " + options.goldfile);

    Results first = evaluate(options.goldfile, options.significant1, false);
    Parser.out.println("file 1 done ");
    Results second = evaluate(options.goldfile, options.significant2, false);

    // Both sample arrays are sized by the first run's sentence count.
    final int samples = first.correctHead.size();
    double[] firstScores = new double[samples];
    double[] secondScores = new double[samples];
    for (int idx = 0; idx < samples; idx++) {
        firstScores[idx] = first.correctHead.get(idx);
        secondScores[idx] = second.correctHead.get(idx);
    }

    try {
        double p = TestUtils.pairedTTest(firstScores, secondScores);
        Parser.out.print("significant to " + p);
    } catch (IllegalArgumentException | MathException e) {
        e.printStackTrace();
    }
}
/**
 * Compares the predicted POS tags ({@code ppos}) and morphological features ({@code pfeats})
 * of {@code pred_file} with the gold values ({@code gpos}, {@code ofeats}) of
 * {@code act_file} and prints the requested reports to {@code Parser.out}.
 * The two files are read in lock-step, one sentence at a time; token 0 of every sentence is
 * skipped.
 *
 * @param act_file path of the gold file
 * @param pred_file path of the file with the predictions
 * @param what selects the reports by substring: {@code "top"} lists the gold/predicted tag
 *        confusions seen more than 10 times (most frequent first), {@code "length"} prints
 *        tag accuracy per token position below 60, {@code "pos"} prints accuracy per gold
 *        tag, {@code "mor"} prints accuracy per gold feature string
 */
public static void evaluateTagger(String act_file, String pred_file, String what) {
    CONLLReader09 goldReader = new CONLLReader09(act_file);
    CONLLReader09 predictedReader = new CONLLReader09();
    predictedReader.startReading(pred_file);

    // confusion key -> number of occurrences / word forms it occurred on
    HashMap<String, Integer> errors = new HashMap<>();
    HashMap<String, StringBuilder> words = new HashMap<>();
    // each int[2] holds {correct, total}
    HashMap<Integer, int[]> correctL = new HashMap<>();
    HashMap<String, int[]> pos = new HashMap<>();
    HashMap<String, int[]> mor = new HashMap<>();

    int total = 0, numsent = 0, corrT = 0;
    // Integer counter: a float counter stops increasing once it reaches 2^24.
    int correctM = 0;

    SentenceData09 goldInstance = goldReader.getNext();
    SentenceData09 predInstance = predictedReader.getNext();
    while (goldInstance != null) {
        int instanceLength = goldInstance.length();
        if (instanceLength != predInstance.length()) {
            Parser.out.println("Lengths do not match on sentence " + numsent);
        }
        String gold[] = goldInstance.gpos;
        String pred[] = predInstance.ppos;
        String goldM[] = goldInstance.ofeats;
        String predM[] = predInstance.pfeats;

        // NOTE: the first item is the root info added during nextInstance(), so we skip it.
        for (int i = 1; i < instanceLength; i++) {
            int[] cwr = correctL.get(i);
            if (cwr == null) {
                cwr = new int[2];
                correctL.put(i, cwr);
            }
            cwr[1]++;

            int[] correctPos = pos.get(gold[i]);
            if (correctPos == null) {
                correctPos = new int[2];
                pos.put(gold[i], correctPos);
            }
            correctPos[1]++;

            int[] correctMor = mor.get(goldM[i]);
            if (correctMor == null) {
                correctMor = new int[2];
                mor.put(goldM[i], correctMor);
            }
            correctMor[1]++;
            // A missing prediction counts as a match for the gold placeholder "_".
            if ((goldM[i].equals("_") && predM[i] == null) || goldM[i].equals(predM[i])) {
                correctM++;
                correctMor[0]++;
            }

            if (gold[i].equals(pred[i])) {
                corrT++;
                cwr[0]++;
                correctPos[0]++;
            } else {
                String key = "gold: '" + gold[i] + "' pred: '" + pred[i] + "'";
                Integer seen = errors.get(key);
                if (seen == null) {
                    errors.put(key, 1);
                    words.put(key, new StringBuilder().append(goldInstance.forms[i]));
                } else {
                    errors.put(key, seen + 1);
                    words.get(key).append(" ").append(goldInstance.forms[i]);
                }
            }
        }
        total += instanceLength - 1; // Subtract one to not score fake root token
        numsent++;
        goldInstance = goldReader.getNext();
        predInstance = predictedReader.getNext();
    }

    ArrayList<Entry<String, Integer>> opsl = new ArrayList<>(errors.entrySet());
    Collections.sort(opsl, new Comparator<Entry<String, Integer>>() {
        @Override
        public int compare(Entry<String, Integer> o1, Entry<String, Integer> o2) {
            // Descending by count. The values are compared numerically: comparing the boxed
            // Integers with == tests identity and breaks the comparator contract.
            return Integer.compare(o2.getValue(), o1.getValue());
        }
    });

    if (what.contains("top")) {
        Parser.out.println("top most errors:");
        for (Entry<String, Integer> e : opsl) {
            if (e.getValue() > 10) {
                Parser.out.println(e.getKey() + " " + e.getValue() + " context: " + words.get(e.getKey()));
            }
        }
    }
    if (what.contains("length")) {
        for (int k = 0; k < 60; k++) {
            int[] cwr = correctL.get(k);
            if (cwr == null) {
                continue;
            }
            Parser.out.print(k + ":" + cwr[0] + ":" + cwr[1] + ":" + percentOf(cwr[0], cwr[1]) + " ");
        }
        Parser.out.println();
    }
    if (what.contains("pos")) {
        for (Entry<String, int[]> e : pos.entrySet()) {
            int[] c = e.getValue();
            Parser.out.print(e.getKey() + ":" + c[0] + ":" + c[1] + ":" + percentOf(c[0], c[1]) + " ");
        }
    }
    Parser.out.println();
    if (what.contains("mor")) {
        for (Entry<String, int[]> e : mor.entrySet()) {
            int[] c = e.getValue();
            Parser.out.print(e.getKey() + ":" + c[0] + ":" + c[1] + ":" + percentOf(c[0], c[1]) + " ");
        }
    }
    Parser.out.println("\nTokens: " + total + " Correct: " + corrT + " " + (float) corrT / total
            + " Correct M.:" + correctM + " morphology " + ((float) correctM / total));
}

/** Returns {@code correct / all} as a percentage rounded to two decimals. */
private static float percentOf(int correct, int all) {
    return ((float) Math.round(10000 * ((float) correct) / ((float) all))) / 100;
}
/**
 * Counts the tokens of {@code s} (token 0 is skipped) whose predicted head differs from the
 * gold head and, when {@code uas} is {@code false}, whose predicted label also differs from
 * the gold label.
 *
 * NOTE(review): with {@code uas == false} a token with a wrong head but a correct label, or
 * a correct head but a wrong label, is not counted. Confirm that this is the intended
 * definition of a labelled error before relying on it.
 *
 * @param s sentence carrying gold ({@code heads}, {@code labels}) and predicted
 *        ({@code pheads}, {@code plabels}) annotation
 * @param uas if {@code true}, only the heads are compared
 * @return the number of tokens counted as errors
 */
public static int errors(SentenceData09 s, boolean uas) {
    int wrong = 0;
    for (int tok = 1; tok < s.length(); tok++) {
        if (s.heads[tok] == s.pheads[tok]) {
            continue;
        }
        if (uas || !s.labels[tok].equals(s.plabels[tok])) {
            wrong++;
        }
    }
    return wrong;
}
/**
 * Collects, per gold label, the tokens on which exactly one of two analyses of a sentence
 * is wrong. A token is wrong in an analysis when its predicted head or its predicted label
 * differs from the gold value. Token 0 is skipped; the loop bound is the length of
 * {@code s1}.
 *
 * @param s1 first analysis
 * @param s2 second analysis
 * @param r1 receives, keyed by the gold label in {@code s1}, the count of tokens that are
 *        wrong in {@code s1} but right in {@code s2}
 * @param r2 receives, keyed by the gold label in {@code s2}, the count of tokens that are
 *        wrong in {@code s2} but right in {@code s1}
 * @return always 0; no error total is accumulated by this method
 */
public static int errors(SentenceData09 s1, SentenceData09 s2, HashMap<String, Integer> r1, HashMap<String, Integer> r2) {
    for (int tok = 1; tok < s1.length(); tok++) {
        boolean firstWrong = s1.heads[tok] != s1.pheads[tok] || !s1.labels[tok].equals(s1.plabels[tok]);
        boolean secondWrong = s2.heads[tok] != s2.pheads[tok] || !s2.labels[tok].equals(s2.plabels[tok]);
        if (firstWrong == secondWrong) {
            // Both right or both wrong: nothing distinguishes the two analyses here.
            continue;
        }
        HashMap<String, Integer> target = firstWrong ? r1 : r2;
        String goldLabel = firstWrong ? s1.labels[tok] : s2.labels[tok];
        Integer seen = target.get(goldLabel);
        target.put(goldLabel, seen == null ? 1 : seen + 1);
    }
    return 0;
}
/**
 * String made of ASCII punctuation characters. It is not referenced by any method of this
 * class; the punctuation filters in this class use their own inline string literals.
 */
public static final String PUNCT = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~";
/**
 * Plain holder for the scores produced by {@code evaluate}. The percentage fields are
 * rounded to three decimals there.
 */
public static class Results {
// number of scored tokens (the artificial root token of each sentence is excluded)
public int total;
// number of tokens whose predicted head equals the gold head
public int corr;
// percentage of tokens with correct head and correct label
public float las;
// percentage of tokens with correct head
public float ula;
// percentage of tokens with correct head, label and POS tag
public float lpas;
// percentage of tokens with correct head and POS tag
public float upla;
// one entry per sentence: fraction of its tokens with correct head AND correct label
ArrayList<Double> correctHead;
}
/**
 * Evaluates {@code pred_file} against {@code act_file} with the per-label table enabled and
 * the per-token output disabled.
 *
 * @param act_file path of the gold file
 * @param pred_file path of the file with the predictions
 * @return the computed scores
 */
public static Results evaluate(String act_file, String pred_file) {
    final boolean printEval = true;
    final boolean sig = false;
    return evaluate(act_file, pred_file, printEval, sig);
}
/**
 * Evaluates {@code pred_file} against {@code act_file} with the per-token output disabled.
 *
 * @param act_file path of the gold file
 * @param pred_file path of the file with the predictions
 * @param printEval whether the per-label table is printed
 * @return the computed scores
 */
public static Results evaluate(String act_file, String pred_file, boolean printEval) {
    final boolean sig = false;
    Results scores = evaluate(act_file, pred_file, printEval, sig);
    return scores;
}
/**
 * Scores the predicted heads and labels ({@code pheads}, {@code plabels}) of
 * {@code pred_file} against the gold heads and labels of {@code act_file} and prints a
 * summary to {@code Parser.out}. The two files are read in lock-step, one sentence at a
 * time; token 0 of every sentence is skipped.
 *
 * <p>Besides the scores over all tokens, three punctuation-filtered variants are printed:
 * <ul>
 * <li>"without .": tokens whose form is not a substring of {@code ,.:''``}</li>
 * <li>"V2": tokens whose form is not a substring of a longer ASCII punctuation string</li>
 * <li>"CHN": tokens whose gold POS tag does not start with "pu" (case-insensitive)</li>
 * </ul>
 *
 * @param act_file path of the gold file
 * @param pred_file path of the file with the predictions
 * @param printEval if {@code true}, a per-label recall / precision / F-score table is
 *        printed after the summary
 * @param sig if {@code true}, one line per token is printed: "1" when head, label and POS
 *        tag are all correct, "0" otherwise
 * @return the overall scores plus, per sentence, the fraction of tokens with correct head
 *         and label
 */
public static Results evaluate(String act_file, String pred_file, boolean printEval, boolean sig) {
    CONLLReader09 goldReader = new CONLLReader09(act_file, -1);
    CONLLReader09 predictedReader = new CONLLReader09(pred_file, -1);

    // all tokens
    int total = 0, corr = 0, corrL = 0, corrLableAndPos = 0, corrHeadAndPos = 0;
    // "without ." filter
    int Ptotal = 0, Pcorr = 0, PcorrL = 0, corrLableAndPosP = 0;
    // "V2" filter
    int BPtotal = 0, BPcorr = 0, BPcorrL = 0;
    // "CHN" filter
    int CPtotal = 0, correctChnWoPunc = 0, correctLChnWoPunc = 0, corrLableAndPosC = 0;
    // sentence level
    int numsent = 0, corrsent = 0, corrsentL = 0, Pcorrsent = 0, PcorrsentL = 0;
    // Projectivity counters: printed below but never updated by this method.
    int proj = 0, nonproj = 0, pproj = 0, pnonproj = 0, nonProjOk = 0, nonProjWrong = 0;

    HashMap<String, Integer> labelCount = new HashMap<>();
    HashMap<String, Integer> labelCorrect = new HashMap<>();
    HashMap<String, Integer> falsePositive = new HashMap<>();
    // per sentence: fraction of tokens with correct head and label
    ArrayList<Double> correctHead = new ArrayList<>();

    SentenceData09 goldInstance = goldReader.getNext();
    SentenceData09 predInstance = predictedReader.getNext();
    while (goldInstance != null) {
        int instanceLength = goldInstance.length();
        if (instanceLength != predInstance.length()) {
            Parser.out.println("Lengths do not match on sentence " + numsent);
        }
        int[] goldHeads = goldInstance.heads;
        String[] goldLabels = goldInstance.labels;
        int[] predHeads = predInstance.pheads;
        String[] predLabels = predInstance.plabels;

        boolean whole = true, wholeL = true, Pwhole = true, PwholeL = true;
        int corrLabels = 0;
        // tokens excluded by the "without .", "V2" and "CHN" filters
        int punc = 0, bpunc = 0, chnPunc = 0;

        // NOTE: the first item is the root info added during nextInstance(), so we skip it.
        for (int i = 1; i < instanceLength; i++) {
            boolean headOk = predHeads[i] == goldHeads[i];

            bump(labelCount, goldLabels[i]);
            boolean labelOk = goldLabels[i].equals(predLabels[i]);
            if (labelOk) {
                bump(labelCorrect, goldLabels[i]);
            } else {
                bump(falsePositive, predLabels[i]);
            }
            // The POS tags are only compared for tokens with a correct head.
            boolean posOk = headOk && goldInstance.gpos[i].equals(predInstance.ppos[i]);

            if (headOk) {
                corr++;
                if (posOk) {
                    corrHeadAndPos++;
                }
                if (labelOk) {
                    corrL++;
                    corrLabels++;
                    if (posOk) {
                        corrLableAndPos++;
                    }
                } else {
                    wholeL = false;
                }
            } else {
                whole = false;
                wholeL = false;
            }
            boolean tlas = headOk && labelOk && posOk;

            String form = goldInstance.forms[i];
            if (!("!\"#$%&''()*+,-./:;<=>?@[\\]^_{|}~``".contains(form))) {
                if (headOk) {
                    BPcorr++;
                    if (labelOk) {
                        BPcorrL++;
                    }
                }
            } else {
                bpunc++;
            }

            if (!(",.:''``".contains(form))) {
                if (headOk) {
                    Pcorr++;
                    if (labelOk) {
                        PcorrL++;
                        if (posOk) {
                            corrLableAndPosP++;
                        }
                    } else {
                        PwholeL = false;
                    }
                } else {
                    Pwhole = false;
                    PwholeL = false;
                }
            } else {
                punc++;
            }

            if (!(goldInstance.gpos[i].toLowerCase().startsWith("pu"))) {
                if (headOk) {
                    correctChnWoPunc++;
                    if (labelOk) {
                        correctLChnWoPunc++;
                        if (posOk) {
                            corrLableAndPosC++;
                        }
                    }
                }
            } else {
                chnPunc++;
            }

            if (sig) {
                Parser.out.println(tlas ? "1\t" : "0\t");
            }
        }

        total += instanceLength - 1; // Subtract one to not score fake root token
        Ptotal += (instanceLength - 1) - punc;
        BPtotal += (instanceLength - 1) - bpunc;
        CPtotal += (instanceLength - 1) - chnPunc;
        if (whole) {
            corrsent++;
        }
        if (wholeL) {
            corrsentL++;
        }
        if (Pwhole) {
            Pcorrsent++;
        }
        if (PwholeL) {
            PcorrsentL++;
        }
        numsent++;
        correctHead.add((double) corrLabels / (instanceLength - 1));

        goldInstance = goldReader.getNext();
        predInstance = predictedReader.getNext();
    }

    // scaledRatio() yields thousandths of a percent; dividing by diff gives percent.
    final int diff = 1000;
    Results r = new Results();
    r.correctHead = correctHead;
    r.total = total;
    r.corr = corr;
    r.las = (float) scaledRatio(corrL, total) / diff;
    r.ula = (float) scaledRatio(corr, total) / diff;
    r.lpas = (float) scaledRatio(corrLableAndPos, total) / diff;
    r.upla = (float) scaledRatio(corrHeadAndPos, total) / diff;
    float tlasp = (float) scaledRatio(corrLableAndPosP, Ptotal) / diff;
    // The CHN counter is normalised by the CHN token total, not by the "without ." total.
    float tlasc = (float) scaledRatio(corrLableAndPosC, CPtotal) / diff;

    Parser.out.print(" LAS/Total/UAS/Total: " + r.las + "/" + (double) scaledRatio(corrsentL, numsent) / diff
            + "/" + r.ula + "/" + (double) scaledRatio(corrsent, numsent) / diff + " LPAS/UPAS " + r.lpas + "/" + r.upla);
    Parser.out.println("; without . " + (double) scaledRatio(PcorrL, Ptotal) / diff + "/"
            + (double) scaledRatio(PcorrsentL, numsent) / diff
            + "/" + (double) scaledRatio(Pcorr, Ptotal) / diff + "/"
            + (double) scaledRatio(Pcorrsent, numsent) / diff + " TLAS " + tlasp
            + " V2 LAS/UAS " + (double) scaledRatio(BPcorrL, BPtotal) / diff
            + "/" + (double) scaledRatio(BPcorr, BPtotal) / diff
            + " CHN LAS/UAS " + (double) scaledRatio(correctLChnWoPunc, CPtotal) / diff
            + "/" + (double) scaledRatio(correctChnWoPunc, CPtotal) / diff + " TLAS " + tlasc);

    float precisionNonProj = ((float) nonProjOk) / ((float) nonProjOk + nonProjWrong);
    float recallNonProj = ((float) nonProjOk) / ((float) (nonproj));
    Parser.out.println("proj " + proj + " nonp " + nonproj + "; predicted proj " + pproj + " non " + pnonproj + "; nonp correct "
            + nonProjOk + " nonp wrong " + nonProjWrong
            + " precision=(nonProjOk)/(non-projOk+nonProjWrong): " + precisionNonProj
            + " recall=nonProjOk/nonproj=" + recallNonProj + " F=" + (2 * precisionNonProj * recallNonProj) / (precisionNonProj + recallNonProj));

    if (!printEval) {
        return r;
    }

    Parser.out.println("label\ttp\tcount\trecall\t\ttp\tfp+tp\tprecision\t F-Score ");
    for (Entry<String, Integer> e : labelCount.entrySet()) {
        String goldLabel = e.getKey();
        int count = e.getValue();
        Integer hits = labelCorrect.get(goldLabel);
        Integer misses = falsePositive.get(goldLabel);
        int tp = hits == null ? 0 : hits.intValue();
        int fp = misses == null ? 0 : misses.intValue();

        float recall = (float) tp / count;
        float precision = (float) tp / (fp + tp);
        // Harmonic mean of precision and recall; defined as 0 when there is no true positive
        // (precision is then 0 or undefined).
        float fScore = tp == 0 ? 0F : (2 * precision * recall) / (precision + recall);

        Parser.out.println(goldLabel + "\t" + tp + "\t" + count + "\t" + roundPercent(recall) + "\t\t" + tp + "\t" + (fp + tp)
                + "\t" + roundPercent(precision) + "\t\t" + roundPercent(fScore));
    }
    return r;
}

/** Adds one to the counter stored under {@code key}, starting from zero if it is absent. */
private static void bump(HashMap<String, Integer> counts, String key) {
    Integer current = counts.get(key);
    counts.put(key, current == null ? 1 : current + 1);
}

/** Returns {@code num / den} as a percentage multiplied by 1000 and rounded to a long. */
private static long scaledRatio(int num, int den) {
    return Math.round(((double) num / den) * 100000);
}
/**
 * Rounds {@code v} to four decimal places.
 *
 * @param v value to round
 * @return {@code v} rounded to the nearest multiple of 0.0001, as a float
 */
public static float round(double v) {
    final float scale = 10000F;
    long scaled = Math.round(v * scale);
    return scaled / scale;
}
/**
 * Converts the fraction {@code v} to a percentage rounded to two decimal places.
 *
 * @param v fraction to convert, e.g. 0.5 for 50 percent
 * @return {@code v * 100} rounded to the nearest multiple of 0.01, as a float
 */
public static float roundPercent(double v) {
    long hundredthsOfPercent = Math.round(v * 10000F);
    return hundredthsOfPercent / 100F;
}
}