/*
* File: KDTreeTest.java
* Authors: Kevin R. Dixon
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright Jul 29, 2009, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government.
* Export of this program may require a license from the United States
* Government. See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.math.geometry;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.math.Metric;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.math.matrix.mtj.Vector3;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.DefaultPair;
import gov.sandia.cognition.util.ObjectUtil;
import gov.sandia.cognition.util.Pair;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import junit.framework.TestCase;
import java.util.Random;
/**
* Unit tests for KDTreeTest.
*
* @author krdixon
*/
public class KDTreeTest
extends TestCase
{
/**
* Random number generator to use for a fixed random seed.
*/
public final Random RANDOM = new Random(1);
/**
* Default tolerance of the regression tests, {@value}.
*/
public final double TOLERANCE = 1e-5;
/**
* Example from http://en.wikipedia.org/wiki/Kd-tree#Construction
*/
public static List<DefaultPair<Vector, Integer>> points
= new ArrayList<DefaultPair<Vector, Integer>>();
static
{
points.add(new DefaultPair<Vector, Integer>(
VectorFactory.getDefault().copyValues(2, 3), 0));
points.add(new DefaultPair<Vector, Integer>(
VectorFactory.getDefault().copyValues(5, 4), 1));
points.add(new DefaultPair<Vector, Integer>(
VectorFactory.getDefault().copyValues(9, 6), 2));
points.add(new DefaultPair<Vector, Integer>(
VectorFactory.getDefault().copyValues(4, 7), 3));
points.add(new DefaultPair<Vector, Integer>(
VectorFactory.getDefault().copyValues(8, 1), 4));
points.add(new DefaultPair<Vector, Integer>(
VectorFactory.getDefault().copyValues(7, 2), 5));
}
/**
* Tests for class KDTreeTest.
*
* @param testName Name of the test.
*/
public KDTreeTest(
String testName)
{
super(testName);
}
/**
* Creates an instance
*
* @return instance.
*/
public KDTree<Vector, Integer, DefaultPair<Vector, Integer>> createInstance()
{
return KDTree.createBalanced(points);
}
/**
* Tests the constructors of class KDTreeTest.
*/
public void testConstructors()
{
System.out.println("Constructors");
KDTree<Vector, Integer, DefaultPair<Vector, Integer>> tree
= new KDTree<Vector, Integer, DefaultPair<Vector, Integer>>();
assertNull(tree.comparator);
assertNull(tree.value);
assertNull(tree.leftChild);
assertNull(tree.rightChild);
assertNull(tree.parent);
tree = this.createInstance();
assertEquals(points.size(), tree.size());
assertSame(points.get(0), tree.leftChild.leftChild.value);
assertSame(points.get(1), tree.leftChild.value);
assertSame(points.get(2), tree.rightChild.value);
assertSame(points.get(3), tree.leftChild.rightChild.value);
assertSame(points.get(4), tree.rightChild.leftChild.value);
assertSame(points.get(5), tree.value);
try
{
tree = new KDTree<Vector, Integer, DefaultPair<Vector, Integer>>(
new ArrayList<DefaultPair<Vector, Integer>>(), null, 0, null);
fail("Cannot give empty points!");
}
catch (Exception e)
{
System.out.println("Good: " + e);
}
}
/**
* Test of clone method, of class KDTree.
*/
public void testClone()
{
System.out.println("clone");
KDTree<Vector, Integer, DefaultPair<Vector, Integer>> tree
= this.createInstance();
KDTree<Vector, Integer, DefaultPair<Vector, Integer>> clone
= tree.clone();
assertNotNull(clone);
assertNotSame(tree, clone);
assertSame(tree.comparator, clone.comparator);
assertSame(tree.parent, clone.parent);
assertNotSame(tree.value, clone.value);
assertEquals(tree.value.getFirst(), clone.value.getFirst());
assertNotSame(tree.leftChild, clone.leftChild);
assertEquals(tree.leftChild.value.getFirst(),
clone.leftChild.value.getFirst());
assertNotSame(tree.rightChild, clone.rightChild);
assertEquals(tree.rightChild.value.getFirst(),
clone.rightChild.value.getFirst());
}
/**
* Test of createBalanced method, of class KDTree.
*/
public void testCreateBalanced()
{
System.out.println("createBalanced");
KDTree<Vector, Integer, DefaultPair<Vector, Integer>> tree
= KDTree.createBalanced(points);
System.out.println("Tree:\n" + tree);
assertEquals(points.size(), tree.size());
assertSame(points.get(0), tree.leftChild.leftChild.value);
assertSame(points.get(1), tree.leftChild.value);
assertSame(points.get(2), tree.rightChild.value);
assertSame(points.get(3), tree.leftChild.rightChild.value);
assertSame(points.get(4), tree.rightChild.leftChild.value);
assertSame(points.get(5), tree.value);
}
/**
* Test of reblanace method, of class KDTree.
*/
public void testReblanace()
{
System.out.println("reblanace");
KDTree<Vector, Integer, DefaultPair<Vector, Integer>> tree
= new KDTree<Vector, Integer, DefaultPair<Vector, Integer>>();
for (DefaultPair<Vector, Integer> point : points)
{
tree.add(point);
}
System.out.println("Tree:\n" + tree);
KDTree<Vector, Integer, DefaultPair<Vector, Integer>> deepCopy
= ObjectUtil.deepCopy(tree);
KDTree<Vector, Integer, DefaultPair<Vector, Integer>> balanced
= tree.reblanace();
assertEquals(deepCopy.toString(), tree.toString());
KDTree<Vector, Integer, DefaultPair<Vector, Integer>> originalBalanced
= KDTree.createBalanced(points);
assertEquals(originalBalanced.toString(), balanced.toString());
}
/**
* Test of add method, of class KDTree.
*/
public void testAdd()
{
System.out.println("add");
KDTree<Vector, Integer, DefaultPair<Vector, Integer>> tree
= new KDTree<Vector, Integer, DefaultPair<Vector, Integer>>();
tree.add(points.get(0));
assertSame(points.get(0), tree.value);
assertNull(tree.leftChild);
assertNull(tree.rightChild);
assertEquals(0, tree.comparator.comparator.getIndex());
tree.add(points.get(1));
assertSame(tree.value, points.get(0));
assertSame(points.get(1), tree.rightChild.value);
assertNull(tree.leftChild);
assertEquals(1, tree.rightChild.comparator.comparator.getIndex());
tree.add(points.get(2));
assertSame(points.get(2), tree.rightChild.rightChild.value);
assertEquals(0,
tree.rightChild.rightChild.comparator.comparator.getIndex());
tree.add(points.get(3));
assertSame(points.get(3), tree.rightChild.rightChild.leftChild.value);
assertEquals(1,
tree.rightChild.rightChild.leftChild.comparator.comparator.getIndex());
}
/**
* Test of size method, of class KDTree.
*/
public void testSize()
{
System.out.println("size");
KDTree<Vector, Integer, DefaultPair<Vector, Integer>> tree
= new KDTree<Vector, Integer, DefaultPair<Vector, Integer>>();
assertEquals(0, tree.size());
for (int i = 0; i < points.size(); i++)
{
tree.add(points.get(i));
assertEquals(i + 1, tree.size());
}
}
/**
* Test of iterator method, of class KDTree.
*/
public void testIterator()
{
System.out.println("iterator");
KDTree<Vector, Integer, DefaultPair<Vector, Integer>> tree
= new KDTree<Vector, Integer, DefaultPair<Vector, Integer>>();
Iterator<DefaultPair<Vector, Integer>> iterator = tree.iterator();
assertFalse(iterator.hasNext());
try
{
iterator.next();
fail("No next");
}
catch (Exception e)
{
System.out.println("Good: " + e);
}
tree = this.createInstance();
for (Pair<? extends Vector, ?> pair : tree)
{
System.out.println(pair.getFirst() + " -> " + pair.getSecond());
}
iterator = tree.iterator();
assertTrue(iterator.hasNext());
assertSame(points.get(0), iterator.next());
assertSame(points.get(1), iterator.next());
assertSame(points.get(3), iterator.next());
assertSame(points.get(5), iterator.next());
assertSame(points.get(4), iterator.next());
assertSame(points.get(2), iterator.next());
assertFalse(iterator.hasNext());
try
{
iterator.next();
fail("Nothing left!");
}
catch (Exception e)
{
System.out.println("Good: " + e);
}
}
private static class EuclideanDistanceMetric
extends AbstractCloneableSerializable
implements Metric<Vectorizable>
{
public static final EuclideanDistanceMetric INSTANCE
= new EuclideanDistanceMetric();
public static int NUM_EVALS = 0;
public EuclideanDistanceMetric()
{
}
public double evaluate(
Vectorizable first,
Vectorizable second)
{
NUM_EVALS++;
Vector delta = first.convertToVector().minus(
second.convertToVector());
return delta.norm2();
}
}
/**
* findNearest
*/
public void testFindNearest()
{
System.out.println("findNearest");
KDTree<Vector, Integer, DefaultPair<Vector, Integer>> tree
= this.createInstance();
EuclideanDistanceMetric.NUM_EVALS = 0;
Collection<? extends Pair<? extends Vector, Integer>> nearest
= tree.findNearest(points.get(0).getFirst(), 1,
EuclideanDistanceMetric.INSTANCE);
assertEquals(1, nearest.size());
System.out.println("Evals: " + EuclideanDistanceMetric.NUM_EVALS);
System.out.println("Found: " + ObjectUtil.toString(
CollectionUtil.getFirst(nearest)));
assertEquals(points.get(0).getFirst(),
CollectionUtil.getFirst(nearest).getFirst());
EuclideanDistanceMetric.NUM_EVALS = 0;
nearest = tree.findNearest(points.get(1).getFirst(), 1,
EuclideanDistanceMetric.INSTANCE);
assertEquals(1, nearest.size());
System.out.println("Evals: " + EuclideanDistanceMetric.NUM_EVALS);
assertEquals(points.get(1).getFirst(),
CollectionUtil.getFirst(nearest).getFirst());
EuclideanDistanceMetric.NUM_EVALS = 0;
nearest = tree.findNearest(points.get(5).getFirst(), 2,
EuclideanDistanceMetric.INSTANCE);
System.out.println("Evals: " + EuclideanDistanceMetric.NUM_EVALS);
assertEquals(2, nearest.size());
Iterator<? extends Pair<? extends Vector, Integer>> iterator
= nearest.iterator();
assertEquals(points.get(4).getFirst(), iterator.next().getFirst());
assertEquals(points.get(5).getFirst(), iterator.next().getFirst());
EuclideanDistanceMetric.NUM_EVALS = 0;
nearest = tree.findNearest(points.get(3).getFirst(), 3,
EuclideanDistanceMetric.INSTANCE);
System.out.println("Evals: " + EuclideanDistanceMetric.NUM_EVALS);
for (Pair<? extends Vector, Integer> neighbor : nearest)
{
System.out.println("Neighbor:\n" + ObjectUtil.toString(neighbor));
}
assertEquals(3, nearest.size());
iterator = nearest.iterator();
assertEquals(points.get(0).getFirst(), iterator.next().getFirst());
assertEquals(points.get(3).getFirst(), iterator.next().getFirst());
assertEquals(points.get(1).getFirst(), iterator.next().getFirst());
assertSame(tree, tree.findNearest(points.get(0).getFirst(), tree.size(),
EuclideanDistanceMetric.INSTANCE));
}
/**
* findNearestWithinRadius
*/
public void testFindNearestWithinRadius()
{
System.out.println("findNearestWithinRadius");
KDTree<Vector, Integer, DefaultPair<Vector, Integer>> tree
= this.createInstance();
EuclideanDistanceMetric.NUM_EVALS = 0;
Collection<? extends Pair<? extends Vector, Integer>> nearest
= tree.findNearestWithinRadius(points.get(0).getFirst(), 1.0,
EuclideanDistanceMetric.INSTANCE);
assertEquals(1, nearest.size());
System.out.println("Evals: " + EuclideanDistanceMetric.NUM_EVALS);
System.out.println("Found: " + ObjectUtil.toString(
CollectionUtil.getFirst(nearest)));
assertEquals(points.get(0).getFirst(),
CollectionUtil.getFirst(nearest).getFirst());
EuclideanDistanceMetric.NUM_EVALS = 0;
nearest = tree.findNearestWithinRadius(points.get(1).getFirst(), 1.0,
EuclideanDistanceMetric.INSTANCE);
assertEquals(1, nearest.size());
System.out.println("Evals: " + EuclideanDistanceMetric.NUM_EVALS);
assertEquals(points.get(1).getFirst(),
CollectionUtil.getFirst(nearest).getFirst());
EuclideanDistanceMetric.NUM_EVALS = 0;
nearest = tree.findNearestWithinRadius(points.get(5).getFirst(), 2.0,
EuclideanDistanceMetric.INSTANCE);
System.out.println("Evals: " + EuclideanDistanceMetric.NUM_EVALS);
assertEquals(2, nearest.size());
Iterator<? extends Pair<? extends Vector, Integer>> iterator
= nearest.iterator();
assertEquals(points.get(4).getFirst(), iterator.next().getFirst());
assertEquals(points.get(5).getFirst(), iterator.next().getFirst());
EuclideanDistanceMetric.NUM_EVALS = 0;
nearest = tree.findNearestWithinRadius(points.get(3).getFirst(), 4.5,
EuclideanDistanceMetric.INSTANCE);
System.out.println("Evals: " + EuclideanDistanceMetric.NUM_EVALS);
for (Pair<? extends Vector, Integer> neighbor : nearest)
{
System.out.println("Neighbor:\n" + ObjectUtil.toString(neighbor));
}
assertEquals(3, nearest.size());
iterator = nearest.iterator();
assertEquals(points.get(0).getFirst(), iterator.next().getFirst());
assertEquals(points.get(3).getFirst(), iterator.next().getFirst());
assertEquals(points.get(1).getFirst(), iterator.next().getFirst());
}
/**
* Neighbor.equals
*/
public void testNeighborEquals()
{
System.out.println("Neighbor.equals");
KDTree.Neighborhood<Vector, Integer, DefaultPair<Vector, Integer>> neighborhood
= new KDTree.Neighborhood<Vector, Integer, DefaultPair<Vector, Integer>>(
1);
KDTree.Neighborhood<Vector, Integer, DefaultPair<Vector, Integer>>.Neighbor<Vector, Integer, DefaultPair<Vector, Integer>> neighbor
= neighborhood.new Neighbor<Vector, Integer, DefaultPair<Vector, Integer>>(
points.get(0), RANDOM.nextDouble());
assertFalse(neighbor.equals(null));
assertFalse(neighbor.equals(new Double(RANDOM.nextDouble())));
assertFalse(neighbor.equals(points.get(0)));
assertFalse(neighbor.equals(points.get(0).getFirst()));
assertTrue(neighbor.equals(
neighborhood.new Neighbor<Vector, Integer, DefaultPair<Vector, Integer>>(
points.get(0), RANDOM.nextDouble())));
}
/**
* Internal iterator
*/
public void testInternalIterator()
{
System.out.println("iterator");
KDTree<?, ?, ?> tree = this.createInstance();
Iterator<?> iterator = tree.iterator();
assertTrue(iterator.hasNext());
try
{
iterator.remove();
fail("Remove isn't implemented");
}
catch (Exception e)
{
System.out.println("Good: " + e);
}
}
public void testPathologicalExample()
{
System.out.println("Pathological Example");
List<DefaultPair<Vector, Integer>> pathological
= new ArrayList<DefaultPair<Vector, Integer>>();
pathological.add(new DefaultPair<Vector, Integer>(
VectorFactory.getDefault().copyValues(0, 0), 0));
pathological.add(new DefaultPair<Vector, Integer>(
VectorFactory.getDefault().copyValues(0, 0), 1));
pathological.add(new DefaultPair<Vector, Integer>(
VectorFactory.getDefault().copyValues(0, 0), 2));
pathological.add(new DefaultPair<Vector, Integer>(
VectorFactory.getDefault().copyValues(0, 0), 3));
pathological.add(new DefaultPair<Vector, Integer>(
VectorFactory.getDefault().copyValues(0, 0), 4));
pathological.add(new DefaultPair<Vector, Integer>(
VectorFactory.getDefault().copyValues(0, 0), 5));
KDTree<Vector, Integer, DefaultPair<Vector, Integer>> tree
= KDTree.createBalanced(pathological);
System.out.println("Pathological Tree:\n" + tree);
assertSame(pathological.get(3), tree.value);
assertSame(pathological.get(4), tree.leftChild.value);
assertSame(pathological.get(0), tree.leftChild.leftChild.value);
assertSame(pathological.get(2), tree.leftChild.rightChild.value);
assertSame(pathological.get(5), tree.rightChild.value);
assertSame(pathological.get(1), tree.rightChild.leftChild.value);
EuclideanDistanceMetric.NUM_EVALS = 0;
Collection<? extends Pair<? extends Vector, Integer>> nearest
= tree.findNearest(pathological.get(1).getFirst(), 1,
EuclideanDistanceMetric.INSTANCE);
System.out.println("Evals: " + EuclideanDistanceMetric.NUM_EVALS);
System.out.println("1 Nearest:\n" + ObjectUtil.toString(nearest));
assertEquals(pathological.get(1).getFirst(),
CollectionUtil.getFirst(nearest).getFirst());
EuclideanDistanceMetric.NUM_EVALS = 0;
nearest = tree.findNearest(pathological.get(1).getFirst(), 3,
EuclideanDistanceMetric.INSTANCE);
System.out.println("Evals: " + EuclideanDistanceMetric.NUM_EVALS);
System.out.println("3 Nearest:\n" + ObjectUtil.toString(nearest));
for (Pair<? extends Vector, Integer> neighbor : nearest)
{
assertEquals(pathological.get(0).getFirst(), neighbor.getFirst());
}
nearest = tree.findNearest(pathological.get(1).getFirst(),
pathological.size() + 1, EuclideanDistanceMetric.INSTANCE);
assertEquals(tree.size(), nearest.size());
for (Pair<? extends Vector, Integer> neighbor : nearest)
{
assertEquals(pathological.get(0).getFirst(), neighbor.getFirst());
}
}
public void testSelfLookup()
{
List<Vector> data = new ArrayList<Vector>();
data.add(new Vector3(0.0, 1.0, 2.0));
data.add(new Vector3(0.0, 1.1, 2.2));
data.add(new Vector3(0.0, -1.0, -2.0));
data.add(new Vector3(0.0, -1.1, -2.2));
List<DefaultPair<Vector, Vector>> pairs
= new ArrayList<DefaultPair<Vector, Vector>>(
data.size());
for (Vector item : data)
{
pairs.add(DefaultPair.create(item, item));
}
KDTree<Vector, Vector, DefaultPair<Vector, Vector>> tree
= KDTree.createBalanced(pairs);
int neighborCount = 3;
for (Vector item : data)
{
if (item.getElement(2) == -2.2)
{
System.out.println("Here!!");
}
Collection<DefaultPair<Vector, Vector>> neighbors
= tree.findNearest(item, neighborCount,
EuclideanDistanceMetric.INSTANCE);
assertEquals(neighborCount, neighbors.size());
System.out.println("Neighbors of " + item);
boolean hasSelf = false;
for (DefaultPair<?, ?> neighbor : neighbors)
{
System.out.println(neighbor.getSecond());
hasSelf = hasSelf || neighbor.getSecond().equals(item);
}
assertTrue(hasSelf);
}
}
}