package edu.stanford.nlp.naturalli;

import edu.stanford.nlp.classify.*;
import edu.stanford.nlp.ie.machinereading.structure.Span;
import edu.stanford.nlp.ie.util.RelationTriple;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.naturalli.ClauseSplitterSearchProblem.*;
import edu.stanford.nlp.semgraph.SemanticGraph;
import edu.stanford.nlp.semgraph.SemanticGraphCoreAnnotations;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.*;
import edu.stanford.nlp.util.logging.Redwood;

import java.io.*;
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiFunction;
import java.util.stream.Stream;
import java.util.zip.GZIPOutputStream;

import static edu.stanford.nlp.util.logging.Redwood.Util.*;

/**
 * Just a convenience alias for a clause splitting search problem factory.
 * Mostly here to form a nice parallel with {@link edu.stanford.nlp.naturalli.ForwardEntailer}.
 *
 * @author Gabor Angeli
 */
public interface ClauseSplitter extends BiFunction<SemanticGraph, Boolean, ClauseSplitterSearchProblem> {

  /** A logger for this class */
  Redwood.RedwoodChannels log = Redwood.channels(ClauseSplitter.class);

  /**
   * The label of a decision made during the clause-splitting search: split off a
   * clause here, an intermediate step on the way to a split, or not a clause.
   */
  enum ClauseClassifierLabel {
    CLAUSE_SPLIT(2),
    CLAUSE_INTERM(1),
    NOT_A_CLAUSE(0);

    public final byte index;

    ClauseClassifierLabel(int val) {
      this.index = (byte) val;
    }

    /** Seriously, why would Java not have this by default? */
    @Override
    public String toString() {
      return this.name();
    }

    @SuppressWarnings("unused")
    public static ClauseClassifierLabel fromIndex(int index) {
      switch (index) {
        case 0:
          return NOT_A_CLAUSE;
        case 1:
          return CLAUSE_INTERM;
        case 2:
          return CLAUSE_SPLIT;
        default:
          throw new IllegalArgumentException("Not a valid index: " + index);
      }
    }
  }

  /**
   * Train a clause searcher factory. That is, train a classifier for deciding which
   * dependency arcs should be split off as new clauses.
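   *
   * <p>A minimal usage sketch. Here {@code readTrainingTriples()} is a hypothetical
   * stand-in for whatever corpus reader supplies annotated sentences paired with
   * gold (subject, object) token spans; it is not part of this class:</p>
   *
   * <pre>{@code
   * Stream<Pair<CoreMap, Collection<Pair<Span, Span>>>> data = readTrainingTriples();
   * ClauseSplitter splitter = ClauseSplitter.train(
   *     data,
   *     Optional.of(new File("/path/to/clauseSplitter.ser.gz")),  // where to serialize the model
   *     Optional.empty(),                                         // no featurized-datum dump
   *     ClauseSplitterSearchProblem.DEFAULT_FEATURIZER);
   * }</pre>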
   *
   * @param trainingData The training data. This is a stream of triples of:
   *                     <ol>
   *                       <li>The sentence containing a known extraction.</li>
   *                       <li>The span of the subject in the sentence, as a token span.</li>
   *                       <li>The span of the object in the sentence, as a token span.</li>
   *                     </ol>
   * @param modelPath The path to save the model to. This is useful for {@link ClauseSplitter#load(String)}.
   * @param trainingDataDump The path to save the training data, as a set of labeled featurized datums.
   * @param featurizer The featurizer to use for this classifier.
   *
   * @return A factory for creating searchers from a given dependency tree.
   */
  static ClauseSplitter train(
      Stream<Pair<CoreMap, Collection<Pair<Span, Span>>>> trainingData,
      Optional<File> modelPath,
      Optional<File> trainingDataDump,
      Featurizer featurizer) {
    // Parse options
    LinearClassifierFactory<ClauseClassifierLabel, String> factory = new LinearClassifierFactory<>();
    // Generally useful objects
    OpenIE openie = new OpenIE(PropertiesUtils.asProperties(
        "splitter.nomodel", "true",
        "optimizefor", "GENERAL"));
    WeightedDataset<ClauseClassifierLabel, String> dataset = new WeightedDataset<>();
    AtomicInteger numExamplesProcessed = new AtomicInteger(0);
    final Optional<PrintWriter> datasetDumpWriter = trainingDataDump.map(file -> {
      try {
        return new PrintWriter(new OutputStreamWriter(new GZIPOutputStream(new FileOutputStream(file))));
      } catch (IOException e) {
        throw new RuntimeIOException(e);
      }
    });

    // Step 1: Loop over data
    forceTrack("Training inference");
    trainingData.forEach(rawExample -> {
      // Parse training datum
      CoreMap sentence = rawExample.first;
      Collection<Pair<Span, Span>> spans = rawExample.second;
      List<CoreLabel> tokens = sentence.get(CoreAnnotations.TokensAnnotation.class);
      SemanticGraph tree = sentence.get(SemanticGraphCoreAnnotations.EnhancedDependenciesAnnotation.class);
      // Create raw clause searcher (no classifier)
      ClauseSplitterSearchProblem problem = new ClauseSplitterSearchProblem(tree, true);

      // Run search
      problem.search(fragmentAndScore -> {
        // Parse the search callback
        List<Counter<String>> features = fragmentAndScore.second;
        SentenceFragment fragment = fragmentAndScore.third.get();

        // Search for extractions
        Set<RelationTriple> extractions = new HashSet<>(openie.relationsInFragments(openie.entailmentsFromClause(fragment)));
        Trilean correct = Trilean.FALSE;
        RELATION_TRIPLE_LOOP: for (RelationTriple extraction : extractions) {
          // Clean up the guesses
          Span subjectGuess = Span.fromValues(extraction.subject.get(0).index() - 1,
              extraction.subject.get(extraction.subject.size() - 1).index());
          Span objectGuess = Span.fromValues(extraction.object.get(0).index() - 1,
              extraction.object.get(extraction.object.size() - 1).index());
          for (Pair<Span, Span> candidateGold : spans) {
            Span subjectSpan = candidateGold.first;
            Span objectSpan = candidateGold.second;
            // Check whether the guess matches this gold candidate
            if ((subjectGuess.equals(subjectSpan) && objectGuess.equals(objectSpan)) ||
                (subjectGuess.equals(objectSpan) && objectGuess.equals(subjectSpan))) {
              correct = Trilean.TRUE;
              break RELATION_TRIPLE_LOOP;
            } else if ((Util.nerOverlap(tokens, subjectSpan, subjectGuess) && Util.nerOverlap(tokens, objectSpan, objectGuess)) ||
                       (Util.nerOverlap(tokens, subjectSpan, objectGuess) && Util.nerOverlap(tokens, objectSpan, subjectGuess))) {
              if (!correct.isTrue()) {
                correct = Trilean.TRUE;
                break RELATION_TRIPLE_LOOP;
              }
            } else {
              if (!correct.isTrue()) {
                correct = Trilean.UNKNOWN;
                break RELATION_TRIPLE_LOOP;
              }
            }
          }
        }

        // Process the datum
        if (!features.isEmpty()) {
          // Convert the path to datums
          List<Pair<Counter<String>, ClauseClassifierLabel>> decisionsToAddAsDatums = new ArrayList<>();
          if (correct.isTrue()) {
            // If this is a "true" path, add the k-1 decisions as INTERM and the last decision as a SPLIT
            for (int i = 0; i < features.size(); ++i) {
              if (i == features.size() - 1) {
                decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_SPLIT));
              } else {
                decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_INTERM));
              }
            }
          } else if (correct.isFalse()) {
            // If this is a "false" path, then we know at least the last decision was bad.
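            // (the earlier decisions along the path may still have been reasonable,
            //  so only this final decision is added as a negative datum)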
            decisionsToAddAsDatums.add(Pair.makePair(features.get(features.size() - 1), ClauseClassifierLabel.NOT_A_CLAUSE));
          } else if (correct.isUnknown()) {
            // If this is an "unknown" path, only add it if it was the result of vanilla splits
            // (check whether it is a sequence of simple splits)
            boolean isSimpleSplit = false;
            for (Counter<String> feats : features) {
              if (featurizer.isSimpleSplit(feats)) {
                isSimpleSplit = true;
                break;
              }
            }
            // (if so, add it as if it were a "true" example)
            if (isSimpleSplit) {
              for (int i = 0; i < features.size(); ++i) {
                if (i == features.size() - 1) {
                  decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_SPLIT));
                } else {
                  decisionsToAddAsDatums.add(Pair.makePair(features.get(i), ClauseClassifierLabel.CLAUSE_INTERM));
                }
              }
            }
          }

          // Add the datums
          for (Pair<Counter<String>, ClauseClassifierLabel> decision : decisionsToAddAsDatums) {
            // (create datum)
            RVFDatum<ClauseClassifierLabel, String> datum = new RVFDatum<>(decision.first);
            datum.setLabel(decision.second);
            // (dump datum to debug log)
            if (datasetDumpWriter.isPresent()) {
              datasetDumpWriter.get().println(decision.second + "\t" +
                  StringUtils.join(decision.first.entrySet().stream().map(entry -> entry.getKey() + "->" + entry.getValue()), ";"));
            }
            // (add datum to dataset)
            dataset.add(datum);
          }
        }
        return true;
      }, new LinearClassifier<>(new ClassicCounter<>()), Collections.emptyMap(), featurizer, 10000);

      // Debug info
      if (numExamplesProcessed.incrementAndGet() % 100 == 0) {
        log("processed " + numExamplesProcessed + " training sentences: " + dataset.size() + " datums");
      }
    });
    endTrack("Training inference");

    // Close the training data dump, if we were writing one
    if (datasetDumpWriter.isPresent()) {
      datasetDumpWriter.get().close();
    }

    // Step 2: Train the classifier, and optionally save it
    forceTrack("Training");
    Classifier<ClauseClassifierLabel, String> fullClassifier = factory.trainClassifier(dataset);
    endTrack("Training");
    if (modelPath.isPresent()) {
      Pair<Classifier<ClauseClassifierLabel, String>, Featurizer> toSave = Pair.makePair(fullClassifier, featurizer);
      try {
        IOUtils.writeObjectToFile(toSave, modelPath.get());
        log("SUCCESS: wrote model to " + modelPath.get().getPath());
      } catch (IOException e) {
        log("ERROR: failed to save model to path: " + modelPath.get().getPath());
        err(e);
      }
    }

    // Step 3: Check the training accuracy of the classifier
    forceTrack("Training accuracy");
    dataset.randomize(42L);
    Util.dumpAccuracy(fullClassifier, dataset);
    endTrack("Training accuracy");

    // Step 4: Cross-validate
    int numFolds = 5;
    forceTrack(numFolds + " fold cross-validation");
    for (int fold = 0; fold < numFolds; ++fold) {
      forceTrack("Fold " + (fold + 1));
      forceTrack("Training");
      Pair<GeneralDataset<ClauseClassifierLabel, String>, GeneralDataset<ClauseClassifierLabel, String>> foldData =
          dataset.splitOutFold(fold, numFolds);
      Classifier<ClauseClassifierLabel, String> classifier = factory.trainClassifier(foldData.first);
      endTrack("Training");
      forceTrack("Test");
      Util.dumpAccuracy(classifier, foldData.second);
      endTrack("Test");
      endTrack("Fold " + (fold + 1));
    }
    endTrack(numFolds + " fold cross-validation");

    // Step 5: Return the trained factory
    return (tree, truth) -> new ClauseSplitterSearchProblem(tree, truth, Optional.of(fullClassifier), Optional.of(featurizer));
  }

  /**
   * Train a clause searcher factory, using the default featurizer.
   *
   * @see ClauseSplitter#train(Stream, Optional, Optional, Featurizer)
   */
  static ClauseSplitter train(
      Stream<Pair<CoreMap, Collection<Pair<Span, Span>>>> trainingData,
      File modelPath, File trainingDataDump) {
    return train(trainingData, Optional.of(modelPath), Optional.of(trainingDataDump), ClauseSplitterSearchProblem.DEFAULT_FEATURIZER);
  }

  /**
   * Load a factory model from a given path. This can be trained with
   * {@link ClauseSplitter#train(Stream, Optional, Optional, Featurizer)}.
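   *
   * <p>A minimal loading sketch; the model path below is illustrative rather than a
   * specific bundled resource, and {@code dependencyTree} stands in for an
   * already-parsed {@link SemanticGraph}. The Boolean argument is the {@code truth}
   * flag forwarded to {@link ClauseSplitterSearchProblem}:</p>
   *
   * <pre>{@code
   * ClauseSplitter splitter = ClauseSplitter.load("/path/to/clauseSplitter.ser.gz");
   * ClauseSplitterSearchProblem problem = splitter.apply(dependencyTree, true);
   * }</pre>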
   *
   * @return A function taking a dependency tree, and returning a clause searcher.
   */
  static ClauseSplitter load(String serializedModel) throws IOException {
    try {
      long start = System.currentTimeMillis();
      Pair<Classifier<ClauseClassifierLabel, String>, Featurizer> data = IOUtils.readObjectFromURLOrClasspathOrFileSystem(serializedModel);
      ClauseSplitter rtn = (tree, truth) -> new ClauseSplitterSearchProblem(tree, truth, Optional.of(data.first), Optional.of(data.second));
      log.info("Loading clause splitter from " + serializedModel + " ... done [" +
          Redwood.formatTimeDifference(System.currentTimeMillis() - start) + "]");
      return rtn;
    } catch (ClassNotFoundException e) {
      throw new IllegalStateException("Invalid model at path: " + serializedModel, e);
    }
  }

}