/*
* RapidMiner
*
* Copyright (C) 2001-2008 by Rapid-I and the contributors
*
* Complete list of developers available at our web site:
*
* http://rapid-i.com
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see http://www.gnu.org/licenses/.
*/
package com.rapidminer.tools.math.container;
import java.util.Collection;
import java.util.LinkedList;
import java.util.ListIterator;
import java.util.Stack;
import com.rapidminer.datatable.SimpleDataTable;
import com.rapidminer.datatable.SimpleDataTableRow;
import com.rapidminer.tools.Tupel;
import com.rapidminer.tools.math.similarity.DistanceMeasure;
/**
* This class is an implementation of a ball tree for organizing multidimensional data points
* in a way that supports nearest neighbour search. It only works well for a low to moderate
* number of dimensions. Since building the tree is very expensive, in most cases a linear
* search strategy will outperform the ball tree in overall performance.
*
* @param <T> the type of the value which is stored with each point and retrieved by nearest
* neighbour search
*
* @author Sebastian Land
* @version $Id: BallTree.java,v 1.4 2008/07/13 20:38:24 ingomierswa Exp $
*/
public class BallTree<T> implements GeometricDataCollection<T> {

	/** Root of the tree, or {@code null} while no point has been added yet. */
	private BallTreeNode<T> root;

	/** Number of coordinates per point; taken from the first point passed to {@link #add}. */
	private int dimension;

	/**
	 * Factor c = sqrt(pi) / Gamma(dimension / 2 + 1)^(1 / dimension), chosen so that
	 * (c * r)^dimension is the volume of a ball with radius r in this many dimensions.
	 */
	private double dimensionFactor;

	/** Measure used for every distance computed while building and searching the tree. */
	private DistanceMeasure distance;

	/**
	 * Creates an empty ball tree.
	 *
	 * @param distance the distance measure used to build and search the tree
	 */
	public BallTree(DistanceMeasure distance) {
		this.distance = distance;
	}

	/**
	 * Inserts a point into the tree.
	 *
	 * The first point becomes the root and fixes the dimensionality. Every later point becomes
	 * a new node that is placed between some node on a root-to-leaf path and one of that node's
	 * children. The position chosen is the one with the smallest sum of the new ball's volume
	 * and the volume growth of the ancestors above it. Afterwards the radii of all ancestors of
	 * the new node are recomputed.
	 *
	 * @param values the coordinates of the point; presumably of the same length as the first
	 *        point added (not checked here - TODO confirm callers guarantee this)
	 * @param storeValue the value returned for this point by nearest neighbour queries
	 */
	public void add(double[] values, T storeValue) {
		if (root == null) {
			root = new BallTreeNode<T>(values, 0, storeValue);
			dimension = values.length;
			dimensionFactor = computeDimensionFactor(dimension);
			return;
		}

		double totalAncestorIncrease = 0;
		double bestVolumeIncrease = Double.POSITIVE_INFINITY;
		BallTreeNode<T> bestNode = null; // will become the father of the new node
		int bestNodeIndex = 0; // position of bestNode within ancestorList
		int bestSide = -1; // < 0: take over the left child, otherwise the right child
		BallTreeNode<T> currentNode = root;
		LinkedList<BallTreeNode<T>> ancestorList = new LinkedList<BallTreeNode<T>>();
		while (true) {
			// growth of currentNode's volume if it has to enclose the new point as well
			totalAncestorIncrease += getVolumeIncludingPoint(currentNode, values) - getVolume(currentNode);
			// volume of the new ball if it is inserted above the left or the right child
			double leftVolume = getNewVolume(currentNode.getLeftChild(), values);
			double rightVolume = getNewVolume(currentNode.getRightChild(), values);
			double minVolume = Math.min(leftVolume, rightVolume);
			if (minVolume + totalAncestorIncrease < bestVolumeIncrease) {
				bestVolumeIncrease = minVolume + totalAncestorIncrease;
				bestNode = currentNode;
				bestSide = Double.compare(leftVolume, rightVolume);
				bestNodeIndex = ancestorList.size();
			}
			ancestorList.add(currentNode);
			if (currentNode.isLeaf())
				break;
			if (currentNode.hasTwoChilds()) {
				// descend into the child whose volume would grow least
				BallTreeNode<T> leftChild = currentNode.getLeftChild();
				double deltaVLeft = getVolumeIncludingPoint(leftChild, values) - getVolume(leftChild);
				BallTreeNode<T> rightChild = currentNode.getRightChild();
				double deltaVRight = getVolumeIncludingPoint(rightChild, values) - getVolume(rightChild);
				currentNode = (deltaVLeft < deltaVRight) ? leftChild : rightChild;
			} else {
				currentNode = currentNode.getChild();
			}
		}

		// insert the new node between bestNode and the chosen child
		BallTreeNode<T> newNode = new BallTreeNode<T>(values, 0, storeValue);
		if (bestSide < 0) {
			newNode.setChild(bestNode.getLeftChild());
			bestNode.setLeftChild(newNode);
		} else {
			newNode.setChild(bestNode.getRightChild());
			bestNode.setRightChild(newNode);
		}
		// the new ball must enclose the ball of the child it adopted
		if (!newNode.isLeaf())
			newNode.setRadius(distance.calculateDistance(values, newNode.getChild().getCenter()) + newNode.getChild().getRadius());

		// walk from bestNode back up to the root and refit every radius to its children
		ListIterator<BallTreeNode<T>> iterator = ancestorList.listIterator(bestNodeIndex + 1);
		while (iterator.hasPrevious()) {
			BallTreeNode<T> ancestor = iterator.previous();
			if (ancestor.hasTwoChilds()) {
				BallTreeNode<T> leftChild = ancestor.getLeftChild();
				BallTreeNode<T> rightChild = ancestor.getRightChild();
				ancestor.setRadius(Math.max(rightChild.getRadius() + distance.calculateDistance(rightChild.getCenter(), ancestor.getCenter()),
						leftChild.getRadius() + distance.calculateDistance(leftChild.getCenter(), ancestor.getCenter())));
			} else {
				BallTreeNode<T> child = ancestor.getChild();
				ancestor.setRadius(distance.calculateDistance(ancestor.getCenter(), child.getCenter()) + child.getRadius());
			}
		}
	}

	/**
	 * Returns the volume of a ball centred at {@code center} that encloses the ball of
	 * {@code child}, i.e. the volume the new node would have if it adopted {@code child}.
	 * Returns 0 if {@code child} is {@code null}, since the new node then has radius 0.
	 */
	private double getNewVolume(BallTreeNode<T> child, double[] center) {
		if (child == null)
			return 0;
		return Math.pow((distance.calculateDistance(center, child.getCenter()) + child.getRadius()) * dimensionFactor, dimension);
	}

	/**
	 * Returns the stored values of the (up to) {@code k} points nearest to {@code values}.
	 * The order is the iteration order of the internal bounded priority queue.
	 * An empty tree yields an empty collection.
	 *
	 * @param k the maximum number of neighbours to return
	 * @param values the query point
	 */
	public Collection<T> getNearestValues(int k, double[] values) {
		BoundedPriorityQueue<Tupel<Double, BallTreeNode<T>>> priorityQueue = getNearestNodes(k, values);
		LinkedList<T> neighboursList = new LinkedList<T>();
		for (Tupel<Double, BallTreeNode<T>> tupel : priorityQueue) {
			neighboursList.add(tupel.getSecond().getStoreValue());
		}
		return neighboursList;
	}

	/**
	 * Like {@link #getNearestValues(int, double[])}, but pairs each stored value with its
	 * distance to {@code values}. An empty tree yields an empty collection.
	 *
	 * @param k the maximum number of neighbours to return
	 * @param values the query point
	 */
	public Collection<Tupel<Double, T>> getNearestValueDistances(int k, double[] values) {
		BoundedPriorityQueue<Tupel<Double, BallTreeNode<T>>> priorityQueue = getNearestNodes(k, values);
		LinkedList<Tupel<Double, T>> neighboursList = new LinkedList<Tupel<Double, T>>();
		for (Tupel<Double, BallTreeNode<T>> tupel : priorityQueue) {
			neighboursList.add(new Tupel<Double, T>(tupel.getFirst(), tupel.getSecond().getStoreValue()));
		}
		return neighboursList;
	}

	/**
	 * Collects the (up to) {@code k} nodes nearest to {@code values}, paired with their distances.
	 *
	 * The search first descends greedily to a leaf. It then unwinds that path; at each
	 * two-child node it also descends into the child not yet taken. That child is skipped only
	 * when the queue is full and the child's ball cannot contain a point closer than the current
	 * worst kept distance. NOTE(review): this assumes {@code peek()} returns the farthest of the
	 * kept entries - confirm against BoundedPriorityQueue.
	 */
	private BoundedPriorityQueue<Tupel<Double, BallTreeNode<T>>> getNearestNodes(int k, double[] values) {
		BoundedPriorityQueue<Tupel<Double, BallTreeNode<T>>> priorityQueue = new BoundedPriorityQueue<Tupel<Double, BallTreeNode<T>>>(k);
		if (root == null)
			return priorityQueue; // empty tree: nothing to search
		Stack<BallTreeNode<T>> nodeStack = new Stack<BallTreeNode<T>>();
		Stack<Integer> sideStack = new Stack<Integer>();
		traverseTree(nodeStack, sideStack, root, values);
		while (!nodeStack.isEmpty()) {
			BallTreeNode<T> currentNode = nodeStack.pop();
			int currentSide = sideStack.pop();
			priorityQueue.add(new Tupel<Double, BallTreeNode<T>>(distance.calculateDistance(currentNode.getCenter(), values), currentNode));
			if (currentNode.hasTwoChilds()) {
				// side < 0 means the left child was taken, so the right one is still unvisited
				BallTreeNode<T> otherChild = (currentSide < 0) ? currentNode.getRightChild() : currentNode.getLeftChild();
				if (!priorityQueue.isFilled() ||
						priorityQueue.peek().getFirst().doubleValue() + otherChild.getRadius() >
						distance.calculateDistance(values, otherChild.getCenter())) {
					traverseTree(nodeStack, sideStack, otherChild, values);
				}
			}
		}
		return priorityQueue;
	}

	/**
	 * Descends from {@code start} to a leaf, always moving to the child whose center is closer
	 * to {@code values}. Every visited node is pushed onto {@code stack}. In parallel,
	 * {@code sideStack} receives one entry per node:
	 * <ul>
	 * <li>the comparison result used at that node (negative means the left child was taken),</li>
	 * <li>0 for single-child nodes and for the final leaf.</li>
	 * </ul>
	 */
	private void traverseTree(Stack<BallTreeNode<T>> stack, Stack<Integer> sideStack, BallTreeNode<T> start, double[] values) {
		BallTreeNode<T> currentNode = start;
		stack.push(currentNode);
		while (!currentNode.isLeaf()) {
			if (currentNode.hasTwoChilds()) {
				double distanceLeft = distance.calculateDistance(currentNode.getLeftChild().getCenter(), values);
				double distanceRight = distance.calculateDistance(currentNode.getRightChild().getCenter(), values);
				currentNode = (distanceLeft < distanceRight) ? currentNode.getLeftChild() : currentNode.getRightChild();
				sideStack.push(Double.compare(distanceLeft, distanceRight));
			} else {
				currentNode = currentNode.getChild();
				sideStack.push(0);
			}
			stack.push(currentNode);
		}
		sideStack.push(0);
	}

	/** Returns the volume of {@code node}'s ball after growing it, if needed, to contain {@code point}. */
	private double getVolumeIncludingPoint(BallTreeNode<T> node, double[] point) {
		return Math.pow(Math.max(node.getRadius(), distance.calculateDistance(point, node.getCenter())) * dimensionFactor, dimension);
	}

	/** Returns the volume of {@code node}'s ball. */
	private double getVolume(BallTreeNode<T> node) {
		return Math.pow(node.getRadius() * dimensionFactor, dimension);
	}

	/**
	 * Returns c = sqrt(pi) / Gamma(dim / 2 + 1)^(1 / dim), so that (c * r)^dim is the volume of a
	 * dim-dimensional ball with radius r. The Gamma value is handled in log space, so large
	 * dimensionalities do not overflow. For dim &lt;= 0 it returns 1, since every volume is then 1.
	 */
	private static double computeDimensionFactor(int dim) {
		if (dim <= 0)
			return 1d;
		return Math.sqrt(Math.PI) / Math.exp(logGammaHalfPlusOne(dim) / dim);
	}

	/**
	 * Returns ln(Gamma(n / 2 + 1)) for n &gt;= 0. It applies Gamma(x + 1) = x * Gamma(x)
	 * repeatedly, down to Gamma(1) = 1 for even n or Gamma(1/2) = sqrt(pi) for odd n. All
	 * values of x involved are integers or half-integers, so the decrements are exact.
	 */
	private static double logGammaHalfPlusOne(int n) {
		double result = (n % 2 == 1) ? 0.5 * Math.log(Math.PI) : 0d;
		for (double x = n / 2.0; x > 0; x -= 1.0) {
			result += Math.log(x);
		}
		return result;
	}

	/**
	 * Returns a table with one row (x, y, radius) per node, in pre-order. The table reads the
	 * first two coordinates of each center, so it needs at least two-dimensional points.
	 * An empty tree yields an empty table.
	 */
	public SimpleDataTable getVisualization() {
		SimpleDataTable table = new SimpleDataTable("BallTree", new String[] {"x", "y", "radius"});
		if (root != null)
			fillTable(table, root);
		return table;
	}

	/** Adds a row for {@code node}, then recursively for its left and right subtrees. */
	private void fillTable(SimpleDataTable table, BallTreeNode<T> node) {
		table.add(new SimpleDataTableRow(new double[] {node.getCenter()[0], node.getCenter()[1], node.getRadius()}));
		if (node.hasLeftChild())
			fillTable(table, node.getLeftChild());
		if (node.hasRightChild())
			fillTable(table, node.getRightChild());
	}
}