/**
* KDTree.java
*/
package rampancy.util.data.kdTree;
import java.util.*;
/**
 * A k-d tree whose dimensions are defined by comparators rather than coordinates: the
 * comparator at index {@code i} of the list given to the constructor orders points along
 * axis {@code i}, and a node at depth {@code d} splits on axis {@code d % numDimensions}.
 *
 * <p>The initial points are built into a median-split tree; points passed to {@link #add}
 * afterwards are attached below an existing node without rebalancing. This class performs no
 * synchronization.
 *
 * @author Matthew Chun-Lum
 */
public class KDTree<T extends KDPoint> {
    /** Root of the tree, or {@code null} while the tree holds no points. */
    private KDNode<T> rootNode;
    /** Backing list of every point given to this tree, so getAllPoints() is constant time. */
    private final List<T> points;
    /** One comparator per dimension; the list index is the axis number. */
    private final List<Comparator<T>> comparators;
    /** Number of dimensions, i.e. {@code comparators.size()}. */
    private final int numDimensions;

    /**
     * Creates an empty tree.
     *
     * @param comparators one comparator per dimension; must not be empty
     * @throws IllegalArgumentException if {@code comparators} is empty
     */
    public KDTree(List<Comparator<T>> comparators) {
        this(new ArrayList<T>(), comparators);
    }

    /**
     * Creates a tree indexing every point in {@code points}.
     *
     * @param points initial points; copied, so the caller's list is neither sorted nor appended to
     * @param comparators one comparator per dimension; must not be empty
     * @throws IllegalArgumentException if {@code comparators} is empty
     */
    public KDTree(List<T> points, List<Comparator<T>> comparators) {
        this(points, points.size(), comparators);
    }

    /**
     * Creates a tree indexing the first {@code size} entries of {@code points}.
     *
     * <p>All of {@code points} is retained and reported by {@link #getAllPoints()}, but only the
     * first {@code size} entries are inserted into the tree and can be returned by searches.
     *
     * @param points initial points; copied, so the caller's list is neither sorted nor appended to
     * @param size number of leading entries of {@code points} to index
     * @param comparators one comparator per dimension; must not be empty
     * @throws IllegalArgumentException if {@code comparators} is empty
     * @throws IndexOutOfBoundsException if {@code size} exceeds {@code points.size()}
     */
    public KDTree(List<T> points, int size, List<Comparator<T>> comparators) {
        if (comparators.isEmpty()) {
            // Without this, depth % numDimensions divides by zero on the first node created.
            throw new IllegalArgumentException("KDTree requires at least one comparator");
        }
        this.comparators = new ArrayList<Comparator<T>>(comparators);
        this.numDimensions = this.comparators.size();
        // Defensive copy: building the tree sorts sub-ranges in place and add() appends; neither
        // should touch (or fail on an unmodifiable / fixed-size) caller-owned list.
        this.points = new ArrayList<T>(points);
        rootNode = generateInitialNodes(this.points.subList(0, size), 0);
    }

    /**
     * Finds the point in the tree closest to {@code point}, as ranked by {@link KDNearestSearch}.
     *
     * @param point the query point
     * @return the nearest point, or {@code null} if the tree is empty
     */
    public T getNearestNeighbor(T point) {
        List<T> nearest = getKNearestNeighbors(point, 1);
        if (nearest == null || nearest.isEmpty()) {
            return null;
        }
        return nearest.get(0);
    }

    /**
     * Finds up to {@code k} points in the tree closest to {@code point}, as ranked by
     * {@link KDNearestSearch}.
     *
     * @param point the query point
     * @param k maximum number of neighbours requested
     * @return the list produced by {@link KDNearestSearch#getNearestNeighbors()}, or a new empty
     *         list if the tree is empty or {@code k <= 0}
     */
    public List<T> getKNearestNeighbors(T point, int k) {
        if (rootNode == null || k <= 0) {
            return new ArrayList<T>();
        }
        KDNearestSearch<T> search = new KDNearestSearch<T>(point, k);
        collectNearestNeighbors(search);
        return search.getNearestNeighbors();
    }

    /**
     * Adds {@code point} to the tree: it is appended to the point list and attached as a new
     * node where its descent from the root ends. The tree is not rebalanced.
     *
     * @param point the point to add
     */
    public void add(T point) {
        points.add(point);
        if (rootNode == null) {
            rootNode = generateNewNode(point, 0);
        } else {
            insertBelow(rootNode, point);
        }
    }

    /**
     * @return the root node (mainly used in testing), or {@code null} if the tree is empty
     */
    public KDNode<T> getRootNode() {
        return rootNode;
    }

    /**
     * @return the number of dimensions, i.e. the number of comparators given to the constructor
     */
    public int getNumberOfDimensions() {
        return numDimensions;
    }

    /**
     * Returns the live backing list of every point given to the constructor or to {@link #add}.
     * Modifying the returned list does not update the tree.
     *
     * @return all points held by this tree
     */
    public List<T> getAllPoints() {
        return points;
    }

    // ------------------ Private ------------------ //

    /**
     * Offers every node's point to {@code search}. At each node the subtree on the query's side
     * of the split is searched before the opposite one, so close candidates tend to be seen first.
     *
     * <p>Why no subtree is skipped: whether the split point itself is admitted says nothing about
     * how far the splitting plane is from the query, so pruning on it can miss true neighbours.
     * A valid prune needs the query's distance to the plane along {@code node.axis}, which the
     * comparators do not provide. Every point goes through {@code attemptToAdd} on this single
     * search object, so that object alone decides what is admitted.
     * TODO: once KDPoint exposes a per-axis distance, skip {@code far} when that distance is not
     * smaller than the current k-th best distance.
     *
     * <p>An explicit stack is used so a degenerate (list-shaped) tree cannot overflow the call
     * stack. Callers must ensure {@code rootNode} is non-null.
     */
    private void collectNearestNeighbors(KDNearestSearch<T> search) {
        Deque<KDNode<T>> pending = new ArrayDeque<KDNode<T>>();
        pending.push(rootNode);
        while (!pending.isEmpty()) {
            KDNode<T> node = pending.pop();
            search.attemptToAdd(node.median);
            boolean targetOnLeft = isToTheLeftOf(search.target, node.median, node.axis);
            KDNode<T> near = targetOnLeft ? node.left : node.right;
            KDNode<T> far = targetOnLeft ? node.right : node.left;
            // Push far first so the entire near subtree is popped (and searched) before it.
            if (far != null) {
                pending.push(far);
            }
            if (near != null) {
                pending.push(near);
            }
        }
    }

    /**
     * Descends from {@code start} following {@link #isToTheLeftOf} and attaches a new node for
     * {@code point} at the first empty child slot. Iterative, so a deep unbalanced tree cannot
     * overflow the call stack.
     *
     * @param start node to start descending from; must not be {@code null}
     * @param point the point to insert
     */
    private void insertBelow(KDNode<T> start, T point) {
        KDNode<T> current = start;
        while (true) {
            if (isToTheLeftOf(point, current.median, current.axis)) {
                if (current.left == null) {
                    current.left = generateNewNode(point, current.depth + 1);
                    return;
                }
                current = current.left;
            } else {
                if (current.right == null) {
                    current.right = generateNewNode(point, current.depth + 1);
                    return;
                }
                current = current.right;
            }
        }
    }

    /**
     * @param point the point being routed
     * @param median the split point of the node
     * @param axis the axis the node splits on
     * @return {@code true} if {@code point} is strictly less than {@code median} on {@code axis};
     *         points comparing equal are routed to the right
     */
    private boolean isToTheLeftOf(T point, T median, int axis) {
        return comparators.get(axis).compare(point, median) < 0;
    }

    /**
     * Creates a detached node holding {@code point} at the given depth, splitting on axis
     * {@code depth % numDimensions}.
     *
     * @param point the node's split point
     * @param depth the node's depth (root = 0)
     * @return the new node, with no children
     */
    private KDNode<T> generateNewNode(T point, int depth) {
        KDNode<T> node = new KDNode<T>();
        node.axis = depth % numDimensions;
        node.depth = depth;
        node.median = point;
        return node;
    }

    /**
     * Recursively builds a balanced subtree from {@code slice}. The slice is sorted in place
     * along this depth's axis; its middle element (index {@code size / 2}) becomes the node, and
     * the ranges before and after it become the left and right subtrees. The sub-ranges are
     * views of the same backing list, so each level sorts only its own range.
     *
     * <p>Points comparing equal to the median on the split axis may land in the left subtree
     * here, whereas {@link #isToTheLeftOf} routes them right; the search visits both subtrees,
     * so it does not depend on that ordering.
     *
     * @param slice the points for this subtree (modified: sorted in place)
     * @param depth depth of the subtree's root
     * @return the subtree's root, or {@code null} if {@code slice} is empty
     */
    private KDNode<T> generateInitialNodes(List<T> slice, int depth) {
        if (slice.isEmpty()) {
            return null;
        }
        Collections.sort(slice, comparators.get(depth % numDimensions));
        int medianIndex = slice.size() / 2;
        KDNode<T> node = generateNewNode(slice.get(medianIndex), depth);
        node.left = generateInitialNodes(slice.subList(0, medianIndex), depth + 1);
        node.right = generateInitialNodes(slice.subList(medianIndex + 1, slice.size()), depth + 1);
        return node;
    }
}