package edu.stanford.nlp.ie;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.IndexedWord;
import edu.stanford.nlp.pipeline.DefaultPaths;
import edu.stanford.nlp.semgraph.SemanticGraph;
import edu.stanford.nlp.semgraph.SemanticGraphCoreAnnotations;
import edu.stanford.nlp.semgraph.semgrex.SemgrexBatchParser;
import edu.stanford.nlp.semgraph.semgrex.SemgrexMatcher;
import edu.stanford.nlp.semgraph.semgrex.SemgrexPattern;
import edu.stanford.nlp.simple.Sentence;
import edu.stanford.nlp.util.ArgumentParser;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.logging.Redwood;
import edu.stanford.nlp.util.logging.RedwoodConfiguration;
import java.io.*;
import java.util.*;
/**
* A tokensregex extractor for KBP.
*
* @author Gabor Angeli
*/
public class KBPSemgrexExtractor implements KBPRelationExtractor {
protected final Redwood.RedwoodChannels logger = Redwood.channels(KBPSemgrexExtractor.class);
@ArgumentParser.Option(name="dir", gloss="The tokensregex directory")
public static String DIR = DefaultPaths.DEFAULT_KBP_SEMGREX_DIR;
@ArgumentParser.Option(name="test", gloss="The dataset to test on")
public static File TEST_FILE = new File("test.conll");
@ArgumentParser.Option(name="predictions", gloss="Dump model predictions to this file")
public static Optional<String> PREDICTIONS = Optional.empty();
private final Map<RelationType, Collection<SemgrexPattern> > rules = new HashMap<>();
public KBPSemgrexExtractor(String semgrexdir) throws IOException {
this(semgrexdir, false);
}
public KBPSemgrexExtractor(String semgrexdir, boolean verbose) throws IOException {
if (verbose)
logger.log("Creating SemgrexRegexExtractor");
// Create extractors
for (RelationType rel : RelationType.values()) {
String filename = semgrexdir + File.separator + rel.canonicalName.replace("/", "SLASH") + ".rules";
if (IOUtils.existsInClasspathOrFileSystem(filename)) {
List<SemgrexPattern> rulesforrel = SemgrexBatchParser.compileStream(IOUtils.getInputStreamFromURLOrClasspathOrFileSystem(filename));
if (verbose)
logger.log("Read " + rulesforrel.size() + " rules from " + filename + " for relation " + rel);
rules.put(rel, rulesforrel);
}
}
}
@Override
public Pair<String, Double> classify(KBPInput input) {
for (RelationType rel : RelationType.values()) {
if (rules.containsKey(rel) &&
rel.entityType == input.subjectType &&
rel.validNamedEntityLabels.contains(input.objectType)) {
Collection<SemgrexPattern> rulesForRel = rules.get(rel);
CoreMap sentence = input.sentence.asCoreMap(Sentence::nerTags, Sentence::dependencyGraph);
boolean matches
= matches(sentence, rulesForRel, input,
sentence.get(SemanticGraphCoreAnnotations.EnhancedPlusPlusDependenciesAnnotation.class)) ||
matches(sentence, rulesForRel, input,
sentence.get(SemanticGraphCoreAnnotations.AlternativeDependenciesAnnotation.class));
if (matches) {
//logger.log("MATCH for " + rel + ". " + sentence: + sentence + " with rules for " + rel);
return Pair.makePair(rel.canonicalName, 1.0);
}
}
}
return Pair.makePair(NO_RELATION, 1.0);
}
/**
* Returns whether any of the given patterns match this tree.
*/
private boolean matches(CoreMap sentence, Collection<SemgrexPattern> rulesForRel,
KBPInput input, SemanticGraph graph) {
if (graph == null || graph.isEmpty()) {
return false;
}
List<CoreLabel> tokens = sentence.get(CoreAnnotations.TokensAnnotation.class);
for (int i : input.subjectSpan) {
if ("O".equals(tokens.get(i).ner())) {
tokens.get(i).setNER(input.subjectType.name);
}
}
for (int i : input.objectSpan) {
if ("O".equals(tokens.get(i).ner())) {
tokens.get(i).setNER(input.objectType.name);
}
}
for (SemgrexPattern p : rulesForRel) {
try {
SemgrexMatcher n = p.matcher(graph);
while (n.find()) {
IndexedWord entity = n.getNode("entity");
IndexedWord slot = n.getNode("slot");
boolean hasSubject = entity.index() >= input.subjectSpan.start() + 1 && entity.index() <= input.subjectSpan.end();
boolean hasObject = slot.index() >= input.objectSpan.start() + 1 && slot.index() <= input.objectSpan.end();
if (hasSubject && hasObject) {
return true;
}
}
} catch (Exception e) {
//Happens when graph has no roots
return false;
}
}
return false;
}
public static void main(String[] args) throws IOException {
RedwoodConfiguration.standard().apply(); // Disable SLF4J crap.
ArgumentParser.fillOptions(KBPSemgrexExtractor.class, args);
KBPSemgrexExtractor extractor = new KBPSemgrexExtractor(DIR);
List<Pair<KBPInput, String>> testExamples = KBPRelationExtractor.readDataset(TEST_FILE);
extractor.computeAccuracy(testExamples.stream(), PREDICTIONS.map(x -> {
try {
return "stdout".equalsIgnoreCase(x) ? System.out : new PrintStream(new FileOutputStream(x));
} catch (IOException e) {
throw new RuntimeIOException(e);
}
}));
}
}