package org.neo4j.graphalgo.shortestpath; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.LinkedList; import java.util.Map; import java.util.Set; import java.util.TreeMap; import org.neo4j.commons.iterator.PrefetchingIterator; import org.neo4j.graphalgo.PathImpl; import org.neo4j.graphdb.GraphDatabaseService; import org.neo4j.graphdb.Node; import org.neo4j.graphdb.Path; import org.neo4j.graphdb.Relationship; import org.neo4j.graphdb.RelationshipExpander; public class AStar implements PathFinder { private final GraphDatabaseService graphDb; private final RelationshipExpander expander; private final CostEvaluator<Double> lengthEvaluator; private final EstimateEvaluator<Double> estimateEvaluator; public AStar( GraphDatabaseService graphDb, RelationshipExpander expander, CostEvaluator<Double> lengthEvaluator, EstimateEvaluator<Double> estimateEvaluator ) { this.graphDb = graphDb; this.expander = expander; this.lengthEvaluator = lengthEvaluator; this.estimateEvaluator = estimateEvaluator; } public Path findSinglePath( Node start, Node end ) { Doer doer = new Doer( start, end ); while ( doer.hasNext() ) { Node node = doer.next(); if ( node.equals( end ) ) { // Hit, return path LinkedList<Relationship> rels = new LinkedList<Relationship>(); Relationship rel = graphDb.getRelationshipById( doer.cameFrom.get( node.getId() ) ); while ( rel != null ) { rels.addFirst( rel ); node = rel.getOtherNode( node ); Long nextRelId = doer.cameFrom.get( node.getId() ); rel = nextRelId == null ? null : graphDb.getRelationshipById( nextRelId ); } Path path = toPath( start, rels ); return path; } } return null; } public Collection<Path> findPaths( Node node, Node end ) { Path path = findSinglePath( node, end ); return path != null ? Arrays.asList( path ) : Collections.<Path>emptyList(); } private Path toPath( Node start, LinkedList<Relationship> rels ) { PathImpl.Builder builder = new PathImpl.Builder( start ); for ( Relationship rel : rels ) { builder = builder.push( rel ); } return builder.build(); } private static class Data { private double wayLength; // acumulated cost to get here (g) private double estimate; // heuristic estimate of cost to reach end (h) double getFscore() { return wayLength + estimate; } } private class Doer extends PrefetchingIterator<Node> { private final Node end; private Node lastNode; private boolean expand; private final Set<Long> visitedNodes = new HashSet<Long>(); private final Set<Node> nextNodesSet = new HashSet<Node>(); private final TreeMap<Double, Collection<Node>> nextNodes = new TreeMap<Double, Collection<Node>>(); private final Map<Long, Long> cameFrom = new HashMap<Long, Long>(); private final Map<Long, Data> score = new HashMap<Long, Data>(); Doer( Node start, Node end ) { this.end = end; Data data = new Data(); data.wayLength = 0; data.estimate = estimateEvaluator.getCost( start, end ); addNext( start, data.getFscore() ); this.score.put( start.getId(), data ); } private void addNext( Node node, double fscore ) { Collection<Node> nodes = this.nextNodes.get( fscore ); if ( nodes == null ) { nodes = new HashSet<Node>(); this.nextNodes.put( fscore, nodes ); } nodes.add( node ); this.nextNodesSet.add( node ); } private Node popLowestScoreNode() { Iterator<Map.Entry<Double, Collection<Node>>> itr = this.nextNodes.entrySet().iterator(); if ( !itr.hasNext() ) { return null; } Map.Entry<Double, Collection<Node>> entry = itr.next(); Node node = entry.getValue().isEmpty() ? null : entry.getValue().iterator().next(); if ( node == null ) { return null; } if ( node != null ) { entry.getValue().remove( node ); this.nextNodesSet.remove( node ); if ( entry.getValue().isEmpty() ) { this.nextNodes.remove( entry.getKey() ); } this.visitedNodes.add( node.getId() ); } return node; } @Override protected Node fetchNextOrNull() { // FIXME if ( !this.expand ) { this.expand = true; } else { expand(); } Node node = popLowestScoreNode(); this.lastNode = node; return node; } private void expand() { for ( Relationship rel : expander.expand( this.lastNode ) ) { Node node = rel.getOtherNode( this.lastNode ); if ( this.visitedNodes.contains( node.getId() ) ) { continue; } Data lastNodeData = this.score.get( this.lastNode.getId() ); double tentativeGScore = lastNodeData.wayLength + lengthEvaluator.getCost( rel, false ); boolean isBetter = false; double estimate = estimateEvaluator.getCost( node, this.end ); if ( !this.nextNodesSet.contains( node ) ) { addNext( node, estimate + tentativeGScore ); isBetter = true; } else if ( tentativeGScore < this.score.get( node.getId() ).wayLength ) { isBetter = true; } if ( isBetter ) { this.cameFrom.put( node.getId(), rel.getId() ); Data data = new Data(); data.wayLength = tentativeGScore; data.estimate = estimate; this.score.put( node.getId(), data ); } } } } }