package edu.stanford.nlp.ie;

import edu.stanford.nlp.classify.Classifier;
import edu.stanford.nlp.classify.LinearClassifier;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.pipeline.DefaultPaths;
import edu.stanford.nlp.util.ArgumentParser;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.logging.Redwood;
import edu.stanford.nlp.util.logging.RedwoodConfiguration;

import java.io.*;
import java.util.List;
import java.util.Optional;

/**
 * An ensemble of other KBP relation extractors.
 * The ensemble asks every extractor for a prediction and returns the
 * highest-confidence prediction that is not {@code NO_RELATION}
 * (ties broken by the order the extractors are passed to the constructor,
 * earlier extractors winning). It only returns {@code NO_RELATION} if no
 * extractor proposed a relation.
 */
@SuppressWarnings("FieldCanBeLocal")
public class KBPEnsembleExtractor implements KBPRelationExtractor {
  protected static final Redwood.RedwoodChannels logger = Redwood.channels(KBPRelationExtractor.class);

  @ArgumentParser.Option(name="model", gloss="The path to the model")
  private static String STATISTICAL_MODEL = DefaultPaths.DEFAULT_KBP_CLASSIFIER;

  @ArgumentParser.Option(name="semgrex", gloss="Semgrex patterns directory")
  private static String SEMGREX_DIR = DefaultPaths.DEFAULT_KBP_SEMGREX_DIR;

  @ArgumentParser.Option(name="tokensregex", gloss="Tokensregex patterns directory")
  private static String TOKENSREGEX_DIR = DefaultPaths.DEFAULT_KBP_TOKENSREGEX_DIR;

  @ArgumentParser.Option(name="predictions", gloss="Dump model predictions to this file")
  public static Optional<String> PREDICTIONS = Optional.empty();

  @ArgumentParser.Option(name="test", gloss="The dataset to test on")
  public static File TEST_FILE = new File("test.conll");

  /**
   * The extractors to run, in the order of priority they should be run in.
   */
  public final KBPRelationExtractor[] extractors;

  /**
   * Creates a new ensemble extractor from the given argument extractors.
   *
   * @param extractors A varargs list of extractors to combine. The array is stored as given (not copied).
   */
  public KBPEnsembleExtractor(KBPRelationExtractor... extractors) {
    this.extractors = extractors;
  }

  /**
   * Classifies the input with every extractor and combines the results.
   *
   * <ul>
   *   <li>If at least one extractor proposes a relation other than {@code NO_RELATION}, the
   *       proposal with the strictly highest confidence is returned; on equal confidence the
   *       earlier extractor wins.</li>
   *   <li>If every extractor returns {@code NO_RELATION}, the last extractor's prediction
   *       (including its confidence) is returned.</li>
   *   <li>If there are no extractors, {@code NO_RELATION} with confidence 1.0 is returned.</li>
   * </ul>
   *
   * @param input The input to classify.
   * @return The chosen (relation, confidence) pair.
   */
  @Override
  public Pair<String, Double> classify(KBPInput input) {
    Pair<String, Double> prediction = Pair.makePair(KBPRelationExtractor.NO_RELATION, 1.0);
    for (KBPRelationExtractor extractor : extractors) {
      Pair<String, Double> classifierPrediction = extractor.classify(input);
      if (prediction.first.equals(KBPRelationExtractor.NO_RELATION) ||
          (!classifierPrediction.first.equals(KBPRelationExtractor.NO_RELATION) &&
              classifierPrediction.second > prediction.second)) {
        // The last prediction was NO_RELATION, or this is not NO_RELATION and has a higher score
        prediction = classifierPrediction;
      }
    }
    return prediction;
  }

  /**
   * Opens the stream that model predictions are dumped to.
   *
   * @param target Either "stdout" (case-insensitive), or a file path to write to.
   * @return {@code System.out} for "stdout", otherwise a new stream writing to the file.
   * @throws RuntimeIOException If the file cannot be opened for writing.
   */
  private static PrintStream openPredictionStream(String target) {
    try {
      return "stdout".equalsIgnoreCase(target) ? System.out : new PrintStream(new FileOutputStream(target));
    } catch (IOException e) {
      throw new RuntimeIOException(e);
    }
  }

  /**
   * Loads the statistical model, builds an ensemble of the tokensregex, semgrex and statistical
   * extractors (in that priority order), and evaluates it on {@link #TEST_FILE}, optionally
   * dumping predictions to the destination given by {@link #PREDICTIONS}.
   *
   * @param args Command line options, parsed by {@link ArgumentParser} into this class's option fields.
   * @throws IOException If a model, pattern directory, or dataset cannot be read.
   * @throws ClassNotFoundException If the serialized model refers to a class that cannot be loaded.
   * @throws ClassCastException If the serialized model is neither a {@link LinearClassifier}
   *                            nor a {@link KBPStatisticalExtractor}.
   */
  public static void main(String[] args) throws IOException, ClassNotFoundException {
    RedwoodConfiguration.standard().apply();  // Disable SLF4J crap.
    ArgumentParser.fillOptions(KBPEnsembleExtractor.class, args);

    // The serialized model may be either a bare classifier or an already-wrapped extractor
    Object object = IOUtils.readObjectFromURLOrClasspathOrFileSystem(STATISTICAL_MODEL);
    KBPRelationExtractor statisticalExtractor;
    if (object instanceof LinearClassifier) {
      //noinspection unchecked
      statisticalExtractor = new KBPStatisticalExtractor((Classifier<String, String>) object);
    } else if (object instanceof KBPStatisticalExtractor) {
      statisticalExtractor = (KBPStatisticalExtractor) object;
    } else {
      // Guard against a null model so that the intended exception is raised rather than an NPE
      throw new ClassCastException((object == null ? "null" : object.getClass()) + " cannot be cast into a " + KBPStatisticalExtractor.class);
    }
    logger.info("Read statistical model from " + STATISTICAL_MODEL);

    KBPRelationExtractor extractor = new KBPEnsembleExtractor(
        new KBPTokensregexExtractor(TOKENSREGEX_DIR),
        new KBPSemgrexExtractor(SEMGREX_DIR),
        statisticalExtractor
    );

    List<Pair<KBPInput, String>> testExamples = KBPRelationExtractor.readDataset(TEST_FILE);
    Optional<PrintStream> predictionsOut = PREDICTIONS.map(KBPEnsembleExtractor::openPredictionStream);
    try {
      extractor.computeAccuracy(testExamples.stream(), predictionsOut);
    } finally {
      // Release the predictions file once evaluation has returned; never close System.out
      predictionsOut.filter(out -> out != System.out).ifPresent(PrintStream::close);
    }
  }

}