package edu.stanford.nlp.coref.statistical; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import java.util.stream.Collectors; import edu.stanford.nlp.io.IOUtils; import edu.stanford.nlp.stats.ClassicCounter; import edu.stanford.nlp.stats.Counter; import edu.stanford.nlp.util.Pair; /** * Loads the data used to train {@link Clusterer}. * @author Kevin Clark */ public class ClustererDataLoader { public static class ClustererDoc { public final int id; public final Counter<Pair<Integer, Integer>> classificationScores; public final Counter<Pair<Integer, Integer>> rankingScores; public final Counter<Integer> anaphoricityScores; public final List<List<Integer>> goldClusters; public final Map<Integer, List<Integer>> mentionToGold; public final List<Integer> mentions; public final Map<Integer, String> mentionTypes; public final Set<Pair<Integer, Integer>> positivePairs; public final Map<Integer, Integer> mentionIndices; public ClustererDoc(int id, Counter<Pair<Integer, Integer>> classificationScores, Counter<Pair<Integer, Integer>> rankingScores, Counter<Integer> anaphoricityScores, Map<Pair<Integer, Integer>, Boolean> labeledPairs, List<List<Integer>> goldClusters, Map<Integer, String> mentionTypes) { this.id = id; this.classificationScores = classificationScores; this.rankingScores = rankingScores; this.goldClusters = goldClusters; this.mentionTypes = mentionTypes; this.anaphoricityScores = anaphoricityScores; positivePairs = labeledPairs.keySet().stream().filter(p -> labeledPairs.get(p)) .collect(Collectors.toSet()); Set<Integer> mentionsSet = new HashSet<>(); for (Pair<Integer, Integer> pair : labeledPairs.keySet()) { mentionsSet.add(pair.first); mentionsSet.add(pair.second); } mentions = new ArrayList<>(mentionsSet); Collections.sort(mentions, (m1, m2) -> { Pair<Integer, Integer> p = new Pair<>(m1, m2); return m1 == m2 ? 0 : (classificationScores.containsKey(p) ? -1 : 1); }); mentionIndices = new HashMap<>(); for (int i = 0; i < mentions.size(); i++) { mentionIndices.put(mentions.get(i), i); } mentionToGold = new HashMap<>(); if (goldClusters != null) { for (List<Integer> gold : goldClusters) { for (int m : gold) { mentionToGold.put(m, gold); } } } } } public static List<ClustererDoc> loadDocuments(int maxDocs) throws Exception { Map<Integer, Map<Pair<Integer, Integer>, Boolean>> labeledPairs = IOUtils.readObjectFromFile(StatisticalCorefTrainer.datasetFile); Map<Integer, Map<Integer, String>> mentionTypes = IOUtils.readObjectFromFile(StatisticalCorefTrainer.mentionTypesFile); Map<Integer, List<List<Integer>>> goldClusters = IOUtils.readObjectFromFile(StatisticalCorefTrainer.goldClustersFile); Map<Integer, Counter<Pair<Integer, Integer>>> classificationScores = IOUtils.readObjectFromFile(StatisticalCorefTrainer.pairwiseModelsPath + StatisticalCorefTrainer.CLASSIFICATION_MODEL + "/" + StatisticalCorefTrainer.predictionsName + ".ser"); Map<Integer, Counter<Pair<Integer, Integer>>> rankingScores = IOUtils.readObjectFromFile(StatisticalCorefTrainer.pairwiseModelsPath + StatisticalCorefTrainer.RANKING_MODEL + "/" + StatisticalCorefTrainer.predictionsName + ".ser"); Map<Integer, Counter<Pair<Integer, Integer>>> anaphoricityScoresLoaded = IOUtils.readObjectFromFile(StatisticalCorefTrainer.pairwiseModelsPath + StatisticalCorefTrainer.ANAPHORICITY_MODEL + "/" + StatisticalCorefTrainer.predictionsName + ".ser"); Map<Integer, Counter<Integer>> anaphoricityScores = new HashMap<>(); for (Map.Entry<Integer, Counter<Pair<Integer, Integer>>> e : anaphoricityScoresLoaded.entrySet()) { Counter<Integer> scores = new ClassicCounter<>(); e.getValue().entrySet().forEach(e2 -> { scores.incrementCount(e2.getKey().second, e2.getValue()); }); anaphoricityScores.put(e.getKey(), scores); } return labeledPairs.keySet().stream().sorted().limit(maxDocs).map(i -> new ClustererDoc(i, classificationScores.get(i), rankingScores.get(i), anaphoricityScores.get(i), labeledPairs.get(i), goldClusters.get(i), mentionTypes.get(i))) .collect(Collectors.toList()); } }