package edu.stanford.nlp.coref.statistical;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import edu.stanford.nlp.coref.CorefDocumentProcessor;
import edu.stanford.nlp.coref.data.CorefCluster;
import edu.stanford.nlp.coref.data.Document;
import edu.stanford.nlp.coref.data.Mention;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.Pair;

/**
 * Writes mention types, gold clusters, and (optionally) word counts from
 * coreference documents to disk for later use by the statistical coreference
 * trainer.
 *
 * @author Kevin Clark
 */
public class MetadataWriter implements CorefDocumentProcessor {
  /** For each document id, a map from mention id to that mention's type. */
  private final Map<Integer, Map<Integer, String>> mentionTypes;

  /** For each document id, the gold clusters, each given as a list of mention ids. */
  private final Map<Integer, List<List<Integer>>> goldClusters;

  /** Counts of head-word pairs, unigrams, and bigrams (used only if countWords is set). */
  private final Counter<String> wordCounts;

  /** For each document id, the candidate mention pairs and their boolean labels. */
  private final Map<Integer, Map<Pair<Integer, Integer>, Boolean>> mentionPairs;

  /** Whether to collect word-count statistics in addition to the other metadata. */
  private final boolean countWords;

public MetadataWriter(boolean countWords) {
this.countWords = countWords;
mentionTypes = new HashMap<>();
goldClusters = new HashMap<>();
wordCounts = new ClassicCounter<>();
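    // The candidate mention pairs, each labeled with a boolean (whether the pair
    // is coreferent), are read from a dataset file serialized earlier in training.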
try {
mentionPairs = IOUtils.readObjectFromFile(StatisticalCorefTrainer.datasetFile);
} catch (Exception e) {
throw new RuntimeException(e);
}
}

  @Override
  public void process(int id, Document document) {
    // Record each predicted mention's type (e.g. PRONOMINAL, NOMINAL, PROPER).
    mentionTypes.put(id, document.predictedMentionsByID.entrySet().stream().collect(
        Collectors.toMap(Map.Entry::getKey, e -> e.getValue().mentionType.toString())));

    // Record the gold coreference clusters, each as a list of its mention ids.
    List<List<Integer>> clusters = new ArrayList<>();
for (CorefCluster c : document.goldCorefClusters.values()) {
List<Integer> cluster = new ArrayList<>();
for (Mention m : c.getCorefMentions()) {
cluster.add(m.mentionID);
}
clusters.add(cluster);
}
goldClusters.put(id, clusters);

    // Optionally count head-word pairs for candidate mention pairs, plus unigrams
    // and bigrams in the sentences containing those mentions.
    if (countWords && mentionPairs.containsKey(id)) {
Set<Pair<Integer, Integer>> pairs = mentionPairs.get(id).keySet();
Set<Integer> mentions = new HashSet<>();
for (Pair<Integer, Integer> pair : pairs) {
mentions.add(pair.first);
mentions.add(pair.second);
Mention m1 = document.predictedMentionsByID.get(pair.first);
Mention m2 = document.predictedMentionsByID.get(pair.second);
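        // Record the pair of lowercased head words as a single feature string,
        // of the form "h_<head1>_<head2>".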
wordCounts.incrementCount("h_" + m1.headWord.word().toLowerCase() + "_"
+ m2.headWord.word().toLowerCase());
}
      // Collect each mention's sentence once, keyed by sentence number, so that a
      // sentence shared by several mentions is only counted once.
      Map<Integer, List<CoreLabel>> sentences = new HashMap<>();
      for (int mention : mentions) {
        Mention m = document.predictedMentionsByID.get(mention);
        sentences.putIfAbsent(m.sentNum, m.sentenceWords);
      }
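      // Count unigrams and adjacent-word bigrams over the collected sentences.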
for (List<CoreLabel> sentence : sentences.values()) {
for (int i = 0; i < sentence.size(); i++) {
CoreLabel cl = sentence.get(i);
if (cl == null) {
continue;
}
String w = cl.word().toLowerCase();
wordCounts.incrementCount(w);
if (i > 0) {
CoreLabel clp = sentence.get(i - 1);
if (clp == null) {
continue;
}
String wp = clp.word().toLowerCase();
wordCounts.incrementCount(wp + "_" + w);
}
}
}
}
}

  @Override
  public void finish() throws Exception {
    // Serialize the collected metadata; word counts are written only if requested.
IOUtils.writeObjectToFile(mentionTypes, StatisticalCorefTrainer.mentionTypesFile);
IOUtils.writeObjectToFile(goldClusters, StatisticalCorefTrainer.goldClustersFile);
if (countWords) {
IOUtils.writeObjectToFile(wordCounts, StatisticalCorefTrainer.wordCountsFile);
}
}
}
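
/**
 * A minimal usage sketch (illustrative, not part of the original class). It assumes
 * that {@code StatisticalCorefTrainer.datasetFile} has already been written, and the
 * {@code documentsById} map is a hypothetical stand-in for documents produced by an
 * upstream coreference pipeline.
 */
class MetadataWriterUsageSketch {
  static void writeMetadata(Map<Integer, Document> documentsById) throws Exception {
    MetadataWriter writer = new MetadataWriter(true);  // true: also collect word counts
    for (Map.Entry<Integer, Document> entry : documentsById.entrySet()) {
      writer.process(entry.getKey(), entry.getValue());
    }
    writer.finish();  // writes mention types, gold clusters, and word counts to disk
  }
}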