package edu.stanford.nlp.naturalli;

import edu.stanford.nlp.classify.Classifier;
import edu.stanford.nlp.classify.GeneralDataset;
import edu.stanford.nlp.ie.machinereading.structure.Span;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.HasIndex;
import edu.stanford.nlp.ling.IndexedWord;
import edu.stanford.nlp.pipeline.Annotation;
import edu.stanford.nlp.pipeline.AnnotationPipeline;
import edu.stanford.nlp.semgraph.SemanticGraph;
import edu.stanford.nlp.semgraph.SemanticGraphEdge;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.trees.GrammaticalRelation;
import edu.stanford.nlp.util.*;

import java.text.DecimalFormat;
import java.util.*;
import java.util.stream.StreamSupport;

import static edu.stanford.nlp.util.logging.Redwood.log;

/**
 * Static utilities used by the natural logic and Open IE annotators:
 * guessing and expanding NER spans, cleaning dependency graphs into
 * proper trees, and small helpers for classifier evaluation.
 *
 * @author Gabor Angeli
 */
public class Util {

  /**
   * Guess the NER label for a span of tokens by majority vote over the
   * tokens' individual NER tags, ignoring "O" and missing tags.
   *
   * @param tokens The tokens of the sentence.
   * @param span The span of tokens to vote over.
   * @return The majority NER label, or "O" if no label covers at least half the span.
   */
  public static String guessNER(List<CoreLabel> tokens, Span span) {
    Counter<String> nerGuesses = new ClassicCounter<>();
    for (int i : span) {
      nerGuesses.incrementCount(tokens.get(i).ner());
    }
    nerGuesses.remove("O");
    nerGuesses.remove(null);
    // Require the majority label to cover at least half the span (integer division)
    if (nerGuesses.size() > 0 && Counters.max(nerGuesses) >= span.size() / 2) {
      return Counters.argmax(nerGuesses);
    } else {
      return "O";
    }
  }

  /**
   * Guess the NER label for an entire list of tokens.
   *
   * @param tokens The tokens to vote over.
   * @return The majority NER label, as per {@link Util#guessNER(List, Span)}.
   */
  public static String guessNER(List<CoreLabel> tokens) {
    return guessNER(tokens, new Span(0, tokens.size()));
  }

  /**
   * Returns a coherent NER span from a list of tokens.
   *
   * @param tokens The tokens of the entire sentence.
   * @param seed The seed span of the intended NER span that should be expanded.
   * @return A 0-indexed span corresponding to a coherent NER chunk grown from the given seed.
   */
  public static Span extractNER(List<CoreLabel> tokens, Span seed) {
    // Error checks
    if (seed == null) {
      return new Span(0, 1);
    }
    if (seed.start() < 0 || seed.end() < 0) {
      return new Span(0, 0);
    }
    if (seed.start() >= tokens.size() || seed.end() > tokens.size()) {
      return new Span(tokens.size(), tokens.size());
    }
    if (tokens.get(seed.start()).ner() == null) {
      return seed;
    }
    if (seed.start() < 0 || seed.end() > tokens.size()) {
      // (defensive; unreachable given the checks above)
      return Span.fromValues(Math.max(0, seed.start()), Math.min(tokens.size(), seed.end()));
    }

    // Find the span's beginning
    int begin = seed.start();
    while (begin < seed.end() - 1 && "O".equals(tokens.get(begin).ner())) {
      begin += 1;
    }
    String beginNER = tokens.get(begin).ner();
    if (!"O".equals(beginNER)) {
      while (begin > 0 && tokens.get(begin - 1).ner().equals(beginNER)) {
        begin -= 1;
      }
    } else {
      begin = seed.start();
    }

    // Find the span's end
    int end = seed.end() - 1;
    while (end > begin && "O".equals(tokens.get(end).ner())) {
      end -= 1;
    }
    String endNER = tokens.get(end).ner();
    if (!"O".equals(endNER)) {
      while (end < tokens.size() - 1 && tokens.get(end + 1).ner().equals(endNER)) {
        end += 1;
      }
    } else {
      end = seed.end() - 1;
    }

    // Check that the NER of the beginning and end are the same
    if (beginNER.equals(endNER)) {
      return Span.fromValues(begin, end + 1);
    } else {
      String bestNER = guessNER(tokens, Span.fromValues(begin, end + 1));
      if (beginNER.equals(bestNER)) {
        return extractNER(tokens, Span.fromValues(begin, begin + 1));
      } else if (endNER.equals(bestNER)) {
        return extractNER(tokens, Span.fromValues(end, end + 1));
      } else {
        // Something super funky is going on...
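  /*
   * Usage sketch for the NER helpers above (illustrative; the tokens and tags
   * are hypothetical, not from the original source). Given a sentence
   * ["Barack", "Obama", "visited", "Paris"] with NER tags
   * [PERSON, PERSON, O, LOCATION], a seed span covering only "Obama" expands
   * to the full PERSON chunk:
   *
   *   Span chunk = Util.extractNER(tokens, new Span(1, 2));  // == [0, 2), i.e. "Barack Obama"
   *   String ner = Util.guessNER(tokens, chunk);             // == "PERSON"
   */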
        return Span.fromValues(begin, end + 1);
      }
    }
  }

  /**
   * Run the given pipeline over a single (already tokenized) sentence,
   * by wrapping it in a fresh single-sentence {@link Annotation}.
   *
   * @param sentence The sentence to annotate, as a {@link CoreMap}.
   * @param pipeline The pipeline to run over the sentence.
   */
  public static void annotate(CoreMap sentence, AnnotationPipeline pipeline) {
    Annotation ann = new Annotation(StringUtils.join(sentence.get(CoreAnnotations.TokensAnnotation.class), " "));
    ann.set(CoreAnnotations.TokensAnnotation.class, sentence.get(CoreAnnotations.TokensAnnotation.class));
    ann.set(CoreAnnotations.SentencesAnnotation.class, Collections.singletonList(sentence));
    pipeline.annotate(ann);
  }
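  /*
   * Usage sketch (illustrative, not from the original source): re-annotating a
   * single sentence after the fact. The pipeline construction here is an
   * assumption; any AnnotationPipeline works.
   *
   *   CoreMap sentence = ...;           // a tokenized sentence missing some annotation
   *   AnnotationPipeline extras = ...;  // assumed: a pipeline with the annotators to add
   *   Util.annotate(sentence, extras);
   *   // the pipeline's annotators have now written their annotations into `sentence`
   */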
  /**
   * Fix some bizarre peculiarities with certain trees.
   * So far, these include:
   * <ul>
   *   <li>Sometimes there's an edge from a word to itself. This seems wrong.</li>
   * </ul>
   *
   * @param tree The tree to clean (in place!).
   * @return A list of extra edges, which are valid but were removed.
   */
  public static List<SemanticGraphEdge> cleanTree(SemanticGraph tree) {
    // assert !isCyclic(tree);

    // Clean nodes
    List<IndexedWord> toDelete = new ArrayList<>();
    for (IndexedWord vertex : tree.vertexSet()) {
      // Clean punctuation
      if (vertex.tag() == null) {
        continue;
      }
      char tag = vertex.backingLabel().tag().charAt(0);
      if (tag == '.' || tag == ',' || tag == '(' || tag == ')' || tag == ':') {
        if (!tree.outgoingEdgeIterator(vertex).hasNext()) {
          // This should really never happen, but it does.
          toDelete.add(vertex);
        }
      }
    }
    toDelete.forEach(tree::removeVertex);

    // Clean edges
    Iterator<SemanticGraphEdge> iter = tree.edgeIterable().iterator();
    List<Triple<IndexedWord, IndexedWord, SemanticGraphEdge>> toAdd = new ArrayList<>();
    toDelete.clear();
    while (iter.hasNext()) {
      SemanticGraphEdge edge = iter.next();
      if (edge.getDependent().index() == edge.getGovernor().index()) {
        // Clean up copy-edges
        if (edge.getDependent().isCopy(edge.getGovernor())) {
          for (SemanticGraphEdge toCopy : tree.outgoingEdgeIterable(edge.getDependent())) {
            toAdd.add(Triple.makeTriple(edge.getGovernor(), toCopy.getDependent(), toCopy));
          }
          toDelete.add(edge.getDependent());
        }
        if (edge.getGovernor().isCopy(edge.getDependent())) {
          for (SemanticGraphEdge toCopy : tree.outgoingEdgeIterable(edge.getGovernor())) {
            toAdd.add(Triple.makeTriple(edge.getDependent(), toCopy.getDependent(), toCopy));
          }
          toDelete.add(edge.getGovernor());
        }
        // Clean self-edges
        iter.remove();
      } else if (edge.getRelation().toString().equals("punct")) {
        // Clean punctuation (again)
        if (!tree.outgoingEdgeIterator(edge.getDependent()).hasNext()) {
          // This should really never happen, but it does.
          iter.remove();
        }
      }
    }
    // (add edges we wanted to add)
    toDelete.forEach(tree::removeVertex);
    for (Triple<IndexedWord, IndexedWord, SemanticGraphEdge> edge : toAdd) {
      tree.addEdge(edge.first, edge.second, edge.third.getRelation(), edge.third.getWeight(), edge.third.isExtra());
    }

    // Handle extra edges.
    // Two cases:
    //   (1) the extra edge is a subj/obj edge and the main edge is a conj:.*
    //       in this case, keep the extra
    //   (2) otherwise, delete the extra
    List<SemanticGraphEdge> extraEdges = new ArrayList<>();
    for (SemanticGraphEdge edge : tree.edgeIterable()) {
      if (edge.isExtra()) {
        List<SemanticGraphEdge> incomingEdges = tree.incomingEdgeList(edge.getDependent());
        SemanticGraphEdge toKeep = null;
        for (SemanticGraphEdge candidate : incomingEdges) {
          if (toKeep == null) {
            toKeep = candidate;
          } else if (toKeep.getRelation().toString().startsWith("conj") &&
              candidate.getRelation().toString().matches(".subj.*|.obj.*")) {
            toKeep = candidate;
          } else if (!candidate.isExtra() &&
              !(candidate.getRelation().toString().startsWith("conj") &&
                  toKeep.getRelation().toString().matches(".subj.*|.obj.*"))) {
            toKeep = candidate;
          }
        }
        for (SemanticGraphEdge candidate : incomingEdges) {
          if (candidate != toKeep) {
            extraEdges.add(candidate);
          }
        }
      }
    }
    extraEdges.forEach(tree::removeEdge);

    // Add apposition edges (simple coref)
    for (SemanticGraphEdge extraEdge : new ArrayList<>(extraEdges)) {  // note[gabor] prevent concurrent modification exception
      for (SemanticGraphEdge candidateAppos : tree.incomingEdgeIterable(extraEdge.getDependent())) {
        if (candidateAppos.getRelation().toString().equals("appos")) {
          extraEdges.add(new SemanticGraphEdge(extraEdge.getGovernor(), candidateAppos.getGovernor(),
              extraEdge.getRelation(), extraEdge.getWeight(), extraEdge.isExtra()));
        }
      }
      for (SemanticGraphEdge candidateAppos : tree.outgoingEdgeIterable(extraEdge.getDependent())) {
        if (candidateAppos.getRelation().toString().equals("appos")) {
          extraEdges.add(new SemanticGraphEdge(extraEdge.getGovernor(), candidateAppos.getDependent(),
              extraEdge.getRelation(), extraEdge.getWeight(), extraEdge.isExtra()));
        }
      }
    }

    // Brute force ensure tree
    // Remove incoming edges from roots
    List<SemanticGraphEdge> rootIncomingEdges = new ArrayList<>();
    for (IndexedWord root : tree.getRoots()) {
      for (SemanticGraphEdge incomingEdge : tree.incomingEdgeIterable(root)) {
        rootIncomingEdges.add(incomingEdge);
      }
    }
    rootIncomingEdges.forEach(tree::removeEdge);
    // Loop until it becomes a tree.
    boolean changed = true;
    while (changed) {  // I just want trees to be trees; is that so much to ask!?
      changed = false;
      List<IndexedWord> danglingNodes = new ArrayList<>();
      List<SemanticGraphEdge> invalidEdges = new ArrayList<>();

      for (IndexedWord vertex : tree.vertexSet()) {
        // Collect statistics
        Iterator<SemanticGraphEdge> incomingIter = tree.incomingEdgeIterator(vertex);
        boolean hasIncoming = incomingIter.hasNext();
        boolean hasMultipleIncoming = false;
        if (hasIncoming) {
          incomingIter.next();
          hasMultipleIncoming = incomingIter.hasNext();
        }

        // Register actions
        if (!hasIncoming && !tree.getRoots().contains(vertex)) {
          danglingNodes.add(vertex);
        } else {
          if (hasMultipleIncoming) {
            for (SemanticGraphEdge edge : new IterableIterator<>(incomingIter)) {
              invalidEdges.add(edge);
            }
          }
        }
      }

      // Perform actions
      for (IndexedWord vertex : danglingNodes) {
        tree.removeVertex(vertex);
        changed = true;
      }
      for (SemanticGraphEdge edge : invalidEdges) {
        tree.removeEdge(edge);
        changed = true;
      }
    }

    // Edge case: remove duplicate dobj to "that."
    // This is a common parse error.
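  /*
   * Usage sketch (illustrative, not from the original source). cleanTree
   * mutates its argument, so callers who need the original graph should clean
   * a copy; the returned edges are valid dependencies that no longer fit in a
   * strict tree:
   *
   *   SemanticGraph tree = new SemanticGraph(parse);  // defensive copy; `parse` is assumed
   *   List<SemanticGraphEdge> extraEdges = Util.cleanTree(tree);
   *   assert Util.isTree(tree);
   */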
    for (IndexedWord vertex : tree.vertexSet()) {
      SemanticGraphEdge thatEdge = null;
      int dobjCount = 0;
      for (SemanticGraphEdge edge : tree.outgoingEdgeIterable(vertex)) {
        if ("that".equalsIgnoreCase(edge.getDependent().word())) {
          thatEdge = edge;
        }
        if ("dobj".equals(edge.getRelation().toString())) {
          dobjCount += 1;
        }
      }
      if (dobjCount > 1 && thatEdge != null) {
        // Case: there are two dobj edges, one of which goes to the word "that"
        // Action: rewrite the dobj edge to "that" to be a "mark" edge.
        tree.removeEdge(thatEdge);
        tree.addEdge(thatEdge.getGovernor(), thatEdge.getDependent(),
            GrammaticalRelation.valueOf(thatEdge.getRelation().getLanguage(), "mark"),
            thatEdge.getWeight(), thatEdge.isExtra());
      }
    }

    // Return
    assert isTree(tree);
    return extraEdges;
  }

  /**
   * Strip away case edges, if the incoming edge is a preposition.
   * This replicates the behavior of the old Stanford dependencies when run on Universal Dependencies trees.
   *
   * @param tree The tree to modify in place.
   */
  public static void stripPrepCases(SemanticGraph tree) {
    // Find incoming case edges that have an 'nmod' incoming edge
    List<SemanticGraphEdge> toClean = new ArrayList<>();
    for (SemanticGraphEdge edge : tree.edgeIterable()) {
      if ("case".equals(edge.getRelation().toString())) {
        boolean isPrepTarget = false;
        for (SemanticGraphEdge incoming : tree.incomingEdgeIterable(edge.getGovernor())) {
          if ("nmod".equals(incoming.getRelation().getShortName())) {
            isPrepTarget = true;
            break;
          }
        }
        if (isPrepTarget && !tree.outgoingEdgeIterator(edge.getDependent()).hasNext()) {
          toClean.add(edge);
        }
      }
    }

    // Delete these edges
    for (SemanticGraphEdge edge : toClean) {
      tree.removeEdge(edge);
      tree.removeVertex(edge.getDependent());
      assert isTree(tree);
    }
  }

  /**
   * Determine if a tree is cyclic.
   *
   * @param tree The tree to check.
   * @return True if the tree has at least one cycle in it.
   */
  public static boolean isCyclic(SemanticGraph tree) {
    for (IndexedWord vertex : tree.vertexSet()) {
      if (tree.getRoots().contains(vertex)) {
        continue;
      }
      // Walk the governor chain upward from this vertex, looking for a repeat
      IndexedWord node = tree.incomingEdgeIterator(vertex).next().getGovernor();
      Set<IndexedWord> seen = new HashSet<>();
      seen.add(vertex);
      while (node != null) {
        if (seen.contains(node)) {
          return true;
        }
        seen.add(node);
        if (tree.incomingEdgeIterator(node).hasNext()) {
          node = tree.incomingEdgeIterator(node).next().getGovernor();
        } else {
          node = null;
        }
      }
    }
    return false;
  }
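  /*
   * Usage sketch (illustrative, not from the original source): a typical guard
   * combining the helpers above before downstream processing:
   *
   *   Util.stripPrepCases(tree);  // drop leaf 'case' markers under nmod arcs
   *   if (Util.isCyclic(tree)) {
   *     return;                   // refuse to process a cyclic graph rather than loop forever
   *   }
   */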
  /**
   * A little utility function to make sure a SemanticGraph is a tree.
   *
   * @param tree The tree to check.
   * @return True if this {@link edu.stanford.nlp.semgraph.SemanticGraph} is a tree (versus a DAG, or a general graph).
   */
  public static boolean isTree(SemanticGraph tree) {
    for (IndexedWord vertex : tree.vertexSet()) {
      // Check one and only one incoming edge
      if (tree.getRoots().contains(vertex)) {
        if (tree.incomingEdgeIterator(vertex).hasNext()) {
          return false;
        }
      } else {
        Iterator<SemanticGraphEdge> iter = tree.incomingEdgeIterator(vertex);
        if (!iter.hasNext()) {
          return false;
        }
        iter.next();
        if (iter.hasNext()) {
          return false;
        }
      }
      // Check incoming and outgoing edges match
      for (SemanticGraphEdge edge : tree.outgoingEdgeIterable(vertex)) {
        boolean foundReverse = false;
        for (SemanticGraphEdge reverse : tree.incomingEdgeIterable(edge.getDependent())) {
          if (reverse == edge) {
            foundReverse = true;
          }
        }
        if (!foundReverse) {
          return false;
        }
      }
      for (SemanticGraphEdge edge : tree.incomingEdgeIterable(vertex)) {
        boolean foundReverse = false;
        for (SemanticGraphEdge reverse : tree.outgoingEdgeIterable(edge.getGovernor())) {
          if (reverse == edge) {
            foundReverse = true;
          }
        }
        if (!foundReverse) {
          return false;
        }
      }
    }

    // Check for cycles
    if (isCyclic(tree)) {
      return false;
    }

    // Check topological sort -- sometimes fails?
//    try {
//      tree.topologicalSort();
//    } catch (Exception e) {
//      e.printStackTrace();
//      return false;
//    }
    return true;
  }

  /**
   * Returns true if the given two spans denote the same coherent NER chunk. That is, if we call
   * {@link Util#extractNER(List, Span)} on these two spans, they would return the same span.
   *
   * @param tokens The tokens in the sentence.
   * @param a The first span.
   * @param b The second span.
   * @param parse The parse tree to traverse looking for coreference chains to exploit.
   *              Note that the current implementation does not make use of this parameter.
   *
   * @return True if these two spans contain exactly the same NER chunk.
   */
  public static boolean nerOverlap(List<CoreLabel> tokens, Span a, Span b, Optional<SemanticGraph> parse) {
    Span nerA = extractNER(tokens, a);
    Span nerB = extractNER(tokens, b);
    return nerA.equals(nerB);
  }

  /** @see Util#nerOverlap(List, Span, Span, Optional) */
  public static boolean nerOverlap(List<CoreLabel> tokens, Span a, Span b) {
    return nerOverlap(tokens, a, b, Optional.empty());
  }
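  /*
   * Usage sketch (illustrative; the tokens and tags are hypothetical). Two
   * seed spans "overlap" if they grow into the same NER chunk; with tokens
   * ["Barack", "Obama", "visited", "Paris"] tagged [PERSON, PERSON, O, LOCATION]:
   *
   *   Util.nerOverlap(tokens, new Span(0, 1), new Span(1, 2));  // true: both expand to [0, 2)
   *   Util.nerOverlap(tokens, new Span(0, 1), new Span(3, 4));  // false: PERSON chunk vs. LOCATION chunk
   */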
  /**
   * A helper function for dumping the accuracy of a trained clause-splitting classifier.
   *
   * @param classifier The classifier to evaluate.
   * @param dataset The dataset to evaluate the classifier on.
   */
  public static void dumpAccuracy(Classifier<ClauseSplitter.ClauseClassifierLabel, String> classifier,
                                  GeneralDataset<ClauseSplitter.ClauseClassifierLabel, String> dataset) {
    DecimalFormat df = new DecimalFormat("0.00%");
    log("size: " + dataset.size());
    log("split count: " + StreamSupport.stream(dataset.spliterator(), false)
        .filter(x -> x.label() == ClauseSplitter.ClauseClassifierLabel.CLAUSE_SPLIT).count());
    log("interm count: " + StreamSupport.stream(dataset.spliterator(), false)
        .filter(x -> x.label() == ClauseSplitter.ClauseClassifierLabel.CLAUSE_INTERM).count());
    Pair<Double, Double> pr = classifier.evaluatePrecisionAndRecall(dataset, ClauseSplitter.ClauseClassifierLabel.CLAUSE_SPLIT);
    log("p (split): " + df.format(pr.first));
    log("r (split): " + df.format(pr.second));
    log("f1 (split): " + df.format(2 * pr.first * pr.second / (pr.first + pr.second)));
    pr = classifier.evaluatePrecisionAndRecall(dataset, ClauseSplitter.ClauseClassifierLabel.CLAUSE_INTERM);
    log("p (interm): " + df.format(pr.first));
    log("r (interm): " + df.format(pr.second));
    log("f1 (interm): " + df.format(2 * pr.first * pr.second / (pr.first + pr.second)));
  }

  /**
   * The dictionary of privative adjectives, as per http://hci.stanford.edu/cstr/reports/2014-04.pdf
   */
  public static final Set<String> PRIVATIVE_ADJECTIVES = Collections.unmodifiableSet(new HashSet<String>() {{
    add("believed");
    add("debatable");
    add("disputed");
    add("dubious");
    add("hypothetical");
    add("impossible");
    add("improbable");
    add("plausible");
    add("putative");
    add("questionable");
    add("so called");
    add("supposed");
    add("suspicious");
    add("theoretical");
    add("uncertain");
    add("unlikely");
    add("would - be");
    add("apparent");
    add("arguable");
    add("assumed");
    add("likely");
    add("ostensible");
    add("possible");
    add("potential");
    add("predicted");
    add("presumed");
    add("probable");
    add("seeming");
    add("anti");
    add("fake");
    add("fictional");
    add("fictitious");
    add("imaginary");
    add("mythical");
    add("phony");
    add("false");
    add("artificial");
    add("erroneous");
    add("mistaken");
    add("mock");
    add("pseudo");
    add("simulated");
    add("spurious");
    add("deputy");
    add("faulty");
    add("virtual");
    add("doubtful");
    add("erstwhile");
    add("ex");
    add("expected");
    add("former");
    add("future");
    add("onetime");
    add("past");
    add("proposed");
  }});

  /**
   * Construct the spanning span of the given list of tokens.
   *
   * @param tokens The tokens that should define the span.
   * @return A span (0-indexed) that covers all of the tokens.
   */
  public static Span tokensToSpan(List<? extends HasIndex> tokens) {
    int min = Integer.MAX_VALUE;
    int max = Integer.MIN_VALUE;
    for (HasIndex token : tokens) {
      min = Math.min(token.index() - 1, min);
      max = Math.max(token.index(), max);
    }
    if (min < 0 || max == Integer.MIN_VALUE) {
      // Either a token had a non-positive index, or the list was empty (max was never updated)
      throw new IllegalArgumentException("Could not compute span from tokens!");
    } else if (min >= max) {
      throw new IllegalStateException("Either logic is broken or Gabor can't code.");
    } else {
      return new Span(min, max);
    }
  }

}
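/*
 * Usage sketch for Util.tokensToSpan (illustrative, not from the original
 * source). CoreLabel token indices are 1-based within a sentence, so the
 * resulting span is shifted down to be 0-indexed:
 *
 *   List<CoreLabel> tokens = ...;           // assumed: tokens with indices 3, 4, 5
 *   Span span = Util.tokensToSpan(tokens);  // == [2, 5)
 */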