package edu.stanford.nlp.coref.statistical;
import java.io.File;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import edu.stanford.nlp.coref.statistical.ClustererDataLoader.ClustererDoc;
import edu.stanford.nlp.coref.statistical.EvalUtils.B3Evaluator;
import edu.stanford.nlp.coref.statistical.EvalUtils.Evaluator;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.logging.Redwood;
/**
* System for building up coreference clusters incrementally, merging a pair of clusters each step.
* Trained with a variant of the SEARN imitation learning algorithm.
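* <p>Typical inference use: construct with the path to a serialized model and call
* {@link #getClusterMerges(ClustererDoc)} to get the sequence of cluster merges for a document.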
* @author Kevin Clark
*/
public class Clusterer {
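// Hyperparameters for feature use, candidate-pair pruning, and SEARN-style training.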
private static final boolean USE_CLASSIFICATION = true;
private static final boolean USE_RANKING = true;
private static final boolean LEFT_TO_RIGHT = false;
private static final boolean EXACT_LOSS = false;
private static final double MUC_WEIGHT = 0.25;
private static final double EXPERT_DECAY = 0.0;
private static final double LEARNING_RATE = 0.05;
private static final int BUFFER_SIZE_MULTIPLIER = 20;
private static final int MAX_DOCS = 1000;
private static final int RETRAIN_ITERATIONS = 100;
private static final int NUM_EPOCHS = 15;
private static final int EVAL_FREQUENCY = 1;
private static final int MIN_PAIRS = 10;
private static final double MIN_PAIRWISE_SCORE = 0.15;
private static final int EARLY_STOP_THRESHOLD = 1000;
private static final double EARLY_STOP_VAL = 1500 / 0.2;
public static int currentDocId = 0;
public static int isTraining = 1;
private final ClustererClassifier classifier;
private final Random random;
public Clusterer() {
random = new Random(0);
classifier = new ClustererClassifier(LEARNING_RATE);
}
public Clusterer(String modelPath) {
random = new Random(0);
classifier = new ClustererClassifier(modelPath, LEARNING_RATE);
}
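/**
 * Greedily runs the learned policy over a document, returning the mention pairs whose
 * clusters were merged, in the order the merges were made.
 */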
public List<Pair<Integer, Integer>> getClusterMerges(ClustererDoc doc) {
List<Pair<Integer, Integer>> merges = new ArrayList<>();
State currentState = new State(doc);
while (!currentState.isComplete()) {
Pair<Integer, Integer> currentPair =
currentState.mentionPairs.get(currentState.currentIndex);
if (currentState.doBestAction(classifier)) {
merges.add(currentPair);
}
}
return merges;
}
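/**
 * Trains the clustering model with SEARN-style imitation learning: repeatedly trains the
 * classifier on buffered (merge, no-merge) examples, then rolls out a mixture of the expert
 * and the current policy over the training documents to collect new examples.
 * Models and progress logs are written under clusteringModelsPath/modelName.
 */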
public void doTraining(String modelName) {
classifier.setWeight("bias", -0.3);
classifier.setWeight("anaphorSeen", -1);
classifier.setWeight("max-ranking", 1);
classifier.setWeight("bias-single", -0.3);
classifier.setWeight("anaphorSeen-single", -1);
classifier.setWeight("max-ranking-single", 1);
String outputPath = StatisticalCorefTrainer.clusteringModelsPath +
modelName + "/";
File outDir = new File(outputPath);
if (!outDir.exists()) {
outDir.mkdir();
}
PrintWriter progressWriter;
List<ClustererDoc> trainDocs;
try {
PrintWriter configWriter = new PrintWriter(outputPath + "config", "UTF-8");
configWriter.print(StatisticalCorefTrainer.fieldValues(this));
configWriter.close();
progressWriter = new PrintWriter(outputPath + "progress", "UTF-8");
Redwood.log("scoref.train", "Loading training data");
StatisticalCorefTrainer.setDataPath("dev");
trainDocs = ClustererDataLoader.loadDocuments(MAX_DOCS);
} catch (Exception e) {
throw new RuntimeException("Error setting up training", e);
}
double bestTrainScore = 0;
List<List<Pair<CandidateAction, CandidateAction>>> examples = new ArrayList<>();
for (int iteration = 0; iteration < RETRAIN_ITERATIONS; iteration++) {
Redwood.log("scoref.train", "ITERATION " + iteration);
classifier.printWeightVector(null);
Redwood.log("scoref.train", "");
try {
classifier.writeWeights(outputPath + "model");
classifier.printWeightVector(IOUtils.getPrintWriter(outputPath + "weights"));
} catch (Exception e) {
throw new RuntimeException("Error writing model", e);
}
long start = System.currentTimeMillis();
Collections.shuffle(trainDocs, random);
examples = examples.subList(Math.max(0, examples.size()
- BUFFER_SIZE_MULTIPLIER * trainDocs.size()), examples.size());
trainPolicy(examples);
if (iteration % EVAL_FREQUENCY == 0) {
double trainScore = evaluatePolicy(trainDocs, true);
if (trainScore > bestTrainScore) {
bestTrainScore = trainScore;
writeModel("best", outputPath);
}
if (iteration % 10 == 0) {
writeModel("iter_" + iteration, outputPath);
}
writeModel("last", outputPath);
double timeElapsed = (System.currentTimeMillis() - start) / 1000.0;
double ffhr = State.ffHits / (double) (State.ffHits + State.ffMisses);
double shr = State.sHits / (double) (State.sHits + State.sMisses);
double fhr = featuresCacheHits /
(double) (featuresCacheHits + featuresCacheMisses);
Redwood.log("scoref.train", modelName);
Redwood.log("scoref.train", String.format("Best train: %.4f", bestTrainScore));
Redwood.log("scoref.train", String.format("Time elapsed: %.2f", timeElapsed));
Redwood.log("scoref.train", String.format("Cost hit rate: %.4f", ffhr));
Redwood.log("scoref.train", String.format("Score hit rate: %.4f", shr));
Redwood.log("scoref.train", String.format("Features hit rate: %.4f", fhr));
Redwood.log("scoref.train", "");
progressWriter.write(iteration + " " + trainScore + " " + timeElapsed
+ " " + ffhr + " " + shr + " " + fhr + "\n");
progressWriter.flush();
}
for (ClustererDoc trainDoc : trainDocs) {
examples.add(runPolicy(trainDoc, Math.pow(EXPERT_DECAY,
(iteration + 1))));
}
}
progressWriter.close();
}
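/** Serializes the classifier's weights and writes a human-readable weight vector under the given name. */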
private void writeModel(String name, String modelPath) {
try {
classifier.writeWeights(modelPath + name + "_model.ser");
classifier.printWeightVector(
IOUtils.getPrintWriter(modelPath + name + "_weights"));
} catch (Exception e) {
throw new RuntimeException("Error writing model", e);
}
}
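/** Flattens the buffered example lists and runs several epochs of classifier updates over them. */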
private void trainPolicy(List<List<Pair<CandidateAction, CandidateAction>>> examples) {
List<Pair<CandidateAction, CandidateAction>> flattenedExamples = new ArrayList<>();
examples.forEach(flattenedExamples::addAll);
for (int epoch = 0; epoch < NUM_EPOCHS; epoch++) {
Collections.shuffle(flattenedExamples, random);
flattenedExamples.forEach(classifier::learn);
}
double totalCost = flattenedExamples.stream()
.mapToDouble(e -> classifier.bestAction(e).cost).sum();
Redwood.log("scoref.train",
String.format("Training cost: %.4f", 100 * totalCost / flattenedExamples.size()));
}
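/** Runs the current greedy policy over the given documents and returns the B^3 F1 score. */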
private double evaluatePolicy(List<ClustererDoc> docs, boolean training) {
isTraining = 0;
B3Evaluator evaluator = new B3Evaluator();
for (ClustererDoc doc : docs) {
State currentState = new State(doc);
while (!currentState.isComplete()) {
currentState.doBestAction(classifier);
}
currentState.updateEvaluator(evaluator);
}
isTraining = 1;
double score = evaluator.getF1();
Redwood.log("scoref.train", String.format("B3 F1 score on %s: %.4f",
training ? "train" : "validate", score));
return score;
}
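/**
 * Rolls out a policy over a document, following the expert (lowest-cost action) with
 * probability beta and the current classifier otherwise, collecting the candidate
 * (merge, no-merge) action pairs encountered as training examples.
 */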
private List<Pair<CandidateAction, CandidateAction>> runPolicy(ClustererDoc doc, double beta) {
List<Pair<CandidateAction, CandidateAction>> examples = new ArrayList<>();
State currentState = new State(doc);
while (!currentState.isComplete()) {
Pair<CandidateAction, CandidateAction> actions = currentState.getActions(classifier);
if (actions == null) {
continue;
}
examples.add(actions);
boolean useExpert = random.nextDouble() < beta;
double action1Score = useExpert ? -actions.first.cost :
classifier.weightFeatureProduct(actions.first.features);
double action2Score = useExpert ? -actions.second.cost :
classifier.weightFeatureProduct(actions.second.features);
currentState.doAction(action1Score >= action2Score);
}
return examples;
}
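/** Features describing where we are in the global sequence of candidate merges. */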
private static class GlobalFeatures {
public boolean anaphorSeen;
public int currentIndex;
public int size;
public double docSize;
}
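/**
 * A search state: the current partition of mentions into clusters, an agenda of candidate
 * mention pairs sorted by pairwise score and pruned by score/position thresholds, and caches
 * of merge decisions and rollout costs keyed by a Zobrist-style hash of the clustering.
 */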
private static class State {
private static int sHits;
private static int sMisses;
private static int ffHits;
private static int ffMisses;
private final Map<MergeKey, Boolean> hashedScores;
private final Map<Long, Double> hashedCosts;
private final ClustererDoc doc;
private final List<Cluster> clusters;
private final Map<Integer, Cluster> mentionToCluster;
private final List<Pair<Integer, Integer>> mentionPairs;
private final List<GlobalFeatures> globalFeatures;
private int currentIndex;
private Cluster c1;
private Cluster c2;
private long hash;
public State(ClustererDoc doc) {
currentDocId = doc.id;
this.doc = doc;
this.hashedScores = new HashMap<>();
this.hashedCosts = new HashMap<>();
this.clusters = new ArrayList<>();
this.hash = 0;
mentionToCluster = new HashMap<>();
for (int m : doc.mentions) {
Cluster c = new Cluster(m);
clusters.add(c);
mentionToCluster.put(m, c);
hash ^= c.hash * 7;
}
List<Pair<Integer, Integer>> allPairs = new ArrayList<>(doc.classificationScores.keySet());
Counter<Pair<Integer, Integer>> scores =
USE_RANKING ? doc.rankingScores : doc.classificationScores;
Collections.sort(allPairs, (p1, p2) -> {
double diff = scores.getCount(p2) - scores.getCount(p1);
return diff == 0 ? 0 : (int) Math.signum(diff);
});
int i = 0;
for (i = 0; i < allPairs.size(); i++) {
double score = scores.getCount(allPairs.get(i));
if (score < MIN_PAIRWISE_SCORE && i > MIN_PAIRS) {
break;
}
if (i >= EARLY_STOP_THRESHOLD && i / score > EARLY_STOP_VAL) {
break;
}
}
mentionPairs = allPairs.subList(0, i);
if (LEFT_TO_RIGHT) {
Collections.sort(mentionPairs, (p1, p2) -> {
if (p1.second.equals(p2.second)) {
double diff = scores.getCount(p2) - scores.getCount(p1);
return diff == 0 ? 0 : (int) Math.signum(diff);
}
return doc.mentionIndices.get(p1.second)
< doc.mentionIndices.get(p2.second) ? -1 : 1;
});
for (int j = 0; j < mentionPairs.size(); j++) {
Pair<Integer, Integer> p1 = mentionPairs.get(j);
for (int k = j + 1; k < mentionPairs.size(); k++) {
Pair<Integer, Integer> p2 = mentionPairs.get(k);
assert(doc.mentionIndices.get(p1.second)
<= doc.mentionIndices.get(p2.second));
}
}
}
Counter<Integer> seenAnaphors = new ClassicCounter<>();
Counter<Integer> seenAntecedents = new ClassicCounter<>();
globalFeatures = new ArrayList<>();
for (int j = 0; j < allPairs.size(); j++) {
Pair<Integer, Integer> mentionPair = allPairs.get(j);
GlobalFeatures gf = new GlobalFeatures();
gf.currentIndex = j;
gf.anaphorSeen = seenAnaphors.containsKey(mentionPair.second);
gf.size = mentionPairs.size();
gf.docSize = doc.mentions.size() / 300.0;
globalFeatures.add(gf);
seenAnaphors.incrementCount(mentionPair.second);
seenAntecedents.incrementCount(mentionPair.first);
}
currentIndex = 0;
setClusters();
}
public State(State state) {
this.hashedScores = state.hashedScores;
this.hashedCosts = state.hashedCosts;
this.doc = state.doc;
this.hash = state.hash;
this.mentionPairs = state.mentionPairs;
this.currentIndex = state.currentIndex;
this.globalFeatures = state.globalFeatures;
this.clusters = new ArrayList<>();
this.mentionToCluster = new HashMap<>();
for (Cluster c : state.clusters) {
Cluster copy = new Cluster(c);
clusters.add(copy);
for (int m : copy.mentions) {
mentionToCluster.put(m, copy);
}
}
setClusters();
}
public void setClusters() {
Pair<Integer, Integer> currentPair = mentionPairs.get(currentIndex);
c1 = mentionToCluster.get(currentPair.first);
c2 = mentionToCluster.get(currentPair.second);
}
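/**
 * Applies a merge/no-merge decision for the current cluster pair (merging the smaller cluster
 * into the larger one), then advances to the next candidate pair whose mentions are not
 * already in the same cluster.
 */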
public void doAction(boolean isMerge) {
if (isMerge) {
if (c2.size() > c1.size()) {
Cluster tmp = c1;
c1 = c2;
c2 = tmp;
}
hash ^= 7 * c1.hash;
hash ^= 7 * c2.hash;
c1.merge(c2);
for (int m : c2.mentions) {
mentionToCluster.put(m, c1);
}
clusters.remove(c2);
hash ^= 7 * c1.hash;
}
currentIndex++;
if (!isComplete()) {
setClusters();
}
while (c1 == c2) {
currentIndex++;
if (isComplete()) {
break;
}
setClusters();
}
}
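/** Applies the classifier's preferred action for the current pair, caching the decision, and returns whether a merge was made. */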
public boolean doBestAction(ClustererClassifier classifier) {
Boolean doMerge = hashedScores.get(new MergeKey(c1, c2, currentIndex));
if (doMerge == null) {
Counter<String> features = getFeatures(doc, c1, c2,
globalFeatures.get(currentIndex));
doMerge = classifier.weightFeatureProduct(features) > 0;
hashedScores.put(new MergeKey(c1, c2, currentIndex), doMerge);
sMisses += isTraining;
} else {
sHits += isTraining;
}
doAction(doMerge);
return doMerge;
}
public boolean isComplete() {
return currentIndex >= mentionPairs.size();
}
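/**
 * Returns the combined evaluation score (EvalUtils.getCombinedF1 with MUC_WEIGHT) of the
 * clustering reached from this state, rolling out the greedy policy to completion when
 * EXACT_LOSS is set; results are cached by the state hash.
 */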
public double getFinalCost(ClustererClassifier classifier) {
while (EXACT_LOSS && !isComplete()) {
if (hashedCosts.containsKey(hash)) {
ffHits += isTraining;
return hashedCosts.get(hash);
}
doBestAction(classifier);
}
ffMisses += isTraining;
double cost = EvalUtils.getCombinedF1(MUC_WEIGHT, doc.goldClusters, clusters,
doc.mentionToGold, mentionToCluster);
hashedCosts.put(hash, cost);
return cost;
}
public void updateEvaluator(Evaluator evaluator) {
evaluator.update(doc.goldClusters, clusters, doc.mentionToGold, mentionToCluster);
}
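/**
 * Builds the pair of candidate actions for the current cluster pair: the merge action with its
 * features and the no-merge action, each with a cost proportional to how much worse its rollout
 * scores than the better of the two (scaled by document size).
 */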
public Pair<CandidateAction, CandidateAction> getActions(ClustererClassifier classifier) {
Counter<String> mergeFeatures = getFeatures(doc, c1, c2,
globalFeatures.get(currentIndex));
double mergeScore = Math.exp(classifier.weightFeatureProduct(mergeFeatures));
hashedScores.put(new MergeKey(c1, c2, currentIndex), mergeScore > 0.5);
State merge = new State(this);
merge.doAction(true);
double mergeB3 = merge.getFinalCost(classifier);
State noMerge = new State(this);
noMerge.doAction(false);
double noMergeB3 = noMerge.getFinalCost(classifier);
double weight = doc.mentions.size() / 100.0;
double maxB3 = Math.max(mergeB3, noMergeB3);
return new Pair<>(
new CandidateAction(mergeFeatures, weight * (maxB3 - mergeB3)),
new CandidateAction(new ClassicCounter<>(), weight * (maxB3 - noMergeB3)));
}
}
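/** Cache key identifying a (cluster pair, agenda position, document) combination via the clusters' hashes. */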
private static class MergeKey {
private final int hash;
public MergeKey(Cluster c1, Cluster c2, int ind) {
hash = (int)(c1.hash ^ c2.hash) + (2003 * ind) + currentDocId;
}
@Override
public int hashCode() {
return hash;
}
@Override
public boolean equals(Object o) {
return o instanceof MergeKey && ((MergeKey) o).hash == hash;
}
}
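/**
 * A cluster of mention ids with an incrementally maintained hash: each mention gets a random
 * long (per document), and a cluster's hash is the XOR of its mentions' hashes.
 */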
public static class Cluster {
private static final Map<Pair<Integer, Integer>, Long> MENTION_HASHES = new HashMap<>();
private static final Random RANDOM = new Random(0);
public final List<Integer> mentions;
public long hash;
public Cluster(int m) {
mentions = new ArrayList<>();
mentions.add(m);
hash = getMentionHash(m);
}
public Cluster(Cluster c) {
mentions = new ArrayList<>(c.mentions);
hash = c.hash;
}
public void merge(Cluster c) {
mentions.addAll(c.mentions);
hash ^= c.hash;
}
public int size() {
return mentions.size();
}
public long getHash() {
return hash;
}
private static long getMentionHash(int m) {
Pair<Integer, Integer> pair = new Pair<>(m, currentDocId);
Long hash = MENTION_HASHES.get(pair);
if (hash == null) {
hash = RANDOM.nextLong();
MENTION_HASHES.put(pair, hash);
}
return hash;
}
}
private static int featuresCacheHits;
private static int featuresCacheMisses;
private static Map<MergeKey, CompressedFeatureVector> featuresCache = new HashMap<>();
private static Compressor<String> compressor = new Compressor<>();
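/** Features for a pair of singleton clusters: the pairwise score of the two mentions. */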
private static Counter<String> getFeatures(ClustererDoc doc, Pair<Integer, Integer> mentionPair,
Counter<Pair<Integer, Integer>> scores) {
Counter<String> features = new ClassicCounter<>();
if (!scores.containsKey(mentionPair)) {
mentionPair = new Pair<>(mentionPair.second, mentionPair.first);
}
double score = scores.getCount(mentionPair);
features.incrementCount("max", score);
return features;
}
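/**
 * Features for a pair of clusters: max, min, average, and average-log pairwise scores over all
 * mention pairs between them, overall and broken down by pronominal/non-pronominal mention types.
 */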
private static Counter<String> getFeatures(ClustererDoc doc,
List<Pair<Integer, Integer>> mentionPairs, Counter<Pair<Integer, Integer>> scores) {
Counter<String> features = new ClassicCounter<>();
double maxScore = 0;
double minScore = 1;
Counter<String> totals = new ClassicCounter<>();
Counter<String> totalsLog = new ClassicCounter<>();
Counter<String> counts = new ClassicCounter<>();
for (Pair<Integer, Integer> mentionPair : mentionPairs) {
if (!scores.containsKey(mentionPair)) {
mentionPair = new Pair<>(mentionPair.second, mentionPair.first);
}
double score = scores.getCount(mentionPair);
double logScore = cappedLog(score);
String mt1 = doc.mentionTypes.get(mentionPair.first);
String mt2 = doc.mentionTypes.get(mentionPair.second);
mt1 = mt1.equals("PRONOMINAL") ? "PRONOMINAL" : "NON_PRONOMINAL";
mt2 = mt2.equals("PRONOMINAL") ? "PRONOMINAL" : "NON_PRONOMINAL";
String conj = "_" + mt1 + "_" + mt2;
maxScore = Math.max(maxScore, score);
minScore = Math.min(minScore, score);
totals.incrementCount("", score);
totalsLog.incrementCount("", logScore);
counts.incrementCount("");
totals.incrementCount(conj, score);
totalsLog.incrementCount(conj, logScore);
counts.incrementCount(conj);
}
features.incrementCount("max", maxScore);
features.incrementCount("min", minScore);
for (String key : counts.keySet()) {
features.incrementCount("avg" + key, totals.getCount(key) / mentionPairs.size());
features.incrementCount("avgLog" + key, totalsLog.getCount(key) / mentionPairs.size());
}
return features;
}
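/** Returns the mention in the cluster that appears earliest in the document. */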
private static int earliestMention(Cluster c, ClustererDoc doc) {
int earliest = -1;
for (int m : c.mentions) {
int pos = doc.mentionIndices.get(m);
if (earliest == -1 || pos < doc.mentionIndices.get(earliest)) {
earliest = m;
}
}
return earliest;
}
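/**
 * Builds (and caches) the feature vector for merging two clusters: global position features,
 * the anaphoricity score of the later-occurring cluster's earliest mention, and pairwise-score
 * statistics from the classification and/or ranking models ("-single" suffixed when both
 * clusters are singletons).
 */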
private static Counter<String> getFeatures(ClustererDoc doc, Cluster c1, Cluster c2, GlobalFeatures gf) {
MergeKey key = new MergeKey(c1, c2, gf.currentIndex);
CompressedFeatureVector cfv = featuresCache.get(key);
Counter<String> features = cfv == null ? null : compressor.uncompress(cfv);
if (features != null) {
featuresCacheHits += isTraining;
return features;
}
featuresCacheMisses += isTraining;
features = new ClassicCounter<>();
if (gf.anaphorSeen) {
features.incrementCount("anaphorSeen");
}
features.incrementCount("docSize", gf.docSize);
features.incrementCount("percentComplete", gf.currentIndex / (double) gf.size);
features.incrementCount("bias", 1.0);
int earliest1 = earliestMention(c1, doc);
int earliest2 = earliestMention(c2, doc);
if (doc.mentionIndices.get(earliest1) > doc.mentionIndices.get(earliest2)) {
int tmp = earliest1;
earliest1 = earliest2;
earliest2 = tmp;
}
features.incrementCount("anaphoricity", doc.anaphoricityScores.getCount(earliest2));
if (c1.mentions.size() == 1 && c2.mentions.size() == 1) {
Pair<Integer, Integer> mentionPair = new Pair<>(c1.mentions.get(0),
c2.mentions.get(0));
if (USE_CLASSIFICATION) {
features.addAll(addSuffix(getFeatures(doc, mentionPair, doc.classificationScores),
"-classification"));
}
if (USE_RANKING) {
features.addAll(addSuffix(getFeatures(doc, mentionPair, doc.rankingScores),
"-ranking"));
}
features = addSuffix(features, "-single");
} else {
List<Pair<Integer, Integer>> between = new ArrayList<>();
for (int m1 : c1.mentions) {
for (int m2 : c2.mentions) {
between.add(new Pair<>(m1, m2));
}
}
if (USE_CLASSIFICATION) {
features.addAll(addSuffix(getFeatures(doc, between, doc.classificationScores),
"-classification"));
}
if (USE_RANKING) {
features.addAll(addSuffix(getFeatures(doc, between, doc.rankingScores),
"-ranking"));
}
}
featuresCache.put(key, compressor.compress(features));
return features;
}
private static Counter<String> addSuffix(Counter<String> features, String suffix) {
Counter<String> withSuffix = new ClassicCounter<>();
for (Map.Entry<String, Double> e : features.entrySet()) {
withSuffix.incrementCount(e.getKey() + suffix, e.getValue());
}
return withSuffix;
}
private static double cappedLog(double x) {
return Math.log(Math.max(x, 1e-8));
}
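/**
 * Linear classifier over merge features. Each training update pushes the zero-cost action's
 * score above the costly action's score, weighted by the costly action's cost.
 */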
private static class ClustererClassifier extends SimpleLinearClassifier {
public ClustererClassifier(double learningRate) {
super(SimpleLinearClassifier.risk(),
SimpleLinearClassifier.constant(learningRate),
0);
}
public ClustererClassifier(String modelFile, double learningRate) {
super(SimpleLinearClassifier.risk(),
SimpleLinearClassifier.constant(learningRate),
0,
modelFile);
}
public CandidateAction bestAction(Pair<CandidateAction, CandidateAction> actions) {
return weightFeatureProduct(actions.first.features) >
weightFeatureProduct(actions.second.features) ? actions.first : actions.second;
}
public void learn(Pair<CandidateAction, CandidateAction> actions) {
CandidateAction goodAction = actions.first;
CandidateAction badAction = actions.second;
if (badAction.cost == 0) {
CandidateAction tmp = goodAction;
goodAction = badAction;
badAction = tmp;
}
Counter<String> features = new ClassicCounter<>(goodAction.features);
for (Map.Entry<String, Double> e : badAction.features.entrySet()) {
features.decrementCount(e.getKey(), e.getValue());
}
learn(features, 0, badAction.cost);
}
}
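/** A candidate (merge or no-merge) action: its feature vector and the cost of choosing it. */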
private static class CandidateAction {
public final Counter<String> features;
public final double cost;
public CandidateAction(Counter<String> features, double cost) {
this.features = features;
this.cost = cost;
}
}
}