package edu.stanford.nlp.coref.statistical;

import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

import edu.stanford.nlp.coref.statistical.MaxMarginMentionRanker.ErrorType;
import edu.stanford.nlp.coref.data.Dictionaries.MentionType;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.logging.Redwood;

/**
 * Class for training coreference models.
 *
 * @author Kevin Clark
 */
public class PairwiseModelTrainer {

  /**
   * Trains a mention-ranking model: for each (potentially) anaphoric mention, the
   * highest-scoring correct antecedent is pitted against the highest-scoring incorrect
   * one and the model is updated on that pair. Runs for {@code model.getNumEpochs()}
   * passes over the serialized training documents, then writes the model to disk.
   *
   * @param model the ranking model to train (a {@link MaxMarginMentionRanker} gets
   *              cost-augmented updates; any other {@link PairwiseModel} gets plain
   *              best-positive-vs-best-negative updates)
   * @throws Exception if reading the serialized feature/compressor files or writing
   *                   the model fails
   */
  public static void trainRanking(PairwiseModel model) throws Exception {
    Redwood.log("scoref-train", "Reading compression...");
    Compressor<String> compressor = IOUtils.readObjectFromFile(
        StatisticalCorefTrainer.compressorFile);
    Redwood.log("scoref-train", "Reading train data...");
    List<DocumentExamples> trainDocuments = IOUtils.readObjectFromFile(
        StatisticalCorefTrainer.extractedFeaturesFile);

    Redwood.log("scoref-train", "Training...");
    for (int i = 0; i < model.getNumEpochs(); i++) {
      Collections.shuffle(trainDocuments);
      int j = 0;
      for (DocumentExamples doc : trainDocuments) {
        j++;
        Redwood.log("scoref-train", "On epoch: " + i + " / " + model.getNumEpochs()
            + ", document: " + j + " / " + trainDocuments.size());

        // Group candidate-antecedent examples by the mention they are candidates for
        // (mentionId2 is the potentially-anaphoric mention).
        Map<Integer, List<Example>> mentionToPotentialAntecedents = new HashMap<>();
        for (Example e : doc.examples) {
          mentionToPotentialAntecedents
              .computeIfAbsent(e.mentionId2, k -> new ArrayList<>())
              .add(e);
        }
        List<List<Example>> examples = new ArrayList<>(mentionToPotentialAntecedents.values());
        Collections.shuffle(examples);

        for (List<Example> es : examples) {
          if (es.isEmpty()) {
            continue;
          }
          if (model instanceof MaxMarginMentionRanker) {
            maxMarginUpdate((MaxMarginMentionRanker) model, es, doc, compressor);
          } else {
            bestPairUpdate(model, es, doc, compressor);
          }
        }
      }
    }
    Redwood.log("scoref-train", "Writing models...");
    model.writeModel();
  }

  /**
   * Performs one cost-augmented max-margin update for a single mention: the
   * highest-scoring correct link vs. the highest-scoring (cost-augmented) incorrect link.
   */
  private static void maxMarginUpdate(MaxMarginMentionRanker ranker, List<Example> es,
      DocumentExamples doc, Compressor<String> compressor) {
    boolean noAntecedent = es.stream().allMatch(e -> e.label == 0);
    // Add a dummy "start a new cluster" candidate; it is the correct decision
    // exactly when the mention has no true antecedent.
    es.add(new Example(es.get(0), noAntecedent));

    // Score every candidate once up front and reuse the scores in both scans below.
    // (The original scored each example twice; predict() is the expensive step and no
    // learning happens between the scans, so the scores cannot change.)
    double[] scores = new double[es.size()];
    for (int k = 0; k < es.size(); k++) {
      scores[k] = ranker.predict(es.get(k), doc.mentionFeatures, compressor);
    }

    // Highest-scoring correct (positive) candidate.
    double maxPositiveScore = -Double.MAX_VALUE;
    Example maxScoringPositive = null;
    for (int k = 0; k < es.size(); k++) {
      Example e = es.get(k);
      if (e.label == 1) {
        // A positive "new link" example can only occur when there is no antecedent.
        assert(!noAntecedent ^ e.isNewLink());
        if (scores[k] > maxPositiveScore) {
          maxPositiveScore = scores[k];
          maxScoringPositive = e;
        }
      }
    }
    assert(maxScoringPositive != null);

    // Highest-scoring incorrect (negative) candidate, with its score inflated by the
    // cost of the error type it would commit.
    double maxNegativeScore = -Double.MAX_VALUE;
    Example maxScoringNegative = null;
    ErrorType maxScoringEt = null;
    for (int k = 0; k < es.size(); k++) {
      Example e = es.get(k);
      if (e.label != 1) {
        assert(!(noAntecedent && e.isNewLink()));
        ErrorType et = classifyError(noAntecedent, e);
        double score;
        if (ranker.multiplicativeCost) {
          // Cost scales the whole margin violation (1 - s+ + s-).
          score = ranker.costs[et.id] * (1 - maxPositiveScore + scores[k]);
        } else {
          // Cost is added to the negative's score (additive cost-augmented inference).
          score = scores[k] + ranker.costs[et.id];
        }
        if (score > maxNegativeScore) {
          maxNegativeScore = score;
          maxScoringNegative = e;
          maxScoringEt = et;
        }
      }
    }
    assert(maxScoringNegative != null);

    ranker.learn(maxScoringPositive, maxScoringNegative, doc.mentionFeatures,
        compressor, maxScoringEt);
  }

  /**
   * Classifies the error a negative candidate would commit:
   * FL = false link (linked when there is no antecedent),
   * FN/FN_PRON = false new (started a new cluster when an antecedent exists),
   * WL = wrong link (linked to the wrong antecedent).
   */
  private static ErrorType classifyError(boolean noAntecedent, Example e) {
    if (noAntecedent && !e.isNewLink()) {
      return ErrorType.FL;
    }
    if (!noAntecedent && e.isNewLink()) {
      return e.mentionType2 == MentionType.PRONOMINAL ? ErrorType.FN_PRON : ErrorType.FN;
    }
    return ErrorType.WL;
  }

  /**
   * Plain (non-max-margin) ranking update: best-scoring positive vs. best-scoring
   * negative candidate for one mention, with a fixed weight of 1.
   */
  private static void bestPairUpdate(PairwiseModel model, List<Example> es,
      DocumentExamples doc, Compressor<String> compressor) {
    double maxPositiveScore = -Double.MAX_VALUE;
    double maxNegativeScore = -Double.MAX_VALUE;
    Example maxScoringPositive = null;
    Example maxScoringNegative = null;
    for (Example e : es) {
      double score = model.predict(e, doc.mentionFeatures, compressor);
      if (e.label == 1) {
        if (score > maxPositiveScore) {
          maxPositiveScore = score;
          maxScoringPositive = e;
        }
      } else {
        if (score > maxNegativeScore) {
          maxNegativeScore = score;
          maxScoringNegative = e;
        }
      }
    }
    model.learn(maxScoringPositive, maxScoringNegative, doc.mentionFeatures, compressor, 1);
  }

  /**
   * Builds one anaphoricity example per mention: the label is whether the mention has
   * any positive link among its candidate antecedents. Consumes {@code documents}
   * destructively (documents are popped off the end) to keep memory bounded.
   *
   * @param documents serialized training documents; emptied by this method
   * @return one (example, mention-features) pair per distinct mention
   */
  public static List<Pair<Example, Map<Integer, CompressedFeatureVector>>>
      getAnaphoricityExamples(List<DocumentExamples> documents) {
    int p = 0;  // number of anaphoric (positive) mentions
    int t = 0;  // total number of mentions
    List<Pair<Example, Map<Integer, CompressedFeatureVector>>> examples = new ArrayList<>();
    while (!documents.isEmpty()) {
      DocumentExamples doc = documents.remove(documents.size() - 1);

      // A mention is anaphoric iff any of its candidate-antecedent examples is positive.
      Map<Integer, Boolean> areAnaphoric = new HashMap<>();
      for (Example e : doc.examples) {
        Boolean isAnaphoric = areAnaphoric.get(e.mentionId2);
        if (isAnaphoric == null) {
          areAnaphoric.put(e.mentionId2, false);
        }
        if (e.label == 1) {
          areAnaphoric.put(e.mentionId2, true);
        }
      }

      for (Map.Entry<Integer, Boolean> e : areAnaphoric.entrySet()) {
        if (e.getValue()) {
          p++;
        }
        t++;
      }

      // Emit exactly one anaphoricity example per mention (remove-on-first-use dedups).
      for (Example e : doc.examples) {
        Boolean isAnaphoric = areAnaphoric.get(e.mentionId2);
        if (isAnaphoric != null) {
          areAnaphoric.remove(e.mentionId2);
          examples.add(new Pair<>(new Example(e, isAnaphoric), doc.mentionFeatures));
        }
      }
    }
    Redwood.log("scoref-train", "Num anaphoricity examples " + p + " positive, " + t + " total");
    return examples;
  }

  /**
   * Flattens all pairwise examples across documents, pairing each with its document's
   * mention features. Consumes {@code documents} destructively to keep memory bounded.
   *
   * @param documents serialized training documents; emptied by this method
   * @return all (example, mention-features) pairs
   */
  public static List<Pair<Example, Map<Integer, CompressedFeatureVector>>> getExamples(
      List<DocumentExamples> documents) {
    List<Pair<Example, Map<Integer, CompressedFeatureVector>>> examples = new ArrayList<>();
    while (!documents.isEmpty()) {
      DocumentExamples doc = documents.remove(documents.size() - 1);
      Map<Integer, CompressedFeatureVector> mentionFeatures = doc.mentionFeatures;
      for (Example e : doc.examples) {
        examples.add(new Pair<>(e, mentionFeatures));
      }
    }
    return examples;
  }

  /**
   * Trains a pairwise classification model by repeatedly shuffling (with a fixed seed,
   * for reproducibility) and iterating the example set until
   * {@code model.getNumTrainingExamples()} updates have been made, then writes the model.
   *
   * @param model             the classifier to train
   * @param anaphoricityModel if true, train on per-mention anaphoricity examples rather
   *                          than pairwise coreference examples
   * @throws Exception if reading the serialized feature/compressor files or writing
   *                   the model fails
   */
  public static void trainClassification(PairwiseModel model, boolean anaphoricityModel)
      throws Exception {
    int numTrainingExamples = model.getNumTrainingExamples();

    Redwood.log("scoref-train", "Reading compression...");
    Compressor<String> compressor = IOUtils.readObjectFromFile(
        StatisticalCorefTrainer.compressorFile);
    Redwood.log("scoref-train", "Reading train data...");
    List<DocumentExamples> trainDocuments = IOUtils.readObjectFromFile(
        StatisticalCorefTrainer.extractedFeaturesFile);

    Redwood.log("scoref-train", "Building train set...");
    List<Pair<Example, Map<Integer, CompressedFeatureVector>>> allExamples = anaphoricityModel
        ? getAnaphoricityExamples(trainDocuments)
        : getExamples(trainDocuments);

    Redwood.log("scoref-train", "Training...");
    Random random = new Random(0);  // fixed seed so runs are reproducible
    int i = 0;
    boolean stopTraining = false;
    while (!stopTraining) {
      Collections.shuffle(allExamples, random);
      for (Pair<Example, Map<Integer, CompressedFeatureVector>> pair : allExamples) {
        // NOTE(review): the strict '>' means numTrainingExamples + 1 examples are actually
        // trained on; kept as-is to preserve the original training behavior.
        if (i++ > numTrainingExamples) {
          stopTraining = true;
          break;
        }
        if (i % 10000 == 0) {
          Redwood.log("scoref-train", String.format("On train example %d/%d = %.2f%%",
              i, numTrainingExamples, 100.0 * i / numTrainingExamples));
        }
        model.learn(pair.first, pair.second, compressor);
      }
    }

    Redwood.log("scoref-train", "Writing models...");
    model.writeModel();
  }

  /**
   * Scores the test set with a trained model, writing per-example predictions to a text
   * file and a serialized map of document scores. For a {@link MaxMarginMentionRanker}
   * the anaphoricity examples are additionally scored into a second file.
   *
   * @param model             the trained model to evaluate
   * @param predictionsName   base name for the output files (placed under the model's
   *                          default output path)
   * @param anaphoricityModel if true, score anaphoricity examples; otherwise pairwise ones
   * @throws Exception if reading the serialized inputs or writing the outputs fails
   */
  public static void test(PairwiseModel model, String predictionsName,
      boolean anaphoricityModel) throws Exception {
    Redwood.log("scoref-train", "Reading compression...");
    Compressor<String> compressor = IOUtils.readObjectFromFile(
        StatisticalCorefTrainer.compressorFile);
    Redwood.log("scoref-train", "Reading test data...");
    List<DocumentExamples> testDocuments = IOUtils.readObjectFromFile(
        StatisticalCorefTrainer.extractedFeaturesFile);

    Redwood.log("scoref-train", "Building test set...");
    List<Pair<Example, Map<Integer, CompressedFeatureVector>>> allExamples = anaphoricityModel
        ? getAnaphoricityExamples(testDocuments)
        : getExamples(testDocuments);

    Redwood.log("scoref-train", "Testing...");
    Map<Integer, Counter<Pair<Integer, Integer>>> scores = new HashMap<>();
    // try-with-resources so the writers are closed even if scoring throws
    // (the original leaked the PrintWriter on any exception).
    try (PrintWriter writer = new PrintWriter(
        model.getDefaultOutputPath() + predictionsName)) {
      writeScores(allExamples, compressor, model, writer, scores);
    }
    if (model instanceof MaxMarginMentionRanker) {
      // Rankers also get anaphoricity scores; re-read the documents because
      // getAnaphoricityExamples/getExamples consumed the first copy.
      try (PrintWriter writer = new PrintWriter(
          model.getDefaultOutputPath() + predictionsName + "_anaphoricity")) {
        testDocuments = IOUtils.readObjectFromFile(
            StatisticalCorefTrainer.extractedFeaturesFile);
        allExamples = getAnaphoricityExamples(testDocuments);
        writeScores(allExamples, compressor, model, writer, scores);
      }
    }
    IOUtils.writeObjectToFile(scores,
        model.getDefaultOutputPath() + predictionsName + ".ser");
  }

  /**
   * Scores each example with the model, writing a line per example
   * ("docId m1,m2 score label") and accumulating the score into the per-document counter.
   *
   * @param examples   (example, mention-features) pairs to score
   * @param compressor feature decompressor
   * @param model      trained scoring model
   * @param writer     destination for the text predictions (not closed here)
   * @param scores     per-document counters, keyed by docId, updated in place
   */
  public static void writeScores(
      List<Pair<Example, Map<Integer, CompressedFeatureVector>>> examples,
      Compressor<String> compressor, PairwiseModel model, PrintWriter writer,
      Map<Integer, Counter<Pair<Integer, Integer>>> scores) {
    int i = 0;
    for (Pair<Example, Map<Integer, CompressedFeatureVector>> pair : examples) {
      if (i++ % 10000 == 0) {
        Redwood.log("scoref-train", String.format("On test example %d/%d = %.2f%%",
            i, examples.size(), 100.0 * i / examples.size()));
      }
      Example example = pair.first;
      Map<Integer, CompressedFeatureVector> mentionFeatures = pair.second;
      double p = model.predict(example, mentionFeatures, compressor);
      writer.println(example.docId + " " + example.mentionId1 + "," + example.mentionId2
          + " " + p + " " + example.label);
      scores.computeIfAbsent(example.docId, k -> new ClassicCounter<>())
          .incrementCount(new Pair<>(example.mentionId1, example.mentionId2), p);
    }
  }
}