//Dstl (c) Crown Copyright 2017
package uk.gov.dstl.baleen.uima.grammar;

import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.apache.uima.cas.text.AnnotationFS;
import org.apache.uima.fit.util.JCasUtil;
import org.apache.uima.jcas.JCas;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.collect.HashMultimap;
import com.google.common.collect.SetMultimap;

import uk.gov.dstl.baleen.types.language.Dependency;
import uk.gov.dstl.baleen.types.language.WordToken;

/**
 * A graph of grammar dependencies within an annotated JCas.
 * <p>
 * Whilst UIMA annotations can store the output of a dependency grammar, they are difficult to work
 * with and slow to query. This class builds a cache which makes finding nearest neighbours (based
 * on dependency distance) faster and easier.
 * <p>
 * The JCas must have been annotated by a dependency grammar (e.g. MaltParser, ClearNlp) before
 * being passed to build().
 */
public class DependencyGraph {

  private static final Logger LOGGER = LoggerFactory.getLogger(DependencyGraph.class);

  private final SetMultimap<WordToken, Edge> edges;
  private final SetMultimap<WordToken, Dependency> dependents;
  private final SetMultimap<WordToken, Dependency> governors;

  /**
   * Instantiates a new dependency graph.
   */
  private DependencyGraph() {
    edges = HashMultimap.create();
    dependents = HashMultimap.create();
    governors = HashMultimap.create();
  }

  /**
   * Instantiates a new dependency graph.
   *
   * @param edges the edges
   * @param dependentMap the dependent map
   * @param governorMap the governor map
   */
  private DependencyGraph(SetMultimap<WordToken, Edge> edges, SetMultimap<WordToken, Dependency> dependentMap,
      SetMultimap<WordToken, Dependency> governorMap) {
    this.edges = edges;
    this.dependents = dependentMap;
    this.governors = governorMap;
  }

  /**
   * Gets the dependencies where the word is the dependent.
   *
   * @param word the word
   * @return the dependencies in which the word is the dependent
   */
  public Set<Dependency> getDependents(WordToken word) {
    return Collections.unmodifiableSet(dependents.get(word));
  }

  /**
   * Gets the dependencies where the word is the governor.
   *
   * @param word the word
   * @return the dependencies in which the word is the governor
   */
  public Set<Dependency> getGovernors(WordToken word) {
    return Collections.unmodifiableSet(governors.get(word));
  }

  /**
   * Gets the words connected to this word by a dependency edge (in either direction).
   *
   * @param word the word
   * @return the words at the other end of this word's edges
   */
  public Stream<WordToken> getEdges(WordToken word) {
    return edges.get(word).stream().map(e -> e.getOther(word));
  }

  /**
   * Adds the edge.
   *
   * @param dependency the dependency
   */
  private void addEdge(final Dependency dependency) {
    final WordToken governor = dependency.getGovernor();
    final WordToken dependent = dependency.getDependent();
    final Edge edge = new Edge(dependent, dependency, governor);
    edges.put(governor, edge);
    edges.put(dependent, edge);
    dependents.put(dependent, dependency);
    governors.put(governor, dependency);
  }
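  /*
   * Illustrative usage sketch (not part of the original class). Assuming a JCas that has already
   * been processed by a dependency parser, a caller might build the graph and inspect the
   * dependencies around a token like this; jCas and token are hypothetical variables:
   *
   *   DependencyGraph graph = DependencyGraph.build(jCas);
   *   graph.getGovernors(token).forEach(d -> System.out.println(d.getDependencyType()));
   *   graph.getEdges(token).forEach(w -> System.out.println(w.getCoveredText()));
   */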
  /**
   * Find the nearest neighbours within dependency distance links of the provided start
   * dependencies.
   *
   * @param distance the dependency distance
   * @param start the dependencies to start from
   * @return the (set of) words within range
   */
  public Set<WordToken> extractWords(final int distance, final Dependency... start) {
    return extractWords(distance, d -> true, start);
  }

  /**
   * Find the nearest neighbours within dependency distance links of the provided start
   * dependencies.
   *
   * @param distance the dependency distance
   * @param predicate the predicate a dependency must satisfy to be followed
   * @param start the dependencies to start from
   * @return the (set of) words within range
   */
  public Set<WordToken> extractWords(final int distance, Predicate<Dependency> predicate,
      final Dependency... start) {
    return extractWords(distance, predicate, Arrays.asList(start));
  }

  /**
   * Find the nearest neighbours within dependency distance links of the provided start
   * dependencies.
   *
   * @param distance the dependency distance
   * @param predicate the predicate a dependency must satisfy to be followed
   * @param start the dependencies to start from
   * @return the (set of) words within range
   */
  public Set<WordToken> extractWords(final int distance, Predicate<Dependency> predicate,
      final Collection<Dependency> start) {
    final Set<WordToken> words = new HashSet<>();

    if (distance <= 0) {
      return words;
    }

    final int governorDistance = distance - 1;
    for (final Dependency d : start) {
      if (governorDistance > 0) {
        extractWords(words, governorDistance, predicate, d.getGovernor());
      }
      extractWords(words, distance, predicate, d.getDependent());
    }

    return words;
  }

  /**
   * Find the nearest neighbours within dependency distance links of the provided start words.
   *
   * @param distance the dependency distance
   * @param start the words to start from
   * @return the (set of) words within range
   */
  public Set<WordToken> nearestWords(final int distance, final WordToken... start) {
    return nearestWords(distance, d -> true, Arrays.asList(start));
  }

  /**
   * Find the nearest neighbours within dependency distance links of the provided start words.
   *
   * @param distance the dependency distance
   * @param predicate the predicate a dependency must satisfy to be followed
   * @param start the words to start from
   * @return the (set of) words within range
   */
  public Set<WordToken> nearestWords(final int distance, Predicate<Dependency> predicate,
      final WordToken... start) {
    return nearestWords(distance, predicate, Arrays.asList(start));
  }

  /**
   * Find the nearest neighbours within dependency distance links of the provided start words.
   *
   * @param distance the dependency distance
   * @param predicate the predicate a dependency must satisfy to be followed
   * @param start the words to start from
   * @return the (set of) words within range
   */
  public Set<WordToken> nearestWords(final int distance, Predicate<Dependency> predicate,
      final Collection<WordToken> start) {
    final Set<WordToken> words = new HashSet<>();

    if (distance <= 0) {
      return words;
    }

    for (final WordToken d : start) {
      extractWords(words, distance, predicate, d);
    }

    return words;
  }
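  /*
   * Illustrative usage sketch (not part of the original class). nearestWords can be restricted by
   * a Predicate<Dependency>, for example to avoid walking through a particular dependency type;
   * the type name "punct" is only an example value, and graph and token are hypothetical:
   *
   *   Set<WordToken> context = graph.nearestWords(2,
   *       d -> !"punct".equalsIgnoreCase(d.getDependencyType()),
   *       token);
   */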
  /**
   * Extract words recursively following the graph.
   *
   * @param collector the collector
   * @param distance the distance
   * @param predicate the predicate
   * @param token the token
   */
  private void extractWords(final Set<WordToken> collector, final int distance, Predicate<Dependency> predicate,
      final WordToken token) {
    // The word itself
    collector.add(token);

    // TODO: Depth first; we potentially revisit the same node repeatedly,
    // so this could be more efficient.

    final List<WordToken> set = edges.get(token).stream()
        .filter(e -> predicate.test(e.getDependency()))
        .map(e -> e.getOther(token))
        .collect(Collectors.toList());

    collector.addAll(set);

    final int newDistance = distance - 1;
    if (newDistance > 0) {
      set.forEach(a -> extractWords(collector, newDistance, predicate, a));
    }
  }

  /**
   * Log the dependency graph to the logger for debugging.
   */
  public void log() {
    final StringBuilder sb = new StringBuilder();
    edges.asMap().entrySet().stream().forEach(e -> {
      sb.append("\t");
      sb.append(e.getKey().getCoveredText());
      sb.append(": ");
      e.getValue().stream()
          .map(x -> x.getOther(e.getKey()))
          .forEach(w -> sb.append(" " + w.getCoveredText()));
      sb.append("\n");
    });

    final StringBuilder governorSb = new StringBuilder();
    governors.asMap().entrySet().stream().forEach(e -> {
      governorSb.append("\t");
      governorSb.append(e.getKey().getCoveredText());
      governorSb.append(": ");
      e.getValue().stream()
          .forEach(w -> governorSb.append(" " + w.getCoveredText() + "[" + w.getDependencyType() + "]"));
      governorSb.append("\n");
    });

    final StringBuilder dependentSb = new StringBuilder();
    dependents.asMap().entrySet().stream().forEach(e -> {
      dependentSb.append("\t");
      dependentSb.append(e.getKey().getCoveredText());
      dependentSb.append(": ");
      e.getValue().stream()
          .forEach(w -> dependentSb.append(" " + w.getCoveredText() + "[" + w.getDependencyType() + "]"));
      dependentSb.append("\n");
    });

    DependencyGraph.LOGGER.info("Dependency graph: Edges:\n{}\n Governors\n{}\n Dependent\n{}",
        sb.toString(), governorSb, dependentSb);
  }

  /**
   * Create a new (sub) graph containing only the words matched by the predicate.
   *
   * @param predicate the predicate
   * @return the new filtered dependency graph
   */
  public DependencyGraph filter(Predicate<WordToken> predicate) {
    final SetMultimap<WordToken, Edge> filteredEdges = HashMultimap.create();
    final SetMultimap<WordToken, Dependency> filteredDependent = HashMultimap.create();
    final SetMultimap<WordToken, Dependency> filteredGovernor = HashMultimap.create();

    edges.asMap().entrySet().stream()
        .filter(w -> predicate.test(w.getKey()))
        .forEach(e -> {
          final WordToken key = e.getKey();
          e.getValue().stream()
              .filter(edge -> predicate.test(edge.getOther(key)))
              .forEach(v -> filteredEdges.put(key, v));
        });

    governors.asMap().keySet().stream().filter(predicate).forEach(k -> {
      final List<Dependency> filtered = governors.get(k).stream()
          .filter(d -> predicate.test(d.getGovernor()) && predicate.test(d.getDependent()))
          .collect(Collectors.toList());
      filteredGovernor.putAll(k, filtered);
    });

    dependents.asMap().keySet().stream().filter(predicate).forEach(k -> {
      final List<Dependency> filtered = dependents.get(k).stream()
          .filter(d -> predicate.test(d.getGovernor()) && predicate.test(d.getDependent()))
          .collect(Collectors.toList());
      filteredDependent.putAll(k, filtered);
    });

    return new DependencyGraph(filteredEdges, filteredDependent, filteredGovernor);
  }

  /**
   * Adds the dependency, ignoring ROOT dependencies and those without both a governor and a
   * dependent.
   *
   * @param dependency the dependency
   */
  private void addDependency(Dependency dependency) {
    if ((dependency.getDependencyType() == null || !"ROOT".equalsIgnoreCase(dependency.getDependencyType()))
        && dependency.getGovernor() != null && dependency.getDependent() != null) {
      addEdge(dependency);
    }
  }

  /**
   * Gets the words in the graph.
   *
   * @return the words
   */
  public Set<WordToken> getWords() {
    return Collections.unmodifiableSet(edges.keySet());
  }
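  /*
   * Illustrative usage sketch (not part of the original class). filter produces a sub-graph
   * containing only the words accepted by the predicate, for example only tokens whose covered
   * text is purely alphabetic; graph is a hypothetical variable:
   *
   *   DependencyGraph alphabeticOnly = graph.filter(w -> w.getCoveredText().matches("[a-zA-Z]+"));
   */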
  /**
   * Find the shortest path between from and to, limited by maxDistance.
   *
   * @param from the words to start from
   * @param to the words to reach
   * @param maxDistance the maximum path length
   * @return the list of words on the shortest path, or an empty list if no path is found within
   *         maxDistance
   */
  public List<WordToken> shortestPath(Collection<WordToken> from, Collection<WordToken> to, int maxDistance) {
    if (from.isEmpty() || to.isEmpty() || maxDistance <= -1) {
      return Collections.emptyList();
    }

    final Set<WordToken> visited = new HashSet<>();
    final PriorityQueue<WordDistance> queue = new PriorityQueue<>();

    from.stream().forEach(t -> {
      queue.add(new WordDistance(t));
      visited.add(t);
    });

    while (!queue.isEmpty()) {
      final WordDistance wd = queue.poll();
      LOGGER.debug("{}", wd);

      if (to.contains(wd.getWord())) {
        return wd.getWords();
      }

      if (wd.getDistance() < maxDistance) {
        final Set<WordToken> nextWords = edges.get(wd.getWord()).stream()
            .map(w -> w.getOther(wd.getWord()))
            .collect(Collectors.toSet());
        nextWords.removeAll(visited);
        nextWords.stream().forEach(t -> {
          queue.add(new WordDistance(t, wd));
          visited.add(t);
        });
      }
    }

    return Collections.emptyList();
  }

  /**
   * Build a dependency graph from a JCas which has already been processed through a dependency
   * grammar, and thus has Dependency annotations.
   *
   * @param jCas the JCas to process
   * @return the dependency graph (non-null)
   */
  public static DependencyGraph build(final JCas jCas) {
    final DependencyGraph graph = new DependencyGraph();
    JCasUtil.select(jCas, Dependency.class).stream()
        .forEach(graph::addDependency);
    return graph;
  }

  /**
   * Build a dependency graph from a JCas which has already been processed through a dependency
   * grammar, and thus has Dependency annotations, but limited to the part of the JCas covered by
   * the annotation.
   *
   * @param jCas the JCas to process
   * @param annotation the annotation to limit coverage to
   * @return the dependency graph (non-null)
   */
  public static DependencyGraph build(final JCas jCas, AnnotationFS annotation) {
    final DependencyGraph graph = new DependencyGraph();
    JCasUtil.selectCovered(jCas, Dependency.class, annotation).stream()
        .forEach(graph::addDependency);
    return graph;
  }

  /**
   * Traverse the graph from the start dependencies, applying the predicate to each edge visited.
   *
   * @param distance the maximum traversal distance
   * @param start the dependencies to start from
   * @param predicate the predicate - use this to act on the graph (e.g. collect information) and
   *            return false to stop or true to continue.
   */
  public void traverse(int distance, Collection<Dependency> start, TraversePredicate predicate) {
    if (distance <= 0) {
      return;
    }

    final ImmutableStack<WordToken> history = new ImmutableStack<WordToken>();
    for (final Dependency d : start) {
      if (predicate.test(d, null, d.getDependent(), history)) {
        final ImmutableStack<WordToken> stack = history.push(d.getDependent());
        traverse(distance, d.getDependent(), stack, predicate);
      }
    }
  }

  /**
   * Traverse the graph from token.
   *
   * @param distance the distance
   * @param token the token
   * @param history the history
   * @param predicate the predicate
   */
  private void traverse(int distance, WordToken token, ImmutableStack<WordToken> history,
      TraversePredicate predicate) {
    final int newDistance = distance - 1;
    if (newDistance <= 0) {
      return;
    }

    for (final Edge e : edges.get(token)) {
      final WordToken other = e.getOther(token);
      if (!history.contains(other) && predicate.test(e.getDependency(), token, other, history)) {
        final ImmutableStack<WordToken> stack = history.push(other);
        traverse(newDistance, other, stack, predicate);
      }
    }
  }
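  /*
   * Illustrative usage sketch (not part of the original class). traverse accepts a
   * TraversePredicate (defined below) which can both collect information and decide whether to
   * continue. This example gathers every word reachable within three links of the start
   * dependencies; graph and startDependencies are hypothetical variables:
   *
   *   Set<WordToken> reachable = new HashSet<>();
   *   graph.traverse(3, startDependencies, (dependency, from, to, history) -> {
   *     reachable.add(to);
   *     return true; // keep traversing
   *   });
   */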
  /**
   * A functional interface used to control traversal of the graph.
   */
  @FunctionalInterface
  public interface TraversePredicate {

    /**
     * Test whether this dependency should be followed.
     *
     * @param dependency the dependency
     * @param from the from word
     * @param to the to word
     * @param history the history (all the word tokens up to from)
     * @return true to continue traversing through this dependency, false to stop
     */
    boolean test(Dependency dependency, WordToken from, WordToken to, ImmutableStack<WordToken> history);
  }
}