package edu.stanford.nlp.naturalli;

import edu.stanford.nlp.util.logging.Redwood;

import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.IndexedWord;
import edu.stanford.nlp.semgraph.SemanticGraph;
import edu.stanford.nlp.semgraph.SemanticGraphEdge;
import edu.stanford.nlp.util.Lazy;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.StringUtils;

import java.util.*;
import java.util.stream.Collectors;

/**
 * A particular instance of a search problem for finding entailed sentences.
 * This problem already specifies the options for the search, as well as the sentence to search from.
 *
 * Note, again, that this only searches for deletions and not insertions or mutations.
 *
 * @author Gabor Angeli
 */
public class ForwardEntailerSearchProblem {

  /** A logger for this class */
  private static Redwood.RedwoodChannels log = Redwood.channels(ForwardEntailerSearchProblem.class);

  /**
   * The parse of this fragment. The vertices in the parse tree should be a subset
   * (possibly not strict) of the tokens in the original sentence.
   */
  public final SemanticGraph parseTree;

  /**
   * The truth of the premise -- determines the direction we can mutate the sentences.
   */
  public final boolean truthOfPremise;

  /**
   * The maximum number of ticks to search for. Otherwise, the search will be exhaustive.
   */
  public final int maxTicks;

  /**
   * The maximum number of results to return from a single search.
   */
  public final int maxResults;

  /**
   * The weights to use for entailment.
   */
  public final NaturalLogicWeights weights;

  /**
   * A result from the search over possible shortenings of the sentence.
   */
  private static class SearchResult {
    public final SemanticGraph tree;
    public final List<String> deletedEdges;
    public final double confidence;

    private SearchResult(SemanticGraph tree, List<String> deletedEdges, double confidence) {
      this.tree = tree;
      this.deletedEdges = deletedEdges;
      this.confidence = confidence;
    }

    @Override
    public String toString() {
      return StringUtils.join(tree.vertexListSorted().stream().map(IndexedWord::word), " ");
    }
  }

  /**
   * A state in the search, denoting a partial shortening of the sentence.
   */
  private static class SearchState {
    public final BitSet deletionMask;
    public final int currentIndex;
    public final SemanticGraph tree;
    public final String lastDeletedEdge;
    public final SearchState source;
    public final double score;

    private SearchState(BitSet deletionMask, int currentIndex, SemanticGraph tree,
                        String lastDeletedEdge, SearchState source, double score) {
      this.deletionMask = (BitSet) deletionMask.clone();
      this.currentIndex = currentIndex;
      this.tree = tree;
      this.lastDeletedEdge = lastDeletedEdge;
      this.source = source;
      this.score = score;
    }
  }

  /**
   * Create a new search problem, fully specified.
   *
   * @see edu.stanford.nlp.naturalli.ForwardEntailer
   */
  protected ForwardEntailerSearchProblem(SemanticGraph parseTree,
                                         boolean truthOfPremise,
                                         int maxResults, int maxTicks,
                                         NaturalLogicWeights weights) {
    this.parseTree = parseTree;
    this.truthOfPremise = truthOfPremise;
    this.maxResults = maxResults;
    this.maxTicks = maxTicks;
    this.weights = weights;
  }

  /**
   * Run a search from this entailer. This will return a list of sentence fragments
   * that are entailed by the original sentence / fragment.
   *
   * @return A list of entailed fragments.
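   *
   * <p>A minimal usage sketch (hypothetical values; in practice the parse and weights
   * come from {@link edu.stanford.nlp.naturalli.ForwardEntailer}, which constructs
   * this problem per sentence -- the constructor is protected):</p>
   *
   * <pre>{@code
   * // assumes `parse` is a SemanticGraph and `weights` a NaturalLogicWeights instance
   * ForwardEntailerSearchProblem problem =
   *     new ForwardEntailerSearchProblem(parse, true, 100, 1000, weights);
   * List<SentenceFragment> entailments = problem.search();
   * }</pre>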
   */
  @SuppressWarnings("unchecked")
  public List<SentenceFragment> search() {
    return searchImplementation().stream()
        .map(x -> new SentenceFragment(x.tree, truthOfPremise, false).changeScore(x.confidence))
        .filter(x -> x.words.size() > 0)
        .collect(Collectors.toList());
  }

  /**
   * The search algorithm, starting with a full sentence and iteratively shortening it to its entailed sentences.
   *
   * @return A list of search results, corresponding to shortenings of the sentence.
   */
  @SuppressWarnings("unchecked")
  private List<SearchResult> searchImplementation() {
    // Pre-process the tree
    SemanticGraph parseTree = new SemanticGraph(this.parseTree);
    assert Util.isTree(parseTree);

    // (remove common determiners)
    List<String> determinerRemovals = new ArrayList<>();
    parseTree.getLeafVertices().stream().filter(vertex ->
        "the".equalsIgnoreCase(vertex.word()) ||
        "a".equalsIgnoreCase(vertex.word()) ||
        "an".equalsIgnoreCase(vertex.word()) ||
        "this".equalsIgnoreCase(vertex.word()) ||
        "that".equalsIgnoreCase(vertex.word()) ||
        "those".equalsIgnoreCase(vertex.word()) ||
        "these".equalsIgnoreCase(vertex.word())
    ).forEach(vertex -> {
      parseTree.removeVertex(vertex);
      assert Util.isTree(parseTree);
      determinerRemovals.add("det");
    });

    // (cut conj:and nodes)
    Set<SemanticGraphEdge> andsToAdd = new HashSet<>();
    for (IndexedWord vertex : parseTree.vertexSet()) {
      if (parseTree.inDegree(vertex) > 1) {
        SemanticGraphEdge conjAnd = null;
        for (SemanticGraphEdge edge : parseTree.incomingEdgeIterable(vertex)) {
          if ("conj:and".equals(edge.getRelation().toString())) {
            conjAnd = edge;
          }
        }
        if (conjAnd != null) {
          parseTree.removeEdge(conjAnd);
          assert Util.isTree(parseTree);
          andsToAdd.add(conjAnd);
        }
      }
    }

    // Clean the tree
    Util.cleanTree(parseTree);
    assert Util.isTree(parseTree);

    // Find the subject / object split
    // This takes max O(n^2) time, expected O(n*log(n)) time.
    // Optimal is O(n), but I'm too lazy to implement it.
    BitSet isSubject = new BitSet(256);
    for (IndexedWord vertex : parseTree.vertexSet()) {
      // Search up the tree for a subj node; if found, mark that vertex as a subject.
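      // (Walk governor-to-governor from this vertex toward the root; if any edge on
      //  that path is a *subj relation -- nsubj, csubj, etc. -- the vertex lies inside
      //  a subject constituent. The 100-iteration cap below guards against apparent
      //  cycles in a malformed graph.)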
      Iterator<SemanticGraphEdge> incomingEdges = parseTree.incomingEdgeIterator(vertex);
      SemanticGraphEdge edge = null;
      if (incomingEdges.hasNext()) {
        edge = incomingEdges.next();
      }
      int numIters = 0;
      while (edge != null) {
        if (edge.getRelation().toString().endsWith("subj")) {
          assert vertex.index() > 0;
          isSubject.set(vertex.index() - 1);
          break;
        }
        incomingEdges = parseTree.incomingEdgeIterator(edge.getGovernor());
        if (incomingEdges.hasNext()) {
          edge = incomingEdges.next();
        } else {
          edge = null;
        }
        numIters += 1;
        if (numIters > 100) {
//          log.error("tree has apparent depth > 100");
          return Collections.EMPTY_LIST;
        }
      }
    }

    // Outputs
    List<SearchResult> results = new ArrayList<>();
    if (!determinerRemovals.isEmpty()) {
      if (andsToAdd.isEmpty()) {
        double score = Math.pow(weights.deletionProbability("det"), (double) determinerRemovals.size());
        assert !Double.isNaN(score);
        assert !Double.isInfinite(score);
        results.add(new SearchResult(parseTree, determinerRemovals, score));
      } else {
        SemanticGraph treeWithAnds = new SemanticGraph(parseTree);
        assert Util.isTree(treeWithAnds);
        for (SemanticGraphEdge and : andsToAdd) {
          treeWithAnds.addEdge(and.getGovernor(), and.getDependent(), and.getRelation(), Double.NEGATIVE_INFINITY, false);
        }
        assert Util.isTree(treeWithAnds);
        results.add(new SearchResult(treeWithAnds, determinerRemovals,
            Math.pow(weights.deletionProbability("det"), (double) determinerRemovals.size())));
      }
    }

    // Initialize the search
    assert Util.isTree(parseTree);
    List<IndexedWord> topologicalVertices;
    try {
      topologicalVertices = parseTree.topologicalSort();
    } catch (IllegalStateException e) {
//      log.info("Could not topologically sort the vertices! Using left-to-right traversal.");
      topologicalVertices = parseTree.vertexListSorted();
    }
    if (topologicalVertices.isEmpty()) {
      return results;
    }
    Stack<SearchState> fringe = new Stack<>();
    fringe.push(new SearchState(new BitSet(256), 0, parseTree, null, null, 1.0));

    // Start the search
    int numTicks = 0;
    while (!fringe.isEmpty()) {
      // Overhead with popping a node.
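      // (Each pop counts as one tick; the search stops early once either maxTicks or
      //  maxResults is exhausted, so the caller gets a best-effort prefix of the full
      //  set of shortenings rather than an unbounded search on large trees.)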
      if (numTicks >= maxTicks) {
        return results;
      }
      numTicks += 1;
      if (results.size() >= maxResults) {
        return results;
      }
      SearchState state = fringe.pop();
      assert state.score > 0.0;
      IndexedWord currentWord = topologicalVertices.get(state.currentIndex);

      // Push the case where we don't delete
      int nextIndex = state.currentIndex + 1;
      int numIters = 0;
      while (nextIndex < topologicalVertices.size()) {
        IndexedWord nextWord = topologicalVertices.get(nextIndex);
        assert nextWord.index() > 0;
        if (!state.deletionMask.get(nextWord.index() - 1)) {
          fringe.push(new SearchState(state.deletionMask, nextIndex, state.tree, null, state, state.score));
          break;
        } else {
          nextIndex += 1;
        }
        numIters += 1;
        if (numIters > 10000) {
//          log.error("logic error (apparent infinite loop); returning");
          return results;
        }
      }

      // Check if we can delete this subtree
      boolean canDelete = !state.tree.getFirstRoot().equals(currentWord);
      for (SemanticGraphEdge edge : state.tree.incomingEdgeIterable(currentWord)) {
        if ("CD".equals(edge.getGovernor().tag())) {
          canDelete = false;
        } else {
          // Get token information
          CoreLabel token = edge.getDependent().backingLabel();
          OperatorSpec operator;
          NaturalLogicRelation lexicalRelation;
          Polarity tokenPolarity = token.get(NaturalLogicAnnotations.PolarityAnnotation.class);
          if (tokenPolarity == null) {
            tokenPolarity = Polarity.DEFAULT;
          }
          // Get the relation for this deletion
          if ((operator = token.get(NaturalLogicAnnotations.OperatorAnnotation.class)) != null) {
            lexicalRelation = operator.instance.deleteRelation;
          } else {
            assert edge.getDependent().index() > 0;
            lexicalRelation = NaturalLogicRelation.forDependencyDeletion(edge.getRelation().toString(),
                isSubject.get(edge.getDependent().index() - 1));
          }
          NaturalLogicRelation projectedRelation = tokenPolarity.projectLexicalRelation(lexicalRelation);
          // Make sure this is a valid entailment
          if (!projectedRelation.applyToTruthValue(truthOfPremise).isTrue()) {
            canDelete = false;
          }
        }
      }

      if (canDelete) {
        // Register the deletion
        Lazy<Pair<SemanticGraph, BitSet>> treeWithDeletionsAndNewMask = Lazy.of(() -> {
          SemanticGraph impl = new SemanticGraph(state.tree);
          BitSet newMask = state.deletionMask;
          for (IndexedWord vertex : state.tree.descendants(currentWord)) {
            impl.removeVertex(vertex);
            assert vertex.index() > 0;
            newMask.set(vertex.index() - 1);
            assert newMask.get(vertex.index() - 1);
          }
          return Pair.makePair(impl, newMask);
        });
        // Compute the score of the sentence
        double newScore = state.score;
        for (SemanticGraphEdge edge : state.tree.incomingEdgeIterable(currentWord)) {
          double multiplier = weights.deletionProbability(edge, state.tree.outgoingEdgeIterable(edge.getGovernor()));
          assert !Double.isNaN(multiplier);
          assert !Double.isInfinite(multiplier);
          newScore *= multiplier;
        }
        // Register the result
        if (newScore > 0.0) {
          SemanticGraph resultTree = new SemanticGraph(treeWithDeletionsAndNewMask.get().first);
          andsToAdd.stream().filter(edge -> resultTree.containsVertex(edge.getGovernor()) && resultTree.containsVertex(edge.getDependent()))
              .forEach(edge -> resultTree.addEdge(edge.getGovernor(), edge.getDependent(), edge.getRelation(), Double.NEGATIVE_INFINITY, false));
          results.add(new SearchResult(resultTree,
              aggregateDeletedEdges(state, state.tree.incomingEdgeIterable(currentWord), determinerRemovals),
              newScore));

          // Push the state with this subtree deleted
          nextIndex = state.currentIndex + 1;
          numIters = 0;
          while (nextIndex < topologicalVertices.size()) {
            IndexedWord nextWord = topologicalVertices.get(nextIndex);
            BitSet newMask = treeWithDeletionsAndNewMask.get().second;
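            // (Advance to the next vertex that survived the deletion; vertices whose
            //  bit is set in the mask were removed with the subtree and must be skipped.)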
            SemanticGraph treeWithDeletions = treeWithDeletionsAndNewMask.get().first;
            if (!newMask.get(nextWord.index() - 1)) {
              assert treeWithDeletions.containsVertex(topologicalVertices.get(nextIndex));
              fringe.push(new SearchState(newMask, nextIndex, treeWithDeletions, null, state, newScore));
              break;
            } else {
              nextIndex += 1;
            }
            numIters += 1;
            if (numIters > 10000) {
//              log.error("logic error (apparent infinite loop); returning");
              return results;
            }
          }
        }
      }
    }

    // Return
    return results;
  }

  /**
   * Backtrace from a search state, collecting all of the deleted edges used to get there.
   *
   * @param state The final search state.
   * @param justDeleted The edges we have just deleted.
   * @param otherEdges Other deletions we want to register.
   *
   * @return A list of deleted edges for that search state.
   */
  private static List<String> aggregateDeletedEdges(SearchState state, Iterable<SemanticGraphEdge> justDeleted, Iterable<String> otherEdges) {
    List<String> rtn = new ArrayList<>();
    for (SemanticGraphEdge edge : justDeleted) {
      rtn.add(edge.getRelation().toString());
    }
    for (String edge : otherEdges) {
      rtn.add(edge);
    }
    while (state != null) {
      if (state.lastDeletedEdge != null) {
        rtn.add(state.lastDeletedEdge);
      }
      state = state.source;
    }
    return rtn;
  }

}