package edu.stanford.nlp.coref.neural;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Properties;
import org.ejml.simple.SimpleMatrix;
import edu.stanford.nlp.coref.CorefProperties;
import edu.stanford.nlp.coref.CorefRules;
import edu.stanford.nlp.coref.data.Dictionaries;
import edu.stanford.nlp.coref.data.Document;
import edu.stanford.nlp.coref.data.Mention;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.neural.NeuralUtils;
import edu.stanford.nlp.util.Pair;
/**
 * Extracts string matching, speaker, distance, and document genre features from mentions.
 *
 * <p>All features are returned as {@link SimpleMatrix} values assembled with
 * {@code NeuralUtils.concatenate}. Speaker and genre features are only computed when the
 * extractor is configured for CoNLL data (see {@code CorefProperties.conll}); otherwise the
 * speaker features are fixed to 0 and the genre encoding is a single 1x1 zero matrix.
 *
 * @author Kevin Clark
 */
public class CategoricalFeatureExtractor {

  /** Number of entries produced by {@link #encodeDistance(int)}: ten buckets plus one scaled value. */
  private static final int DISTANCE_FEATURE_SIZE = 11;

  /** Length of the one-hot mention type encoding. */
  private static final int MENTION_TYPE_COUNT = 4;

  private final Dictionaries dictionaries;

  /** Maps a genre code (the part of DOC_ID before the first '/') to its one-hot index; null unless conll. */
  private final Map<String, Integer> genres;

  private final boolean conll;

  /**
   * Creates an extractor configured from the given properties.
   *
   * <p>When CoNLL mode is on, the genre table is built with consecutive indices. The "pt" genre
   * is only included for English, so for other languages "tc" and "wb" shift down by one.
   *
   * @param props coref properties, read through {@code CorefProperties}
   * @param dictionaries dictionaries handed on to the speaker rules in {@code CorefRules}
   */
  public CategoricalFeatureExtractor(Properties props, Dictionaries dictionaries) {
    this.dictionaries = dictionaries;
    this.conll = CorefProperties.conll(props);

    Map<String, Integer> genreIndex = null;
    if (conll) {
      // Compare locales by value, not by reference.
      boolean english = Locale.ENGLISH.equals(CorefProperties.getLanguage(props));

      List<String> genreCodes = new ArrayList<>();
      genreCodes.add("bc");
      genreCodes.add("bn");
      genreCodes.add("mz");
      genreCodes.add("nw");
      if (english) {
        genreCodes.add("pt");
      }
      genreCodes.add("tc");
      genreCodes.add("wb");

      // Each code gets the index of its position in the list above.
      genreIndex = new HashMap<>();
      for (String code : genreCodes) {
        genreIndex.put(code, genreIndex.size());
      }
    }
    this.genres = genreIndex;
  }

  /**
   * Builds the feature vector for a pair of mentions.
   *
   * <p>The arguments passed to {@code NeuralUtils.concatenate} are, in order: the values from
   * {@link #pairwiseFeatures}, the encoded sentence distance, the encoded mention distance
   * ({@code m2.mentionNum - m1.mentionNum - 1}), a 0/1 flag that is set when both mentions are in
   * the same sentence and the first one ends after the second one starts, the mention features
   * of each mention, and the genre encoding.
   *
   * @param pair IDs of the two mentions, looked up in {@code document.predictedMentionsByID}
   * @param document the document containing both mentions
   * @param mentionsByHeadIndex mentions grouped by head index, used for the nesting feature
   * @return the concatenated pair features
   */
  public SimpleMatrix getPairFeatures(Pair<Integer, Integer> pair, Document document,
      Map<Integer, List<Mention>> mentionsByHeadIndex) {
    Mention m1 = document.predictedMentionsByID.get(pair.first);
    Mention m2 = document.predictedMentionsByID.get(pair.second);

    List<Integer> featureVals = pairwiseFeatures(document, m1, m2, dictionaries, conll);
    SimpleMatrix pairwise = new SimpleMatrix(featureVals.size(), 1);
    for (int i = 0; i < featureVals.size(); i++) {
      int value = featureVals.get(i);
      pairwise.set(i, value);
    }

    boolean firstEndsAfterSecondStarts =
        m1.sentNum == m2.sentNum && m1.endIndex > m2.startIndex;
    SimpleMatrix overlapFlag =
        new SimpleMatrix(new double[][] {{firstEndsAfterSecondStarts ? 1 : 0}});

    return NeuralUtils.concatenate(pairwise,
        encodeDistance(m2.sentNum - m1.sentNum),
        encodeDistance(m2.mentionNum - m1.mentionNum - 1),
        overlapFlag,
        getMentionFeatures(m1, document, mentionsByHeadIndex),
        getMentionFeatures(m2, document, mentionsByHeadIndex),
        encodeGenre(document));
  }

  /**
   * Computes six 0/1 features for a pair of mentions, in this order:
   * <ol>
   *   <li>both head words carry the same speaker annotation (CoNLL only, else 0);</li>
   *   <li>{@code CorefRules.antecedentIsMentionSpeaker(document, m2, m1, ...)} (CoNLL only);</li>
   *   <li>{@code CorefRules.antecedentIsMentionSpeaker(document, m1, m2, ...)} (CoNLL only);</li>
   *   <li>{@code m1.headsAgree(m2)};</li>
   *   <li>the trimmed, lower-cased string forms of the mentions are equal;</li>
   *   <li>the relaxed string match from the statistical feature extractor.</li>
   * </ol>
   *
   * @param document the document containing both mentions
   * @param m1 the first mention
   * @param m2 the second mention
   * @param dictionaries dictionaries passed to the speaker rules
   * @param isConll whether speaker features should be computed
   * @return the six feature values
   */
  public static List<Integer> pairwiseFeatures(Document document, Mention m1, Mention m2,
      Dictionaries dictionaries, boolean isConll) {
    String speaker1 = m1.headWord.get(CoreAnnotations.SpeakerAnnotation.class);
    String speaker2 = m2.headWord.get(CoreAnnotations.SpeakerAnnotation.class);

    List<Integer> features = new ArrayList<>();

    // Null-safe: a missing speaker annotation on m1 yields 0 rather than a NullPointerException.
    boolean sameSpeaker = isConll && speaker1 != null && speaker1.equals(speaker2);
    features.add(sameSpeaker ? 1 : 0);

    boolean m1SpeakerRule =
        isConll && CorefRules.antecedentIsMentionSpeaker(document, m2, m1, dictionaries);
    features.add(m1SpeakerRule ? 1 : 0);

    boolean m2SpeakerRule =
        isConll && CorefRules.antecedentIsMentionSpeaker(document, m1, m2, dictionaries);
    features.add(m2SpeakerRule ? 1 : 0);

    features.add(m1.headsAgree(m2) ? 1 : 0);

    // Locale.ROOT keeps the comparison independent of the JVM's default locale.
    String text1 = m1.toString().trim().toLowerCase(Locale.ROOT);
    String text2 = m2.toString().trim().toLowerCase(Locale.ROOT);
    features.add(text1.equals(text2) ? 1 : 0);

    boolean relaxedMatch =
        edu.stanford.nlp.coref.statistical.FeatureExtractor.relaxedStringMatch(m1, m2);
    features.add(relaxedMatch ? 1 : 0);

    return features;
  }

  /**
   * Builds the feature vector for a single mention: its mention features followed by the
   * document genre encoding.
   *
   * @param m the mention to describe
   * @param document the document containing the mention
   * @param mentionsByHeadIndex mentions grouped by head index, used for the nesting feature
   * @return the concatenated features
   */
  public SimpleMatrix getAnaphoricityFeatures(Mention m, Document document,
      Map<Integer, List<Mention>> mentionsByHeadIndex) {
    SimpleMatrix mentionFeatures = getMentionFeatures(m, document, mentionsByHeadIndex);
    SimpleMatrix genre = encodeGenre(document);
    return NeuralUtils.concatenate(mentionFeatures, genre);
  }

  /**
   * Builds the per-mention features: a one-hot encoding of the mention type ordinal, the encoded
   * value of {@code endIndex - startIndex - 1}, the mention number divided by the number of
   * predicted mentions in the document, and a 0/1 flag set when another mention with the same
   * head index contains this one according to {@code insideIn}.
   */
  private SimpleMatrix getMentionFeatures(Mention m, Document document,
      Map<Integer, List<Mention>> mentionsByHeadIndex) {
    SimpleMatrix type = NeuralUtils.oneHot(m.mentionType.ordinal(), MENTION_TYPE_COUNT);
    SimpleMatrix length = encodeDistance(m.endIndex - m.startIndex - 1);

    double relativePosition = m.mentionNum / (double) document.predictedMentionsByID.size();
    boolean nested = mentionsByHeadIndex.get(m.headIndex).stream()
        .anyMatch(other -> m != other && m.insideIn(other));
    SimpleMatrix positionAndNesting = new SimpleMatrix(new double[][] {
        {relativePosition},
        {nested ? 1 : 0}});

    return NeuralUtils.concatenate(type, length, positionAndNesting);
  }

  /**
   * Encodes a distance as an 11x1 matrix. Exactly one of the first ten entries is set to 1:
   * entries 0-4 for the exact values 0-4, then entry 5 for 5-7, 6 for 8-15, 7 for 16-31,
   * 8 for 32-63 and 9 for 64 and above. Entry 10 holds {@code min(d, 64) / 64.0}.
   *
   * <p>Values below 5 are used directly as the index to set, so a negative {@code d} is passed
   * to {@code SimpleMatrix.set} as-is; callers are expected to supply non-negative values.
   */
  private static SimpleMatrix encodeDistance(int d) {
    int bucket;
    if (d < 5) {
      bucket = d;
    } else if (d < 8) {
      bucket = 5;
    } else if (d < 16) {
      bucket = 6;
    } else if (d < 32) {
      bucket = 7;
    } else if (d < 64) {
      bucket = 8;
    } else {
      bucket = 9;
    }

    SimpleMatrix encoded = new SimpleMatrix(DISTANCE_FEATURE_SIZE, 1);
    encoded.set(bucket, 1);
    encoded.set(DISTANCE_FEATURE_SIZE - 1, Math.min(d, 64) / 64.0);
    return encoded;
  }

  /**
   * Encodes the document genre. Outside CoNLL mode this is a 1x1 zero matrix. In CoNLL mode the
   * genre code is the part of the document's "DOC_ID" entry before the first '/', and the result
   * is a one-hot vector over the genres registered in the constructor.
   *
   * @throws IllegalArgumentException in CoNLL mode, if the document has no "DOC_ID" entry or its
   *     genre code is not one of the registered genres
   */
  private SimpleMatrix encodeGenre(Document document) {
    if (!conll) {
      return new SimpleMatrix(1, 1);
    }

    String docId = document.docInfo.get("DOC_ID");
    if (docId == null) {
      throw new IllegalArgumentException(
          "Cannot encode genre: document has no DOC_ID entry in docInfo");
    }

    String genre = docId.split("/")[0];
    Integer index = genres.get(genre);
    if (index == null) {
      // Fail with the offending value instead of a bare NullPointerException from unboxing.
      throw new IllegalArgumentException("Unknown genre '" + genre + "' in DOC_ID '" + docId
          + "'; expected one of " + genres.keySet());
    }
    return NeuralUtils.oneHot(index, genres.size());
  }
}