package edu.stanford.nlp.naturalli;
import edu.stanford.nlp.util.logging.Redwood;
import edu.stanford.nlp.ie.machinereading.structure.Span;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.HasIndex;
import edu.stanford.nlp.ling.IndexedWord;
import edu.stanford.nlp.pipeline.Annotation;
import edu.stanford.nlp.process.TSVSentenceProcessor;
import edu.stanford.nlp.semgraph.SemanticGraph;
import edu.stanford.nlp.semgraph.SemanticGraphCoreAnnotations;
import edu.stanford.nlp.semgraph.SemanticGraphEdge;
import edu.stanford.nlp.semgraph.semgrex.SemgrexMatcher;
import edu.stanford.nlp.semgraph.semgrex.SemgrexPattern;
import edu.stanford.nlp.trees.PennTreeReader;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.TreeReader;
import edu.stanford.nlp.trees.UniversalEnglishGrammaticalStructureFactory;
import edu.stanford.nlp.util.*;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.text.DecimalFormat;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import static edu.stanford.nlp.util.logging.Redwood.Util.*;
/**
 * A script to convert a TSV dump from our KBP sentences table into a Turk-task-ready clause-splitting dataset.
*
* @author Gabor Angeli
*/
public class CreateClauseDataset implements TSVSentenceProcessor {
/** A logger for this class */
  private static final Redwood.RedwoodChannels log = Redwood.channels(CreateClauseDataset.class);
@ArgumentParser.Option(name="in", gloss="The input to read from")
private static InputStream in = System.in;
public CreateClauseDataset() {
}
private static Span toSpan(List<? extends HasIndex> chunk) {
int min = Integer.MAX_VALUE;
int max = -1;
for (HasIndex word : chunk) {
min = Math.min(word.index() - 1, min);
max = Math.max(word.index(), max);
}
assert min >= 0;
assert max < Integer.MAX_VALUE && max > 0;
return new Span(min, max);
}
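  // For example (illustration only, not in the original source): toSpan on a chunk
  // whose words have 1-based indices {3, 4, 5} returns new Span(2, 5), i.e. a
  // 0-based, end-exclusive token span.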
@Override
public void process(long id, Annotation doc) {
CoreMap sentence = doc.get(CoreAnnotations.SentencesAnnotation.class).get(0);
SemanticGraph depparse = sentence.get(SemanticGraphCoreAnnotations.BasicDependenciesAnnotation.class);
log.info("| " + sentence.get(CoreAnnotations.TextAnnotation.class));
// Get all valid subject spans
BitSet consumedAsSubjects = new BitSet();
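    // (consumedAsSubjects tracks token indices already covered by an extracted
    //  subject NP, so that NPs nested inside a larger NP are skipped below)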
@SuppressWarnings("MismatchedQueryAndUpdateOfCollection") List<Span> subjectSpans = new ArrayList<>();
NEXTNODE: for (IndexedWord head : depparse.topologicalSort()) {
// Check if the node is a noun/pronoun
if (head.tag().startsWith("N") || head.tag().equals("PRP")) {
// Try to get the NP chunk
Optional<List<IndexedWord>> subjectChunk = segmenter.getValidChunk(depparse, head, segmenter.VALID_SUBJECT_ARCS, Optional.empty(), true);
if (subjectChunk.isPresent()) {
// Make sure it's not already a member of a larger NP
for (IndexedWord tok : subjectChunk.get()) {
if (consumedAsSubjects.get(tok.index())) {
continue NEXTNODE; // Already considered. Continue to the next node.
}
}
// Register it as an NP
for (IndexedWord tok : subjectChunk.get()) {
consumedAsSubjects.set(tok.index());
}
// Add it as a subject
subjectSpans.add(toSpan(subjectChunk.get()));
}
}
}
}
/**
* The pattern for traces which are potential subjects
*/
  private static final Pattern TRACE_TARGET_PATTERN = Pattern.compile("(NP-.*)-([0-9]+)");
/**
* The pattern for trace markers.
*/
  private static final Pattern TRACE_SOURCE_PATTERN = Pattern.compile(".*\\*-([0-9]+)");
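  // Illustrations (not in the original source): a constituent label such as
  // "NP-SBJ-1" matches TRACE_TARGET_PATTERN with group(2) = "1", while a trace
  // leaf such as "*-1" or "*T*-1" matches TRACE_SOURCE_PATTERN with group(1) = "1".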
/**
* The converter from constituency to dependency trees.
*/
  private static final UniversalEnglishGrammaticalStructureFactory parser = new UniversalEnglishGrammaticalStructureFactory();
/**
* The OpenIE segmenter to use.
*/
  private static final RelationTripleSegmenter segmenter = new RelationTripleSegmenter();
/**
* The natural logic annotator for marking polarity.
*/
  private static final NaturalLogicAnnotator natlog = new NaturalLogicAnnotator();
/**
* Parse a given constituency tree into a dependency graph.
*
* @param tree The constituency tree, in Penn Treebank style.
* @return The dependency graph for the tree.
*/
private static SemanticGraph parse(Tree tree) {
return new SemanticGraph(parser.newGrammaticalStructure(tree).typedDependenciesCollapsed());
}
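  // A hedged illustration (not in the original source): in the collapsed
  // representation produced by parse(Tree) above, prepositions are folded into
  // the relation name, so for a phrase like "the CEO of Apple" the dependent
  // "Apple" attaches to "CEO" via a relation like nmod:of.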
/**
   * Create a dataset of subject/object span pairs, such that a sequence of clause splits
   * that isolates a given subject and object is a correct sequence.
*
   * @param depparse The dependency parse of the sentence.
   * @param tokens The tokens of the sentence.
* @param traceTargets The set of spans corresponding to targets of traces.
* @param traceSources The set of indices in a sentence corresponding to the sources of traces.
* @return A dataset of subject/object spans.
*/
@SuppressWarnings("UnusedParameters")
private static Collection<Pair<Span, Span>> subjectObjectPairs(SemanticGraph depparse,
List<CoreLabel> tokens,
Map<Integer, Span> traceTargets,
Map<Integer, Integer> traceSources) {
// log(StringUtils.join(tokens.stream().map(CoreLabel::word), " "));
List<Pair<Span, Span>> data = new ArrayList<>();
for (SemgrexPattern vpPattern : segmenter.VP_PATTERNS) {
SemgrexMatcher matcher = vpPattern.matcher(depparse);
while (matcher.find()) {
// Get the verb and object
IndexedWord verb = matcher.getNode("verb");
IndexedWord object = matcher.getNode("object");
if (verb != null && object != null) {
// See if there is already a subject attached
boolean hasSubject = false;
for (SemanticGraphEdge edge : depparse.outgoingEdgeIterable(verb)) {
if (edge.getRelation().toString().contains("subj")) {
hasSubject = true;
}
}
for (SemanticGraphEdge edge : depparse.outgoingEdgeIterable(object)) {
if (edge.getRelation().toString().contains("subj")) {
hasSubject = true;
}
}
if (!hasSubject) {
// Get the spans for the verb and object
Optional<List<IndexedWord>> verbChunk = segmenter.getValidChunk(depparse, verb, segmenter.VALID_ADVERB_ARCS, Optional.empty(), true);
Optional<List<IndexedWord>> objectChunk = segmenter.getValidChunk(depparse, object, segmenter.VALID_OBJECT_ARCS, Optional.empty(), true);
if (verbChunk.isPresent() && objectChunk.isPresent()) {
              verbChunk.get().sort(Comparator.comparingInt(IndexedWord::index));
              objectChunk.get().sort(Comparator.comparingInt(IndexedWord::index));
// Find a trace
int traceId = -1;
Span verbSpan = toSpan(verbChunk.get());
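              // (the search window widens the verb span by one token on each side,
              //  so a trace source immediately adjacent to the verb chunk counts)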
Span traceSpan = Span.fromValues(verbSpan.start() - 1, verbSpan.end() + 1);
for (Map.Entry<Integer, Integer> entry : traceSources.entrySet()) {
if (traceSpan.contains(entry.getValue())) {
traceId = entry.getKey();
}
}
//noinspection StatementWithEmptyBody
if (traceId < 0) {
// Register the VP as an unknown VP
// List<CoreLabel> vpChunk = new ArrayList<>();
// vpChunk.addAll(verbChunk.get());
// vpChunk.addAll(objectChunk.get());
// Collections.sort(vpChunk, (a, b) -> a.index() - b.index());
// debug("could not find trace for " + vpChunk);
} else {
// Add the obj chunk
Span subjectSpan = traceTargets.get(traceId);
Span objectSpan = toSpan(objectChunk.get());
if (subjectSpan != null) {
// debug("(" +
// StringUtils.join(tokens.subList(subjectSpan.start(), subjectSpan.end()).stream().map(CoreLabel::word), " ") + "; " +
// verb.word() + "; " +
// StringUtils.join(tokens.subList(objectSpan.start(), objectSpan.end()).stream().map(CoreLabel::word), " ") +
// ")");
data.add(Pair.makePair(subjectSpan, objectSpan));
}
}
}
}
}
}
}
// Run vanilla pattern splits
for (SemgrexPattern vpPattern : segmenter.VERB_PATTERNS) {
SemgrexMatcher matcher = vpPattern.matcher(depparse);
while (matcher.find()) {
        // Get the subject and object
IndexedWord subject = matcher.getNode("subject");
IndexedWord object = matcher.getNode("object");
if (subject != null && object != null) {
Optional<List<IndexedWord>> subjectChunk = segmenter.getValidChunk(depparse, subject, segmenter.VALID_SUBJECT_ARCS, Optional.empty(), true);
Optional<List<IndexedWord>> objectChunk = segmenter.getValidChunk(depparse, object, segmenter.VALID_OBJECT_ARCS, Optional.empty(), true);
if (subjectChunk.isPresent() && objectChunk.isPresent()) {
Span subjectSpan = toSpan(subjectChunk.get());
Span objectSpan = toSpan(objectChunk.get());
data.add(Pair.makePair(subjectSpan, objectSpan));
}
}
}
}
return data;
}
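  // A sketch (PTB-style example, not from the original source) of the trace logic
  // in subjectObjectPairs above, assuming standard PTB trace co-indexing:
  //
  //   (S (NP-SBJ-1 Disney)
  //      (VP announced
  //          (NP plans
  //              (S (NP-SBJ (-NONE- *-1))
  //                 (VP to (VP build (NP a theme park)))))))
  //
  // The VP "build ... a theme park" has no overt subject edge; the nearby trace
  // source "*-1" resolves through traceSources/traceTargets to the NP-SBJ-1 span
  // "Disney", so the pair ("Disney", "a theme park") is emitted. Exact chunk
  // boundaries depend on the segmenter.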
/**
* Collect all the possible targets for traces. This is limited to NP-style traces.
*
   * @param root The tree to search in; this function recurses over its children.
   * @return The set of trace targets. The key is the id of the trace; the value is the span of the trace's target.
*/
private static Map<Integer, Span> findTraceTargets(Tree root) {
Map<Integer, Span> spansInTree = new HashMap<>(4);
Matcher m = TRACE_TARGET_PATTERN.matcher(root.label().value() == null ? "NULL" : root.label().value());
if (m.matches()) {
int index = Integer.parseInt(m.group(2));
spansInTree.put(index, Span.fromPair(root.getSpan()).toExclusive());
}
for (Tree child : root.children()) {
spansInTree.putAll(findTraceTargets(child));
}
return spansInTree;
}
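  // For example (illustration only): running findTraceTargets over a subtree
  // "(NP-SBJ-1 (DT The) (NN dog))" registers trace id 1, mapped to the
  // end-exclusive span covering "The dog".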
/**
* Collect all the trace markers in the sentence.
*
   * @param root The tree to search in; this function recurses over its children.
   * @return A map of trace sources. The key is the id of the trace; the value is the index of the trace's source in the sentence.
*/
private static Map<Integer, Integer> findTraceSources(Tree root) {
Map<Integer, Integer> spansInTree = new HashMap<>(4);
Matcher m = TRACE_SOURCE_PATTERN.matcher(root.label().value() == null ? "NULL" : root.label().value());
if (m.matches()) {
int index = Integer.parseInt(m.group(1));
spansInTree.put(index, ((CoreLabel) root.label()).index() - 1);
}
for (Tree child : root.children()) {
spansInTree.putAll(findTraceSources(child));
}
return spansInTree;
}
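  // For example (illustration only): running findTraceSources over a tree
  // containing the trace leaf "*-1" (under a -NONE- tag) registers trace id 1,
  // mapped to that leaf's 0-based token index.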
/**
   * Count the number of extractions in the given dataset; that is, the total number
   * of subject/object span pairs, summed over all sentences.
   *
   * @param data The dataset.
   * @return The number of extractions in the dataset.
*/
private static int countDatums(List<Pair<CoreMap, Collection<Pair<Span,Span>>>> data) {
int count = 0;
for (Pair<CoreMap, Collection<Pair<Span, Span>>> datum : data) {
count += datum.second.size();
}
return count;
}
/**
* Process all the trees in the given directory. For example, the WSJ section of the Penn Treebank.
*
* @param name The name of the directory we are processing.
* @param directory The directory we are processing.
* @return A dataset of subject/object pairs in the trees in the directory.
* This is a list of sentences, such that each sentence has a collection of pairs of spans.
* Each pair of spans is a subject/object span pair that constitutes a valid extraction.
   * @throws IOException Thrown if a treebank file could not be read.
*/
private static List<Pair<CoreMap, Collection<Pair<Span, Span>>>> processDirectory(String name, File directory) throws IOException {
forceTrack("Processing " + name);
// Prepare the files to iterate over
Iterable<File> files = IOUtils.iterFilesRecursive(directory, "mrg");
Tree tree;
int numTreesProcessed = 0;
List<Pair<CoreMap, Collection<Pair<Span, Span>>>> trainingData = new ArrayList<>(1024);
// Iterate over the files
for (File file : files) {
// log(file);
      // Use try-with-resources so the reader is closed even if a tree fails to read
      try (TreeReader reader = new PennTreeReader(IOUtils.readerFromFile(file))) {
        while ( (tree = reader.readTree()) != null ) {
try {
// Prepare the tree
tree.indexSpans();
tree.setSpans();
// Get relevant information from sentence
List<CoreLabel> tokens = tree.getLeaves().stream()
.map(leaf -> (CoreLabel) leaf.label())
// .filter(leaf -> !TRACE_SOURCE_PATTERN.matcher(leaf.word()).matches() && !leaf.tag().equals("-NONE-"))
.collect(Collectors.toList());
SemanticGraph graph = parse(tree);
Map<Integer, Span> targets = findTraceTargets(tree);
Map<Integer, Integer> sources = findTraceSources(tree);
// Create a sentence object
CoreMap sentence = new ArrayCoreMap(4) {{
set(CoreAnnotations.TokensAnnotation.class, tokens);
set(SemanticGraphCoreAnnotations.BasicDependenciesAnnotation.class, graph);
set(SemanticGraphCoreAnnotations.EnhancedDependenciesAnnotation.class, graph);
set(SemanticGraphCoreAnnotations.EnhancedPlusPlusDependenciesAnnotation.class, graph);
}};
natlog.doOneSentence(null, sentence);
// Generate training data
Collection<Pair<Span, Span>> trainingDataFromSentence = subjectObjectPairs(graph, tokens, targets, sources);
trainingData.add(Pair.makePair(sentence, trainingDataFromSentence));
// Debug print
numTreesProcessed += 1;
if (numTreesProcessed % 100 == 0) {
log("[" + new DecimalFormat("00000").format(numTreesProcessed) + "] " + countDatums(trainingData) + " known extractions");
}
} catch (Throwable t) {
t.printStackTrace();
}
        }
      }
    }
// End
log("" + numTreesProcessed + " trees processed yielding " + countDatums(trainingData) + " known extractions");
endTrack("Processing " + name);
return trainingData;
}
/**
* The main entry point of the code.
*/
public static void main(String[] args) throws IOException {
forceTrack("Processing treebanks");
List<Pair<CoreMap, Collection<Pair<Span, Span>>>> trainingData = new ArrayList<>();
trainingData.addAll(processDirectory("WSJ", new File("/home/gabor/lib/data/penn_treebank/wsj")));
trainingData.addAll(processDirectory("Brown", new File("/home/gabor/lib/data/penn_treebank/brown")));
endTrack("Processing treebanks");
forceTrack("Training");
log("dataset size: " + trainingData.size());
ClauseSplitter.train(
trainingData.stream(),
new File("/home/gabor/tmp/clauseSearcher.ser.gz"),
new File("/home/gabor/tmp/clauseSearcherData.tab.gz"));
endTrack("Training");
// Execution.fillOptions(CreateClauseDataset.class, args);
//
// new CreateClauseDataset().runAndExit(in, System.err, code -> code);
}
}