package edu.stanford.nlp.coref.md; import edu.stanford.nlp.util.logging.Redwood; import java.io.IOException; import java.io.Serializable; import java.util.List; import java.util.Map; import java.util.Properties; import java.util.Set; import edu.stanford.nlp.coref.hybrid.rf.RandomForest; import edu.stanford.nlp.coref.data.Dictionaries; import edu.stanford.nlp.coref.data.Mention; import edu.stanford.nlp.io.IOUtils; import edu.stanford.nlp.ling.CoreLabel; import edu.stanford.nlp.ling.RVFDatum; import edu.stanford.nlp.stats.ClassicCounter; import edu.stanford.nlp.stats.Counter; import edu.stanford.nlp.stats.Counters; import edu.stanford.nlp.util.Generics; public class MentionDetectionClassifier implements Serializable { /** A logger for this class */ private static Redwood.RedwoodChannels log = Redwood.channels(MentionDetectionClassifier.class); private static final long serialVersionUID = -4100580709477023158L; public RandomForest rf; public MentionDetectionClassifier(RandomForest rf) { this.rf = rf; } public static Counter<String> extractFeatures(Mention p, Set<Mention> shares, Set<String> neStrings, Dictionaries dict, Properties props) { Counter<String> features = new ClassicCounter<>(); String span = p.lowercaseNormalizedSpanString(); String ner = p.headWord.ner(); int sIdx = p.startIndex; int eIdx = p.endIndex; List<CoreLabel> sent = p.sentenceWords; CoreLabel preWord = (sIdx==0)? null : sent.get(sIdx-1); CoreLabel nextWord = (eIdx == sent.size())? null : sent.get(eIdx); CoreLabel firstWord = p.originalSpan.get(0); CoreLabel lastWord = p.originalSpan.get(p.originalSpan.size()-1); features.incrementCount("B-NETYPE-"+ner); if(neStrings.contains(span)) { features.incrementCount("B-NE-STRING-EXIST"); if( ( preWord==null || !preWord.ner().equals(ner) ) && ( nextWord==null || !nextWord.ner().equals(ner) ) ) { features.incrementCount("B-NE-FULLSPAN"); } } if(preWord!=null) features.incrementCount("B-PRECEDINGWORD-"+preWord.word()); if(nextWord!=null) features.incrementCount("B-FOLLOWINGWORD-"+nextWord.word()); if(preWord!=null) features.incrementCount("B-PRECEDINGPOS-"+preWord.tag()); if(nextWord!=null) features.incrementCount("B-FOLLOWINGPOS-"+nextWord.tag()); features.incrementCount("B-FIRSTWORD-"+firstWord.word()); features.incrementCount("B-FIRSTPOS-"+firstWord.tag()); features.incrementCount("B-LASTWORD-"+lastWord.word()); features.incrementCount("B-LASTWORD-"+lastWord.tag()); for(Mention s : shares) { if(s==p) continue; if(s.insideIn(p)) { features.incrementCount("B-BIGGER-THAN-ANOTHER"); break; } } for(Mention s : shares) { if(s==p) continue; if(p.insideIn(s)) { features.incrementCount("B-SMALLER-THAN-ANOTHER"); break; } } return features; } public static MentionDetectionClassifier loadMentionDetectionClassifier(String filename) throws ClassNotFoundException, IOException { log.info("loading MentionDetectionClassifier ..."); MentionDetectionClassifier mdc = IOUtils.readObjectFromURLOrClasspathOrFileSystem(filename); log.info("done"); return mdc; } public double probabilityOf(Mention p, Set<Mention> shares, Set<String> neStrings, Dictionaries dict, Properties props) { try { boolean dummyLabel = false; RVFDatum<Boolean, String> datum = new RVFDatum<>(extractFeatures(p, shares, neStrings, dict, props), dummyLabel); return rf.probabilityOfTrue(datum); } catch (Exception e) { throw new RuntimeException(e); } } public void classifyMentions(List<List<Mention>> predictedMentions, Dictionaries dict, Properties props) { Set<String> neStrings = Generics.newHashSet(); for (List<Mention> predictedMention : predictedMentions) { for (Mention m : predictedMention) { String ne = m.headWord.ner(); if (ne.equals("O")) continue; for (CoreLabel cl : m.originalSpan) { if (!cl.ner().equals(ne)) continue; } neStrings.add(m.lowercaseNormalizedSpanString()); } } for (List<Mention> predicts : predictedMentions) { Map<Integer, Set<Mention>> headPositions = Generics.newHashMap(); for (Mention p : predicts) { if (!headPositions.containsKey(p.headIndex)) headPositions.put(p.headIndex, Generics.newHashSet()); headPositions.get(p.headIndex).add(p); } Set<Mention> remove = Generics.newHashSet(); for (int hPos : headPositions.keySet()) { Set<Mention> shares = headPositions.get(hPos); if (shares.size() > 1) { Counter<Mention> probs = new ClassicCounter<>(); for (Mention p : shares) { double trueProb = probabilityOf(p, shares, neStrings, dict, props); probs.incrementCount(p, trueProb); } // add to remove Mention keep = Counters.argmax(probs, (m1, m2) -> m1.spanToString().compareTo(m2.spanToString())); probs.remove(keep); remove.addAll(probs.keySet()); } } for (Mention r : remove) { predicts.remove(r); } } } }