package de.fub.agg2graph.structs.frechet; import de.fub.agg2graph.agg.AggNode; import de.fub.agg2graph.structs.GPSPoint; import de.fub.agg2graph.structs.GPSRegion; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.Comparator; import java.util.List; /** * Static KD-Tree with search of nearest neighbor, k-nearest-neighbors and * points in axis aligned box. * * Used basic implementation O(n log^2 n) recursive build with alternating * selection of axis. * * * @param <L> */ public class KdTreeIndex<L extends AggNode> implements SearchIndex<L> { public KdTreeIndex(Collection<L> locations) { root = build(new ArrayList<L>(locations), 0); } @Override public L searchNN(L point) { SearchState<L> state = new SearchState<L>(); root.searchNN(point, 0, state); return state.bestSoFar; } @Override public List<L> searchKnn(L point, int k) { SearchState<L> state = new SearchState<>(); root.searchKnn(point, 0, state, k); List<L> result = new ArrayList<>(); for (int i = 0; i < k && i < state.bests.size(); ++i) { result.add(state.bests.get(i).value); } return result; } @Override public List<L> searchRegion(GPSRegion region) { assert (region.minLocation.getLon() <= region.maxLocation.getLon()); assert (region.minLocation.getLat() <= region.maxLocation.getLat()); List<L> result = new ArrayList<L>(); root.searchRegion(region, 0, result); return result; } public int size() { return root.size(); } static class SearchState<L extends GPSPoint> { L bestSoFar = null; ArrayList<Tuple<Double, L>> bests = new ArrayList<>(); double minimumDistance = Double.POSITIVE_INFINITY; } static class Tuple<K extends Comparable<K>, V> implements Comparable<Tuple<K, V>> { public K key; public V value; Tuple(K key, V value) { this.key = key; this.value = value; } @Override public int compareTo(Tuple<K, V> other) { if (other == null) { throw new NullPointerException(); } return key.compareTo(other.key); } } private class Node { ArrayList<L> locations; Node left; Node right; Node(L location) { locations = new ArrayList<L>(); locations.add(location); left = null; right = null; } public Node(ArrayList<L> locations) { this.locations = locations; left = null; right = null; } void searchNN(L point, int depth, SearchState<L> state) { for (L location : locations) { double pointDistance = point.getSquaredDistanceTo(location); if (pointDistance < state.minimumDistance) { state.bestSoFar = location; state.minimumDistance = pointDistance; } } final int axis = getAxis(depth); Node close, far; // All attached locations share the axis coordinate so we use location at 0 as a representative. double diff = distanceToAxis(axis, locations.get(0), point); if (diff <= 0) { close = right; far = left; } else { close = left; far = right; } if (close != null) { close.searchNN(point, depth + 1, state); } if (diff * diff < state.minimumDistance && far != null) { far.searchNN(point, depth + 1, state); } } public void searchKnn(L point, int depth, SearchState<L> state, int k) { Tuple<Double, L> maxTp = (state.bests.isEmpty()) ? null : Collections.max(state.bests); for (L location : locations) { double pointDistance = point.getSquaredDistanceTo(location); if (maxTp == null) { maxTp = new Tuple<Double, L>(pointDistance, location); state.bests.add(maxTp); } else if (state.bests.size() < k || pointDistance < maxTp.key) { state.bests.add(new Tuple<Double, L>(pointDistance, location)); } if (state.bests.size() > k) { // could be maximum of 1 diff here. state.bests.remove(maxTp); } } final int axis = getAxis(depth); Node close, far; // All attached locations share the axis coordinate so we use location at 0 as a representative. double diff = distanceToAxis(axis, locations.get(0), point); if (diff <= 0) { close = right; far = left; } else { close = left; far = right; } if (close != null) { close.searchKnn(point, depth + 1, state, k); } if (far != null && (diff * diff < maxTp.key || state.bests.size() < k)) { far.searchKnn(point, depth + 1, state, k); } } public void searchRegion(GPSRegion region, int depth, List<L> result) { int axis = getAxis(depth); double cmp1 = distanceToAxis(axis, locations.get(0), region.minLocation); double cmp2 = distanceToAxis(axis, locations.get(0), region.maxLocation); if (cmp1 >= 0 && left != null) { left.searchRegion(region, depth + 1, result); } if (cmp2 <= 0 && right != null) { right.searchRegion(region, depth + 1, result); } if (cmp1 >= 0 && cmp2 <= 0) { for (L location : locations) { if (region.contains(location)) { result.add(location); } } } } public int size() { return locations.size() + ((left != null) ? left.size() : 0) + ((right != null) ? right.size() : 0); } } private Node build(List<L> points, int depth) { if (points.isEmpty()) { return null; } if (points.size() == 1) { return new Node(points.get(0)); } final int axis = getAxis(depth); int median = points.size() / 2; sortByAxis(points, axis); ArrayList<L> listOfEqualCoord = new ArrayList<L>(); L medianObject = points.get(median); listOfEqualCoord.add(medianObject); int higherIndex = median + 1; for (int i = higherIndex; i < points.size(); ++i) { L current = points.get(i); if (comparators[axis].compare(medianObject, current) == 0) { listOfEqualCoord.add(current); higherIndex = i; } else { break; } } int lowerIndex = median - 1; for (int i = lowerIndex; i > 0; --i) { L current = points.get(i); if (comparators[axis].compare(medianObject, current) == 0) { listOfEqualCoord.add(current); lowerIndex = i; } else { break; } } Node node = new Node(listOfEqualCoord); node.left = build(points.subList(0, lowerIndex + 1), depth + 1); node.right = build(points.subList(higherIndex, points.size()), depth + 1); return node; } private static int getAxis(int depth) { return depth % 2; } private static double distanceToAxis(int axis, GPSPoint here, GPSPoint point) { if (axis == 0) { return here.getLon() - point.getLon(); } else { return here.getLat() - point.getLat(); } } private void sortByAxis(List<L> points, final int axis) { Collections.sort(points, comparators[axis]); } /** * Defined axis comparators. Used in the sort step for sorting based on the * axis. */ @SuppressWarnings("unchecked") static Comparator<GPSPoint> comparators[] = new Comparator[2]; { // static initialize the comparators. comparators[0] = new LongitudeComparator(); comparators[1] = new LatitudeComparator(); } private static class LatitudeComparator implements Comparator<GPSPoint> { @Override public int compare(GPSPoint o1, GPSPoint o2) { return Double.compare(o1.getLat(), o2.getLat()); } } private static class LongitudeComparator implements Comparator<GPSPoint> { @Override public int compare(GPSPoint o1, GPSPoint o2) { return Double.compare(o1.getLon(), o2.getLon()); } } private Node root = null; }