package edu.stanford.nlp.ie;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.ling.CoreAnnotation;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.tokensregex.CoreMapExpressionExtractor;
import edu.stanford.nlp.ling.tokensregex.Env;
import edu.stanford.nlp.ling.tokensregex.MatchedExpression;
import edu.stanford.nlp.ling.tokensregex.TokenSequencePattern;
import edu.stanford.nlp.pipeline.DefaultPaths;
import edu.stanford.nlp.simple.Sentence;
import edu.stanford.nlp.util.ArgumentParser;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.logging.Redwood;
import edu.stanford.nlp.util.logging.RedwoodConfiguration;
import java.io.*;
import java.util.*;
/**
 * A tokensregex extractor for KBP.
 *
 * <p>One rule file per relation is loaded from a rules directory (together with a shared
 * {@code defs.rules} file). {@link #classify(KBPInput)} marks the subject and object mentions
 * on the sentence tokens, runs the rules of every applicable relation, and reports the first
 * relation whose rules fire.</p>
 *
 * IMPORTANT: Don't rename this class without updating the rules defs file.
 *
 * @author Gabor Angeli
 */
public class KBPTokensregexExtractor implements KBPRelationExtractor {

  protected static final Redwood.RedwoodChannels logger = Redwood.channels(KBPTokensregexExtractor.class);

  /** Directory holding {@code defs.rules} and the per-relation {@code *.rules} files. */
  @ArgumentParser.Option(name="dir", gloss="The tokensregex directory")
  public static String DIR = DefaultPaths.DEFAULT_KBP_TOKENSREGEX_DIR;

  /** Dataset read by {@link #main(String[])}. */
  @ArgumentParser.Option(name="test", gloss="The dataset to test on")
  public static File TEST_FILE = new File("test.conll");

  /** Where {@link #main(String[])} dumps predictions: a file path, or "stdout" (case-insensitive). */
  @ArgumentParser.Option(name="predictions", gloss="Dump model predictions to this file")
  public static Optional<String> PREDICTIONS = Optional.empty();

  /** The compiled rules, keyed by relation. Relations without a rule file have no entry. */
  private final Map<RelationType, CoreMapExpressionExtractor> rules = new HashMap<>();

  /**
   * Annotation key set to "true" on every token of the subject mention while rules run.
   *
   * IMPORTANT: Don't rename this class without updating the rules defs file.
   */
  public static class Subject implements CoreAnnotation<String> {
    public Class<String> getType() { return String.class; }
  }

  /**
   * Annotation key set to "true" on every token of the object mention while rules run.
   * Note that inside this file the simple name {@code Object} refers to this class,
   * not to {@code java.lang.Object}.
   *
   * IMPORTANT: Don't rename this class without updating the rules defs file.
   */
  public static class Object implements CoreAnnotation<String> {
    public Class<String> getType() { return String.class; }
  }

  /**
   * Create a non-verbose extractor from the rules in the given directory.
   *
   * @param tokensregexDir The directory containing the rule files.
   */
  public KBPTokensregexExtractor(String tokensregexDir) {
    this(tokensregexDir, false);
  }

  /**
   * Create an extractor from the rules in the given directory. For every relation type, the
   * file {@code <canonical name with "/" replaced by "SLASH">.rules} is looked up; relations
   * whose file does not exist (on the classpath or file system) are silently skipped.
   *
   * @param tokensregexDir The directory containing the rule files.
   * @param verbose If true, log progress and bind {@code verbose} to true in the rule environment.
   */
  public KBPTokensregexExtractor(String tokensregexDir, boolean verbose) {
    if (verbose) {
      logger.log("Creating TokensRegexExtractor");
    }
    // The shared definitions file is the same for every relation, so build its path once.
    String defsPath = tokensregexDir + File.separator + "defs.rules";
    // Create extractors
    for (RelationType rel : RelationType.values()) {
      // Plain (non-regex) substitution: '/' cannot appear in a file name.
      String path = tokensregexDir + File.separator + rel.canonicalName.replace("/", "SLASH") + ".rules";
      if (!IOUtils.existsInClasspathOrFileSystem(path)) {
        continue;
      }
      List<String> listFiles = new ArrayList<>();
      listFiles.add(defsPath);
      listFiles.add(path);
      if (verbose) {
        logger.log("Rule files for relation " + rel + " is " + path);
      }
      Env env = TokenSequencePattern.getNewEnv();
      env.bind("collapseExtractionRules", true);
      env.bind("verbose", verbose);
      CoreMapExpressionExtractor extr = CoreMapExpressionExtractor.createExtractorFromFiles(env, listFiles).keepTemporaryTags();
      rules.put(rel, extr);
    }
  }

  /**
   * Classify the relation between the subject and object of the given input.
   *
   * <p>Relations are tried in {@code RelationType.values()} order; only relations that have
   * rules loaded, whose entity type equals the subject type, and whose valid named entity
   * labels contain the object type are run. The first relation with at least one extraction
   * wins, and the weight of its best extraction is returned as the score.</p>
   *
   * <p>The {@link Subject} and {@link Object} markers are always removed from the tokens before
   * this method returns, including when it exits with an exception. NER tags that were "O" and
   * got overwritten with the mention type are not restored.</p>
   *
   * @param input The sentence together with its subject / object spans and types.
   * @return The canonical relation name and its weight, or {@code NO_RELATION} with score 1.0
   *         if no rule matched.
   */
  @Override
  public Pair<String, Double> classify(KBPInput input) {
    // Annotate Sentence
    CoreMap sentenceAsMap = input.sentence.asCoreMap(Sentence::nerTags);
    List<CoreLabel> tokens = sentenceAsMap.get(CoreAnnotations.TokensAnnotation.class);
    try {
      // Annotate where the subject is
      for (int i : input.subjectSpan) {
        CoreLabel token = tokens.get(i);
        token.set(Subject.class, "true");
        // NOTE(review): this NER overwrite is never undone; confirm callers do not rely on the old tag.
        if ("O".equals(token.ner())) {
          token.setNER(input.subjectType.name);
        }
      }
      // Annotate where the object is
      for (int i : input.objectSpan) {
        CoreLabel token = tokens.get(i);
        token.set(Object.class, "true");
        if ("O".equals(token.ner())) {
          token.setNER(input.objectType.name);
        }
      }
      // Run Rules
      for (RelationType rel : RelationType.values()) {
        CoreMapExpressionExtractor extractor = rules.get(rel);
        if (extractor == null
            || rel.entityType != input.subjectType
            || !rel.validNamedEntityLabels.contains(input.objectType)) {
          continue;
        }
        @SuppressWarnings("unchecked")
        List<MatchedExpression> extractions = extractor.extractExpressions(sentenceAsMap);
        if (extractions != null && !extractions.isEmpty()) {
          MatchedExpression best = MatchedExpression.getBestMatched(extractions, MatchedExpression.EXPR_WEIGHT_SCORER);
          return Pair.makePair(rel.canonicalName, best.getWeight());
        }
      }
      return Pair.makePair(NO_RELATION, 1.0);
    } finally {
      // Un-Annotate Sentence, on every exit path (match, no match, or exception).
      clearMentionMarkers(tokens);
    }
  }

  /** Remove the temporary {@link Subject} and {@link Object} markers from every token. */
  private static void clearMentionMarkers(List<CoreLabel> tokens) {
    for (CoreLabel token : tokens) {
      token.remove(Subject.class);
      token.remove(Object.class);
    }
  }

  /**
   * Evaluate the extractor on {@link #TEST_FILE}, optionally dumping predictions to
   * {@link #PREDICTIONS}. The static option fields are filled from the command line arguments.
   *
   * @param args Command line arguments, parsed by {@link ArgumentParser}.
   * @throws IOException If the test dataset cannot be read.
   */
  public static void main(String[] args) throws IOException {
    // Apply the standard Redwood configuration (originally noted as disabling SLF4J output).
    RedwoodConfiguration.standard().apply();
    ArgumentParser.fillOptions(KBPTokensregexExtractor.class, args);
    KBPTokensregexExtractor extractor = new KBPTokensregexExtractor(DIR);
    List<Pair<KBPInput, String>> testExamples = KBPRelationExtractor.readDataset(TEST_FILE);
    Optional<PrintStream> predictionsOut = PREDICTIONS.map(x -> {
      try {
        return "stdout".equalsIgnoreCase(x) ? System.out : new PrintStream(new FileOutputStream(x));
      } catch (IOException e) {
        throw new RuntimeIOException(e);
      }
    });
    try {
      extractor.computeAccuracy(testExamples.stream(), predictionsOut);
    } finally {
      // Release the predictions file, but never close the process-wide System.out.
      predictionsOut.filter(out -> out != System.out).ifPresent(PrintStream::close);
    }
  }

}