package edu.stanford.nlp.parser.metrics;

import edu.stanford.nlp.util.logging.Redwood;

import java.io.File;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Stack;
import java.util.TreeMap;

import edu.stanford.nlp.international.Language;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.HasIndex;
import edu.stanford.nlp.ling.Label;
import edu.stanford.nlp.parser.lexparser.TreebankLangParserParams;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.TreeTransformer;
import edu.stanford.nlp.trees.Treebank;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.StringUtils;

/**
 * Implementation of the Leaf Ancestor metric first described by Sampson and Babarczy (2003) and
 * later analyzed more completely by Clegg and Shepherd (2005).
 * <p>
 * This implementation assumes that the guess/gold input files are of equal length, and have one tree per
 * line.
 * <p>
 * TODO (spenceg): This implementation doesn't insert the "boundary symbols" as described by both
 * Sampson and Clegg. Need to add those.
 *
 * @author Spence Green
 *
 */
public class LeafAncestorEval {

  /** A logger for this class */
  private static final Redwood.RedwoodChannels log = Redwood.channels(LeafAncestorEval.class);

  private final String name;

  private static final boolean DEBUG = false;

  //Sentence level (macro-averaged): sum of per-sentence mean scores and number of sentences
  private double sentAvg = 0.0;
  private double sentNum = 0.0;
  private int sentExact = 0;

  //Corpus level (micro-averaged): sum of all per-lineage scores and number of lineages
  private double corpusAvg = 0.0;
  private double corpusNum = 0.0;

  //Category level: per gold lineage, the sum of scores and the number of occurrences
  private final Map<List<CoreLabel>,Double> catAvg;
  private final Map<List<CoreLabel>,Double> catNum;

  /**
   * Creates an empty metric accumulator.
   *
   * @param str Name used to label this metric in the output of {@link #display(boolean, PrintWriter)}
   */
  public LeafAncestorEval(String str) {
    this.name = str;

    catAvg = Generics.newHashMap();
    catNum = Generics.newHashMap();
  }

  /**
   * Depth-first search through the tree with an explicit stack, recording the label stack
   * as the lineage every time a preterminal is reached. Each node is visited before its
   * children; children are pushed left to right, so they are popped (and lineages are
   * recorded) from the rightmost preterminal to the leftmost. A lineage runs from the root
   * label down to the parent of the preterminal; the preterminal's own label is not included.
   *
   * This implementation uses the Index annotation to store depth, and therefore overwrites
   * the index of every visited node's label. If CoreLabels are not present in the trees
   * (or at least something that implements HasIndex), an exception will result.
   *
   * @param t The tree
   * @return A list of lineages, or null if {@code t} is null
   */
  private static List<List<CoreLabel>> makeLineages(final Tree t) {
    if(t == null) return null;

    ((HasIndex) t.label()).setIndex(0);

    final Stack<Tree> treeStack = new Stack<>();
    treeStack.push(t);

    final Stack<CoreLabel> labelStack = new Stack<>();
    CoreLabel rootLabel = new CoreLabel(t.label());
    rootLabel.setIndex(0);
    labelStack.push(rootLabel);

    final List<List<CoreLabel>> lineages = new ArrayList<>();

    while(!treeStack.isEmpty()) {
      Tree node = treeStack.pop();
      int nodeDepth = ((HasIndex) node.label()).index();

      // Unwind the label stack until its top is the parent (depth - 1) of the current node.
      // For the root (depth 0) this empties the stack; the root label is re-added below.
      while(!labelStack.isEmpty() && labelStack.peek().index() != nodeDepth - 1)
        labelStack.pop();

      if(node.isPreTerminal()) {
        // Snapshot of the ancestors, ordered from the root downwards
        List<CoreLabel> lin = new ArrayList<>(labelStack);
        lineages.add(lin);

      } else {
        for(Tree kid : node.children()) {
          ((HasIndex) kid.label()).setIndex(nodeDepth + 1);
          treeStack.push(kid);
        }

        CoreLabel nodeLabel = new CoreLabel(node.label());
        nodeLabel.setIndex(nodeDepth);
        labelStack.add(nodeLabel);
      }
    }

    if(DEBUG) {
      System.out.println("Lineages:");
      for(List<CoreLabel> lin : lineages) {
        for(CoreLabel cl : lin)
          System.out.print(cl.value() + " <- ");
        System.out.println();
      }
    }

    return lineages;
  }

  /**
   * Adds one score to the running per-category totals.
   *
   * @param lineage The (gold) lineage that identifies the category
   * @param score The leaf-ancestor score obtained for this occurrence of the lineage
   */
  private void updateCatAverages(final List<CoreLabel> lineage, double score) {
    if(catAvg.get(lineage) == null) {
      catAvg.put(lineage, score);
      catNum.put(lineage, 1.0);

    } else {
      double newAvg = catAvg.get(lineage) + score;
      catAvg.put(lineage, newAvg);
      double newNum = catNum.get(lineage) + 1.0;
      catNum.put(lineage, newNum);
    }
  }

  /**
   * Scores one guess/gold tree pair and adds the result to the running statistics.
   * The score of a lineage pair is {@code 1 - editDistance / (|guess| + |gold|)}.
   * Pairs whose trees are null, whose lineage counts differ, or which have no lineages
   * at all are reported on stderr and left out of the statistics.
   *
   * @param guess The guess tree
   * @param gold The gold tree
   * @param pw Not used by this method
   */
  public void evaluate(Tree guess, Tree gold, PrintWriter pw) {
    if(gold == null || guess == null) {
      System.err.printf("%s: Cannot compare against a null gold or guess tree!%n",this.getClass().getName());
      return;
    }

    final List<List<CoreLabel>> guessLineages = makeLineages(guess);
    final List<List<CoreLabel>> goldLineages = makeLineages(gold);

    if(guessLineages.size() == goldLineages.size()) {

      if(goldLineages.isEmpty()) {
        // Without this guard localScores / 0 below would be NaN and would poison
        // sentAvg for the rest of the corpus.
        System.err.printf("%s: No lineages found; skipping tree pair.%n",this.getClass().getName());
        return;
      }

      double localScores = 0.0;
      for(int i = 0; i < guessLineages.size(); i++) {
        List<CoreLabel> guessLin = guessLineages.get(i);
        List<CoreLabel> goldLin = goldLineages.get(i);

        double levDist = editDistance(guessLin, goldLin);
        int totalLength = guessLin.size() + goldLin.size();
        // Two empty lineages are identical; score them 1.0 instead of computing 0.0 / 0.0 (NaN)
        double la = (totalLength == 0) ? 1.0 : 1.0 - (levDist / (double) totalLength);

        localScores += la;

        updateCatAverages(goldLin, la);
      }

      corpusAvg += localScores;
      corpusNum += goldLineages.size();

      double localSentAvg = localScores / goldLineages.size();
      if(localSentAvg == 1.0) sentExact++;
      sentAvg += localSentAvg;
      sentNum++;

    } else {
      System.err.printf("%s: Number of guess (%d) gold (%d) don't match!%n",this.getClass().getName(),guessLineages.size(),goldLineages.size());
      log.info("Cannot evaluate!");
      System.err.printf("GUESS tree:%n%s%n", guess.toString());
      System.err.printf("GOLD tree:%n%s%n", gold.toString());
    }
  }

  /**
   * Computes Levenshtein edit distance between two lists of labels,
   * with unit cost for insertion, deletion and substitution. Labels are
   * compared with {@code equals}.
   *
   * @param l1 The first lineage
   * @param l2 The second lineage
   * @return The edit distance between {@code l1} and {@code l2}
   */
  private static int editDistance(final List<CoreLabel> l1, final List<CoreLabel> l2) {
    // m[i][j] is the distance between the first i labels of l1 and the first j labels of l2
    int[][] m = new int[l1.size()+1][l2.size()+1];
    for(int i = 1; i <= l1.size(); i++)
      m[i][0] = i;
    for(int j = 1; j <= l2.size(); j++)
      m[0][j] = j;

    for(int i = 1; i <= l1.size(); i++) {
      for(int j = 1; j <= l2.size(); j++) {
        m[i][j] = Math.min(m[i-1][j-1] + ((l1.get(i-1).equals(l2.get(j-1))) ? 0 : 1), m[i-1][j] + 1);
        m[i][j] = Math.min(m[i][j], m[i][j-1] + 1);
      }
    }

    return m[l1.size()][l2.size()];
  }

  /**
   * Renders a lineage as its label values, each followed by {@code " <-- "}.
   *
   * @param lineage The lineage to render
   * @return The rendered string (empty for an empty lineage)
   */
  private static String toString(final List<CoreLabel> lineage) {
    StringBuilder sb = new StringBuilder();
    for(CoreLabel cl : lineage) {
      sb.append(cl.value());
      sb.append(" <-- ");
    }

    return sb.toString();
  }

  /**
   * Prints the accumulated statistics.
   *
   * @param verbose If true, print the full report, including the per-category table
   *                (categories seen fewer than 30 times are omitted) sorted by ascending
   *                average score; otherwise print a one-line summary
   * @param pw Where to print
   */
  public void display(boolean verbose, PrintWriter pw) {
    double corpusLevel = corpusAvg / corpusNum;
    double sentLevel = sentAvg / sentNum;
    double sentEx = 100.0 * sentExact / sentNum;

    if(verbose) {
      // Group the categories by average score. The TreeMap iterates in ascending score
      // order, and categories with the same average share a bucket, so no category
      // is lost and the ordering is exact and deterministic.
      Map<Double,List<List<CoreLabel>>> avgMap = new TreeMap<>();
      for (Map.Entry<List<CoreLabel>, Double> entry : catAvg.entrySet()) {
        double avg = entry.getValue() / catNum.get(entry.getKey());
        if (Double.isNaN(avg)) {
          avg = -1.0;
        }
        List<List<CoreLabel>> bucket = avgMap.get(avg);
        if (bucket == null) {
          bucket = new ArrayList<>();
          avgMap.put(avg, bucket);
        }
        bucket.add(entry.getKey());
      }

      pw.println("============================================================");
      pw.println("Leaf Ancestor Metric" + "(" + name + ") -- final statistics");
      pw.println("============================================================");
      pw.println("#Sentences: " + (int) sentNum);
      pw.println();
      pw.println("Sentence-level (macro-averaged)");
      pw.printf(" Avg: %.3f%n", sentLevel);
      pw.printf(" Exact: %.2f%%%n", sentEx);
      pw.println();
      pw.println("Corpus-level (micro-averaged)");
      pw.printf(" Avg: %.3f%n", corpusLevel);
      pw.println("============================================================");

      for (List<List<CoreLabel>> bucket : avgMap.values()) {
        for (List<CoreLabel> lineage : bucket) {
          if(catNum.get(lineage) < 30.0) continue;

          double avg = catAvg.get(lineage) / catNum.get(lineage);
          pw.printf(" %.3f\t%d\t%s%n",avg, (int) ((double)catNum.get(lineage)),toString(lineage));
        }
      }

      pw.println("============================================================");

    } else {
      pw.printf("%s summary: corpus: %.3f sent: %.3f sent-ex: %.2f%n", name,corpusLevel,sentLevel,sentEx);
    }
  }

  private static final String USAGE = String.format("Usage: java %s [OPTS] goldFile guessFile%n%nOptions:%n -l lang : Language name %s%n" + " -y num : Skip gold trees with yields longer than num.%n -v : Verbose output%n", LeafAncestorEval.class.getName(), Language.langList);

  private static final int MIN_ARGS = 2;

  //Command line options
  private static boolean VERBOSE = false;
  private static Language LANGUAGE = Language.English;
  private static int MAX_GOLD_YIELD = Integer.MAX_VALUE;

  private static File guessFile = null;
  private static File goldFile = null;

  public static final Map<String,Integer> optionArgDefs = Generics.newHashMap();
  static {
    optionArgDefs.put("-y", 1);
    optionArgDefs.put("-l", 1);
    optionArgDefs.put("-v", 0);
  }

  /**
   * Parses the command line into the static option fields.
   *
   * @param args The raw command line arguments
   * @return false if an option is unknown, an option value is missing or malformed, or the
   *         number of non-option arguments is not exactly {@code MIN_ARGS}; true otherwise
   */
  private static boolean validateCommandLine(String[] args) {
    Map<String, String[]> argsMap = StringUtils.argsToMap(args,optionArgDefs);

    for(Map.Entry<String, String[]> opt : argsMap.entrySet()) {
      String key = opt.getKey();
      if (key == null) {
        // The null key holds the non-option arguments, which are handled below
        continue;
      }

      String[] values = opt.getValue();
      switch (key) {
        case "-y":
          if (values == null || values.length < 1) return false;
          try {
            MAX_GOLD_YIELD = Integer.parseInt(values[0]);
          } catch (NumberFormatException e) {
            // Not a number: report as a usage error rather than crashing
            return false;
          }
          break;

        case "-l":
          if (values == null || values.length < 1) return false;
          try {
            LANGUAGE = Language.valueOf(values[0]);
          } catch (IllegalArgumentException e) {
            // Unknown language name: report as a usage error rather than crashing
            return false;
          }
          break;

        case "-v":
          VERBOSE = true;
          break;

        default:
          return false;
      }
    }

    //Regular arguments
    String[] rest = argsMap.get(null);
    if(rest == null || rest.length != MIN_ARGS) {
      return false;
    } else {
      goldFile = new File(rest[0]);
      guessFile = new File(rest[1]);
    }

    return true;
  }

  /**
   * Execute with no arguments for usage.
   */
  public static void main(String[] args) {

    if(!validateCommandLine(args)) {
      log.info(USAGE);
      System.exit(-1);
    }

    final TreebankLangParserParams tlpp = LANGUAGE.params;
    final PrintWriter pwOut = tlpp.pw();

    final Treebank guessTreebank = tlpp.diskTreebank();
    guessTreebank.loadPath(guessFile);
    pwOut.println("GUESS TREEBANK:");
    pwOut.println(guessTreebank.textualSummary());

    final Treebank goldTreebank = tlpp.diskTreebank();
    goldTreebank.loadPath(goldFile);
    pwOut.println("GOLD TREEBANK:");
    pwOut.println(goldTreebank.textualSummary());

    final LeafAncestorEval metric = new LeafAncestorEval("LeafAncestor");

    final TreeTransformer tc = tlpp.collinizer();

    //The evalb ref implementation assigns status for each tree pair as follows:
    //
    //   0 - Ok (yields match)
    //   1 - length mismatch
    //   2 - null parse e.g. (()).
    //
    //In the cases of 1,2, evalb does not include the tree pair in the LP/LR computation.
    final Iterator<Tree> goldItr = goldTreebank.iterator();
    final Iterator<Tree> guessItr = guessTreebank.iterator();
    int goldLineId = 0;
    int guessLineId = 0;
    int skippedGuessTrees = 0;
    while( guessItr.hasNext() && goldItr.hasNext() ) {
      Tree guessTree = guessItr.next();
      List<Label> guessYield = guessTree.yield();
      guessLineId++;

      Tree goldTree = goldItr.next();
      List<Label> goldYield = goldTree.yield();
      goldLineId++;

      // Check that we should evaluate this tree
      if(goldYield.size() > MAX_GOLD_YIELD) {
        skippedGuessTrees++;
        continue;
      }

      // Only trees with equal yields can be evaluated
      if(goldYield.size() != guessYield.size()) {
        pwOut.printf("Yield mismatch gold: %d tokens vs. guess: %d tokens (lines: gold %d guess %d)%n", goldYield.size(), guessYield.size(), goldLineId, guessLineId);
        skippedGuessTrees++;
        continue;
      }

      final Tree evalGuess = tc.transformTree(guessTree);
      final Tree evalGold = tc.transformTree(goldTree);

      metric.evaluate(evalGuess, evalGold, ((VERBOSE) ? pwOut : null));
    }

    if(guessItr.hasNext() || goldItr.hasNext()) {
      // The period belongs before the line break, not at the start of the next line
      System.err.printf("Guess/gold files do not have equal lengths (guess: %d gold: %d).%n", guessLineId, goldLineId);
    }

    pwOut.println("================================================================================");
    if(skippedGuessTrees != 0) pwOut.printf("%s %d guess trees%n", "Unable to evaluate", skippedGuessTrees);
    metric.display(true, pwOut);

    pwOut.close();
  }
}