package edu.stanford.nlp.coref.statistical;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.stream.Collectors;

import edu.stanford.nlp.coref.CorefDocumentProcessor;
import edu.stanford.nlp.coref.CorefUtils;
import edu.stanford.nlp.coref.data.Document;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.util.Pair;

/**
 * Produces train/dev/test sets for training coreference models, optionally downsampling
 * negative examples so each document meets a minimum fraction of positive mention pairs.
 *
 * @author Kevin Clark
 */
public class DatasetBuilder implements CorefDocumentProcessor {
  private final int maxExamplesPerDocument;
  private final double minClassImbalancedPerDocument;
  private final Map<Integer, Map<Pair<Integer, Integer>, Boolean>> mentionPairs;
  private final Random random;

  public DatasetBuilder() {
    this(0, Integer.MAX_VALUE);
  }

  public DatasetBuilder(double minClassImbalancedPerDocument, int maxExamplesPerDocument) {
    this.maxExamplesPerDocument = maxExamplesPerDocument;
    this.minClassImbalancedPerDocument = minClassImbalancedPerDocument;
    mentionPairs = new HashMap<>();
    random = new Random(0);  // fixed seed so the sampling is reproducible across runs
  }

  @Override
  public void process(int id, Document document) {
    Map<Pair<Integer, Integer>, Boolean> labeledPairs =
        CorefUtils.getLabeledMentionPairs(document);

    // Count positive (coreferent) pairs and collect the negative ones.
    long numP = labeledPairs.keySet().stream().filter(labeledPairs::get).count();
    List<Pair<Integer, Integer>> negative = labeledPairs.keySet().stream()
        .filter(m -> !labeledPairs.get(m))
        .collect(Collectors.toList());
    int numN = negative.size();

    // If positives make up too small a fraction of the examples, randomly discard
    // negatives until the positive fraction reaches minClassImbalancedPerDocument.
    if (numP / (float) (numP + numN) < minClassImbalancedPerDocument) {
      numN = (int) (numP / minClassImbalancedPerDocument - numP);
      Collections.shuffle(negative, random);  // use the seeded Random, not an unseeded shuffle
      for (int i = numN; i < negative.size(); i++) {
        labeledPairs.remove(negative.get(i));
      }
    }

    // Group candidate antecedents by anaphor: pair.first is the antecedent,
    // pair.second the (later) mention.
    Map<Integer, List<Integer>> mentionToCandidateAntecedents = new HashMap<>();
    for (Pair<Integer, Integer> pair : labeledPairs.keySet()) {
      mentionToCandidateAntecedents
          .computeIfAbsent(pair.second, k -> new ArrayList<>())
          .add(pair.first);
    }

    // If the document still has too many examples, drop randomly chosen mentions
    // (together with all of their candidate antecedents) until it fits the cap.
    List<Integer> mentions = new ArrayList<>(mentionToCandidateAntecedents.keySet());
    while (labeledPairs.size() > maxExamplesPerDocument) {
      int mention = mentions.remove(random.nextInt(mentions.size()));
      for (int candidateAntecedent : mentionToCandidateAntecedents.get(mention)) {
        labeledPairs.remove(new Pair<>(candidateAntecedent, mention));
      }
    }

    mentionPairs.put(id, labeledPairs);
  }

  @Override
  public void finish() throws Exception {
    IOUtils.writeObjectToFile(mentionPairs, StatisticalCorefTrainer.datasetFile);
  }
}
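/*
 * Usage sketch. This assumes the default run(Properties, Dictionaries) helper inherited
 * from CorefDocumentProcessor, which builds documents, calls process(...) on each, and
 * then calls finish(); props and dictionaries here stand in for whatever the surrounding
 * training pipeline (e.g., StatisticalCorefTrainer) has already constructed.
 *
 *   // Require at least 5% positive pairs per document; cap each document at 50000 pairs.
 *   DatasetBuilder builder = new DatasetBuilder(0.05, 50000);
 *   builder.run(props, dictionaries);
 *   // After finish(), the id -> {(antecedent, anaphor) -> coreferent?} map has been
 *   // serialized to StatisticalCorefTrainer.datasetFile.
 */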