package edu.stanford.nlp.coref.statistical;
import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Map;
import edu.stanford.nlp.coref.statistical.SimpleLinearClassifier.LearningRateSchedule;
import edu.stanford.nlp.coref.statistical.SimpleLinearClassifier.Loss;
import edu.stanford.nlp.stats.Counter;
/**
 * Pairwise mention-classification model.
 *
 * <p>Wraps a {@link SimpleLinearClassifier} whose features are produced by a
 * {@link MetaFeatureExtractor} from an {@link Example} plus per-mention compressed feature
 * vectors. Instances are configured through {@link Builder}.
 *
 * @author Kevin Clark
 */
public class PairwiseModel {
  /** Model name; also used to derive the default output directory. */
  public final String name;
  private final int trainingExamples;
  private final int epochs;
  protected final SimpleLinearClassifier classifier;
  /** Weight multiplier applied when only one side of a correct/incorrect pair is available. */
  private final double singletonRatio;
  /** Serialized builder configuration, written to the {@code config} file by writeModel. */
  private final String str;
  protected final MetaFeatureExtractor meta;

  /**
   * Builder for {@link PairwiseModel}.
   *
   * <p>NOTE: the declared fields of this class are serialized reflectively by
   * {@code StatisticalCorefTrainer.fieldValues} into the model's config file, so their names
   * and declaration order are part of the on-disk format.
   */
  public static class Builder {
    private final String name;
    private final MetaFeatureExtractor meta;
    @SuppressWarnings("unused") // output in config file with reflection
    private final String source = StatisticalCorefTrainer.extractedFeaturesFile;
    private int trainingExamples = 100000000;
    private int epochs = 8;
    private Loss loss = SimpleLinearClassifier.log();
    private LearningRateSchedule learningRateSchedule =
        SimpleLinearClassifier.adaGrad(0.05, 30.0);
    private double regularizationStrength = 1e-7;
    private double singletonRatio = 0.3;
    private String modelFile = null;

    /**
     * Creates a builder with default hyperparameters.
     *
     * @param name model name
     * @param meta extractor that turns examples into classifier features
     */
    public Builder(String name, MetaFeatureExtractor meta) {
      this.name = name;
      this.meta = meta;
    }

    /** Sets the number of training examples (default 100000000). */
    public Builder trainingExamples(int trainingExamples) {
      this.trainingExamples = trainingExamples;
      return this;
    }

    /** Sets the number of training epochs (default 8). */
    public Builder epochs(int epochs) {
      this.epochs = epochs;
      return this;
    }

    /** Sets the weight multiplier for unpaired (singleton) updates (default 0.3). */
    public Builder singletonRatio(double singletonRatio) {
      this.singletonRatio = singletonRatio;
      return this;
    }

    /** Sets the classifier loss (default {@code SimpleLinearClassifier.log()}). */
    public Builder loss(Loss loss) {
      this.loss = loss;
      return this;
    }

    /** Sets the regularization strength (default 1e-7). */
    public Builder regularizationStrength(double regularizationStrength) {
      this.regularizationStrength = regularizationStrength;
      return this;
    }

    /** Sets the learning-rate schedule (default {@code adaGrad(0.05, 30.0)}). */
    public Builder learningRateSchedule(LearningRateSchedule learningRateSchedule) {
      this.learningRateSchedule = learningRateSchedule;
      return this;
    }

    /**
     * Sets the model to load. A value ending in {@code .ser} or {@code .gz} is used as a file
     * path directly; any other value is treated as a model name under
     * {@code StatisticalCorefTrainer.pairwiseModelsPath}.
     */
    public Builder modelPath(String modelFile) {
      this.modelFile = modelFile;
      return this;
    }

    /** Builds the configured model. */
    public PairwiseModel build() {
      return new PairwiseModel(this);
    }
  }

  /** Returns a new {@link Builder}; equivalent to {@code new Builder(name, meta)}. */
  public static Builder newBuilder(String name, MetaFeatureExtractor meta) {
    return new Builder(name, meta);
  }

  /**
   * Creates a model from the given builder's settings.
   *
   * @param builder configuration source; its fields are also serialized into {@link #str}
   */
  public PairwiseModel(Builder builder) {
    name = builder.name;
    meta = builder.meta;
    trainingExamples = builder.trainingExamples;
    epochs = builder.epochs;
    singletonRatio = builder.singletonRatio;
    classifier = new SimpleLinearClassifier(builder.loss, builder.learningRateSchedule,
        builder.regularizationStrength, resolveModelPath(builder.modelFile));
    str = StatisticalCorefTrainer.fieldValues(builder);
  }

  /**
   * Maps the builder's {@code modelFile} setting to the path handed to the classifier.
   *
   * @return {@code null} if {@code modelFile} is null; {@code modelFile} itself if it ends in
   *     {@code .ser} or {@code .gz}; otherwise {@code <pairwiseModelsPath><modelFile>/model.ser}
   */
  private static String resolveModelPath(String modelFile) {
    if (modelFile == null) {
      return null;
    }
    if (modelFile.endsWith(".ser") || modelFile.endsWith(".gz")) {
      return modelFile;
    }
    return StatisticalCorefTrainer.pairwiseModelsPath + modelFile + "/model.ser";
  }

  /** Returns {@code <pairwiseModelsPath><name>/}. */
  public String getDefaultOutputPath() {
    return StatisticalCorefTrainer.pairwiseModelsPath + name + "/";
  }

  /** Returns the underlying linear classifier. */
  public SimpleLinearClassifier getClassifier() {
    return classifier;
  }

  /** Writes the model to {@link #getDefaultOutputPath()}. */
  public void writeModel() throws Exception {
    writeModel(getDefaultOutputPath());
  }

  /**
   * Writes {@code config}, {@code weights} and {@code model.ser} into the given directory,
   * creating it (and any missing parents) if necessary.
   *
   * @param outputPath output directory; a trailing separator is optional
   * @throws IOException if the directory cannot be created or a file cannot be written
   * @throws Exception propagated from the classifier's serialization methods
   */
  public void writeModel(String outputPath) throws Exception {
    File outDir = new File(outputPath);
    // mkdirs() also creates missing parents; check its result instead of failing later with
    // an opaque FileNotFoundException.
    if (!outDir.isDirectory() && !outDir.mkdirs()) {
      throw new IOException("Could not create output directory: " + outDir);
    }
    // Resolve each file against the directory so the result does not depend on whether
    // outputPath ends with a separator.
    try (PrintWriter writer = new PrintWriter(new File(outDir, "config"), "UTF-8")) {
      writer.print(str);
    }
    try (PrintWriter writer = new PrintWriter(new File(outDir, "weights"), "UTF-8")) {
      classifier.printWeightVector(writer);
    }
    classifier.writeWeights(new File(outDir, "model.ser").getPath());
  }

  /**
   * Performs one classifier update on a single example with weight 1.0.
   * A label of exactly 1.0 is a positive target (+1); anything else is negative (-1).
   */
  public void learn(Example example,
      Map<Integer, CompressedFeatureVector> mentionFeatures, Compressor<String> compressor) {
    learnExample(example, mentionFeatures, compressor, 1.0);
  }

  /**
   * Performs one weighted classifier update on a single example.
   * A label of exactly 1.0 is a positive target (+1); anything else is negative (-1).
   */
  public void learn(Example example,
      Map<Integer, CompressedFeatureVector> mentionFeatures, Compressor<String> compressor,
      double weight) {
    learnExample(example, mentionFeatures, compressor, weight);
  }

  /** Shared implementation of the single-example {@code learn} overloads. */
  private void learnExample(Example example,
      Map<Integer, CompressedFeatureVector> mentionFeatures, Compressor<String> compressor,
      double weight) {
    Counter<String> features = meta.getFeatures(example, mentionFeatures, compressor);
    classifier.learn(features, example.label == 1.0 ? 1.0 : -1.0, weight);
  }

  /**
   * Updates the classifier with a correct (+1) and/or incorrect (-1) example.
   *
   * <p>When both are present each is learned with {@code weight}. When only one is present it
   * is learned with {@code weight * singletonRatio}, or skipped entirely if
   * {@code singletonRatio == 0}. When both are null nothing is learned.
   *
   * @param correct example to push toward the positive class; may be null
   * @param incorrect example to push toward the negative class; may be null
   */
  public void learn(Example correct, Example incorrect,
      Map<Integer, CompressedFeatureVector> mentionFeatures, Compressor<String> compressor,
      double weight) {
    Counter<String> cFeatures =
        correct == null ? null : meta.getFeatures(correct, mentionFeatures, compressor);
    Counter<String> iFeatures =
        incorrect == null ? null : meta.getFeatures(incorrect, mentionFeatures, compressor);

    boolean bothPresent = correct != null && incorrect != null;
    if (!bothPresent && singletonRatio == 0) {
      return;
    }
    double updateWeight = bothPresent ? weight : weight * singletonRatio;
    if (correct != null) {
      classifier.learn(cFeatures, 1.0, updateWeight);
    }
    if (incorrect != null) {
      classifier.learn(iFeatures, -1.0, updateWeight);
    }
  }

  /** Returns the classifier's {@code label(...)} output for the example's features. */
  public double predict(Example example,
      Map<Integer, CompressedFeatureVector> mentionFeatures, Compressor<String> compressor) {
    Counter<String> features = meta.getFeatures(example, mentionFeatures, compressor);
    return classifier.label(features);
  }

  /** Returns the configured number of training examples. */
  public int getNumTrainingExamples() {
    return trainingExamples;
  }

  /** Returns the configured number of training epochs. */
  public int getNumEpochs() {
    return epochs;
  }
}