package edu.stanford.nlp.coref.neural;

import java.io.File;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Properties;

import javax.json.Json;
import javax.json.JsonArray;
import javax.json.JsonArrayBuilder;
import javax.json.JsonObject;
import javax.json.JsonObjectBuilder;

import edu.stanford.nlp.coref.CorefDocumentProcessor;
import edu.stanford.nlp.coref.CorefProperties;
import edu.stanford.nlp.coref.CorefProperties.Dataset;
import edu.stanford.nlp.coref.CorefUtils;
import edu.stanford.nlp.coref.data.CorefCluster;
import edu.stanford.nlp.coref.data.Dictionaries;
import edu.stanford.nlp.coref.data.Document;
import edu.stanford.nlp.coref.data.Document.DocType;
import edu.stanford.nlp.coref.data.Mention;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.CoreAnnotations.SentencesAnnotation;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.semgraph.SemanticGraphEdge;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.StringUtils;

/**
 * Outputs the CoNLL data for training the neural coreference system
 * (implemented in Python/Theano).
 * See <a href="https://github.com/clarkkev/deep-coref">https://github.com/clarkkev/deep-coref</a>
 * for the training code.
 *
 * @author Kevin Clark
 */
public class NeuralCorefDataExporter implements CorefDocumentProcessor {
  private final boolean conll;
  private final PrintWriter dataWriter;
  private final PrintWriter goldClusterWriter;
  private final Dictionaries dictionaries;

  public NeuralCorefDataExporter(Properties props, Dictionaries dictionaries, String dataPath,
      String goldClusterPath) {
    conll = CorefProperties.conll(props);
    this.dictionaries = dictionaries;
    try {
      dataWriter = IOUtils.getPrintWriter(dataPath);
      goldClusterWriter = IOUtils.getPrintWriter(goldClusterPath);
    } catch (Exception e) {
      throw new RuntimeException("Error creating data exporter", e);
    }
  }
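  /**
   * Writes one line to the gold-cluster file mapping this document's id to its gold coreference
   * clusters (lists of mention ids), and one line to the data file containing a JSON
   * representation of the document. Judging from the builders assembled below, the document
   * object has roughly this shape:
   * <pre>
   * {"sentences": [...], "mentions": {...}, "labels": {...},
   *  "pair_feature_names": [...], "pair_features": {...}, "document_features": {...}}
   * </pre>
   */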
  @Override
  public void process(int id, Document document) {
    // Gold coreference clusters, each a list of mention ids, keyed by document id.
    JsonArrayBuilder clusters = Json.createArrayBuilder();
    for (CorefCluster gold : document.goldCorefClusters.values()) {
      JsonArrayBuilder c = Json.createArrayBuilder();
      for (Mention m : gold.corefMentions) {
        c.add(m.mentionID);
      }
      clusters.add(c.build());
    }
    goldClusterWriter.println(Json.createObjectBuilder()
        .add(String.valueOf(id), clusters.build()).build());

    Map<Pair<Integer, Integer>, Boolean> mentionPairs = CorefUtils.getLabeledMentionPairs(document);
    List<Mention> mentionsList = CorefUtils.getSortedMentions(document);

    // Group the predicted mentions by head index so nested mentions can be detected below.
    Map<Integer, List<Mention>> mentionsByHeadIndex = new HashMap<>();
    for (int i = 0; i < mentionsList.size(); i++) {
      Mention m = mentionsList.get(i);
      List<Mention> withIndex = mentionsByHeadIndex.get(m.headIndex);
      if (withIndex == null) {
        withIndex = new ArrayList<>();
        mentionsByHeadIndex.put(m.headIndex, withIndex);
      }
      withIndex.add(m);
    }

    // Document-level features.
    JsonObjectBuilder docFeatures = Json.createObjectBuilder();
    docFeatures.add("doc_id", id);
    docFeatures.add("type", document.docType == DocType.ARTICLE ? 1 : 0);
    docFeatures.add("source", document.docInfo.get("DOC_ID").split("/")[0]);

    JsonArrayBuilder sentences = Json.createArrayBuilder();
    for (CoreMap sentence : document.annotation.get(SentencesAnnotation.class)) {
      sentences.add(getSentenceArray(sentence.get(CoreAnnotations.TokensAnnotation.class)));
    }

    // Per-mention features, including the mention head's incoming dependency relation and parent word.
    JsonObjectBuilder mentions = Json.createObjectBuilder();
    for (Mention m : document.predictedMentionsByID.values()) {
      Iterator<SemanticGraphEdge> iterator =
          m.enhancedDependency.incomingEdgeIterator(m.headIndexedWord);
      SemanticGraphEdge relation = iterator.hasNext() ? iterator.next() : null;
      String depRelation = relation == null ? "no-parent" : relation.getRelation().toString();
      String depParent = relation == null ? "<missing>" : relation.getSource().word();

      mentions.add(String.valueOf(m.mentionNum), Json.createObjectBuilder()
          .add("doc_id", id)
          .add("mention_id", m.mentionID)
          .add("mention_num", m.mentionNum)
          .add("sent_num", m.sentNum)
          .add("start_index", m.startIndex)
          .add("end_index", m.endIndex)
          .add("head_index", m.headIndex)
          .add("mention_type", m.mentionType.toString())
          .add("dep_relation", depRelation)
          .add("dep_parent", depParent)
          .add("sentence", getSentenceArray(m.sentenceWords))
          .add("contained-in-other-mention", mentionsByHeadIndex.get(m.headIndex).stream()
              .anyMatch(m2 -> m != m2 && m.insideIn(m2)) ? 1 : 0)
          .build());
    }

    // Pairwise categorical features and gold labels for each labeled mention pair.
    JsonArrayBuilder featureNames = Json.createArrayBuilder()
        .add("same-speaker")
        .add("antecedent-is-mention-speaker")
        .add("mention-is-antecedent-speaker")
        .add("relaxed-head-match")
        .add("exact-string-match")
        .add("relaxed-string-match");
    JsonObjectBuilder features = Json.createObjectBuilder();
    JsonObjectBuilder labels = Json.createObjectBuilder();
    for (Map.Entry<Pair<Integer, Integer>, Boolean> e : mentionPairs.entrySet()) {
      Mention m1 = document.predictedMentionsByID.get(e.getKey().first);
      Mention m2 = document.predictedMentionsByID.get(e.getKey().second);
      String key = m1.mentionNum + " " + m2.mentionNum;

      JsonArrayBuilder builder = Json.createArrayBuilder();
      for (int val : CategoricalFeatureExtractor.pairwiseFeatures(
          document, m1, m2, dictionaries, conll)) {
        builder.add(val);
      }
      features.add(key, builder.build());
      labels.add(key, e.getValue() ? 1 : 0);
    }

    // Assemble and write the full JSON object for this document.
    JsonObject docData = Json.createObjectBuilder()
        .add("sentences", sentences.build())
        .add("mentions", mentions.build())
        .add("labels", labels.build())
        .add("pair_feature_names", featureNames.build())
        .add("pair_features", features.build())
        .add("document_features", docFeatures.build())
        .build();
    dataWriter.println(docData);
  }

  @Override
  public void finish() throws Exception {
    dataWriter.close();
    goldClusterWriter.close();
  }

  private static JsonArray getSentenceArray(List<CoreLabel> sentence) {
    JsonArrayBuilder sentenceBuilder = Json.createArrayBuilder();
    sentence.stream().map(CoreLabel::word)
        .map(w -> w.equals("/.") ? "." : w)
        .map(w -> w.equals("/?") ? "?" : w)
        .forEach(sentenceBuilder::add);
    return sentenceBuilder.build();
  }
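  /**
   * Points the coref properties at the given {@link Dataset}, ensures the output directories
   * exist, and runs the exporter: the per-document data is written under
   * {@code outputPath/data_raw/} and the gold clusters under {@code outputPath/gold/}, with the
   * lowercased dataset name as the file name.
   */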
  public static void exportData(String outputPath, Dataset dataset, Properties props,
      Dictionaries dictionaries) throws Exception {
    CorefProperties.setInput(props, dataset);
    String dataPath = outputPath + "/data_raw/";
    String goldClusterPath = outputPath + "/gold/";
    IOUtils.ensureDir(new File(outputPath));
    IOUtils.ensureDir(new File(dataPath));
    IOUtils.ensureDir(new File(goldClusterPath));
    new NeuralCorefDataExporter(props, dictionaries,
        dataPath + dataset.toString().toLowerCase(),
        goldClusterPath + dataset.toString().toLowerCase()).run(props, dictionaries);
  }

  public static void main(String[] args) throws Exception {
    Properties props = StringUtils.argsToProperties(new String[] {"-props", args[0]});
    Dictionaries dictionaries = new Dictionaries(props);
    String outputPath = args[1];
    exportData(outputPath, Dataset.TRAIN, props, dictionaries);
    exportData(outputPath, Dataset.DEV, props, dictionaries);
    exportData(outputPath, Dataset.TEST, props, dictionaries);
  }
}