package edu.stanford.nlp.coref.statistical;
import java.io.File;
import java.lang.reflect.Field;
import java.util.Properties;
import edu.stanford.nlp.coref.CorefProperties;
import edu.stanford.nlp.coref.CorefProperties.Dataset;
import edu.stanford.nlp.coref.data.Dictionaries;
import edu.stanford.nlp.util.StringUtils;
/**
 * Main class for training new statistical coreference systems.
 *
 * <p>All state is held in public static fields that are filled in by
 * {@link #setTrainingPath(Properties)} and {@link #setDataPath(String)}. Paths are built by plain
 * string concatenation, so the configured training path is expected to end with a path separator
 * (NOTE(review): confirm that {@code StatisticalCorefProperties.trainingPath} guarantees this).
 *
 * @author Kevin Clark
 */
public class StatisticalCorefTrainer {
  /** Name used when building the pairwise classification model. */
  public static final String CLASSIFICATION_MODEL = "classification";
  /** Name used when building the pairwise ranking model. */
  public static final String RANKING_MODEL = "ranking";
  /** Name used when building the anaphoricity model. */
  public static final String ANAPHORICITY_MODEL = "anaphoricity";
  /** Name passed to the clusterer when training the clustering model. */
  public static final String CLUSTERING_MODEL_NAME = "clusterer";
  /** Name of the sub-directory (inside a data directory) holding extracted features. */
  public static final String EXTRACTED_FEATURES_NAME = "features";

  /** Number of training examples requested for the anaphoricity model. */
  private static final int ANAPHORICITY_TRAINING_EXAMPLES = 5000000;

  /** Root directory for models and temporary files; set by {@link #setTrainingPath}. */
  public static String trainingPath;
  /** {@code trainingPath + "pairwise_models/"}; set by {@link #setTrainingPath}. */
  public static String pairwiseModelsPath;
  /** {@code trainingPath + "clustering_models/"}; set by {@link #setTrainingPath}. */
  public static String clusteringModelsPath;

  /** {@code name + "_predictions"} for the current data set; set by {@link #setDataPath}. */
  public static String predictionsName;
  /** Path of {@code dataset.ser} for the current data set; set by {@link #setDataPath}. */
  public static String datasetFile;
  /** Path of {@code gold_clusters.ser} for the current data set; set by {@link #setDataPath}. */
  public static String goldClustersFile;
  /** Path of the training set's {@code word_counts.ser}; set by {@link #doTraining}. */
  public static String wordCountsFile;
  /** Path of {@code mention_types.ser} for the current data set; set by {@link #setDataPath}. */
  public static String mentionTypesFile;
  /** Path of {@code compressor.ser} in the features directory; set by {@link #setDataPath}. */
  public static String compressorFile;
  /** Path of {@code compressed_features.ser}; set by {@link #setDataPath}. */
  public static String extractedFeaturesFile;

  /**
   * Ensures that {@code path} exists as a directory, creating any missing parent directories.
   *
   * @param path directory to create
   * @throws IllegalStateException if the directory does not exist and could not be created
   */
  private static void makeDir(String path) {
    File outDir = new File(path);
    // mkdirs() (unlike mkdir()) also creates missing parents. The trailing isDirectory() check
    // tolerates the directory having been created concurrently by someone else.
    if (!outDir.isDirectory() && !outDir.mkdirs() && !outDir.isDirectory()) {
      throw new IllegalStateException("Could not create directory: " + outDir.getAbsolutePath());
    }
  }

  /**
   * Reads the training path from {@code props}, derives the pairwise and clustering model
   * directories from it, and creates those directories.
   *
   * @param props properties passed to {@code StatisticalCorefProperties.trainingPath}
   */
  public static void setTrainingPath(Properties props) {
    trainingPath = StatisticalCorefProperties.trainingPath(props);
    pairwiseModelsPath = trainingPath + "pairwise_models/";
    clusteringModelsPath = trainingPath + "clustering_models/";
    makeDir(pairwiseModelsPath);
    makeDir(clusteringModelsPath);
  }

  /**
   * Points the per-data-set static file paths at the directory {@code trainingPath + name + "/"}
   * and creates that directory and its features sub-directory. Uses the current value of
   * {@link #trainingPath}, so {@link #setTrainingPath(Properties)} should have been called first.
   *
   * @param name name of the data set (this class uses {@code "train"} and {@code "dev"})
   */
  public static void setDataPath(String name) {
    String dataPath = trainingPath + name + "/";
    String extractedFeaturesPath = dataPath + EXTRACTED_FEATURES_NAME + "/";
    makeDir(dataPath);
    makeDir(extractedFeaturesPath);

    datasetFile = dataPath + "dataset.ser";
    predictionsName = name + "_predictions";
    goldClustersFile = dataPath + "gold_clusters.ser";
    mentionTypesFile = dataPath + "mention_types.ser";
    compressorFile = extractedFeaturesPath + "compressor.ser";
    extractedFeaturesFile = extractedFeaturesPath + "compressed_features.ser";
  }

  /**
   * Renders every field declared directly on {@code o}'s class as one {@code "name = value\n"}
   * line, in the order returned by {@link Class#getDeclaredFields()}.
   *
   * @param o object whose declared fields are read reflectively
   * @return the concatenated lines; empty if the class declares no fields
   * @throws RuntimeException if a field cannot be made accessible or read
   */
  public static String fieldValues(Object o) {
    StringBuilder description = new StringBuilder();
    for (Field field : o.getClass().getDeclaredFields()) {
      try {
        field.setAccessible(true);
        description.append(field.getName()).append(" = ").append(field.get(o)).append("\n");
      } catch (IllegalAccessException | RuntimeException e) {
        throw new RuntimeException("Error getting field value for " + field.getName(), e);
      }
    }
    return description.toString();
  }

  /**
   * Runs dataset building, metadata writing and feature extraction, in that order, from scratch.
   *
   * @param props properties forwarded to each stage
   * @param dictionaries dictionaries forwarded to each stage
   * @param isTrainSet when true the dataset builder is configured with the class-imbalance and
   *     per-document example limits from {@code props}; otherwise its defaults are used
   * @throws Exception if any stage fails
   */
  private static void preprocess(Properties props, Dictionaries dictionaries, boolean isTrainSet)
      throws Exception {
    DatasetBuilder datasetBuilder;
    if (isTrainSet) {
      datasetBuilder = new DatasetBuilder(
          StatisticalCorefProperties.minClassImbalance(props),
          StatisticalCorefProperties.maxTrainExamplesPerDocument(props));
    } else {
      datasetBuilder = new DatasetBuilder();
    }
    datasetBuilder.runFromScratch(props, dictionaries);
    new MetadataWriter(isTrainSet).runFromScratch(props, dictionaries);
    new FeatureExtractorRunner(props, dictionaries).runFromScratch(props, dictionaries);
  }

  /**
   * Runs the full training pipeline: preprocesses the train and dev sets, trains the ranking,
   * classification and anaphoricity pairwise models with the data paths pointing at "train",
   * tests them with the data paths pointing at "dev", and finally trains the clusterer.
   *
   * <p>The steps communicate through this class's static path fields, so the order of the
   * {@link #setDataPath(String)} calls matters. {@code props} is handed to
   * {@code CorefProperties.setInput}, first for the train set and then for the dev set.
   *
   * @param props training configuration
   * @throws Exception if any stage of the pipeline fails
   */
  public static void doTraining(Properties props) throws Exception {
    setTrainingPath(props);
    Dictionaries dictionaries = new Dictionaries(props);

    // Preprocess the training set.
    setDataPath("train");
    wordCountsFile = trainingPath + "train/word_counts.ser";
    CorefProperties.setInput(props, Dataset.TRAIN);
    preprocess(props, dictionaries, true);

    // Preprocess the development set.
    setDataPath("dev");
    CorefProperties.setInput(props, Dataset.DEV);
    preprocess(props, dictionaries, false);

    // Train the pairwise models with the data paths pointing back at the training set.
    setDataPath("train");
    // Presumably drops the reference so the dictionaries can be reclaimed before training.
    dictionaries = null;
    PairwiseModel classificationModel = PairwiseModel.newBuilder(CLASSIFICATION_MODEL,
        MetaFeatureExtractor.newBuilder().build()).build();
    PairwiseModel rankingModel = PairwiseModel.newBuilder(RANKING_MODEL,
        MetaFeatureExtractor.newBuilder().build()).build();
    PairwiseModel anaphoricityModel = PairwiseModel.newBuilder(ANAPHORICITY_MODEL,
        MetaFeatureExtractor.anaphoricityMFE())
        .trainingExamples(ANAPHORICITY_TRAINING_EXAMPLES).build();
    PairwiseModelTrainer.trainRanking(rankingModel);
    PairwiseModelTrainer.trainClassification(classificationModel, false);
    PairwiseModelTrainer.trainClassification(anaphoricityModel, true);

    // Test the pairwise models with the data paths pointing at the development set, then train
    // the clusterer.
    setDataPath("dev");
    PairwiseModelTrainer.test(classificationModel, predictionsName, false);
    PairwiseModelTrainer.test(rankingModel, predictionsName, false);
    PairwiseModelTrainer.test(anaphoricityModel, predictionsName, true);
    new Clusterer().doTraining(CLUSTERING_MODEL_NAME);
  }

  /**
   * Run the training. Main options:
   * <ul>
   *   <li>-coref.data: location of training data (CoNLL format)</li>
   *   <li>-coref.statistical.trainingPath: where to write trained models and temporary files</li>
   *   <li>-coref.statistical.minClassImbalance: use this to downsample negative examples to
   *   speed up and reduce the memory footprint of training</li>
   *   <li>-coref.statistical.maxTrainExamplesPerDocument: use this to downsample examples from
   *   each document to speed up and reduce the memory footprint of training</li>
   * </ul>
   *
   * @param args command-line arguments, converted to properties with
   *     {@code StringUtils.argsToProperties}
   * @throws Exception if training fails
   */
  public static void main(String[] args) throws Exception {
    doTraining(StringUtils.argsToProperties(args));
  }
}