package is2.mtag; import is2.data.SentenceData09; import is2.io.*; import is2.parser.Parser; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; import java.util.HashMap; import java.util.Map.Entry; public class Evaluator { public static void evaluate(String act_file, String pred_file, int formatTask) throws Exception { CONLLReader goldReader = null; CONLLReader predictedReader = null; if (formatTask == 4) { goldReader = new CONLLReader04(act_file); predictedReader = new CONLLReader04(act_file); } else if (formatTask == 6) { goldReader = new CONLLReader06(act_file); predictedReader = new CONLLReader06(act_file); } else if (formatTask == 8) { goldReader = new CONLLReader08(act_file); predictedReader = new CONLLReader08(act_file); } else if (formatTask == 9) { goldReader = new CONLLReader09(act_file); predictedReader = new CONLLReader09(act_file); } //DependencyReader.createDependencyReader(); // boolean labeled = goldReader.startReading(act_file); predictedReader.startReading(pred_file); // if (labeled != predLabeled) // Parser.out.println("Gold file and predicted file appear to differ on whether or not they are labeled. Expect problems!!!"); int total = 0, totalP = 0, corr = 0, corrL = 0, corrT = 0, totalX = 0; int totalD = 0, corrD = 0, err = 0; int numsent = 0, corrsent = 0, corrsentL = 0; SentenceData09 goldInstance = goldReader.getNext(); SentenceData09 predInstance = predictedReader.getNext(); HashMap<String, Integer> errors = new HashMap<>(); HashMap<String, StringBuffer> words = new HashMap<>(); while (goldInstance != null) { int instanceLength = goldInstance.length(); if (instanceLength != predInstance.length()) { Parser.out.println("Lengths do not match on sentence " + numsent); } String gold[] = goldInstance.ofeats; String pred[] = predInstance.pfeats; boolean whole = true; boolean wholeL = true; // NOTE: the first item is the root info added during nextInstance(), so we skip it. for (int i = 1; i < instanceLength; i++) { if (gold[i].equals(pred[i]) || (gold[i].equals("_") && pred[i] == null)) { corrT++; } else { // Parser.out.println("gold:"+goldFeats[i]+" pred:"+predFeats[i]+" "+goldInstance.forms[i]+" snt "+numsent+" i:"+i); //for (int k = 1; k < instanceLength; k++) { // Parser.out.print(goldInstance.forms[k]+":"+goldInstance.gpos[k]); // if (k==i) Parser.out.print(":"+predInstance.gpos[k]); // Parser.out.print(" "); // } //Parser.out.println(); String key = "gold: '" + gold[i] + "' pred: '" + pred[i] + "'"; Integer cnt = errors.get(key); StringBuffer errWrd = words.get(key); if (cnt == null) { errors.put(key, 1); words.put(key, new StringBuffer().append(goldInstance.forms[i])); } else { errors.put(key, cnt + 1); errWrd.append(" ").append(goldInstance.forms[i]); } err++; } String[] gf = gold[i].split("|"); int eq = 0; if (pred[i] != null) { String[] pf = pred[i].split("|"); totalP += pf.length; if (pf.length > gf.length) { totalX += pf.length; } else { totalX += gf.length; } for (String g : gf) { for (String p : pf) { if (g.equals(p)) { eq++; break; } } } } else { totalX += gf.length; } totalD += gf.length; corrD += eq; } total += instanceLength - 1; // Subtract one to not score fake root token if (whole) { corrsent++; } if (wholeL) { corrsentL++; } numsent++; goldInstance = goldReader.getNext(); predInstance = predictedReader.getNext(); } ArrayList<Entry<String, Integer>> opsl = new ArrayList<>(); for (Entry<String, Integer> e : errors.entrySet()) { opsl.add(e); } Collections.sort(opsl, new Comparator<Entry<String, Integer>>() { @Override public int compare(Entry<String, Integer> o1, Entry<String, Integer> o2) { return o1.getValue() == o2.getValue() ? 0 : o1.getValue() > o2.getValue() ? -1 : 1; } }); int cnt = 0; Parser.out.println("10 top most errors:"); for (Entry<String, Integer> e : opsl) { cnt++; // Parser.out.println(e.getKey()+" "+e.getValue()+" context: "+words.get(e.getKey())); } Parser.out.println("Tokens: " + total + " Correct: " + corrT + " " + (float) corrT / total + " R " + ((float) corrD / totalD) + " tP " + totalP + " tG " + totalD + " P " + (float) corrD / totalP); Parser.out.println("err: " + err + " total " + total + " corr " + corrT); // Parser.out.println("Unlabeled Complete Correct: " + ((double)corrsent/numsent)); } public static void main(String[] args) throws Exception { int format = 9; if (args.length > 2) { try { format = Integer.parseInt(args[2]); } catch(Exception ex) {} } evaluate(args[0], args[1], format); } }