package org.vertexium.accumulo;
import org.apache.accumulo.core.client.IteratorSetting;
import org.apache.accumulo.core.client.ScannerBase;
import org.apache.accumulo.core.data.Key;
import org.apache.accumulo.core.data.Value;
import org.apache.accumulo.core.trace.Span;
import org.apache.accumulo.core.trace.Trace;
import org.vertexium.*;
import org.vertexium.accumulo.iterator.ConnectedVertexIdsIterator;
import org.vertexium.accumulo.util.RangeUtils;
import org.vertexium.util.IterableUtils;
import org.vertexium.util.VertexiumLogger;
import org.vertexium.util.VertexiumLoggerFactory;
import java.io.IOException;
import java.util.*;
import java.util.stream.Collectors;
/**
 * Finds paths between the source and destination vertex named in a
 * {@link FindPathOptions}, using connected-vertex-id lookups against an
 * {@link AccumuloGraph}.
 *
 * <p>The search approach is chosen from {@code options.getMaxHops()}:
 * one hop is a direct neighbor check, two hops intersects the neighbor sets of
 * both end points, and anything larger expands the neighborhood level by level
 * and then enumerates paths over the collected adjacency map.
 */
public class AccumuloFindPathStrategy {
    private static final VertexiumLogger LOGGER = VertexiumLoggerFactory.getLogger(AccumuloFindPathStrategy.class);
    private final AccumuloGraph graph;
    private final FindPathOptions options;
    private final ProgressCallback progressCallback;
    private final Authorizations authorizations;

    /**
     * @param graph            graph used to create the element scanners
     * @param options          source/destination vertex ids, hop limit and label filters
     * @param progressCallback receives progress notifications while searching
     * @param authorizations   authorizations passed to every scanner that is created
     */
    public AccumuloFindPathStrategy(
        AccumuloGraph graph,
        FindPathOptions options,
        ProgressCallback progressCallback,
        Authorizations authorizations
    ) {
        this.graph = graph;
        this.options = options;
        this.progressCallback = progressCallback;
        this.authorizations = authorizations;
    }

    /**
     * Runs the search described by the options given at construction time.
     *
     * @return the paths that were found; empty when there are none
     * @throws IllegalArgumentException if the configured max hops is less than 1
     */
    public Iterable<Path> findPaths() {
        progressCallback.progress(0, ProgressCallback.Step.FINDING_PATH);

        List<Path> foundPaths = new ArrayList<>();
        int maxHops = options.getMaxHops();
        if (maxHops < 1) {
            throw new IllegalArgumentException("maxHops cannot be less than 1");
        } else if (maxHops == 1) {
            // A single hop only needs to know whether dest is a direct neighbor of source.
            Set<String> sourceConnectedVertexIds = getConnectedVertexIds(options.getSourceVertexId());
            if (sourceConnectedVertexIds.contains(options.getDestVertexId())) {
                foundPaths.add(new Path(options.getSourceVertexId(), options.getDestVertexId()));
            }
        } else if (maxHops == 2) {
            findPathsSetIntersection(foundPaths);
        } else {
            findPathsBreadthFirst(foundPaths, options.getSourceVertexId(), options.getDestVertexId(), maxHops);
        }

        progressCallback.progress(1, ProgressCallback.Step.COMPLETE);
        return foundPaths;
    }

    /**
     * Two-hop search: adds the direct path (if source and dest are neighbors) and one
     * path for every vertex id that is connected to both the source and the destination.
     *
     * @param foundPaths list the discovered paths are appended to
     */
    private void findPathsSetIntersection(List<Path> foundPaths) {
        String sourceVertexId = options.getSourceVertexId();
        String destVertexId = options.getDestVertexId();

        // Fetch the neighbors of both end points with a single scan.
        Set<String> vertexIds = new HashSet<>();
        vertexIds.add(sourceVertexId);
        vertexIds.add(destVertexId);
        Map<String, Set<String>> connectedVertexIds = getConnectedVertexIds(vertexIds);

        progressCallback.progress(0.1, ProgressCallback.Step.SEARCHING_SOURCE_VERTEX_EDGES);
        Set<String> sourceVertexConnectedVertexIds = connectedVertexIds.get(sourceVertexId);
        if (sourceVertexConnectedVertexIds == null) {
            return;
        }

        progressCallback.progress(0.3, ProgressCallback.Step.SEARCHING_DESTINATION_VERTEX_EDGES);
        Set<String> destVertexConnectedVertexIds = connectedVertexIds.get(destVertexId);
        if (destVertexConnectedVertexIds == null) {
            return;
        }

        if (sourceVertexConnectedVertexIds.contains(destVertexId)) {
            foundPaths.add(new Path(sourceVertexId, destVertexId));
            if (options.isGetAnyPath()) {
                return;
            }
        }

        progressCallback.progress(0.6, ProgressCallback.Step.MERGING_EDGES);
        // After this call the source set holds only the ids shared with the destination set.
        sourceVertexConnectedVertexIds.retainAll(destVertexConnectedVertexIds);

        progressCallback.progress(0.9, ProgressCallback.Step.ADDING_PATHS);
        for (String connectedVertexId : sourceVertexConnectedVertexIds) {
            foundPaths.add(new Path(sourceVertexId, connectedVertexId, destVertexId));
        }
    }

    /**
     * Search for hop limits above two: collects connected vertex ids level by level,
     * starting from both end points, then enumerates paths over the collected map.
     *
     * @param foundPaths     list the discovered paths are appended to
     * @param sourceVertexId id of the vertex paths start at
     * @param destVertexId   id of the vertex paths end at
     * @param hops           maximum number of hops a path may have
     */
    private void findPathsBreadthFirst(List<Path> foundPaths, String sourceVertexId, String destVertexId, int hops) {
        Map<String, Set<String>> connectedVertexIds = getConnectedVertexIds(sourceVertexId, destVertexId);

        // Every id that has already been scanned, including ids the scan returned no row for,
        // so that those ids are not scanned again on each subsequent level.
        Set<String> searchedVertexIds = new HashSet<>();
        searchedVertexIds.add(sourceVertexId);
        searchedVertexIds.add(destVertexId);

        // start at 2 since we already got the source and dest vertex connected vertex ids
        for (int i = 2; i < hops; i++) {
            progressCallback.progress((double) i / (double) hops, ProgressCallback.Step.FINDING_PATH);

            Set<String> vertexIdsToSearch = new HashSet<>();
            for (Set<String> neighborIds : connectedVertexIds.values()) {
                vertexIdsToSearch.addAll(neighborIds);
            }
            vertexIdsToSearch.removeAll(connectedVertexIds.keySet());
            vertexIdsToSearch.removeAll(searchedVertexIds);

            connectedVertexIds.putAll(getConnectedVertexIds(vertexIdsToSearch));
            searchedVertexIds.addAll(vertexIdsToSearch);
        }

        progressCallback.progress(0.9, ProgressCallback.Step.ADDING_PATHS);
        Set<String> seenVertices = new HashSet<>();
        Path currentPath = new Path(sourceVertexId);
        findPathsRecursive(connectedVertexIds, foundPaths, sourceVertexId, destVertexId, hops, seenVertices, currentPath, progressCallback);
    }

    /**
     * Depth-first enumeration of paths over the adjacency map. A vertex id is never
     * visited twice within one path, and recursion stops once {@code hops} reaches 0.
     * When {@code options.isGetAnyPath()} is set, nothing more is added after the
     * first path has been found.
     *
     * @param connectedVertexIds adjacency map: vertex id to the ids connected to it
     * @param foundPaths         list the discovered paths are appended to
     * @param sourceVertexId     vertex id currently being visited
     * @param destVertexId       vertex id a path must end at
     * @param hops               number of hops still allowed from the current vertex
     * @param seenVertices       vertex ids on the path currently being explored
     * @param currentPath        path from the original source to {@code sourceVertexId}
     * @param progressCallback   unused; passed through to recursive calls
     */
    private void findPathsRecursive(
        Map<String, Set<String>> connectedVertexIds,
        List<Path> foundPaths,
        final String sourceVertexId,
        String destVertexId,
        int hops,
        Set<String> seenVertices,
        Path currentPath,
        @SuppressWarnings("UnusedParameters") ProgressCallback progressCallback
    ) {
        if (options.isGetAnyPath() && !foundPaths.isEmpty()) {
            return;
        }

        seenVertices.add(sourceVertexId);
        if (sourceVertexId.equals(destVertexId)) {
            foundPaths.add(currentPath);
        } else if (hops > 0) {
            Set<String> vertexIds = connectedVertexIds.get(sourceVertexId);
            if (vertexIds != null) {
                for (String childId : vertexIds) {
                    if (!seenVertices.contains(childId)) {
                        findPathsRecursive(connectedVertexIds, foundPaths, childId, destVertexId, hops - 1, seenVertices, new Path(currentPath, childId), progressCallback);
                    }
                }
            }
        }
        // Backtrack so sibling branches may pass through this vertex again.
        seenVertices.remove(sourceVertexId);
    }

    /**
     * Looks up the ids connected to a single vertex.
     *
     * @param vertexId id of the vertex to look up
     * @return the connected vertex ids, or a new empty set when the scan returned nothing for the id
     */
    private Set<String> getConnectedVertexIds(String vertexId) {
        Set<String> vertexIds = new HashSet<>();
        vertexIds.add(vertexId);
        Set<String> vertexIdResults = getConnectedVertexIds(vertexIds).get(vertexId);
        return vertexIdResults == null ? new HashSet<>() : vertexIdResults;
    }

    /**
     * Looks up the ids connected to two vertices with a single scan.
     *
     * @param vertexId1 first vertex id
     * @param vertexId2 second vertex id
     * @return map from row (vertex id) to the ids connected to it
     */
    private Map<String, Set<String>> getConnectedVertexIds(String vertexId1, String vertexId2) {
        Set<String> vertexIds = new HashSet<>();
        vertexIds.add(vertexId1);
        vertexIds.add(vertexId2);
        return getConnectedVertexIds(vertexIds);
    }

    /**
     * Scans the given vertices with a {@link ConnectedVertexIdsIterator} attached and
     * collects the decoded connected vertex ids per returned row.
     *
     * @param vertexIds ids of the vertices to scan; an empty set performs no scan
     * @return mutable map from row (as a string) to the decoded set of connected vertex ids;
     *         ids the scan returned no row for have no entry
     * @throws VertexiumException if a returned value cannot be decoded
     */
    private Map<String, Set<String>> getConnectedVertexIds(Set<String> vertexIds) {
        Span trace = Trace.start("getConnectedVertexIds");
        try {
            if (LOGGER.isTraceEnabled()) {
                LOGGER.trace("getConnectedVertexIds:\n  %s", IterableUtils.join(vertexIds, "\n  "));
            }

            if (vertexIds.isEmpty()) {
                return new HashMap<>();
            }

            List<org.apache.accumulo.core.data.Range> ranges = new ArrayList<>();
            for (String vertexId : vertexIds) {
                ranges.add(RangeUtils.createRangeFromString(vertexId));
            }

            // Build the iterator configuration before the scanner is opened so that a failure
            // here cannot leave an open scanner behind.
            IteratorSetting connectedVertexIdsIteratorSettings = createConnectedVertexIdsIteratorSettings();

            int maxVersions = 1;
            Long startTime = null;
            Long endTime = null;
            ScannerBase scanner = graph.createElementScanner(
                FetchHint.EDGE_REFS,
                ElementType.VERTEX,
                maxVersions,
                startTime,
                endTime,
                ranges,
                false,
                authorizations
            );

            final long timerStartTime = System.currentTimeMillis();
            try {
                // Inside the try so the scanner is closed even if attaching the iterator fails.
                scanner.addScanIterator(connectedVertexIdsIteratorSettings);

                Map<String, Set<String>> results = new HashMap<>();
                for (Map.Entry<Key, Value> row : scanner) {
                    try {
                        Set<String> rowVertexIds = ConnectedVertexIdsIterator.decodeValue(row.getValue());
                        results.put(row.getKey().getRow().toString(), rowVertexIds);
                    } catch (IOException e) {
                        throw new VertexiumException("Could not decode vertex ids for row: " + row.getKey().toString(), e);
                    }
                }
                return results;
            } finally {
                scanner.close();
                AccumuloGraph.GRAPH_LOGGER.logEndIterator(System.currentTimeMillis() - timerStartTime);
            }
        } finally {
            trace.stop();
        }
    }

    /**
     * Creates the scan iterator configuration for {@link ConnectedVertexIdsIterator},
     * carrying the label and excluded-label filters from the options.
     */
    private IteratorSetting createConnectedVertexIdsIteratorSettings() {
        IteratorSetting settings = new IteratorSetting(
            1000,
            ConnectedVertexIdsIterator.class.getSimpleName(),
            ConnectedVertexIdsIterator.class
        );
        ConnectedVertexIdsIterator.setLabels(settings, options.getLabels());
        ConnectedVertexIdsIterator.setExcludedLabels(settings, options.getExcludedLabels());
        return settings;
    }
}