package edu.stanford.nlp.coref.md;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Properties;
import java.util.Set;
import java.util.regex.Pattern;

import edu.stanford.nlp.coref.CorefProperties;
import edu.stanford.nlp.coref.data.Dictionaries;
import edu.stanford.nlp.coref.data.Mention;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.IndexedWord;
import edu.stanford.nlp.pipeline.Annotation;
import edu.stanford.nlp.semgraph.SemanticGraph;
import edu.stanford.nlp.semgraph.SemanticGraphCoreAnnotations;
import edu.stanford.nlp.semgraph.SemanticGraphCoreAnnotations.BasicDependenciesAnnotation;
import edu.stanford.nlp.semgraph.SemanticGraphEdge;
import edu.stanford.nlp.semgraph.SemanticGraphUtils;
import edu.stanford.nlp.trees.GrammaticalRelation;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.TreeCoreAnnotations.TreeAnnotation;
import edu.stanford.nlp.trees.UniversalEnglishGrammaticalRelations;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.IntPair;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.logging.Redwood;

/**
 * Mention finder that proposes candidate mentions from the basic dependency
 * graph of each sentence (nouns, pronouns and determiners as head words),
 * in addition to premarked and named-entity mentions.
 */
public class DependencyCorefMentionFinder extends CorefMentionFinder {

  /** A logger for this class */
  private static final Redwood.RedwoodChannels log = Redwood.channels(DependencyCorefMentionFinder.class);

  /** Relation short names whose dependents are never used as mention heads. */
  private static final Pattern SKIPPED_HEAD_RELATION = Pattern.compile("det|compound");

  /** Relation short names of children that are left out of the NP span. */
  private static final Pattern OUTSIDE_NP_RELATION = Pattern.compile("dep|discourse|punct");

  /** POS tags that separate the first conjunct from the rest of a coordination. */
  private static final Pattern CONJUNCTION_SEPARATOR_TAG = Pattern.compile("CC|,");

  /** Words that may directly follow a pronoun and be absorbed into its mention ("you all", "they both"). */
  private static final Pattern PRONOUN_QUANTIFIER = Pattern.compile("all|both");

  /** Mention detection classifier; {@code null} when running in mention-detection training mode. */
  public MentionDetectionClassifier mdClassifier = null;

  /**
   * Creates a finder configured from {@code props}.
   * The mention detection model is loaded unless the properties request mention-detection training.
   *
   * @param props coref properties (language, training flag, model location)
   * @throws ClassNotFoundException if the serialized classifier cannot be resolved
   * @throws IOException if the classifier model cannot be read
   */
  public DependencyCorefMentionFinder(Properties props) throws ClassNotFoundException, IOException {
    this.lang = CorefProperties.getLanguage(props);
    mdClassifier = (CorefProperties.isMentionDetectionTraining(props))
        ? null
        : IOUtils.readObjectFromURLOrClasspathOrFileSystem(CorefProperties.getMentionDetectionModel(props));
  }

  /**
   * Main method of mention detection.
   * Extract all NP, PRP or NE, and filter out by manually written patterns.
   *
   * @param doc annotated document; its sentences are read from {@code SentencesAnnotation}
   * @param dict dictionaries passed through to filtering and classification
   * @param props coref properties
   * @return one list of mentions per sentence, in sentence order
   */
  @Override
  public List<List<Mention>> findMentions(Annotation doc, Dictionaries dict, Properties props) {
    List<List<Mention>> predictedMentions = new ArrayList<>();
    Set<String> neStrings = Generics.newHashSet();
    List<Set<IntPair>> mentionSpanSetList = Generics.newArrayList();
    List<CoreMap> sentences = doc.get(CoreAnnotations.SentencesAnnotation.class);

    for (CoreMap s : sentences) {
      List<Mention> mentions = new ArrayList<>();
      predictedMentions.add(mentions);
      Set<IntPair> mentionSpanSet = Generics.newHashSet();
      Set<IntPair> namedEntitySpanSet = Generics.newHashSet();

      extractPremarkedEntityMentions(s, mentions, mentionSpanSet, namedEntitySpanSet);
      HybridCorefMentionFinder.extractNamedEntityMentions(s, mentions, mentionSpanSet, namedEntitySpanSet);
      extractNPorPRPFromDependency(s, mentions, mentionSpanSet, namedEntitySpanSet);

      addNamedEntityStrings(s, neStrings, namedEntitySpanSet);
      mentionSpanSetList.add(mentionSpanSet);
    }

    // extractNamedEntityModifiers(sentences, mentionSpanSetList, predictedMentions, neStrings);

    for (int i = 0; i < sentences.size(); i++) {
      findHead(sentences.get(i), predictedMentions.get(i));
    }

    // mention selection based on document-wise info
    removeSpuriousMentions(doc, predictedMentions, dict, CorefProperties.removeNestedMentions(props), lang);

    // if this is for MD training, skip classification
    if (!CorefProperties.isMentionDetectionTraining(props)) {
      mdClassifier.classifyMentions(predictedMentions, dict, props);
    }

    return predictedMentions;
  }

  /**
   * Assigns consecutive mention IDs, starting at {@code maxID + 1}, in sentence and list order.
   *
   * @param predictedMentions mentions per sentence
   * @param maxID the largest ID already in use
   */
  protected static void assignMentionIDs(List<List<Mention>> predictedMentions, int maxID) {
    for (List<Mention> mentions : predictedMentions) {
      for (Mention m : mentions) {
        m.mentionID = (++maxID);
      }
    }
  }

  /**
   * Marks single-token mentions whose head is tagged {@code NNS} as generic.
   * A head without a part-of-speech annotation is left untouched.
   *
   * @param mentions mentions to inspect; each must already have a head word
   */
  protected static void setBarePlural(List<Mention> mentions) {
    for (Mention m : mentions) {
      String pos = m.headWord.get(CoreAnnotations.PartOfSpeechAnnotation.class);
      // constant-first comparison: a missing POS tag must not throw
      if (m.originalSpan.size() == 1 && "NNS".equals(pos)) {
        m.generic = true;
      }
    }
  }

  /**
   * Walks over every noun, pronoun and determiner node of the basic dependency graph
   * and extracts a mention for it, unless it is attached via {@code det} or {@code compound}.
   */
  private void extractNPorPRPFromDependency(CoreMap s, List<Mention> mentions,
      Set<IntPair> mentionSpanSet, Set<IntPair> namedEntitySpanSet) {
    SemanticGraph basic = s.get(BasicDependenciesAnnotation.class);

    // DT is for "this, these, etc"
    List<IndexedWord> nounsOrPrp = basic.getAllNodesByPartOfSpeechPattern("N.*|PRP.*|DT");

    for (IndexedWord w : nounsOrPrp) {
      SemanticGraphEdge edge = basic.getEdge(basic.getParent(w), w);
      // if edge is null, it's root
      String shortname = (edge == null) ? "root" : edge.getRelation().getShortName();

      // TODO: what to remove? remove more?
      if (SKIPPED_HEAD_RELATION.matcher(shortname).matches()) {
        continue;
      }
      extractMentionForHeadword(w, basic, s, mentions, mentionSpanSet, namedEntitySpanSet);
    }
  }

  /**
   * Returns the enhanced dependency graph of the sentence, falling back to the
   * basic dependency graph when no enhanced graph is annotated.
   */
  private static SemanticGraph enhancedOrBasicDependencies(CoreMap s) {
    SemanticGraph enhanced = s.get(SemanticGraphCoreAnnotations.EnhancedDependenciesAnnotation.class);
    if (enhanced == null) {
      enhanced = s.get(SemanticGraphCoreAnnotations.BasicDependenciesAnnotation.class);
    }
    return enhanced;
  }

  /**
   * Extracts the mention(s) headed by {@code headword}: pronouns are delegated to
   * {@link #extractPronounForHeadword}; otherwise the NP span around the head is added,
   * plus the first conjunct alone when the head governs a conjunction.
   */
  private void extractMentionForHeadword(IndexedWord headword, SemanticGraph dep, CoreMap s,
      List<Mention> mentions, Set<IntPair> mentionSpanSet, Set<IntPair> namedEntitySpanSet) {
    // pronoun
    if (headword.tag().startsWith("PRP")) {
      extractPronounForHeadword(headword, dep, s, mentions, mentionSpanSet, namedEntitySpanSet);
      return;
    }

    List<CoreLabel> sent = s.get(CoreAnnotations.TokensAnnotation.class);
    SemanticGraph basic = s.get(SemanticGraphCoreAnnotations.BasicDependenciesAnnotation.class);
    SemanticGraph enhanced = enhancedOrBasicDependencies(s);

    // add NP mention
    IntPair npSpan = getNPSpan(headword, dep, sent);
    int beginIdx = npSpan.get(0);
    int endIdx = npSpan.get(1) + 1;
    if (",".equals(sent.get(endIdx - 1).word())) {
      endIdx--; // try not to have span that ends with ,
    }
    if ("IN".equals(sent.get(beginIdx).tag())) {
      beginIdx++; // try to remove first IN.
    }
    addMention(beginIdx, endIdx, headword, mentions, mentionSpanSet, namedEntitySpanSet, sent, basic, enhanced);

    // extract the first element in conjunction
    // (A and B -> extract A here; "A and B" and "B" will be extracted above)
    Set<IndexedWord> conjChildren = dep.getChildrenWithReln(headword, UniversalEnglishGrammaticalRelations.CONJUNCT);
    if (conjChildren.isEmpty()) {
      return;
    }

    // to make sure we find the first conjunction
    IndexedWord conjChild = dep.getChildWithReln(headword, UniversalEnglishGrammaticalRelations.CONJUNCT);
    for (IndexedWord c : conjChildren) {
      if (c.index() < conjChild.index()) {
        conjChild = c;
      }
    }

    IndexedWord left = SemanticGraphUtils.leftMostChildVertice(conjChild, dep);
    // walk leftwards past conjunctions/commas to find where the first conjunct ends
    for (int endIdxFirstElement = left.index() - 1; endIdxFirstElement > beginIdx; endIdxFirstElement--) {
      if (!CONJUNCTION_SEPARATOR_TAG.matcher(sent.get(endIdxFirstElement - 1).tag()).matches()) {
        // only add the shortened span when it still contains the head word
        if (headword.index() - 1 < endIdxFirstElement) {
          addMention(beginIdx, endIdxFirstElement, headword, mentions, mentionSpanSet, namedEntitySpanSet,
              sent, basic, enhanced);
        }
        break;
      }
    }
  }

  /**
   * return the left and right most node except copula relation (nsubj &amp; cop) and some others (maybe discourse?)
   * e.g., you are the person -> return "the person"
   *
   * @return inclusive, 0-based token span (begin, end) of the NP
   */
  private IntPair getNPSpan(IndexedWord headword, SemanticGraph dep, List<CoreLabel> sent) {
    int headwordIdx = headword.index() - 1;
    List<IndexedWord> children = dep.getChildList(headword);

    // check if we have copula relation; if so only children after the copula are considered
    IndexedWord cop = dep.getChildWithReln(headword, UniversalEnglishGrammaticalRelations.COPULA);
    int startIdx = (cop == null) ? 0 : children.indexOf(cop) + 1;

    // children which will be inside of NP
    List<IndexedWord> insideNP = Generics.newArrayList();
    for (int i = startIdx; i < children.size(); i++) {
      IndexedWord child = children.get(i);
      SemanticGraphEdge edge = dep.getEdge(headword, child);
      if (!OUTSIDE_NP_RELATION.matcher(edge.getRelation().getShortName()).matches()) {
        insideNP.add(child);
      }
    }

    if (insideNP.isEmpty()) {
      return new IntPair(headwordIdx, headwordIdx); // the headword is the only word
    }

    Pair<IndexedWord, IndexedWord> firstChildLeftRight =
        SemanticGraphUtils.leftRightMostChildVertices(insideNP.get(0), dep);
    Pair<IndexedWord, IndexedWord> lastChildLeftRight =
        SemanticGraphUtils.leftRightMostChildVertices(insideNP.get(insideNP.size() - 1), dep);

    // headword can be first or last word
    int beginIdx = Math.min(headwordIdx, firstChildLeftRight.first.index() - 1);
    int endIdx = Math.max(headwordIdx, lastChildLeftRight.second.index() - 1);

    return new IntPair(beginIdx, endIdx);
  }

  /**
   * Earlier span computation, retained alongside {@link #getNPSpan}; not called from this class.
   *
   * @return inclusive, 0-based token span (begin, end)
   */
  private IntPair getNPSpanOld(IndexedWord headword, SemanticGraph dep, List<CoreLabel> sent) {
    IndexedWord cop = dep.getChildWithReln(headword, UniversalEnglishGrammaticalRelations.COPULA);
    Pair<IndexedWord, IndexedWord> leftRight = SemanticGraphUtils.leftRightMostChildVertices(headword, dep);

    // headword can be first or last word
    int beginIdx = Math.min(headword.index() - 1, leftRight.first.index() - 1);
    int endIdx = Math.max(headword.index() - 1, leftRight.second.index() - 1);

    // no copula relation
    if (cop == null) {
      return new IntPair(beginIdx, endIdx);
    }

    // if we have copula relation
    List<IndexedWord> children = dep.getChildList(headword);
    int copIdx = children.indexOf(cop);

    if (copIdx + 1 < children.size()) {
      beginIdx = Math.min(headword.index() - 1,
          SemanticGraphUtils.leftMostChildVertice(children.get(copIdx + 1), dep).index() - 1);
    } else {
      beginIdx = headword.index() - 1;
    }

    return new IntPair(beginIdx, endIdx);
  }

  /**
   * Adds a mention for the token span {@code [beginIdx, endIdx)} headed by {@code headword},
   * unless the same span was already recorded or {@code insideNE} reports it as lying inside a named entity.
   * The span is recorded in {@code mentionSpanSet} when a mention is added.
   */
  private void addMention(int beginIdx, int endIdx, IndexedWord headword, List<Mention> mentions,
      Set<IntPair> mentionSpanSet, Set<IntPair> namedEntitySpanSet, List<CoreLabel> sent,
      SemanticGraph basic, SemanticGraph enhanced) {
    IntPair mSpan = new IntPair(beginIdx, endIdx);
    if (mentionSpanSet.contains(mSpan) || insideNE(mSpan, namedEntitySpanSet)) {
      return;
    }

    int dummyMentionId = -1;
    Mention m = new Mention(dummyMentionId, beginIdx, endIdx, sent, basic, enhanced,
        new ArrayList<>(sent.subList(beginIdx, endIdx)));
    m.headIndex = headword.index() - 1;
    m.headWord = sent.get(m.headIndex);
    m.headString = m.headWord.word().toLowerCase(Locale.ENGLISH);
    mentions.add(m);
    mentionSpanSet.add(mSpan);
  }

  /**
   * Extracts the mention for a pronoun head: the pronoun itself, extended by a directly
   * following "all"/"both" that depends on it, plus the whole NP span when the pronoun
   * governs a conjunction (e.g., you and I).
   */
  private void extractPronounForHeadword(IndexedWord headword, SemanticGraph dep, CoreMap s,
      List<Mention> mentions, Set<IntPair> mentionSpanSet, Set<IntPair> namedEntitySpanSet) {
    List<CoreLabel> sent = s.get(CoreAnnotations.TokensAnnotation.class);
    SemanticGraph basic = s.get(SemanticGraphCoreAnnotations.BasicDependenciesAnnotation.class);
    SemanticGraph enhanced = enhancedOrBasicDependencies(s);

    int beginIdx = headword.index() - 1;
    int endIdx = headword.index();

    // handle "you all", "they both" etc
    if (sent.size() > headword.index()
        && PRONOUN_QUANTIFIER.matcher(sent.get(headword.index()).word()).matches()) {
      // safe lookup: the following token may have no vertex in the graph,
      // in which case getNodeByIndex would throw instead of letting us skip it
      IndexedWord next = dep.getNodeByIndexSafe(headword.index() + 1);
      if (next != null && dep.getEdge(headword, next) != null) {
        endIdx++;
      }
    }

    addMention(beginIdx, endIdx, headword, mentions, mentionSpanSet, namedEntitySpanSet, sent, basic, enhanced);

    // when pronoun is a part of conjunction (e.g., you and I)
    Set<IndexedWord> conjChildren = dep.getChildrenWithReln(headword, UniversalEnglishGrammaticalRelations.CONJUNCT);
    if (!conjChildren.isEmpty()) {
      IntPair npSpan = getNPSpan(headword, dep, sent);
      beginIdx = npSpan.get(0);
      endIdx = npSpan.get(1) + 1;
      if (",".equals(sent.get(endIdx - 1).word())) {
        endIdx--; // try not to have span that ends with ,
      }
      addMention(beginIdx, endIdx, headword, mentions, mentionSpanSet, namedEntitySpanSet, sent, basic, enhanced);
    }
  }

  /**
   * Fills in the head of every mention in {@code mentions} that does not have one yet.
   *
   * @param s the sentence the mentions belong to
   * @param mentions mentions of that sentence
   */
  public static void findHeadInDependency(CoreMap s, List<Mention> mentions) {
    for (Mention m : mentions) {
      findHeadInDependency(s, m);
    }
  }

  /**
   * Fills in missing heads using the basic dependency graph of the sentence.
   *
   * @param s the sentence the mentions belong to
   * @param mentions mentions of that sentence
   */
  @Override
  public void findHead(CoreMap s, List<Mention> mentions) {
    findHeadInDependency(s, mentions);
  }

  /**
   * Sets {@code headIndex}, {@code headWord} and {@code headString} of {@code m} when it has no head word yet.
   * Starting from the rightmost token of the mention that has a node in the basic dependency graph,
   * the parent chain is followed for as long as the parent stays inside the mention span.
   * If no token of the mention has a node, the last token of the mention is used.
   *
   * @param s the sentence the mention belongs to
   * @param m the mention to update
   */
  // TODO: still errors in head finder
  public static void findHeadInDependency(CoreMap s, Mention m) {
    if (m.headWord != null) {
      return;
    }

    List<CoreLabel> sent = s.get(CoreAnnotations.TokensAnnotation.class);
    SemanticGraph basicDep = s.get(BasicDependenciesAnnotation.class);

    // when there's punctuation, no node found in the dependency tree
    int curIdx;
    IndexedWord cur = null;
    for (curIdx = m.endIndex - 1; curIdx >= m.startIndex; curIdx--) {
      cur = basicDep.getNodeByIndexSafe(curIdx + 1);
      if (cur != null) {
        break;
      }
    }

    if (cur == null) {
      curIdx = m.endIndex - 1;
    }

    while (cur != null) {
      IndexedWord p = basicDep.getParent(cur);
      if (p == null || p.index() - 1 < m.startIndex || p.index() - 1 >= m.endIndex) {
        break;
      }
      curIdx = p.index() - 1;
      cur = basicDep.getNodeByIndexSafe(curIdx + 1);
    }

    m.headIndex = curIdx;
    m.headWord = sent.get(m.headIndex);
    m.headString = m.headWord.word().toLowerCase(Locale.ENGLISH);
  }
}