package mstparser; import java.io.IOException; import mstparser.io.DependencyReader; public class DependencyEvaluator { public static void evaluate(String act_file, String pred_file, String format, boolean hasConfidence) throws IOException { DependencyReader goldReader = DependencyReader.createDependencyReader(format); boolean labeled = goldReader.startReading(act_file); DependencyReader predictedReader; if (hasConfidence) { predictedReader = DependencyReader.createDependencyReaderWithConfidenceScores(format); } else { predictedReader = DependencyReader.createDependencyReader(format); } boolean predLabeled = predictedReader.startReading(pred_file); if (labeled != predLabeled) { DependencyParser.out.println("Gold file and predicted file appear to differ " + "on whether or not they are labeled. Expect problems!!!"); } int total = 0; int corr = 0; int corrL = 0; int numsent = 0; int corrsent = 0; int corrsentL = 0; int root_act = 0; int root_guess = 0; int root_corr = 0; DependencyInstance goldInstance = goldReader.getNext(); DependencyInstance predInstance = predictedReader.getNext(); while (goldInstance != null) { int instanceLength = goldInstance.length(); if (instanceLength != predInstance.length()) { DependencyParser.out.println("Lengths do not match on sentence " + numsent); } int[] goldHeads = goldInstance.heads; String[] goldLabels = goldInstance.deprels; int[] predHeads = predInstance.heads; String[] predLabels = predInstance.deprels; boolean whole = true; boolean wholeL = true; // NOTE: the first item is the root info added during nextInstance(), so we skip it. for (int i = 1; i < instanceLength; i++) { if (predHeads[i] == goldHeads[i]) { corr++; if (labeled) { if (goldLabels[i].equals(predLabels[i])) { corrL++; } else { wholeL = false; } } } else { whole = false; wholeL = false; } } total += instanceLength - 1; // Subtract one to not score fake root token if (whole) { corrsent++; } if (wholeL) { corrsentL++; } numsent++; goldInstance = goldReader.getNext(); predInstance = predictedReader.getNext(); } DependencyParser.out.println("Tokens: " + total); DependencyParser.out.println("Correct: " + corr); DependencyParser.out.println("Unlabeled Accuracy: " + ((double) corr / total)); DependencyParser.out.println("Unlabeled Complete Correct: " + ((double) corrsent / numsent)); if (labeled) { DependencyParser.out.println("Labeled Accuracy: " + ((double) corrL / total)); DependencyParser.out.println("Labeled Complete Correct: " + ((double) corrsentL / numsent)); } } public static void main(String[] args) throws IOException { String format = "CONLL"; if (args.length > 2) { format = args[2]; } evaluate(args[0], args[1], format, false); } }