package edu.stanford.nlp.ie.machinereading;
import java.io.*;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import edu.stanford.nlp.classify.*;
import edu.stanford.nlp.ie.machinereading.structure.*;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.pipeline.Annotation;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.ArgumentParser.Option;
public class BasicRelationExtractor implements Extractor {
private static final long serialVersionUID = 2606577772115897869L;
private static final Logger logger = Logger.getLogger(BasicRelationExtractor.class.getName());
protected LinearClassifier<String, String> classifier;
@Option(name="featureCountThreshold", gloss="feature count threshold to apply to dataset")
public int featureCountThreshold = 2;
@Option(name="featureFactory", gloss="Feature factory for the relation extractor")
public RelationFeatureFactory featureFactory;
/**
* strength of the prior on the linear classifier (passed to LinearClassifierFactory) or the C constant if relationExtractorClassifierType=svm
*/
@Option(name="sigma", gloss="strength of the prior on the linear classifier (passed to LinearClassifierFactory) or the C constant if relationExtractorClassifierType=svm")
public double sigma = 1.0;
/**
* which classifier to use (can be 'linear' or 'svm')
*/
public String relationExtractorClassifierType = "linear";
/**
* If true, it creates automatically negative examples by generating all combinations between EntityMentions in a sentence
* This is the common behavior, but for some domain (i.e., KBP) it must disabled. In these domains, the negative relation examples are created in the reader
*/
protected boolean createUnrelatedRelations;
/** Verifies that predicted labels are compatible with the relation arguments */
private LabelValidator validator;
protected RelationMentionFactory relationMentionFactory;
public void setValidator(LabelValidator lv) { validator = lv; }
public void setRelationExtractorClassifierType(String s) { relationExtractorClassifierType = s; }
public void setFeatureCountThreshold(int i) {featureCountThreshold = i; }
public void setSigma(double d) { sigma = d; }
public BasicRelationExtractor(RelationFeatureFactory featureFac, Boolean createUnrelatedRelations, RelationMentionFactory factory) {
featureFactory = featureFac;
this.createUnrelatedRelations = createUnrelatedRelations;
this.relationMentionFactory = factory;
logger.setLevel(Level.INFO);
}
public void setCreateUnrelatedRelations(boolean b) {
createUnrelatedRelations = b;
}
public static BasicRelationExtractor load(String modelPath) throws IOException, ClassNotFoundException {
return IOUtils.readObjectFromURLOrClasspathOrFileSystem(modelPath);
}
@Override
public void save(String modelpath) throws IOException {
// make sure modelpath directory exists
int lastSlash = modelpath.lastIndexOf(File.separator);
if(lastSlash > 0){
String path = modelpath.substring(0, lastSlash);
File f = new File(path);
if (! f.exists()) {
f.mkdirs();
}
}
FileOutputStream fos = new FileOutputStream(modelpath);
ObjectOutputStream out = new ObjectOutputStream(fos);
out.writeObject(this);
out.close();
}
/**
* Train on a list of ExtractionSentence containing labeled RelationMention objects
*/
@Override
public void train(Annotation sentences) {
// Train a single multi-class classifier
GeneralDataset<String, String> trainSet = createDataset(sentences);
trainMulticlass(trainSet);
}
public void trainMulticlass(GeneralDataset<String, String> trainSet) {
if (relationExtractorClassifierType.equalsIgnoreCase("linear")) {
LinearClassifierFactory<String, String> lcFactory = new LinearClassifierFactory<>(1e-4, false, sigma);
lcFactory.setVerbose(false);
// use in-place SGD instead of QN. this is faster but much worse!
// lcFactory.useInPlaceStochasticGradientDescent(-1, -1, 1.0);
// use a hybrid minimizer: start with in-place SGD, continue with QN
// lcFactory.useHybridMinimizerWithInPlaceSGD(50, -1, sigma);
classifier = lcFactory.trainClassifier(trainSet);
} else if (relationExtractorClassifierType.equalsIgnoreCase("svm")) {
SVMLightClassifierFactory<String, String> svmFactory = new SVMLightClassifierFactory<>();
svmFactory.setC(sigma);
classifier = svmFactory.trainClassifier(trainSet);
} else {
throw new RuntimeException("Invalid classifier type: " + relationExtractorClassifierType);
}
if (logger.isLoggable(Level.FINE)) {
reportWeights(classifier, null);
}
}
protected static void reportWeights(LinearClassifier<String, String> classifier, String classLabel) {
if (classLabel != null) logger.fine("CLASSIFIER WEIGHTS FOR LABEL " + classLabel);
Map<String, Counter<String>> labelsToFeatureWeights = classifier.weightsAsMapOfCounters();
List<String> labels = new ArrayList<>(labelsToFeatureWeights.keySet());
Collections.sort(labels);
for (String label: labels) {
Counter<String> featWeights = labelsToFeatureWeights.get(label);
List<Pair<String, Double>> sorted = Counters.toSortedListWithCounts(featWeights);
StringBuilder bos = new StringBuilder();
bos.append("WEIGHTS FOR LABEL ").append(label).append(':');
for (Pair<String, Double> feat: sorted) {
bos.append(' ').append(feat.first()).append(':').append(feat.second()+"\n");
}
logger.fine(bos.toString());
}
}
protected String classOf(Datum<String, String> datum, ExtractionObject rel) {
Counter<String> probs = classifier.probabilityOf(datum);
List<Pair<String, Double>> sortedProbs = Counters.toDescendingMagnitudeSortedListWithCounts(probs);
double nrProb = probs.getCount(RelationMention.UNRELATED);
for(Pair<String, Double> choice: sortedProbs){
if(choice.first.equals(RelationMention.UNRELATED)) return choice.first;
if(nrProb >= choice.second) return RelationMention.UNRELATED; // no prediction, all probs have the same value
if(compatibleLabel(choice.first, rel)) return choice.first;
}
return RelationMention.UNRELATED;
}
private boolean compatibleLabel(String label, ExtractionObject rel) {
if(rel == null) return true;
if(validator != null) return validator.validLabel(label, rel);
return true;
}
protected Counter<String> probabilityOf(Datum<String, String> testDatum) {
return classifier.probabilityOf(testDatum);
}
protected void justificationOf(Datum<String, String> testDatum, PrintWriter pw, String label) {
classifier.justificationOf(testDatum, pw);
}
/**
* Predict a relation for each pair of entities in the sentence; including relations of type unrelated.
* This creates new RelationMention objects!
*/
protected List<RelationMention> extractAllRelations(CoreMap sentence) {
List<RelationMention> extractions = new ArrayList<>();
List<RelationMention> cands = null;
if(createUnrelatedRelations){
// creates all possible relations between all entities in the sentence
cands = AnnotationUtils.getAllUnrelatedRelations(relationMentionFactory, sentence, false);
} else {
// just take the candidates produced by the reader (in KBP)
cands = sentence.get(MachineReadingAnnotations.RelationMentionsAnnotation.class);
if(cands == null){
cands = new ArrayList<>();
}
}
// the actual classification takes place here!
for (RelationMention rel : cands) {
Datum<String, String> testDatum = createDatum(rel);
String label = classOf(testDatum, rel);
Counter<String> probs = probabilityOf(testDatum);
double prob = probs.getCount(label);
StringWriter sw = new StringWriter();
PrintWriter pw = new PrintWriter(sw);
if (logger.isLoggable(Level.INFO)) {
justificationOf(testDatum, pw, label);
}
logger.info("Current sentence: " + AnnotationUtils.tokensAndNELabelsToString(rel.getArg(0).getSentence()) + "\n"
+ "Classifying relation: " + rel + "\n"
+ "JUSTIFICATION for label GOLD:" + rel.getType() + " SYS:" + label + " (prob:" + prob + "):\n"
+ sw.toString());
logger.info("Justification done.");
RelationMention relation = relationMentionFactory.constructRelationMention(
rel.getObjectId(),
sentence,
rel.getExtent(),
label,
null,
rel.getArgs(),
probs);
extractions.add(relation);
if(! relation.getType().equals(rel.getType())){
logger.info("Classification: found different type " + relation.getType() + " for relation: " + rel);
logger.info("The predicted relation is: " + relation);
logger.info("Current sentence: " + AnnotationUtils.tokensAndNELabelsToString(rel.getArg(0).getSentence()));
} else{
logger.info("Classification: found similar type " + relation.getType() + " for relation: " + rel);
logger.info("The predicted relation is: " + relation);
logger.info("Current sentence: " + AnnotationUtils.tokensAndNELabelsToString(rel.getArg(0).getSentence()));
}
}
return extractions;
}
public List<String> annotateMulticlass(List<Datum<String, String>> testDatums) {
List<String> predictedLabels = new ArrayList<>();
for (Datum<String, String> testDatum: testDatums) {
String label = classOf(testDatum, null);
Counter<String> probs = probabilityOf(testDatum);
double prob = probs.getCount(label);
StringWriter sw = new StringWriter();
PrintWriter pw = new PrintWriter(sw);
if (logger.isLoggable(Level.FINE)) {
justificationOf(testDatum, pw, label);
}
logger.fine("JUSTIFICATION for label GOLD:" + testDatum.label() + " SYS:" + label + " (prob:" + prob + "):\n"
+ sw.toString() + "\nJustification done.");
predictedLabels.add(label);
if(! testDatum.label().equals(label)){
logger.info("Classification: found different type " + label + " for relation: " + testDatum);
} else{
logger.info("Classification: found similar type " + label + " for relation: " + testDatum);
}
}
return predictedLabels;
}
public void annotateSentence(CoreMap sentence) {
// this stores all relation mentions generated by this extractor
List<RelationMention> relations = new ArrayList<>();
// extractAllRelations creates new objects for every predicted relation
for (RelationMention rel : extractAllRelations(sentence)) {
// add all relations. potentially useful for a joint model
// if (! RelationMention.isUnrelatedLabel(rel.getType()))
relations.add(rel);
}
// caution: this removes the old list of relation mentions!
for (RelationMention r: relations) {
if (! r.getType().equals(RelationMention.UNRELATED)) {
logger.fine("Found positive relation in annotateSentence: " + r);
}
}
sentence.set(MachineReadingAnnotations.RelationMentionsAnnotation.class, relations);
}
@Override
public void annotate(Annotation dataset) {
for (CoreMap sentence : dataset.get(CoreAnnotations.SentencesAnnotation.class)){
annotateSentence(sentence);
}
}
protected GeneralDataset<String, String> createDataset(Annotation corpus) {
GeneralDataset<String, String> dataset = new RVFDataset<>();
for (CoreMap sentence : corpus.get(CoreAnnotations.SentencesAnnotation.class)) {
for (RelationMention rel : AnnotationUtils.getAllRelations(relationMentionFactory, sentence, createUnrelatedRelations)) {
dataset.add(createDatum(rel));
}
}
dataset.applyFeatureCountThreshold(featureCountThreshold);
return dataset;
}
protected Datum<String, String> createDatum(RelationMention rel) {
assert(featureFactory != null);
return featureFactory.createDatum(rel);
}
protected Datum<String, String> createDatum(RelationMention rel, String label) {
assert(featureFactory != null);
Datum<String, String> datum = featureFactory.createDatum(rel, label);
return datum;
}
@Override
public void setLoggerLevel(Level level) {
logger.setLevel(level);
}
}