package edu.stanford.nlp.coref;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import edu.stanford.nlp.coref.data.CorefCluster;
import edu.stanford.nlp.coref.data.Document;
import edu.stanford.nlp.coref.data.Mention;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.RuntimeInterruptedException;
/**
 * Useful utilities for coreference resolution: ordering mentions, enumerating and labeling
 * mention pairs, merging and pruning coreference clusters, and heuristically filtering
 * candidate antecedents.
 *
 * <p>All members are static.
 *
 * @author Kevin Clark
 */
public class CorefUtils {

  /** POS tags whose tokens are treated as content words by {@link #getContentWords}. */
  private static final Set<String> NOUN_POS_TAGS =
      Collections.unmodifiableSet(new HashSet<>(Arrays.asList("NN", "NNS", "NNP", "NNPS")));

  /**
   * Returns the document's predicted mentions sorted with {@code Mention.appearEarlierThan}.
   *
   * <p>The sort is stable. Mentions for which neither {@code appearEarlierThan} direction holds
   * keep the iteration order of {@code document.predictedMentionsByID.values()}.
   *
   * @param document document whose {@code predictedMentionsByID} values are sorted
   * @return a new mutable list; the document itself is not modified
   */
  public static List<Mention> getSortedMentions(Document document) {
    List<Mention> mentions = new ArrayList<>(document.predictedMentionsByID.values());
    mentions.sort((m1, m2) -> {
      if (m1.appearEarlierThan(m2)) {
        return -1;
      }
      // Must return 0 on ties: a comparator that never returns 0 violates the Comparator
      // contract and can make the sort throw IllegalArgumentException.
      return m2.appearEarlierThan(m1) ? 1 : 0;
    });
    return mentions;
  }

  /**
   * Enumerates every pair of distinct predicted mentions.
   *
   * @param document document whose predicted mentions are paired
   * @return pairs {@code (earlierID, laterID)} in the order of {@link #getSortedMentions}. The
   *     list is grouped by the later mention, and within each group ordered by the earlier one.
   */
  public static List<Pair<Integer, Integer>> getMentionPairs(Document document) {
    List<Pair<Integer, Integer>> pairs = new ArrayList<>();
    List<Mention> mentions = getSortedMentions(document);
    for (int i = 0; i < mentions.size(); i++) {
      for (int j = 0; j < i; j++) {
        pairs.add(new Pair<>(mentions.get(j).mentionID, mentions.get(i).mentionID));
      }
    }
    return pairs;
  }

  /**
   * Returns every mention pair from {@link #getMentionPairs}, each mapped to {@code false}.
   *
   * <p>The result is always a mutable {@link HashMap}. {@link #getLabeledMentionPairs} writes
   * into it, and {@code Collectors.toMap} gives no guarantee about mutability.
   *
   * @param document document whose predicted mentions are paired
   * @return mutable map from mention pair to {@code false}
   * @throws IllegalStateException if the same pair is produced twice
   */
  public static Map<Pair<Integer, Integer>, Boolean> getUnlabeledMentionPairs(Document document) {
    Map<Pair<Integer, Integer>, Boolean> unlabeled = new HashMap<>();
    for (Pair<Integer, Integer> pair : getMentionPairs(document)) {
      if (unlabeled.put(pair, false) != null) {
        throw new IllegalStateException("Duplicate mention pair: " + pair);
      }
    }
    return unlabeled;
  }

  /**
   * Returns every mention pair from {@link #getMentionPairs}, labeled by gold coreference.
   *
   * <p>A pair is labeled {@code true} when both of its mentions belong to the same cluster in
   * {@code document.goldCorefClusters}. Otherwise it is labeled {@code false}. Gold pairs that
   * are not among the predicted mention pairs are ignored.
   *
   * @param document document providing predicted mentions and gold clusters
   * @return mutable map from mention pair to its coreference label
   */
  public static Map<Pair<Integer, Integer>, Boolean> getLabeledMentionPairs(Document document) {
    Map<Pair<Integer, Integer>, Boolean> mentionPairs = getUnlabeledMentionPairs(document);
    for (CorefCluster c : document.goldCorefClusters.values()) {
      List<Mention> clusterMentions = new ArrayList<>(c.getCorefMentions());
      for (Mention clusterMention : clusterMentions) {
        for (Mention clusterMention2 : clusterMentions) {
          // Map.replace only writes when the key is already present, so pairs outside the
          // predicted pair set (including (m, m)) are skipped.
          mentionPairs.replace(
              new Pair<>(clusterMention.mentionID, clusterMention2.mentionID), true);
        }
      }
    }
    return mentionPairs;
  }

  /**
   * Merges the coreference clusters of the two mentions in {@code mentionPair}.
   *
   * <p>Both IDs are looked up in {@code document.predictedMentionsByID}. A missing ID causes a
   * {@link NullPointerException}. If the mentions already share a {@code corefClusterID},
   * nothing happens. Otherwise the method calls {@code CorefCluster.mergeClusters(c2, c1)},
   * presumably folding the first mention's cluster into the second's (argument order assumed
   * to be (to, from); confirm). It then removes the first mention's original cluster from
   * {@code document.corefClusters}.
   *
   * @param mentionPair mention IDs of the two mentions to link
   * @param document document whose {@code corefClusters} map is updated in place
   */
  public static void mergeCoreferenceClusters(Pair<Integer, Integer> mentionPair,
      Document document) {
    Mention m1 = document.predictedMentionsByID.get(mentionPair.first);
    Mention m2 = document.predictedMentionsByID.get(mentionPair.second);
    if (m1.corefClusterID == m2.corefClusterID) {
      return;
    }
    // Capture the ID before merging, in case mergeClusters reassigns m1.corefClusterID.
    int removeId = m1.corefClusterID;
    CorefCluster c1 = document.corefClusters.get(m1.corefClusterID);
    CorefCluster c2 = document.corefClusters.get(m2.corefClusterID);
    CorefCluster.mergeClusters(c2, c1);
    document.corefClusters.remove(removeId);
  }

  /**
   * Removes every cluster in {@code document.corefClusters} that contains exactly one mention.
   *
   * <p>Iterates over a snapshot of the values so the map can be modified during the loop.
   * Clusters are removed by {@code clusterID}, so this assumes the map is keyed by cluster ID
   * (NOTE(review): confirm).
   *
   * @param document document whose cluster map is pruned in place
   */
  public static void removeSingletonClusters(Document document) {
    for (CorefCluster c : new ArrayList<>(document.corefClusters.values())) {
      if (c.getCorefMentions().size() == 1) {
        document.corefClusters.remove(c.clusterID);
      }
    }
  }

  /**
   * Throws {@link RuntimeInterruptedException} if the current thread has been interrupted.
   *
   * <p>Uses {@link Thread#interrupted()}, which also clears the thread's interrupted status.
   */
  public static void checkForInterrupt() {
    if (Thread.interrupted()) {
      throw new RuntimeInterruptedException();
    }
  }

  /**
   * Selects candidate antecedents for each mention with two cheap heuristics.
   *
   * <p>A mention's candidates are, in this order:
   * <ol>
   *   <li>the (at most) {@code maxMentionDistance} mentions directly before it in
   *       {@code sortedMentions}; and
   *   <li>any other mention that shares a lowercased noun token with it (see
   *       {@link #getContentWords}), has a smaller {@code mentionNum}, and is at most
   *       {@code maxMentionDistanceWithStringMatch} behind it in {@code mentionNum}.
   * </ol>
   * Duplicates are dropped, and first-insertion order is kept.
   *
   * @param sortedMentions mentions in document order (presumably the output of
   *     {@link #getSortedMentions}; {@code mentionNum} is assumed to reflect that order)
   * @param maxMentionDistance size of the positional window of preceding mentions
   * @param maxMentionDistanceWithStringMatch maximum {@code mentionNum} distance for
   *     string-matched antecedents
   * @return map from mention ID to its candidate antecedent IDs. Mentions with no candidates
   *     are omitted.
   */
  public static Map<Integer, List<Integer>> heuristicFilter(List<Mention> sortedMentions,
      int maxMentionDistance, int maxMentionDistanceWithStringMatch) {
    // Inverted index: lowercased content word -> mentions containing it.
    Map<String, List<Mention>> wordToMentions = new HashMap<>();
    for (Mention m : sortedMentions) {
      for (String word : getContentWords(m)) {
        wordToMentions.computeIfAbsent(word, k -> new ArrayList<>()).add(m);
      }
    }

    Map<Integer, List<Integer>> mentionToCandidateAntecedents = new HashMap<>();
    for (int i = 0; i < sortedMentions.size(); i++) {
      Mention m = sortedMentions.get(i);
      // LinkedHashSet de-duplicates in O(1) while preserving insertion order.
      Set<Integer> candidateAntecedents = new LinkedHashSet<>();
      for (int j = Math.max(0, i - maxMentionDistance); j < i; j++) {
        candidateAntecedents.add(sortedMentions.get(j).mentionID);
      }
      for (String word : getContentWords(m)) {
        List<Mention> withStringMatch = wordToMentions.get(word);
        if (withStringMatch == null) {
          continue;
        }
        for (Mention match : withStringMatch) {
          if (match.mentionNum < m.mentionNum
              && match.mentionNum >= m.mentionNum - maxMentionDistanceWithStringMatch) {
            candidateAntecedents.add(match.mentionID);
          }
        }
      }
      if (!candidateAntecedents.isEmpty()) {
        mentionToCandidateAntecedents.put(m.mentionID, new ArrayList<>(candidateAntecedents));
      }
    }
    return mentionToCandidateAntecedents;
  }

  /**
   * Returns the lowercased words of the mention's tokens tagged NN, NNS, NNP or NNPS.
   *
   * <p>Tokens are read from {@code m.sentenceWords} over {@code [m.startIndex, m.endIndex)}.
   * Tokens without a POS annotation are skipped. Lowercasing uses {@link Locale#ROOT}, so the
   * result does not depend on the JVM's default locale.
   *
   * @param m mention whose tokens are inspected
   * @return content words in token order; may contain duplicates
   */
  private static List<String> getContentWords(Mention m) {
    List<String> words = new ArrayList<>();
    for (int i = m.startIndex; i < m.endIndex; i++) {
      CoreLabel cl = m.sentenceWords.get(i);
      String pos = cl.get(CoreAnnotations.PartOfSpeechAnnotation.class);
      // HashSet.contains(null) is false, so untagged tokens are skipped instead of throwing.
      if (NOUN_POS_TAGS.contains(pos)) {
        words.add(cl.word().toLowerCase(Locale.ROOT));
      }
    }
    return words;
  }
}