package func.inst;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Random;

import shared.*;
import shared.DataSet;
import shared.Instance;

/**
 * A KDTree implementation
 * Algorithms from Andrew Moore's tutorial
 * @author Andrew Guillory
 * @version 1.0
 */
public class KDTree implements Serializable {
    /**
     * Random number generator used to pick splitter nodes and split dimensions
     */
    private static final Random random = new Random();

    /**
     * The head node of the kd tree, or null when the tree holds no keys
     */
    private KDTreeNode head;

    /**
     * The dimensionality of the tree (k)
     */
    private int dimensions;

    /**
     * The distance measure to use
     */
    private DistanceMeasure distanceMeasure;

    /**
     * Build a kd tree from the given data set using the given distance measure.
     * The dimensionality is taken from the size of the first instance.
     * An empty data set produces an empty tree (null head, zero dimensions)
     * instead of failing on the lookup of the first instance.
     * @param keys the data set whose instances become the tree's nodes
     * @param distance the distance measure
     */
    public KDTree(DataSet keys, DistanceMeasure distance) {
        int count = keys.size();
        // guard the first-instance lookup so an empty data set is accepted
        dimensions = count == 0 ? 0 : keys.get(0).size();
        distanceMeasure = distance;
        KDTreeNode[] nodes = new KDTreeNode[count];
        for (int i = 0; i < count; i++) {
            nodes[i] = new KDTreeNode(keys.get(i));
        }
        head = buildTree(nodes, 0, nodes.length);
    }

    /**
     * Build a kd tree from the given data set using a
     * {@link EuclideanDistance} measure
     * @param keys the data set whose instances become the tree's nodes
     */
    public KDTree(DataSet keys) {
        this(keys, new EuclideanDistance());
    }

    /**
     * Build a (sub)tree from the nodes in the half open range [start, end)
     * @param nodes the list of nodes
     * @param start the first index of the range (inclusive)
     * @param end the last index of the range (exclusive)
     * @return the head node of the built tree, or null for an empty range
     */
    private KDTreeNode buildTree(KDTreeNode[] nodes, int start, int end) {
        if (start >= end) {
            // if we're done return null
            return null;
        } else if (start + 1 == end) {
            // or the last element; it becomes a leaf and no split
            // dimension is assigned to it here
            return nodes[start];
        }
        // choose splitter (this also assigns the splitter's dimension)
        int splitterIndex = chooseSplitterRandom(nodes, start, end);
        KDTreeNode splitter = nodes[splitterIndex];
        // partition based on splitter; the splitter ends up at the returned index
        splitterIndex = partition(nodes, start, end, splitterIndex);
        // recursively build tree from the ranges on either side of the splitter
        splitter.setLeft(buildTree(nodes, start, splitterIndex));
        splitter.setRight(buildTree(nodes, splitterIndex + 1, end));
        return splitter;
    }

    /**
     * Partition the range [start, end) of an array around a splitter.
     * Afterwards every element the splitter compares strictly greater than
     * precedes the splitter, and all other elements follow it.
     * @param comparables the array
     * @param start the start index (inclusive)
     * @param end the end index (exclusive)
     * @param splitterIndex the splitter's index
     * @return the new splitter index
     */
    // The raw Comparable signature is kept so existing callers compile
    // unchanged; the elements passed in are nodes of this tree.
    @SuppressWarnings({"rawtypes", "unchecked"})
    private int partition(Comparable[] comparables, int start, int end,
            int splitterIndex) {
        // park the splitter in the last slot while the rest is partitioned
        swap(comparables, splitterIndex, end - 1);
        splitterIndex = end - 1;
        Comparable splitter = comparables[splitterIndex];
        // i is the last index of the "less than splitter" prefix
        int i = start - 1;
        for (int j = start; j < end - 1; j++) {
            if (splitter.compareTo(comparables[j]) > 0) {
                i++;
                swap(comparables, i, j);
            }
        }
        // move the splitter to just after the prefix
        swap(comparables, splitterIndex, i + 1);
        return i + 1;
    }

    /**
     * Swap two elements in an array
     * @param objects the array
     * @param i the first index
     * @param j the second index
     */
    private void swap(Object[] objects, int i, int j) {
        Object temp = objects[i];
        objects[i] = objects[j];
        objects[j] = temp;
    }

    /**
     * Choose a random splitter from the range [start, end) and
     * assign it a random split dimension
     * @param nodes the nodes to choose from
     * @param start the start (inclusive)
     * @param end the end (exclusive)
     * @return the index of the splitter
     */
    private int chooseSplitterRandom(KDTreeNode[] nodes, int start, int end) {
        int splitter = random.nextInt(end - start) + start;
        int dimension = random.nextInt(dimensions);
        nodes[splitter].setDimension(dimension);
        return splitter;
    }

    /**
     * Choose a splitter from a list of nodes: the node closest to the
     * midpoint of the dimension with the widest range of values.
     * This isn't used because it is much slower than random,
     * and im not sure how to implement it any faster
     * @param nodes the list of nodes to pick from
     * @param start the start index (inclusive)
     * @param end the end index (exclusive)
     * @return the best splitter from the list
     */
    private int chooseSplitterSmart(KDTreeNode[] nodes, int start, int end) {
        // find the ranges of the dimensions
        double[] min = new double[dimensions];
        Arrays.fill(min, Double.POSITIVE_INFINITY);
        double[] max = new double[dimensions];
        Arrays.fill(max, Double.NEGATIVE_INFINITY);
        for (int i = start; i < end; i++) {
            Instance key = nodes[i].getInstance();
            for (int j = 0; j < dimensions; j++) {
                double value = key.getContinuous(j);
                min[j] = Math.min(min[j], value);
                max[j] = Math.max(max[j], value);
            }
        }
        // find the widest dimension
        int widestDimension = 0;
        double widestWidth = max[0] - min[0];
        for (int i = 1; i < dimensions; i++) {
            if (max[i] - min[i] > widestWidth) {
                widestDimension = i;
                widestWidth = max[i] - min[i];
            }
        }
        // find the middle of the widest dimension; the lower bound has to be
        // added back, otherwise this is only half the width and not a
        // position inside [min, max]
        double median = min[widestDimension]
            + (max[widestDimension] - min[widestDimension]) / 2;
        // find the best splitter: the node closest to the middle
        double bestDifference = Double.POSITIVE_INFINITY;
        int splitterIndex = -1;
        for (int i = start; i < end; i++) {
            double difference = Math.abs(
                nodes[i].getInstance().getContinuous(widestDimension) - median);
            if (difference < bestDifference) {
                splitterIndex = i;
                bestDifference = difference;
            }
        }
        nodes[splitterIndex].setDimension(widestDimension);
        return splitterIndex;
    }

    /**
     * Perform a k nearest neighbor search
     * @param target the target of the search
     * @param k how many neighbors to find
     * @return the neighbors
     */
    public Instance[] knn(Instance target, int k) {
        NearestNeighborQueue results = new NearestNeighborQueue(k);
        knn(head, target, new HyperRectangle(dimensions), results);
        return results.getNearest();
    }

    /**
     * Perform a nearest neighbor search
     * @param target the target
     * @return the neighbors
     */
    public Instance[] nn(Instance target) {
        NearestNeighborQueue results = new NearestNeighborQueue();
        knn(head, target, new HyperRectangle(dimensions), results);
        return results.getNearest();
    }

    /**
     * Perform a range search
     * @param target the target
     * @param range the range
     * @return the neighbors in the range
     */
    public Instance[] range(Instance target, double range) {
        NearestNeighborQueue results = new NearestNeighborQueue(range);
        knn(head, target, new HyperRectangle(dimensions), results);
        return results.getNearest();
    }

    /**
     * Perform a k nearest neighbor range search
     * @param target the target
     * @param k the k value
     * @param range the range
     * @return the neighbours
     */
    public Instance[] knnrange(Instance target, int k, double range) {
        NearestNeighborQueue results = new NearestNeighborQueue(k, range);
        knn(head, target, new HyperRectangle(dimensions), results);
        return results.getNearest();
    }

    /**
     * Perform a nearest neighbor search on the subtree rooted at a node,
     * collecting candidates into the given queue
     * @param node the node to search on (null ends the recursion)
     * @param target the target
     * @param hr the hyper rectangle bounding the node's subtree
     * @param results the current results
     */
    private void knn(KDTreeNode node, Instance target, HyperRectangle hr,
            NearestNeighborQueue results) {
        if (node == null) {
            return;
        }
        // split the bounding rectangle at the node's split value
        HyperRectangle leftHR = hr.splitLeft(
            node.getSplitValue(), node.getDimension());
        HyperRectangle rightHR = hr.splitRight(
            node.getSplitValue(), node.getDimension());
        // the near side is the one the target falls on
        HyperRectangle nearHR, farHR;
        KDTreeNode nearNode, farNode;
        if (target.getContinuous(node.getDimension()) < node.getSplitValue()) {
            nearHR = leftHR;
            nearNode = node.getLeft();
            farHR = rightHR;
            farNode = node.getRight();
        } else {
            nearHR = rightHR;
            nearNode = node.getRight();
            farHR = leftHR;
            farNode = node.getLeft();
        }
        // search the near side first
        knn(nearNode, target, nearHR, results);
        // the node itself and the far side are only considered when the far
        // rectangle is no further away than the queue's current max distance
        // NOTE(review): presumably safe for the node's own instance because
        // its point lies on the splitting plane bounding the far rectangle -
        // confirm against KDTreeNode.getSplitValue and HyperRectangle
        if (distanceMeasure.value(
                farHR.pointNearestTo(target), target) <= results.getMaxDistance()) {
            results.add(node.getInstance(),
                distanceMeasure.value(node.getInstance(), target));
            knn(farNode, target, farHR, results);
        }
    }
}