package edu.ucsd.arcum.util; import static edu.ucsd.arcum.util.Graph.Visited.NOT_VISITED; import static edu.ucsd.arcum.util.Graph.Visited.VISITED; import static edu.ucsd.arcum.util.Graph.Visited.VISIT_STARTED; import java.util.*; import java.util.Map.Entry; import com.google.common.collect.Lists; public class Graph<T> { public interface NodeVisitor<T> { void visit(T node); } public interface TopSortVisitor<T> extends NodeVisitor<T> { void cycleFound(List<T> nodes); } public interface LayeredVisitor<T, E extends Throwable> { void visitLayer(List<T> layer) throws E; void cycleFound(List<T> cycle) throws E; } // adjacency list representation private HashMap<Node<T>, Set<Node<T>>> graph; private HashMap<T, Node<T>> nodes; public static <T> Graph<T> newGraph() { return new Graph<T>(); } public Graph() { this.graph = new LinkedHashMap<Node<T>, Set<Node<T>>>(); this.nodes = new HashMap<T, Node<T>>(); } public void addNode(T value) { selectiveCreateNode(value); } public void addEdge(T u, T v) { Node<T> nodeU = selectiveCreateNode(u); Node<T> nodeV = selectiveCreateNode(v); addEdge(nodeU, nodeV); } private Node<T> selectiveCreateNode(T value) { if (!nodes.containsKey(value)) { Node<T> node = new Node<T>(value); nodes.put(value, node); createNode(node); return node; } else { return nodes.get(value); } } private void createNode(Node<T> u) { _addNode(u, false); } private void addEdge(Node<T> u, Node<T> v) { Set<Node<T>> set = _addNode(u, true); _addNode(v, false); set.add(v); } private Set<Node<T>> _addNode(Node<T> u, boolean lookup) { if (!graph.containsKey(u)) { Set<Node<T>> set = new LinkedHashSet<Node<T>>(); graph.put(u, set); return set; } else if (lookup) { return graph.get(u); } else { return null; } } public Collection<T> depthFirstSearch(T start) { final Collection<T> list = Lists.newArrayList(); depthFirstSearch(start, new NodeVisitor<T>() { @Override public void visit(T node) { list.add(node); } }); return list; } public void depthFirstSearch(T start, NodeVisitor<T> visitor) { Node<T> startNode = beginTraversal(start); _dfs(startNode, visitor); } public Collection<T> breadthFirstSearch(T start) { final Collection<T> list = Lists.newArrayList(); breadthFirstSearch(start, new NodeVisitor<T>() { @Override public void visit(T node) { list.add(node); } }); return list; } public void breadthFirstSearch(T start, NodeVisitor<T> visitor) { Node<T> startNode = beginTraversal(start); markAndImmediatelyVisit(startNode, visitor); // LinkedLists are FIFO queues Queue<Node<T>> queue = new LinkedList<Node<T>>(); queue.add(startNode); for (;;) { Node<T> current = queue.poll(); if (current == null) break; Set<Node<T>> connectedTo = graph.get(current); for (Node<T> v : connectedTo) { if (v.getVisited() == NOT_VISITED) { markAndImmediatelyVisit(v, visitor); queue.add(v); } } } } public <E extends Throwable> void iterateOverTopologicalLayers( LayeredVisitor<T, E> layeredVisitor) throws E { markUnvisited(); while (notAllVisited()) { Map<Node<T>, Integer> inwardCount = computeUnvisitedInwardCount(); List<T> layer = Lists.newArrayList(); List<Node<T>> layerNodes = Lists.newArrayList(); for (Entry<Node<T>, Integer> entry : inwardCount.entrySet()) { Node<T> node = entry.getKey(); Integer count = entry.getValue(); if (count == 0 && node.getVisited() != VISITED) { layer.add(node.getValue()); layerNodes.add(node); } } if (layer.size() == 0) { layeredVisitor.cycleFound(new ArrayList<T>()); return; } layeredVisitor.visitLayer(layer); for (Node<T> layerNode : layerNodes) { layerNode.setVisited(VISITED); } } } private boolean notAllVisited() { for (Node n : graph.keySet()) { if (n.getVisited() != VISITED) { return true; } } return false; } private Map<Node<T>, Integer> computeUnvisitedInwardCount() { Map<Node<T>, Integer> result = new HashMap<Node<T>, Integer>(); for (Node<T> node : graph.keySet()) { if (node.getVisited() != VISITED) { result.put(node, 0); } } nodeCounting: for (Entry<Node<T>, Set<Node<T>>> entry : graph.entrySet()) { Node<T> key = entry.getKey(); if (key.getVisited() == VISITED) { continue nodeCounting; } Set<Node<T>> adjList = entry.getValue(); for (Node<T> pointedTo : adjList) { if (pointedTo.getVisited() != VISITED) { result.put(pointedTo, result.get(pointedTo) + 1); } } } return result; } // visits all nodes in the graph in a topological order, if possible. If // the graph has cycles the traversal the cycleFound method will be invoked // on the visitor, without any nodes being visited public void topologicalSort(TopSortVisitor<T> visitor) { assert visitor != null; _topsort(visitor); } public List<T> getTrees() { AbstractCollection<Node<T>> roots = _topsort(null); if (roots == null) { return Lists.newArrayList(); } else { ArrayList<T> result = new ArrayList<T>(roots.size()); for (Node<T> root : roots) { result.add(root.getValue()); } return result; } } @Override public String toString() { StringBuilder builder = new StringBuilder(); for (Entry<Node<T>, Set<Node<T>>> entry : graph.entrySet()) { builder.append(entry.getKey().toString()); builder.append(" => "); builder.append(StringUtil.separate(entry.getValue())); builder.append(String.format("%n")); } return builder.toString(); } // if visitor is null then the topsort only checks if there are cycles // and returns all "root" nodes (i.e. nodes with no inward edges). If // visitor is non-null then the return result will be null and should be // ignored private AbstractCollection<Node<T>> _topsort(TopSortVisitor<T> visitor) { Map<Node<T>, Integer> inwardCount = computeUnvisitedInwardCount(); markUnvisited(); for (Node<T> node : graph.keySet()) { inwardCount.put(node, 0); // and also initialize our count list List<T> cycles = cycleCheckingVisit(node); if (cycles != null) { if (visitor != null) { int i = cycles.lastIndexOf(cycles.get(0)); if (i == 0) { i = cycles.size() - 1; } visitor.cycleFound(cycles.subList(0, i + 1)); } return null; } } for (Set<Node<T>> adjList : graph.values()) { for (Node<T> pointedTo : adjList) { inwardCount.put(pointedTo, inwardCount.get(pointedTo) + 1); } } HashSet<Node<T>> zeroList = new HashSet<Node<T>>(); for (Entry<Node<T>, Integer> entry : inwardCount.entrySet()) { if (entry.getValue() == 0) { zeroList.add(entry.getKey()); } } if (visitor == null) { return zeroList; } List<T> toVisit = new ArrayList<T>(nodes.size()); while (!zeroList.isEmpty()) { Node<T> node = zeroList.iterator().next(); zeroList.remove(node); toVisit.add(node.getValue()); Set<Node<T>> adjList = graph.get(node); for (Node<T> pointedTo : adjList) { int count = inwardCount.get(pointedTo) - 1; inwardCount.put(pointedTo, count); if (count == 0) { zeroList.add(pointedTo); } } } for (T value : toVisit) { visitor.visit(value); } return null; } private List<T> cycleCheckingVisit(Node<T> node) { if (node.getVisited() == VISIT_STARTED) { List<T> result = new ArrayList<T>(); result.add(node.getValue()); return result; } if (node.getVisited() == VISITED) return null; node.setVisited(VISIT_STARTED); Set<Node<T>> adjList = graph.get(node); for (Node<T> pointedTo : adjList) { List<T> result = cycleCheckingVisit(pointedTo); if (result != null) { result.add(0, node.getValue()); return result; } } node.setVisited(VISITED); return null; } private Node<T> beginTraversal(T start) { assumeInGraph(start); Node<T> startNode = nodes.get(start); markUnvisited(); return startNode; } private void assumeInGraph(T start) { if (!nodes.containsKey(start)) { throw new IllegalArgumentException("Node is not in this graph"); } } private void markUnvisited() { for (Node n : graph.keySet()) { n.setVisited(NOT_VISITED); } } private void _dfs(Node<T> node, NodeVisitor<T> visitor) { markAndImmediatelyVisit(node, visitor); Set<Node<T>> connectedTo = graph.get(node); for (Node<T> v : connectedTo) { if (v.getVisited() == NOT_VISITED) { _dfs(v, visitor); } } } private void markAndImmediatelyVisit(Node<T> node, NodeVisitor<T> visitor) { visitor.visit(node.getValue()); node.setVisited(VISITED); } private static class Node<T> { private T value; private Visited visited = NOT_VISITED; public Node(T data) { this.value = data; } public T getValue() { return value; } public void setValue(T data) { this.value = data; } public Visited getVisited() { return visited; } public void setVisited(Visited visited) { this.visited = visited; } @Override public final int hashCode() { return super.hashCode(); } @Override public final boolean equals(Object o) { return super.equals(o); } @Override public String toString() { return String.valueOf(value); } } static enum Visited { NOT_VISITED, VISIT_STARTED, VISITED } }