package edu.stanford.nlp.coref.statistical;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import edu.stanford.nlp.coref.CorefAlgorithm;
import edu.stanford.nlp.coref.CorefProperties;
import edu.stanford.nlp.coref.CorefUtils;
import edu.stanford.nlp.coref.data.Dictionaries;
import edu.stanford.nlp.coref.data.Dictionaries.MentionType;
import edu.stanford.nlp.coref.data.Document;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.RuntimeInterruptedException;
/**
 * Does best-first coreference resolution by linking each mention to its highest scoring candidate
 * antecedent if that score is above a threshold. The model is described in
 * <p>
 * Kevin Clark and Christopher D. Manning. 2015.
 * <a href="http://nlp.stanford.edu/pubs/clark-manning-acl15-entity.pdf">
 * Entity-Centric Coreference Resolution with Model Stacking</a>.
 * In Association for Computational Linguistics.
 * <p>
 * See {@link StatisticalCorefTrainer} for training a new model.
 *
 * @author Kevin Clark
 */
public class StatisticalCorefAlgorithm implements CorefAlgorithm {

  /** Number of thresholds required: one per (pronominal?, pronominal?) combination. */
  private static final int NUM_THRESHOLDS = 4;

  /**
   * Minimum pairwise score required to link a pair, keyed by
   * (first mention is pronominal, second mention is pronominal).
   */
  private final Map<Pair<Boolean, Boolean>, Double> thresholds;

  /** Produces the pairwise examples and mention features that are fed to the classifier. */
  private final FeatureExtractor extractor;

  /** Scores each candidate mention pair. */
  private final PairwiseModel classifier;

  /** Passed to {@link CorefUtils#heuristicFilter} when selecting candidate pairs. */
  private final int maxMentionDistance;

  /** Passed to {@link CorefUtils#heuristicFilter} when selecting candidate pairs. */
  private final int maxMentionDistanceWithStringMatch;

  /**
   * Creates the algorithm with every setting (word counts path, model path, distance limits and
   * score thresholds) read from {@code props}.
   *
   * @param props configuration properties
   * @param dictionaries dictionaries handed to the feature extractor
   */
  public StatisticalCorefAlgorithm(Properties props, Dictionaries dictionaries) {
    this(props, dictionaries,
        StatisticalCorefProperties.wordCountsPath(props),
        StatisticalCorefProperties.rankingModelPath(props),
        CorefProperties.maxMentionDistance(props),
        CorefProperties.maxMentionDistanceWithStringMatch(props),
        StatisticalCorefProperties.pairwiseScoreThresholds(props));
  }

  /**
   * Creates the algorithm using one score threshold for all four mention-type combinations.
   *
   * @param props configuration properties handed to the feature extractor
   * @param dictionaries dictionaries handed to the feature extractor
   * @param wordCountsFile word counts file handed to the feature extractor
   * @param modelFile path the pairwise model is configured with
   * @param maxMentionDistance distance limit used when filtering candidate pairs
   * @param maxMentionDistanceWithStringMatch distance limit used when filtering candidate pairs
   * @param threshold score a pair must exceed to be linked, regardless of mention types
   */
  public StatisticalCorefAlgorithm(Properties props, Dictionaries dictionaries, String wordCountsFile,
      String modelFile, int maxMentionDistance, int maxMentionDistanceWithStringMatch,
      double threshold) {
    this(props, dictionaries, wordCountsFile, modelFile, maxMentionDistance,
        maxMentionDistanceWithStringMatch,
        new double[] {threshold, threshold, threshold, threshold});
  }

  /**
   * Creates the algorithm with a separate score threshold per mention-type combination.
   *
   * @param props configuration properties handed to the feature extractor
   * @param dictionaries dictionaries handed to the feature extractor
   * @param wordCountsFile word counts file handed to the feature extractor
   * @param modelPath path the pairwise model is configured with
   * @param maxMentionDistance distance limit used when filtering candidate pairs
   * @param maxMentionDistanceWithStringMatch distance limit used when filtering candidate pairs
   * @param thresholds at least four thresholds, indexed as
   *     {@code [0]} both mentions pronominal, {@code [1]} only the first pronominal,
   *     {@code [2]} only the second pronominal, {@code [3]} neither pronominal
   * @throws IllegalArgumentException if {@code thresholds} is null or has fewer than four entries
   */
  public StatisticalCorefAlgorithm(Properties props, Dictionaries dictionaries, String wordCountsFile,
      String modelPath, int maxMentionDistance, int maxMentionDistanceWithStringMatch,
      double[] thresholds) {
    extractor = new FeatureExtractor(props, dictionaries, null, wordCountsFile);
    classifier = PairwiseModel.newBuilder("classifier",
        MetaFeatureExtractor.newBuilder().build()).modelPath(modelPath).build();
    this.maxMentionDistance = maxMentionDistance;
    this.maxMentionDistanceWithStringMatch = maxMentionDistanceWithStringMatch;
    this.thresholds = makeThresholds(thresholds);
  }

  /**
   * Converts the threshold array into a lookup keyed by
   * (first mention is pronominal, second mention is pronominal).
   *
   * @param thresholds at least four values; entries beyond the fourth are ignored
   * @return map from mention-type combination to its threshold
   * @throws IllegalArgumentException if {@code thresholds} is null or has fewer than four entries
   */
  private static Map<Pair<Boolean, Boolean>, Double> makeThresholds(double[] thresholds) {
    if (thresholds == null || thresholds.length < NUM_THRESHOLDS) {
      throw new IllegalArgumentException("Expected at least " + NUM_THRESHOLDS
          + " pairwise score thresholds but got "
          + (thresholds == null ? "null" : String.valueOf(thresholds.length)));
    }
    Map<Pair<Boolean, Boolean>, Double> thresholdsMap = new HashMap<>();
    thresholdsMap.put(new Pair<>(true, true), thresholds[0]);
    thresholdsMap.put(new Pair<>(true, false), thresholds[1]);
    thresholdsMap.put(new Pair<>(false, true), thresholds[2]);
    thresholdsMap.put(new Pair<>(false, false), thresholds[3]);
    return thresholdsMap;
  }

  /**
   * Orders two scores from highest to lowest, placing NaN after every real score.
   *
   * <p>This is a consistent total order, which the sort requires. NaN goes last so that a pair
   * with an undefined score is never considered ahead of a pair with a real score. Positive zero
   * is ordered ahead of negative zero.
   *
   * @param score1 score of the first pair
   * @param score2 score of the second pair
   * @return negative if {@code score1} should come first, positive if {@code score2} should,
   *     zero if they tie
   */
  private static int compareScoresDescending(double score1, double score2) {
    boolean nan1 = Double.isNaN(score1);
    boolean nan2 = Double.isNaN(score2);
    if (nan1 || nan2) {
      return Boolean.compare(nan1, nan2);
    }
    return Double.compare(score2, score1);
  }

  /**
   * Scores the candidate mention pairs of {@code document} and, going through the pairs from
   * highest to lowest score, merges the coreference clusters of the first pair seen for each
   * anaphor when that pair's score exceeds the threshold for its mention types.
   *
   * @param document the document whose coreference clusters are updated in place
   * @throws RuntimeInterruptedException if the current thread has been interrupted
   */
  @Override
  public void runCoref(Document document) {
    Compressor<String> compressor = new Compressor<>();
    if (Thread.interrupted()) { // Allow interrupting
      throw new RuntimeInterruptedException();
    }

    // Candidate pairs are stored as (element of the filtered list, map key); every pair is
    // flagged true.
    Map<Pair<Integer, Integer>, Boolean> pairs = new HashMap<>();
    Map<Integer, List<Integer>> candidates = CorefUtils.heuristicFilter(
        CorefUtils.getSortedMentions(document),
        maxMentionDistance, maxMentionDistanceWithStringMatch);
    for (Map.Entry<Integer, List<Integer>> e : candidates.entrySet()) {
      for (int m1 : e.getValue()) {
        pairs.put(new Pair<>(m1, e.getKey()), true);
      }
    }

    // NOTE(review): the leading 0 is presumably a document id - confirm against FeatureExtractor.
    DocumentExamples examples = extractor.extract(0, document, pairs, compressor);
    Counter<Pair<Integer, Integer>> pairwiseScores = new ClassicCounter<>();
    for (Example mentionPair : examples.examples) {
      if (Thread.interrupted()) { // Allow interrupting
        throw new RuntimeInterruptedException();
      }
      pairwiseScores.incrementCount(new Pair<>(mentionPair.mentionId1, mentionPair.mentionId2),
          classifier.predict(mentionPair, examples.mentionFeatures, compressor));
    }

    // Best-first: visit pairs from highest to lowest score.
    List<Pair<Integer, Integer>> mentionPairs = new ArrayList<>(pairwiseScores.keySet());
    mentionPairs.sort((p1, p2) ->
        compareScoresDescending(pairwiseScores.getCount(p1), pairwiseScores.getCount(p2)));

    Set<Integer> seenAnaphors = new HashSet<>();
    for (Pair<Integer, Integer> pair : mentionPairs) {
      if (seenAnaphors.contains(pair.second)) {
        continue; // only the highest-scoring pair for each anaphor is considered
      }
      if (Thread.interrupted()) { // Allow interrupting
        throw new RuntimeInterruptedException();
      }
      // The anaphor is marked as handled even when its best pair falls below the threshold.
      seenAnaphors.add(pair.second);

      MentionType mt1 = document.predictedMentionsByID.get(pair.first).mentionType;
      MentionType mt2 = document.predictedMentionsByID.get(pair.second).mentionType;
      double threshold = thresholds.get(new Pair<>(mt1 == MentionType.PRONOMINAL,
          mt2 == MentionType.PRONOMINAL));
      if (pairwiseScores.getCount(pair) > threshold) {
        CorefUtils.mergeCoreferenceClusters(pair, document);
      }
    }
  }
}