package edu.berkeley.nlp.syntax; import java.util.ArrayList; import java.util.IdentityHashMap; import java.util.List; /** * Tool for finding path relationships between nodes in a tree * * @author David Burkett */ public class TreePathFinder<L> { private Tree<L> root; private IdentityHashMap<Tree<L>, List<Tree<L>>> pathsFromRoot; public Tree<L> getRoot() { return root; } public TreePathFinder(Tree<L> tree) { root = tree; pathsFromRoot = new IdentityHashMap<Tree<L>, List<Tree<L>>>(tree.getPreOrderTraversal().size()); constructPaths(root, new ArrayList<Tree<L>>()); } private void constructPaths(Tree<L> node, List<Tree<L>> path) { path.add(node); pathsFromRoot.put(node, path); for (Tree<L> child : node.getChildren()) { ArrayList<Tree<L>> childPath = new ArrayList<Tree<L>>(path.size()+1); childPath.addAll(path); constructPaths(child, childPath); } } public Tree<L> findParent(Tree<L> node) { if (!pathsFromRoot.containsKey(node)) { throw new IllegalArgumentException("Tree must be node in the tree used to initialize the TreePathFinder"); } if (node == root) { return null; } List<Tree<L>> path = pathsFromRoot.get(node); return path.get(path.size() - 2); } public TreePath<L> findPath(Tree<L> start, Tree<L> end) { validateInput(start, end); List<TreePath.Transition<L>> transitions = new ArrayList<TreePath.Transition<L>>(); if (start != end) { List<Tree<L>> startPath = pathsFromRoot.get(start); List<Tree<L>> endPath = pathsFromRoot.get(end); // Find root of common subtree int rootIndex = findRootIndex(startPath, endPath); // Transitions from start node up to root of common subtree for (int i = startPath.size() - 1; i > rootIndex; i--) { transitions.add(new TreePath.Transition<L>(startPath.get(i), startPath.get(i-1), TreePath.Direction.UP)); } // First transition down from root of common subtree (directional if there have been up transitions) if (rootIndex < endPath.size() - 1) { TreePath.Direction postRootDirection = TreePath.Direction.DOWN; if (rootIndex < startPath.size() - 1) { postRootDirection = TreePath.Direction.DOWN_RIGHT; for (Tree<L> rootChild : startPath.get(rootIndex).getChildren()) { if (startPath.get(rootIndex+1) == rootChild) { break; } if (endPath.get(rootIndex+1) == rootChild) { postRootDirection = TreePath.Direction.DOWN_LEFT; break; } } } transitions.add(new TreePath.Transition<L>(endPath.get(rootIndex), endPath.get(rootIndex+1), postRootDirection)); } // Remaining transitions down to end node for (int i = rootIndex + 1; i < endPath.size() - 1; i++) { transitions.add(new TreePath.Transition<L>(endPath.get(i), endPath.get(i+1), TreePath.Direction.DOWN)); } } return new TreePath<L>(transitions); } private int findRootIndex(List<Tree<L>> startPath, List<Tree<L>> endPath) { int rootIndex = 0; for (Tree<L> node : startPath) { if (rootIndex == endPath.size() || node != endPath.get(rootIndex)) { break; } rootIndex++; } rootIndex--; return rootIndex; } public Tree<L> findLowestCommonAncestor(Tree<L> start, Tree<L> end) { validateInput(start, end); if (start == end) return start; List<Tree<L>> startPath = pathsFromRoot.get(start); List<Tree<L>> endPath = pathsFromRoot.get(end); int rootIndex = findRootIndex(startPath, endPath); return startPath.get(rootIndex); } private void validateInput(Tree<L> start, Tree<L> end) { if (start == null || end == null) { throw new IllegalArgumentException("Cannot provide null trees"); } if (!pathsFromRoot.containsKey(start) || !pathsFromRoot.containsKey(end)) { throw new IllegalArgumentException("Both trees must be nodes in the tree used to initialize the TreePathFinder"); } } }