package edu.stanford.nlp.parser.ensemble.utils;

import java.io.*;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;

/**
 * Combines the dependency parses produced by several base parsers into a
 * single parse per sentence. The combination strategy is selected through
 * {@link #mode}: plain majority voting, Attardi-style re-parsing, chart-based
 * re-parsing ({@link #parse(List, int)}), or delegation to the external
 * MaltBlender tool.
 */
public class Eisner {

  /** Available ensemble strategies. */
  static enum PARSE_MODE {
    MAJORITY, ATTARDI, EISNER, CHU_LIU_EDMOND
  }

  /** Strategy used by {@link #ensemble(String, List, String)}. */
  static PARSE_MODE mode = PARSE_MODE.EISNER;

  /**
   * Command-line entry point.
   *
   * <p>Without arguments the built-in {@link #demo()} is run. Otherwise the
   * arguments are {@code <outFile> <goldFile> <sysFile>...}.
   *
   * @param args command-line arguments
   * @throws IllegalArgumentException if arguments are given but fewer than three
   * @throws Exception if the ensemble run fails
   */
  public static void main(String[] args) throws Exception {
    if (args.length == 0) {
      demo();
      return;
    }
    if (args.length < 3) {
      // Previously this surfaced as an ArrayIndexOutOfBoundsException.
      throw new IllegalArgumentException(
          "Usage: Eisner <outFile> <goldFile> <sysFile> [<sysFile> ...]");
    }

    String outFile = args[0];
    String goldFile = args[1];
    List<String> bases = new ArrayList<String>();
    for (int i = 2; i < args.length; i++) {
      bases.add(args[i]);
    }
    ensemble(goldFile, bases, outFile);
  }

  /**
   * Sets {@link #mode} from its textual name and runs the ensemble.
   *
   * @param goldFile  gold-standard file (only used by the MaltBlender mode)
   * @param sysFiles  output files of the base parsers
   * @param outFile   file the combined parse is written to
   * @param parseMode case-insensitive name of a {@link PARSE_MODE} constant
   * @throws RuntimeException if {@code parseMode} names no known mode
   * @throws IOException if reading or writing fails
   */
  public static void ensemble(String goldFile, List<String> sysFiles, String outFile,
      String parseMode) throws IOException {
    try {
      // Fixed locale: the default-locale toUpperCase() maps 'i' to a dotted
      // capital under e.g. Turkish, which would make "eisner" unparsable.
      mode = PARSE_MODE.valueOf(parseMode.toUpperCase(java.util.Locale.ENGLISH));
    } catch (IllegalArgumentException ex) {
      throw new RuntimeException("Unknown mode: " + parseMode, ex);
    }
    ensemble(goldFile, sysFiles, outFile);
  }

  /**
   * Combines the parses in {@code sysFiles} according to the current
   * {@link #mode} and writes the result to {@code outFile}.
   *
   * @param goldFile gold-standard file (only used by the MaltBlender mode)
   * @param sysFiles output files of the base parsers
   * @param outFile  file the combined parse is written to
   * @throws IOException if reading, writing, or running MaltBlender fails
   */
  public static void ensemble(String goldFile, List<String> sysFiles, String outFile)
      throws IOException {
    if (mode == PARSE_MODE.CHU_LIU_EDMOND) {
      runMaltBlender(goldFile, sysFiles, outFile);
    } else {
      reparse(sysFiles, outFile);
    }
  }

  /**
   * Runs the external MaltBlender jar and echoes its output to
   * {@code System.out}.
   *
   * <p>The command is passed as an argument list so that file names
   * containing whitespace reach the tool intact. The child's stderr is merged
   * into its stdout so that an undrained error stream cannot stall the child.
   */
  private static void runMaltBlender(String goldFile, List<String> sysFiles, String outFile)
      throws IOException {
    List<String> command = new ArrayList<String>();
    command.add("java");

    // ---- Java arguments ----
    // NOTE(review): interpreter and jar locations are hard-coded Windows paths.
    // path to perl interpreter (MaltBlender default: "perl")
    command.add("-DPERL=C:\\Perl\\bin\\perl.exe");
    // path to perl evaluation script (MaltBlender default: "eval07.pl")
    command.add("-DEVALUATOR=eval07.pl");
    command.add("-jar");
    command.add("lib\\MaltBlender.jar");

    // ---- MaltBlender arguments ----
    // Weighting strategy wrt the parsers (default: 2)
    //   2 = parsers are weighted equally
    //   3 = parsers are weighted according to total labeled accuracy
    //   4 = parsers are weighted according to labeled accuracy per coarse grained postag
    command.add("-w");
    command.add("4");
    // Weighting strategy wrt the labels (default: LABELED_BEST_AMONG_THE_SELECTED)
    //   1 = FIRST_BEST
    //   2 = BEST_AMONG_ALL
    //   3 = BEST_AMONG_THE_SELECTED
    //   4 = LABELED_BEST_AMONG_THE_SELECTED
    command.add("-l");
    command.add("4");
    // merged output file
    command.add("-o");
    command.add(outFile);
    // Options deliberately not passed:
    //   -g <gold-standard file>
    //   -e  activate evaluation of all individual input CoNLL files
    //   -c  combine all subsets of input CoNLL files (sorted by accuracy if -g is given)
    //   -d  dependency type to exclude (default: exclude none)
    // held-out test file (used for estimating weights)
    command.add("-h");
    command.add(goldFile);
    // held-out parsed files (used for estimating weights)
    command.add("-H");
    command.addAll(sysFiles);
    // input CoNLL files
    command.add("-F");
    command.addAll(sysFiles);

    ProcessBuilder builder = new ProcessBuilder(command);
    builder.redirectErrorStream(true);
    Process p = builder.start();

    BufferedReader stdInput = new BufferedReader(new InputStreamReader(p.getInputStream()));
    try {
      String s;
      while ((s = stdInput.readLine()) != null) {
        System.out.println(s);
      }
    } finally {
      closeQuietly(stdInput);
    }

    try {
      int exitCode = p.waitFor();
      if (exitCode != 0) {
        System.err.println("MaltBlender exited with status " + exitCode);
      }
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      InterruptedIOException ioe =
          new InterruptedIOException("Interrupted while waiting for MaltBlender");
      ioe.initCause(e);
      throw ioe;
    }
  }

  /**
   * Reads the system files sentence by sentence in lock-step, combines the
   * candidate dependencies according to {@link #mode}, and prints one token per
   * line followed by an empty line per sentence.
   */
  private static void reparse(List<String> sysFiles, String outFile) throws IOException {
    if (sysFiles.isEmpty()) {
      // Previously this surfaced as an ArrayIndexOutOfBoundsException.
      throw new IllegalArgumentException("At least one system file is required");
    }

    PrintStream os = new PrintStream(new FileOutputStream(outFile));
    BufferedReader[] is = new BufferedReader[sysFiles.size()];
    try {
      for (int i = 0; i < sysFiles.size(); i++) {
        is[i] = FileUtils.openForReading(sysFiles.get(i));
      }

      @SuppressWarnings("unchecked")
      List<Token>[] sents = new List[is.length];

      // The first system file determines how many sentences are processed.
      // NOTE(review): a shorter sibling file leaves a null entry in sents;
      // how Token.mergeSentences treats that is not visible here.
      while ((sents[0] = Token.readNextSentCoNLLX(is[0])) != null) {
        for (int i = 1; i < is.length; i++) {
          sents[i] = Token.readNextSentCoNLLX(is[i]);
        }

        if (mode == PARSE_MODE.EISNER) {
          for (int i = 0; i < is.length; i++) {
            Token.fixMultipleRoots(sents[i]);
          }
        }

        List<Token> cands = Token.mergeSentences(sents);
        // A candidate's score is the number of models that proposed it.
        for (Token cand : cands) {
          cand.setScore(cand.getModels().size());
        }

        List<Token> tree = null;
        if (mode == PARSE_MODE.ATTARDI) {
          tree = attardiVote(cands);
        } else if (mode == PARSE_MODE.EISNER) {
          // +1: index 0 of the chart is the artificial root.
          Span<Token> top = parse(cands, sents[0].size() + 1);
          if (top != null) {
            if (verbose) {
              System.err.println("Score: " + top.score);
            }
            tree = top.dependencies;
          } else {
            throw new RuntimeException("Did not find TOP!");
          }
        } else if (mode == PARSE_MODE.MAJORITY) {
          tree = majorityVote(cands, sents[0].size());
        } else {
          throw new RuntimeException("Unknown mode: " + mode);
        }

        DependencyUtils.sortById(tree);
        for (Token t : tree) {
          os.println(t);
        }
        os.println();
      }

      // PrintStream never throws on write failures; surface them explicitly.
      if (os.checkError()) {
        throw new IOException("Failed to write " + outFile);
      }
    } finally {
      os.close();
      for (BufferedReader reader : is) {
        closeQuietly(reader);
      }
    }
  }

  /** Closes {@code c} if non-null, ignoring failures of the close itself. */
  private static void closeQuietly(Closeable c) {
    if (c == null) {
      return;
    }
    try {
      c.close();
    } catch (IOException ignored) {
      // Only used for streams that were read from; nothing is lost here.
    }
  }

  /**
   * Picks, for every token {@code 1..len}, the candidate proposed by the
   * largest number of models (the first one wins ties).
   *
   * @param cands candidate dependencies of one sentence
   * @param len   number of tokens in the sentence
   * @return one entry per token; an entry is {@code null} when no candidate
   *         exists for that token and assertions are disabled
   */
  static List<Token> majorityVote(List<Token> cands, int len) {
    List<Token> out = new ArrayList<Token>();
    for (int i = 1; i <= len; i++) {
      Token bestDep = null;
      int bestScore = 0;
      for (Token c : cands) {
        if (c.mod() == i && c.getModels().size() > bestScore) {
          bestDep = c;
          bestScore = c.getModels().size();
        }
      }
      assert (bestDep != null);
      out.add(bestDep);
    }
    return out;
  }

  /** Number of selected dependencies, indexed by how many models voted for them. */
  static int[] totalByVotes = new int[10];

  /** Number of correct selected dependencies, indexed like {@link #totalByVotes}. */
  static int[] correctByVotes = new int[10];

  /** Grows the vote statistics so that index {@code votes} is valid. */
  private static void ensureVoteCapacity(int votes) {
    if (votes >= totalByVotes.length) {
      int[] grown = new int[votes + 1];
      System.arraycopy(totalByVotes, 0, grown, 0, totalByVotes.length);
      totalByVotes = grown;
    }
    if (votes >= correctByVotes.length) {
      int[] grown = new int[votes + 1];
      System.arraycopy(correctByVotes, 0, grown, 0, correctByVotes.length);
      correctByVotes = grown;
    }
  }

  /** Returns whether {@code dep} matches a gold token in modifier, head and label. */
  private static boolean isCorrect(Token dep, List<Token> golds) {
    for (Token gold : golds) {
      if (dep.mod() == gold.mod() && dep.head() == gold.head()
          && dep.label().equals(gold.label())) {
        return true;
      }
    }
    return false;
  }

  /**
   * Picks, for every token {@code 1..len}, the candidate with the highest
   * positive score, falling back to the candidate proposed by model 0 for that
   * token. Also updates {@link #totalByVotes} and {@link #correctByVotes}.
   *
   * @param cands    candidate dependencies of one sentence
   * @param len      number of tokens in the sentence
   * @param goldSent gold-standard tokens used for the correctness statistics
   * @return one selected dependency per token
   * @throws IllegalStateException if a token has no usable candidate
   */
  static List<Token> weightedMajorityVote(List<Token> cands, int len, List<Token> goldSent) {
    List<Token> out = new ArrayList<Token>();
    for (int i = 1; i <= len; i++) {
      Token bestDep = null;
      Token firstDep = null;
      double bestScore = 0;
      for (Token c : cands) {
        if (c.mod() != i) {
          continue;
        }
        if (c.score > bestScore) {
          bestDep = c;
          bestScore = c.score;
        }
        // The fallback must belong to the current token as well; the previous
        // code accepted model 0's dependency for any token.
        if (firstDep == null && c.getModels().contains(0)) {
          firstDep = c;
        }
      }
      if (bestDep == null) {
        bestDep = firstDep;
      }
      if (bestDep == null) {
        throw new IllegalStateException("No candidate dependency for token " + i);
      }

      int votes = bestDep.getModels().size();
      ensureVoteCapacity(votes);
      totalByVotes[votes]++;
      if (isCorrect(bestDep, goldSent)) {
        correctByVotes[votes]++;
      }
      out.add(bestDep);
    }
    return out;
  }

  /**
   * Runs {@link #parse(List, int)} on a small hand-built candidate set. The
   * result is discarded; the chart is only printed when {@link #verbose} is
   * enabled.
   */
  static void demo() {
    List<Token> cands = new ArrayList<Token>();
    cands.add(new Token(1, "The", "DT", 2, "NMOD", 1.0));
    cands.add(new Token(2, "singer", "NN", 3, "SUBJ", 1.0));
    cands.add(new Token(2, "singer", "NN", 5, "NMOD", 10.0));
    cands.add(new Token(3, "played", "VBZ", 0, "ROOT", 1.0));
    cands.add(new Token(4, "the", "DT", 5, "NMOD", 1.0));
    cands.add(new Token(5, "celo", "NN", 3, "OBJ", 1.0));
    cands.add(new Token(6, "well", "JJ", 3, "MNR", 1.0));
    parse(cands, 7);
  }

  /** A scored set of dependencies covering a contiguous range of the sentence. */
  static class Span<T extends Dependency> {
    List<T> dependencies;
    double score;

    /** Creates an empty span with score 0. */
    public Span() {
      score = 0;
      dependencies = new ArrayList<T>();
    }

    /**
     * Creates the union of two spans, optionally linked by {@code dep}.
     *
     * @param left  left sub-span
     * @param right right sub-span
     * @param dep   dependency joining the two, or {@code null} for none
     */
    public Span(Span<T> left, Span<T> right, T dep) {
      score = left.score + right.score + (dep != null ? dep.score() : 0.0);
      dependencies = new ArrayList<T>();
      if (dep != null) {
        dependencies.add(dep);
      }
      dependencies.addAll(left.dependencies);
      dependencies.addAll(right.dependencies);
    }

    @Override
    public String toString() {
      StringBuilder os = new StringBuilder();
      os.append("{").append(score).append("}");
      for (T d : dependencies) {
        os.append(" [").append(d.mod()).append(", ").append(d.head()).append(", ")
            .append(d.label()).append("]");
      }
      return os.toString();
    }
  }

  /** Chart slot type written by merges whose new dependency has its head on the left. */
  static final int HEAD_LEFT = 0;

  /** Chart slot type written by merges whose new dependency has its head on the right. */
  static final int HEAD_RIGHT = 1;

  /** Table of the best span found so far for every (start, end, type) triple. */
  static class Chart<T extends Dependency> {
    Span<T>[][][] chart;

    /**
     * Creates a {@code dimension x dimension x 2} chart whose single-token
     * cells are initialised with empty spans.
     */
    @SuppressWarnings("unchecked")
    public Chart(int dimension) {
      chart = new Span[dimension][dimension][2];
      for (int i = 0; i < chart.length; i++) {
        chart[i][i][HEAD_LEFT] = new Span<T>();
        chart[i][i][HEAD_RIGHT] = new Span<T>();
      }
    }

    /** Returns the stored span, or {@code null} if the cell is empty. */
    Span<T> get(int start, int end, int type) {
      return chart[start][end][type];
    }

    /** Stores {@code span} if the cell is empty or holds a strictly lower score. */
    void set(int start, int end, int type, Span<T> span) {
      Span<T> current = chart[start][end][type];
      if (current == null || current.score < span.score) {
        chart[start][end][type] = span;
      }
    }

    /** Prints all non-empty cells covering exactly {@code dimension} positions. */
    void display(PrintStream os, int dimension) {
      for (int i = 0; i < chart.length; i++) {
        for (int j = 0; j < chart[i].length; j++) {
          if (j - i + 1 != dimension) {
            continue;
          }
          for (int k = 0; k < chart[i][j].length; k++) {
            if (chart[i][j][k] != null) {
              os.printf("[%d, %d, %d]: ", i, j, k);
              os.println(chart[i][j][k]);
            }
          }
        }
      }
    }
  }

  /** Enables diagnostic output on {@code System.err}. */
  static final boolean verbose = false;

  /**
   * Stores the union of {@code left} and {@code right} plus {@code dep} in the
   * chart cell {@code (start, end, type)}, provided all three are present.
   */
  @SuppressWarnings("unchecked")
  private static <T extends Dependency> void combine(Chart<T> chart, int start, int end,
      int type, Span<T> left, Span<T> right, Dependency dep) {
    if (left != null && right != null && dep != null) {
      chart.set(start, end, type, new Span<T>(left, right, (T) dep));
    }
  }

  /**
   * Chart-based re-parsing: builds increasingly long spans bottom-up and
   * returns the best-scoring span covering positions {@code 0..length-1}.
   *
   * <p>NOTE(review): the combinations of sub-span types and candidate
   * dependencies below reproduce the original implementation one-to-one; they
   * have not been re-derived from Eisner's algorithm.
   *
   * @param <T>    dependency type
   * @param cands  candidate dependencies; {@code mod()} and {@code head()}
   *               must lie in {@code 0..length-1}
   * @param length number of chart positions (tokens plus the root at index 0)
   * @return the best {@link #HEAD_LEFT} span over the whole sentence, or
   *         {@code null} if none was found
   */
  static <T extends Dependency> Span<T> parse(List<T> cands, int length) {
    Chart<T> chart = new Chart<T>(length);
    // candTable[mod][head] is the best candidate for that attachment.
    Dependency[][] candTable = toTable(cands, length);

    for (int spanLength = 2; spanLength <= length; spanLength++) {
      if (verbose) {
        System.err.println("Span length: " + spanLength);
      }
      for (int start = 0; start + spanLength <= length; start++) {
        int end = start + spanLength - 1;
        if (verbose) {
          System.err.printf("Span: [%d, %d]\n", start, end);
        }

        for (int split = start; split < end; split++) {
          // None of these four cells is (start, end), so reading them once
          // per split is equivalent to re-reading them before every merge.
          Span<T> leftHL = chart.get(start, split, HEAD_LEFT);
          Span<T> leftHR = chart.get(start, split, HEAD_RIGHT);
          Span<T> rightHL = chart.get(split + 1, end, HEAD_LEFT);
          Span<T> rightHR = chart.get(split + 1, end, HEAD_RIGHT);

          // modifier start, head end
          combine(chart, start, end, HEAD_RIGHT, leftHL, rightHR, candTable[start][end]);
          // modifier start, head split + 1
          combine(chart, start, end, HEAD_RIGHT, leftHL, rightHR, candTable[start][split + 1]);
          // modifier end, head start
          combine(chart, start, end, HEAD_LEFT, leftHL, rightHR, candTable[end][start]);
          // modifier end, head split
          combine(chart, start, end, HEAD_LEFT, leftHL, rightHR, candTable[end][split]);
          // modifier split, head split + 1
          combine(chart, start, end, HEAD_RIGHT, leftHR, rightHR, candTable[split][split + 1]);
          // modifier split, head end
          combine(chart, start, end, HEAD_RIGHT, leftHR, rightHR, candTable[split][end]);
          // modifier split + 1, head split
          combine(chart, start, end, HEAD_LEFT, leftHL, rightHL, candTable[split + 1][split]);
          // modifier split + 1, head start
          combine(chart, start, end, HEAD_LEFT, leftHL, rightHL, candTable[split + 1][start]);

          // Merge [start, split] and [split, end] without adding a dependency.
          // For split == start the right operand is the cell being built, so
          // it has to be read here, after the merges above.
          Span<T> overlapping = chart.get(split, end, HEAD_LEFT);
          if (leftHL != null && overlapping != null) {
            chart.set(start, end, HEAD_LEFT, new Span<T>(leftHL, overlapping, null));
          }
          overlapping = chart.get(split, end, HEAD_RIGHT);
          if (leftHR != null && overlapping != null) {
            chart.set(start, end, HEAD_RIGHT, new Span<T>(leftHR, overlapping, null));
          }
        }
      }
      if (verbose) {
        chart.display(System.err, spanLength);
      }
    }
    return chart.get(0, length - 1, HEAD_LEFT);
  }

  /**
   * Stores all candidates in a table indexed by {@code [mod][head]} for easier
   * access. When several candidates share a cell, the one with the highest
   * score is kept (the earliest one on ties).
   *
   * @param <T>    dependency type
   * @param cands  candidate dependencies
   * @param length size of each table dimension
   * @return the {@code length x length} table; empty cells are {@code null}
   */
  static <T extends Dependency> Dependency[][] toTable(List<T> cands, int length) {
    int discarded = 0;
    Dependency[][] table = new Dependency[length][length];
    for (T c : cands) {
      Dependency current = table[c.mod()][c.head()];
      if (current == null || current.score() < c.score()) {
        table[c.mod()][c.head()] = c;
      } else {
        discarded++;
      }
    }
    if (discarded > 0 && verbose) {
      System.err.printf("Discarded %d redundant dependencies.\n", discarded);
    }
    return table;
  }

  /**
   * Implements Attardi's re-parsing algorithm. Note from the original author:
   * this is a poor man's implementation; while it should have the exact same
   * output, the runtime complexity is higher (O(N^2)) vs. O(N) of the original
   * algorithm.
   *
   * <p>Starting from the root (node 0), the best-scoring candidate whose head
   * is already in the tree is attached repeatedly. Candidates for tokens that
   * could not be reached are processed by a recursive call.
   *
   * @param <T>   dependency type
   * @param cands candidate dependencies; this list is modified (consumed
   *              candidates are removed)
   * @return the selected dependencies
   * @throws IllegalStateException if candidates remain but none of them
   *         attaches to the root, directly or indirectly
   */
  static <T extends Dependency> List<T> attardiVote(List<T> cands) {
    List<T> treeDeps = new ArrayList<T>();
    HashSet<Integer> treeNodes = new HashSet<Integer>();
    treeNodes.add(0);

    // Used when the ensemble does not have tree form: candidates of tokens
    // that have not been attached yet.
    List<T> backupCands = new ArrayList<T>(cands);

    // F is the fringe: candidates whose head is already part of the tree.
    List<T> F = new ArrayList<T>();
    for (int i = 0; i < cands.size();) {
      if (cands.get(i).head() == 0) {
        F.add(cands.remove(i));
      } else {
        i++;
      }
    }

    while (!F.isEmpty()) {
      double bestScore = -1;
      T bestDep = null;
      for (T f : F) {
        if (treeNodes.contains(f.head()) && f.score() > bestScore) {
          bestScore = f.score();
          bestDep = f;
        }
      }
      assert (bestDep != null);

      treeDeps.add(bestDep);
      treeNodes.add(bestDep.mod());

      // The token is attached now; drop all of its candidates from the backup.
      for (int k = 0; k < backupCands.size();) {
        if (backupCands.get(k).mod() == bestDep.mod()) {
          backupCands.remove(k);
        } else {
          k++;
        }
      }

      // Drop fringe entries for attached tokens and any further root candidates.
      for (int i = 0; i < F.size();) {
        if (treeNodes.contains(F.get(i).mod()) || F.get(i).head() == 0) {
          F.remove(i);
        } else {
          i++;
        }
      }

      // Move candidates that became attachable onto the fringe.
      for (int i = 0; i < cands.size();) {
        T cand = cands.get(i);
        if (treeNodes.contains(cand.head()) && !treeNodes.contains(cand.mod())) {
          F.add(cands.remove(i));
        } else {
          i++;
        }
      }
    }

    // If not empty, the graph is not connected: look for further subtrees.
    if (!backupCands.isEmpty()) {
      if (treeDeps.isEmpty()) {
        // Nothing was attached, so a recursive call would see the very same
        // candidates again and never terminate.
        throw new IllegalStateException(
            "Cannot attach " + backupCands.size() + " candidate dependencies to the root");
      }
      treeDeps.addAll(attardiVote(backupCands));
    }
    return treeDeps;
  }
}