package ca.pfv.spmf.datastructures.kdtree;
/* This file is copyright (c) 2008-2013 Philippe Fournier-Viger
*
* This file is part of the SPMF DATA MINING SOFTWARE
* (http://www.philippe-fournier-viger.com/spmf).
*
* SPMF is free software: you can redistribute it and/or modify it under the
* terms of the GNU General Public License as published by the Free Software
* Foundation, either version 3 of the License, or (at your option) any later
* version.
*
* SPMF is distributed in the hope that it will be useful, but WITHOUT ANY
* WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
* A PARTICULAR PURPOSE. See the GNU General Public License for more details.
* You should have received a copy of the GNU General Public License along with
* SPMF. If not, see <http://www.gnu.org/licenses/>.
*/
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;
import ca.pfv.spmf.algorithms.clustering.distanceFunctions.DistanceEuclidian;
import ca.pfv.spmf.algorithms.clustering.distanceFunctions.DistanceFunction;
import ca.pfv.spmf.datastructures.redblacktree.RedBlackTree;
import ca.pfv.spmf.patterns.cluster.DoubleArray;
/**
 * This is an implementation of a "KD tree" based on the description in the
 * book: "Algorithms in a Nutshell" by Heineman et al. (2008).
 * <br/><br/>
 *
 * This implementation uses the Randomized-Select algorithm described in the
 * book "Introduction to Algorithms" by Cormen et al. (2001), as suggested
 * by Heineman.
 * <br/><br/>
 *
 * Elements that are inserted into the tree have to be arrays of double.
 * <br/><br/>
 *
 * The class provides methods for
 * - building the tree from a list of points,
 * - finding the nearest neighbor of a given point,
 * - finding the k nearest neighbors of a given point,
 * - finding all the points within a given radius of a point.
 * <br/><br/>
 *
 * To find the k nearest neighbors, the closest points are stored in a red-black tree.
 * <br/><br/>
 *
 * Queries keep their intermediate state in fields of this object
 * (see {@link #nearest(DoubleArray)} and {@link #knearest(DoubleArray, int)}),
 * so a single instance must not be queried by several threads at the same time.
 *
 * @author Philippe Fournier-Viger
 */
public class KDTree {

    /** Number of nodes in the tree. */
    private int nodeCount = 0;

    /** Root of the tree, or null as long as no tree has been built. */
    private KDNode root = null;

    /** Number of dimensions of the points stored in the tree. */
    int dimensionCount = 0;

    /** Random number generator used by the Randomized-Select algorithm. */
    private static final Random random = new Random(System.currentTimeMillis());

    /** The distance function used for comparing vectors. */
    DistanceFunction distanceFunction = new DistanceEuclidian();

    /**
     * Default constructor.
     */
    public KDTree() {
    }

    /**
     * Get the number of nodes in the KD-tree.
     * @return the number of nodes
     */
    public int size() {
        return nodeCount;
    }

    /**
     * Build the KD-tree from a list of points.
     * This method is meant to be called only once. The list is reordered in
     * place while the medians are selected. If the list is empty, the method
     * returns without modifying the tree.
     * @param points a list of points, where each point is a DoubleArray; the
     *        number of dimensions is taken from the first point
     */
    public void buildtree(List<DoubleArray> points) {
        if (points.isEmpty()) {
            return;
        }
        // A new build replaces the previous tree, so the node counter restarts;
        // otherwise size() would add up the nodes of every build.
        nodeCount = 0;
        dimensionCount = points.get(0).size();
        root = generateNode(0, points, 0, points.size() - 1);
    }

    /**
     * Generate the node (and recursively its subtree) for the points located
     * at positions left..right of the list, splitting on dimension currentD.
     * @param currentD the dimension used to split at this level
     * @param points the list of points
     * @param left the first position to consider (inclusive)
     * @param right the last position to consider (inclusive)
     * @return the node, or null if the range is empty
     */
    private KDNode generateNode(int currentD, List<DoubleArray> points, int left, int right) {
        // if there is no point
        if (right < left) {
            return null;
        }
        nodeCount++;
        // if there is only a single point
        if (right == left) {
            return new KDNode(points.get(left), currentD);
        }
        // Otherwise, there is more than one point:
        // compute the rank (relative to "left") that corresponds to the median
        int m = (right - left) / 2;
        // Select the median. As a side effect, the median ends up at position
        // left + m, with the smaller-or-equal points before it and the larger
        // points after it (for dimension currentD).
        DoubleArray medianNode = randomizedSelect(points, m, left, right, currentD);
        // this point separates the two branches of the tree
        KDNode node = new KDNode(medianNode, currentD);
        // the next level uses the next dimension, cycling back to the first one
        int nextD = currentD + 1;
        if (nextD == dimensionCount) {
            nextD = 0;
        }
        // recursively create the nodes of the two branches
        node.below = generateNode(nextD, points, left, left + m - 1);
        node.above = generateNode(nextD, points, left + m + 1, right);
        return node;
    }

    /**
     * Select the point having the i-th smallest value (i = 0 for the smallest)
     * for dimension currentD among points[left..right], in average linear time.
     * It is based on the pseudo-code of the Randomized-Select algorithm in
     * the book "Introduction to Algorithms" by Cormen et al. (2001), with some
     * modifications such as using a while loop instead of recursive calls.
     * The sub-list is partially reordered in place.
     *
     * @param points the list of points
     * @param i the rank of the desired point, relative to "left"
     * @param left the first position to consider (inclusive)
     * @param right the last position to consider (inclusive)
     * @param currentD the dimension that is used for comparing points
     * @return the point of rank i
     */
    private DoubleArray randomizedSelect(List<DoubleArray> points, int i, int left,
            int right, int currentD) {
        int p = left;
        int r = right;
        while (true) {
            if (p == r) {
                return points.get(p);
            }
            int q = randomizedPartition(points, p, r, currentD);
            // k is the number of points in p..q, so the pivot has rank k - 1
            int k = q - p + 1;
            if (i == k - 1) {
                return points.get(q);
            } else if (i < k) {
                r = q - 1;
            } else {
                i = i - k;
                p = q + 1;
            }
        }
    }

    /**
     * Choose a random pivot in points[p..r], move it to position r and
     * partition the range around it (see the book by Cormen et al. for details).
     * Callers must ensure that p <= r.
     * @return the final position of the pivot
     */
    private int randomizedPartition(List<DoubleArray> points, int p, int r, int currentD) {
        // The bound is r - p + 1 so that position r itself can be chosen and so
        // that the bound stays positive when p == r (nextInt rejects a bound of 0).
        int pivotIndex = p + random.nextInt(r - p + 1);
        swap(points, r, pivotIndex);
        // call the partition method of quicksort
        return partition(points, p, r, currentD);
    }

    /**
     * Partition points[p..r] around the pivot stored at position r, as in
     * quicksort (see the book by Cormen et al. for details). Points whose
     * value for dimension currentD is smaller than or equal to the pivot's are
     * moved before it.
     * @return the final position of the pivot
     */
    private int partition(List<DoubleArray> points, int p, int r, int currentD) {
        DoubleArray x = points.get(r);
        int i = p - 1;
        for (int j = p; j <= r - 1; j++) {
            if (points.get(j).data[currentD] <= x.data[currentD]) {
                i = i + 1;
                swap(points, i, j);
            }
        }
        swap(points, i + 1, r);
        return i + 1;
    }

    /**
     * Swap two points in a list.
     * @param points the list
     * @param i the position of the first point
     * @param j the position of the second point
     */
    private void swap(List<DoubleArray> points, int i, int j) {
        DoubleArray valueI = points.get(i);
        points.set(i, points.get(j));
        points.set(j, valueI);
    }

    //=====================================================================================
    //======================= To find the first nearest neighbor =========================
    //=====================================================================================

    /** The current nearest neighbor (state of the last call to nearest()). */
    DoubleArray nearestNeighboor = null;

    /** The distance between the current nearest neighbor and the target point. */
    double minDist = 0;

    /**
     * Get the nearest neighbor of a point.
     * @param targetPoint the target point
     * @return the closest point stored in the tree, or null if the tree is empty
     */
    public DoubleArray nearest(DoubleArray targetPoint) {
        if (root == null) {
            return null;
        }
        // Find the node under which the point would be inserted and use it as
        // the first estimate of the nearest neighbor.
        findParent(targetPoint, root, 0);
        // Then start from the root and visit every branch that may still contain
        // a point closer than the current estimate.
        nearest(root, targetPoint);
        return nearestNeighboor;
    }

    /**
     * Find the node under which the target point would be inserted in the
     * KD-tree, and store it (and its distance to the target) as the current
     * nearest neighbor.
     * @param target the point
     * @param node the node where the descent starts
     * @param d the dimension used at the level of that node
     */
    private void findParent(DoubleArray target, KDNode node, int d) {
        KDNode current = node;
        int dim = d;
        while (true) {
            // Exactly ONE branch is followed at each level: "below" if the
            // target is smaller on the current dimension, "above" otherwise.
            KDNode child = (target.data[dim] < current.values.data[dim])
                    ? current.below : current.above;
            if (child == null) {
                // the target would become a child of this node
                nearestNeighboor = current.values;
                minDist = distanceFunction.calculateDistance(current.values, target);
                return;
            }
            current = child;
            if (++dim == dimensionCount) {
                dim = 0;
            }
        }
    }

    /**
     * Visit the subtree rooted at a node and update the current nearest
     * neighbor whenever a closer point is met.
     * NOTE(review): the pruning assumes that the distance between two points
     * is never smaller than their difference on a single dimension, which
     * holds for the default Euclidean distance.
     * @param node the root of the subtree (not null)
     * @param targetPoint the target point
     */
    private void nearest(KDNode node, DoubleArray targetPoint) {
        // if this point is closer, it becomes the new nearest neighbor
        double distance = distanceFunction.calculateDistance(node.values, targetPoint);
        if (distance < minDist) {
            minDist = distance;
            nearestNeighboor = node.values;
        }
        // Distance between the target and the splitting plane of this node.
        // Both coordinates are read on the SAME dimension: the one on which
        // this node splits the space.
        int axis = node.d;
        double nodeCoordinate = node.values.data[axis];
        double targetCoordinate = targetPoint.data[axis];
        double perpendicularDistance = Math.abs(nodeCoordinate - targetCoordinate);

        boolean targetIsBelow = targetCoordinate < nodeCoordinate;
        KDNode nearSide = targetIsBelow ? node.below : node.above;
        KDNode farSide = targetIsBelow ? node.above : node.below;

        // the side containing the target is always explored
        if (nearSide != null) {
            nearest(nearSide, targetPoint);
        }
        // the other side can only hold a closer point if the splitting plane
        // itself is closer than the best point found so far
        if (farSide != null && perpendicularDistance < minDist) {
            nearest(farSide, targetPoint);
        }
    }

    //=====================================================================================
    //======================= Method to find the k nearest neighbors ======================
    //=====================================================================================

    /** The current k nearest neighbors of the target point (state of the last call to knearest()). */
    RedBlackTree<KNNPoint> resultKNN = null;

    /** The parameter k of the last call to knearest(). */
    int k = 0;

    /**
     * Get the k nearest neighbors of a point.
     * @param targetPoint the target point
     * @param k the number of neighbors to find
     * @return a red-black tree containing the neighbors found (empty if k is
     *         not positive), or null if the KD-tree is empty
     */
    public RedBlackTree<KNNPoint> knearest(DoubleArray targetPoint, int k) {
        this.k = k;
        this.resultKNN = new RedBlackTree<KNNPoint>();
        if (root == null) {
            return null;
        }
        // Nothing can be kept when k <= 0; searching would only query the
        // maximum of an empty result tree.
        if (k <= 0) {
            return resultKNN;
        }
        // First traverse the tree to find the place where the point would be inserted.
        findParent_knn(targetPoint, root, 0);
        // Now start back at the root, and visit every branch that may contain
        // a point closer than the k best points found until now.
        nearest_knn(root, targetPoint);
        // return the k nearest neighbors
        return resultKNN;
    }

    /**
     * Traverse the tree to find the place where the target would be inserted,
     * offering every node met on the way as a candidate neighbor.
     * @param target the target point
     * @param node the node where the descent starts
     * @param d the dimension used at the level of that node
     */
    private void findParent_knn(DoubleArray target, KDNode node, int d) {
        KDNode current = node;
        int dim = d;
        while (current != null) {
            tryToSave(current, target);
            // Exactly ONE branch is followed at each level: "below" if the
            // target is smaller on the current dimension, "above" otherwise.
            current = (target.data[dim] < current.values.data[dim])
                    ? current.below : current.above;
            if (++dim == dimensionCount) {
                dim = 0;
            }
        }
    }

    /**
     * Try to save a node in the set of the current k closest neighbors.
     * @param node the node to be added (ignored if null)
     * @param target the target point
     */
    private void tryToSave(KDNode node, DoubleArray target) {
        if (node == null) {
            return;
        }
        double distance = distanceFunction.calculateDistance(target, node.values);
        // the set is full and this point is farther than all its members
        if (resultKNN.size() == k && resultKNN.maximum().distance < distance) {
            return;
        }
        KNNPoint point = new KNNPoint(node.values, distance);
        // the same node may be offered several times during a search
        if (resultKNN.contains(point)) {
            return;
        }
        resultKNN.add(point);
        // keep at most k points by removing the farthest one
        if (resultKNN.size() > k) {
            resultKNN.popMaximum();
        }
    }

    /**
     * Visit the subtree rooted at a node, offering its points as candidate
     * neighbors and skipping the branches that cannot contain a point closer
     * than the k best points found until now.
     * @param node the root of the subtree (not null)
     * @param targetPoint the target point
     */
    private void nearest_knn(KDNode node, DoubleArray targetPoint) {
        tryToSave(node, targetPoint);
        // Distance between the target and the splitting plane of this node,
        // with both coordinates read on the dimension on which this node splits.
        int axis = node.d;
        double nodeCoordinate = node.values.data[axis];
        double targetCoordinate = targetPoint.data[axis];
        double perpendicularDistance = Math.abs(nodeCoordinate - targetCoordinate);

        boolean targetIsBelow = targetCoordinate < nodeCoordinate;
        KDNode nearSide = targetIsBelow ? node.below : node.above;
        KDNode farSide = targetIsBelow ? node.above : node.below;

        // the side containing the target is always explored
        if (nearSide != null) {
            nearest_knn(nearSide, targetPoint);
        }
        // The other side must be explored as long as fewer than k neighbors are
        // known; after that, only if the splitting plane is closer than the
        // farthest neighbor kept so far.
        if (farSide != null
                && (resultKNN.size() < k
                        || perpendicularDistance < resultKNN.maximum().distance)) {
            nearest_knn(farSide, targetPoint);
        }
    }

    // =========================== METHOD TO FIND POINTS WITHIN A RADIUS - used by DBSCAN =============================

    /**
     * Get all the points within the radius of a given target point, EXCEPT the
     * target point itself (the object passed as parameter, compared by reference).
     * @param targetPoint the target point
     * @param radius the radius; points at a distance equal to the radius are included
     * @return the list of points, or null if the KD-tree is empty
     */
    public List<DoubleArray> pointsWithinRadiusOf(DoubleArray targetPoint, double radius) {
        if (root == null) {
            return null;
        }
        List<DoubleArray> result = new ArrayList<DoubleArray>();
        // Start at the root, and visit every branch that may contain a point
        // within the radius.
        findPointsWithinRadius(root, targetPoint, result, radius);
        return result;
    }

    /**
     * Visit the subtree rooted at a node and collect the points that are within
     * the radius of the target point.
     * @param node the root of the subtree (not null)
     * @param targetPoint the target point
     * @param result the list of points within the radius (to be filled by this method)
     * @param radius the radius
     */
    private void findPointsWithinRadius(KDNode node, DoubleArray targetPoint, List<DoubleArray> result, double radius) {
        // if it is the target point itself, we skip it because we don't want to return it
        if (node.values != targetPoint) {
            tryToSaveRadius(node, targetPoint, result, radius);
        }
        // Distance between the target and the splitting plane of this node,
        // with both coordinates read on the dimension on which this node splits.
        int axis = node.d;
        double nodeCoordinate = node.values.data[axis];
        double targetCoordinate = targetPoint.data[axis];
        double perpendicularDistance = Math.abs(nodeCoordinate - targetCoordinate);

        // "<=" because points at a distance exactly equal to the radius are
        // accepted, and the "below" branch may hold points lying exactly on the
        // splitting plane (partition() sends ties to the low side).
        if (perpendicularDistance <= radius) {
            // explore the "above" and "below" branches
            if (node.above != null) {
                findPointsWithinRadius(node.above, targetPoint, result, radius);
            }
            if (node.below != null) {
                findPointsWithinRadius(node.below, targetPoint, result, radius);
            }
        } else if (targetCoordinate < nodeCoordinate) {
            // explore only the side of the tree that contains the target
            if (node.below != null) {
                findPointsWithinRadius(node.below, targetPoint, result, radius);
            }
        } else if (node.above != null) {
            findPointsWithinRadius(node.above, targetPoint, result, radius);
        }
    }

    /**
     * Add the point of a node to the result if it is within the radius of the target.
     * @param node the node to be added (ignored if null)
     * @param target the target point
     * @param result the list to which the point is added
     * @param radius the radius
     */
    private void tryToSaveRadius(KDNode node, DoubleArray target, List<DoubleArray> result, double radius) {
        if (node == null) {
            return;
        }
        double distance = distanceFunction.calculateDistance(target, node.values);
        if (distance <= radius) {
            result.add(node.values);
        }
    }

    // =========================== METHOD TO FIND POINTS WITHIN A RADIUS AND KEEP THE DISTANCE - used by OPTICS =============================

    /**
     * Get all the points within the radius of a given target point, EXCEPT the
     * target point itself (the object passed as parameter, compared by
     * reference), together with their distance to the target point.
     * @param targetPoint the target point
     * @param radius the radius; points at a distance equal to the radius are included
     * @return the list of points with their distances, or null if the KD-tree is empty
     */
    public List<KNNPoint> pointsWithinRadiusOfWithDistance(DoubleArray targetPoint, double radius) {
        if (root == null) {
            return null;
        }
        List<KNNPoint> result = new ArrayList<KNNPoint>();
        // Start at the root, and visit every branch that may contain a point
        // within the radius.
        findPointsWithinRadiusWithDistance(root, targetPoint, result, radius);
        return result;
    }

    /**
     * Visit the subtree rooted at a node and collect the points that are within
     * the radius of the target point, with their distance to the target.
     * @param node the root of the subtree (not null)
     * @param targetPoint the target point
     * @param result the list of points within the radius and their distances
     *        (to be filled by this method)
     * @param radius the radius
     */
    private void findPointsWithinRadiusWithDistance(KDNode node, DoubleArray targetPoint, List<KNNPoint> result, double radius) {
        // if it is the target point itself, we skip it because we don't want to return it
        if (node.values != targetPoint) {
            tryToSaveRadiusWithDistance(node, targetPoint, result, radius);
        }
        // Distance between the target and the splitting plane of this node,
        // with both coordinates read on the dimension on which this node splits.
        int axis = node.d;
        double nodeCoordinate = node.values.data[axis];
        double targetCoordinate = targetPoint.data[axis];
        double perpendicularDistance = Math.abs(nodeCoordinate - targetCoordinate);

        // "<=" because points at a distance exactly equal to the radius are
        // accepted, and the "below" branch may hold points lying exactly on the
        // splitting plane (partition() sends ties to the low side).
        if (perpendicularDistance <= radius) {
            // explore the "above" and "below" branches
            if (node.above != null) {
                findPointsWithinRadiusWithDistance(node.above, targetPoint, result, radius);
            }
            if (node.below != null) {
                findPointsWithinRadiusWithDistance(node.below, targetPoint, result, radius);
            }
        } else if (targetCoordinate < nodeCoordinate) {
            // explore only the side of the tree that contains the target
            if (node.below != null) {
                findPointsWithinRadiusWithDistance(node.below, targetPoint, result, radius);
            }
        } else if (node.above != null) {
            findPointsWithinRadiusWithDistance(node.above, targetPoint, result, radius);
        }
    }

    /**
     * Add the point of a node, with its distance to the target, to the result
     * if it is within the radius of the target.
     * @param node the node to be added (ignored if null)
     * @param target the target point
     * @param result the list to which the point is added
     * @param radius the radius
     */
    private void tryToSaveRadiusWithDistance(KDNode node, DoubleArray target, List<KNNPoint> result, double radius) {
        if (node == null) {
            return;
        }
        double distance = distanceFunction.calculateDistance(target, node.values);
        if (distance <= radius) {
            result.add(new KNNPoint(node.values, distance));
        }
    }

    /**
     * Convert a vector of double to a string representation.
     * @param values the vector
     * @return a string
     */
    private String toString(double[] values) {
        StringBuilder buffer = new StringBuilder();
        for (double element : values) {
            buffer.append(" ").append(element);
        }
        return buffer.toString();
    }

    /**
     * Convert this tree to a string representation.
     * @return a string
     */
    @Override
    public String toString() {
        return toString(root, " ");
    }

    /**
     * Convert a subtree to a string while using some indentation.
     * @param node the root of the subtree (may be null)
     * @param indent the current indentation
     * @return a string, empty if the node is null
     */
    private String toString(KDNode node, String indent) {
        if (node == null) {
            return "";
        }
        // both branches are printed with the same, deeper indentation
        String childIndent = indent + " |";
        return node.values + " (" + node.d + ") \n"
                + indent + toString(node.above, childIndent) + "\n"
                + indent + toString(node.below, childIndent);
    }
}