package edu.stanford.nlp.coref.neural;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import edu.stanford.nlp.coref.CorefAlgorithm;
import edu.stanford.nlp.coref.CorefProperties;
import edu.stanford.nlp.coref.CorefUtils;
import edu.stanford.nlp.coref.data.Dictionaries;
import edu.stanford.nlp.coref.data.Document;
import edu.stanford.nlp.coref.data.Mention;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.logging.Redwood;
import org.ejml.simple.SimpleMatrix;
/**
 * Neural mention-ranking coreference model as described in:
 * <p>
 * Kevin Clark and Christopher D. Manning. 2016.
 * <a href="http://nlp.stanford.edu/pubs/clark2016deep.pdf">
 * Deep Reinforcement Learning for Mention-Ranking Coreference Models</a>.
 * In Proceedings of Empirical Methods in Natural Language Processing (EMNLP).
 * <p>
 * Training code is implemented in Python and is available at
 * <a href="https://github.com/clarkkev/deep-coref">https://github.com/clarkkev/deep-coref</a>.
 *
 * @author Kevin Clark
 */
public class NeuralCorefAlgorithm implements CorefAlgorithm {
  private static final Redwood.RedwoodChannels log = Redwood.channels(NeuralCorefAlgorithm.class);

  /**
   * Multiplier applied to {@code (greedyness - GREEDYNESS_MIDPOINT)} to obtain the amount by which
   * the "no antecedent" threshold is shifted below a mention's anaphoricity score.
   */
  private static final double GREEDYNESS_SCALE = 50;

  /** Greedyness value at which the link threshold equals the raw anaphoricity score. */
  private static final double GREEDYNESS_MIDPOINT = 0.5;

  /** Read from {@code NeuralCorefProperties.greedyness}; larger values lower the link threshold. */
  private final double greedyness;
  /** Distance limit passed to {@link CorefUtils#heuristicFilter}. */
  private final int maxMentionDistance;
  /** Distance limit (string-match case) passed to {@link CorefUtils#heuristicFilter}. */
  private final int maxMentionDistanceWithStringMatch;
  private final CategoricalFeatureExtractor featureExtractor;
  private final EmbeddingExtractor embeddingExtractor;
  private final NeuralCorefModel model;

  /**
   * Loads the neural coreference model and the resources it needs.
   *
   * <p>The model is read from {@code NeuralCorefProperties.modelPath(props)}. Pretrained
   * embeddings are read from {@code NeuralCorefProperties.pretrainedEmbeddingsPath(props)}. Each
   * path is resolved as a URL, a classpath resource or a file-system path.
   *
   * @param props configuration, read for greedyness, the two mention-distance limits, the model
   *     and embeddings paths, and the {@code conll} setting
   * @param dictionaries dictionaries handed to the categorical feature extractor
   */
  public NeuralCorefAlgorithm(Properties props, Dictionaries dictionaries) {
    greedyness = NeuralCorefProperties.greedyness(props);
    maxMentionDistance = CorefProperties.maxMentionDistance(props);
    maxMentionDistanceWithStringMatch = CorefProperties.maxMentionDistanceWithStringMatch(props);
    model = IOUtils.readObjectAnnouncingTimingFromURLOrClasspathOrFileSystem(
        log, "Loading coref model", NeuralCorefProperties.modelPath(props));
    embeddingExtractor = new EmbeddingExtractor(CorefProperties.conll(props),
        IOUtils.readObjectAnnouncingTimingFromURLOrClasspathOrFileSystem(
            log, "Loading coref embeddings", NeuralCorefProperties.pretrainedEmbeddingsPath(props)),
        model.getWordEmbeddings());
    featureExtractor = new CategoricalFeatureExtractor(props, dictionaries);
  }

  /**
   * Resolves coreference in {@code document} by mention ranking.
   *
   * <p>For every mention that has candidate antecedents (as chosen by
   * {@link CorefUtils#heuristicFilter}), each candidate is scored with the model's pairwise scorer.
   * The best candidate is linked only if its score strictly exceeds the mention's anaphoricity
   * score minus {@code GREEDYNESS_SCALE * (greedyness - GREEDYNESS_MIDPOINT)}. Otherwise the mention
   * is left without an antecedent. On ties, the earliest candidate in the list wins. Links are
   * applied to {@code document} via {@link CorefUtils#mergeCoreferenceClusters}.
   *
   * @param document the document whose mentions are linked
   */
  @Override
  public void runCoref(Document document) {
    List<Mention> sortedMentions = CorefUtils.getSortedMentions(document);

    // Group mentions by head-token index; the feature extractors consume this index.
    Map<Integer, List<Mention>> mentionsByHeadIndex = new HashMap<>();
    for (Mention mention : sortedMentions) {
      List<Mention> sameHead = mentionsByHeadIndex.get(mention.headIndex);
      if (sameHead == null) {
        sameHead = new ArrayList<>();
        mentionsByHeadIndex.put(mention.headIndex, sameHead);
      }
      sameHead.add(mention);
    }

    // Per-mention antecedent/anaphor embeddings and anaphoricity scores, keyed by mentionID.
    SimpleMatrix documentEmbedding = embeddingExtractor.getDocumentEmbedding(document);
    Map<Integer, SimpleMatrix> antecedentEmbeddings = new HashMap<>();
    Map<Integer, SimpleMatrix> anaphorEmbeddings = new HashMap<>();
    Counter<Integer> anaphoricityScores = new ClassicCounter<>();
    for (Mention mention : sortedMentions) {
      SimpleMatrix mentionEmbedding =
          embeddingExtractor.getMentionEmbeddings(mention, documentEmbedding);
      antecedentEmbeddings.put(mention.mentionID, model.getAntecedentEmbedding(mentionEmbedding));
      anaphorEmbeddings.put(mention.mentionID, model.getAnaphorEmbedding(mentionEmbedding));
      anaphoricityScores.incrementCount(mention.mentionID,
          model.getAnaphoricityScore(mentionEmbedding,
              featureExtractor.getAnaphoricityFeatures(mention, document, mentionsByHeadIndex)));
    }

    // Loop-invariant: greedyness above the midpoint lowers the bar, so more mentions get linked.
    double thresholdShift = GREEDYNESS_SCALE * (greedyness - GREEDYNESS_MIDPOINT);

    Map<Integer, List<Integer>> mentionToCandidateAntecedents = CorefUtils.heuristicFilter(
        sortedMentions, maxMentionDistance, maxMentionDistanceWithStringMatch);
    for (Map.Entry<Integer, List<Integer>> e : mentionToCandidateAntecedents.entrySet()) {
      int m = e.getKey();
      // Seeding bestScore with the threshold means a candidate must strictly beat it to be linked.
      double bestScore = anaphoricityScores.getCount(m) - thresholdShift;
      Integer antecedent = null;
      for (int ca : e.getValue()) {
        double score = model.getPairwiseScore(antecedentEmbeddings.get(ca),
            anaphorEmbeddings.get(m),
            featureExtractor.getPairFeatures(new Pair<>(ca, m), document, mentionsByHeadIndex));
        if (score > bestScore) {
          bestScore = score;
          antecedent = ca;
        }
      }
      if (antecedent != null) {
        CorefUtils.mergeCoreferenceClusters(new Pair<>(antecedent, m), document);
      }
    }
  }
}