package edu.stanford.nlp.coref.statistical;
import java.util.Map;
import edu.stanford.nlp.coref.statistical.SimpleLinearClassifier.Loss;
import edu.stanford.nlp.stats.Counter;
/**
 * A max-margin mention-ranking coreference model.
 *
 * <p>Each training update is driven by a pair of examples (one correct, one incorrect) together
 * with the {@link ErrorType} of the mistake. The per-error-type cost is applied in one of two ways,
 * selected by {@link #multiplicativeCost}:
 * <ul>
 *   <li>{@code true}: the cost is passed to the classifier as the third argument of
 *       {@code classifier.learn} alongside a shared {@code maxMargin(1.0)} loss;</li>
 *   <li>{@code false}: the cost is baked into a per-error-type loss created with
 *       {@code SimpleLinearClassifier.maxMargin(cost)}, and {@code 1.0} is passed as the third
 *       argument instead.</li>
 * </ul>
 *
 * @author Kevin Clark
 */
public class MaxMarginMentionRanker extends PairwiseModel {

  /**
   * Kinds of ranking error; {@link #id} indexes into the cost and loss arrays.
   *
   * <p>NOTE(review): judging by the parameter names of {@link Builder#setCosts}, the constants
   * presumably stand for false-new, false-new on a pronoun, false-link and wrong-link errors --
   * confirm against the training code.
   */
  public enum ErrorType {
    FN(0), FN_PRON(1), FL(2), WL(3);

    /** Position of this error type in {@code costs} and in the per-type loss array. */
    public final int id;

    private ErrorType(int id) {
      this.id = id;
    }
  }

  /** Per-error-type losses; populated only when {@link #multiplicativeCost} is {@code false}. */
  private final Loss[] losses = new Loss[ErrorType.values().length];

  /** Shared unit-cost loss used when {@link #multiplicativeCost} is {@code true}. */
  private final Loss loss;

  /** Cost of each error type, indexed by {@link ErrorType#id}. */
  public final double[] costs;

  /** Whether costs are passed to the classifier directly rather than folded into the loss. */
  public final boolean multiplicativeCost;

  /** Builder for {@link MaxMarginMentionRanker}; defaults are costs {1.2, 1.2, 0.5, 1.0} and multiplicative cost. */
  public static class Builder extends PairwiseModel.Builder {
    private double[] costs = new double[] {1.2, 1.2, 0.5, 1.0};
    private boolean multiplicativeCost = true;

    /**
     * @param name passed through to {@link PairwiseModel.Builder}
     * @param meta passed through to {@link PairwiseModel.Builder}
     */
    public Builder(String name, MetaFeatureExtractor meta) {
      super(name, meta);
    }

    /**
     * Sets the cost of each error type, in {@link ErrorType#id} order
     * ({@code FN}, {@code FN_PRON}, {@code FL}, {@code WL}).
     *
     * @return this builder, for chaining
     */
    public Builder setCosts(double fnCost, double fnPronounCost, double faCost, double wlCost) {
      this.costs = new double[] {fnCost, fnPronounCost, faCost, wlCost};
      return this;
    }

    /**
     * Selects how costs are applied during learning; see the class documentation.
     *
     * @return this builder, for chaining
     */
    public Builder multiplicativeCost(boolean multiplicativeCost) {
      this.multiplicativeCost = multiplicativeCost;
      return this;
    }

    @Override
    public MaxMarginMentionRanker build() {
      return new MaxMarginMentionRanker(this);
    }
  }

  /** Convenience factory equivalent to {@code new Builder(name, meta)}. */
  public static Builder newBuilder(String name, MetaFeatureExtractor meta) {
    return new Builder(name, meta);
  }

  /**
   * Creates a ranker from the builder's settings.
   *
   * @param builder source of the costs and of the cost-application mode
   */
  public MaxMarginMentionRanker(Builder builder) {
    super(builder);
    // Copy so that later changes to this public array cannot leak into the builder
    // (or into other models built from the same builder), and vice versa.
    costs = builder.costs.clone();
    multiplicativeCost = builder.multiplicativeCost;
    // The per-error-type losses are read by learn() only in the non-multiplicative branch,
    // so they must be created in exactly that case; otherwise learn() would pass null.
    if (!multiplicativeCost) {
      for (ErrorType et : ErrorType.values()) {
        losses[et.id] = SimpleLinearClassifier.maxMargin(costs[et.id]);
      }
    }
    loss = SimpleLinearClassifier.maxMargin(1.0);
  }

  /**
   * Performs one update from a (correct, incorrect) example pair.
   *
   * <p>The features of {@code correct} are subtracted from those of {@code incorrect}, and the
   * resulting difference vector is handed to the classifier with a value of {@code 1.0} as the
   * second argument. The cost of {@code errorType} is applied either as the third argument or
   * through the loss, depending on {@link #multiplicativeCost}.
   *
   * @param correct example whose features are subtracted
   * @param incorrect example whose features are subtracted from
   * @param mentionFeatures forwarded to {@code meta.getFeatures} for both examples
   * @param compressor forwarded to {@code meta.getFeatures} for both examples
   * @param errorType selects which cost (or per-type loss) applies to this update
   */
  public void learn(Example correct, Example incorrect,
      Map<Integer, CompressedFeatureVector> mentionFeatures, Compressor<String> compressor,
      ErrorType errorType) {
    Counter<String> cFeatures = meta.getFeatures(correct, mentionFeatures, compressor);
    Counter<String> iFeatures = meta.getFeatures(incorrect, mentionFeatures, compressor);
    // iFeatures becomes (incorrect - correct), updated in place.
    for (Map.Entry<String, Double> e : cFeatures.entrySet()) {
      iFeatures.decrementCount(e.getKey(), e.getValue());
    }
    if (multiplicativeCost) {
      classifier.learn(iFeatures, 1.0, costs[errorType.id], loss);
    } else {
      classifier.learn(iFeatures, 1.0, 1.0, losses[errorType.id]);
    }
  }
}