/*
* File: KNearestNeighborTestHarness.java
* Authors: Kevin R. Dixon
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright March 7, 2007, Sandia Corporation. Under the terms of Contract
* DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
* or on behalf of the U.S. Government. Export of this program may require a
* license from the United States Government. See CopyrightHistory.txt for
* complete details.
*
*/
package gov.sandia.cognition.learning.algorithm.nearest;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.distance.EuclideanDistanceMetric;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.util.Summarizer;
import gov.sandia.cognition.math.DivergenceFunction;
import gov.sandia.cognition.math.UnivariateStatisticsUtil;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.util.Pair;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Random;
import junit.framework.TestCase;
/**
 * Abstract test harness for implementations of AbstractKNearestNeighbor.
 * Subclasses supply a concrete instance via createInstance and inherit the
 * common tests defined here.
 *
 * @author Kevin R. Dixon
 * @since 1.0
 */
public abstract class KNearestNeighborTestHarness
    extends TestCase
{

    /** Deterministically seeded random number generator for the tests. */
    public final Random RANDOM = new Random(1);

    /**
     * Tolerance for floating-point assertions in the tests.
     */
    public final double TOLERANCE = 1e-5;

    /**
     * Euclidean distance metric that counts how many times {@code evaluate}
     * is invoked, so tests can measure how many distance computations an
     * implementation performs.
     */
    public static class CounterEuclidenDistance
        extends EuclideanDistanceMetric
    {

        /** Number of times {@code evaluate} has been called. */
        public int evaluations = 0;

        @Override
        public double evaluate(
            Vectorizable first,
            Vectorizable second)
        {
            this.evaluations++;
            return super.evaluate(first, second);
        }

    }

    /**
     * Sample 2-D inputs with scalar outputs, taken from the kd-tree
     * construction example at
     * http://en.wikipedia.org/wiki/Kd-tree#Construction
     */
    @SuppressWarnings("unchecked")
    public static final List<DefaultInputOutputPair<Vector,Double>> POINTS = Arrays.asList(
        new DefaultInputOutputPair<Vector, Double>(VectorFactory.getDefault().copyValues(2,3), 0.0),
        new DefaultInputOutputPair<Vector, Double>(VectorFactory.getDefault().copyValues(5,4), 1.0),
        new DefaultInputOutputPair<Vector, Double>(VectorFactory.getDefault().copyValues(9,6), 2.0),
        new DefaultInputOutputPair<Vector, Double>(VectorFactory.getDefault().copyValues(4,7), 3.0),
        new DefaultInputOutputPair<Vector, Double>(VectorFactory.getDefault().copyValues(8,1), 4.0),
        new DefaultInputOutputPair<Vector, Double>(VectorFactory.getDefault().copyValues(7,2), 5.0)
    );

    /**
     * Creates the test harness.
     *
     * @param testName Name of the test, passed to JUnit.
     */
    public KNearestNeighborTestHarness(
        String testName )
    {
        super( testName );
    }

    /**
     * Creates an instance of the KNN under test.
     *
     * @param k Number of neighbors to consider.
     * @param data Input-output pairs that populate the instance.
     * @return KNN instance to test.
     */
    public abstract AbstractKNearestNeighbor<Vector, Double> createInstance(
        int k,
        Collection<? extends InputOutputPair<Vector,Double>> data );

    /**
     * Test constructors of the concrete implementation.
     */
    public abstract void testConstructors();

    /**
     * Test the learner of the concrete implementation.
     */
    public abstract void testLearner();

    /**
     * Tests to make sure you implemented the createInstance method properly:
     * the returned instance must have the requested "k" and contain all of
     * the given data.
     */
    public void testCreateInstance()
    {
        System.out.println( "createInstance" );

        int k = RANDOM.nextInt(10) + 1;
        AbstractKNearestNeighbor<Vector, Double> knn = this.createInstance( k, POINTS );
        assertEquals( k, knn.getK() );
        assertEquals( POINTS.size(), knn.getData().size() );

        // Every stored datum must correspond (by input) to one of the
        // original POINTS, with the matching output value.
        for( Pair<? extends Vector,Double> d : knn.getData() )
        {
            boolean found = false;
            for( Pair<Vector,Double> point : POINTS )
            {
                if( point.getFirst().equals( d.getFirst() ) )
                {
                    found = true;
                    assertEquals( point.getSecond(), d.getSecond() );
                }
            }
            assertTrue( found );
        }
    }

    /**
     * Test clone: the copy must be a distinct object with equal "k" and a
     * deep-copied averager, divergence function, and data collection.
     */
    public void testClone()
    {
        System.out.println( "Clone" );

        AbstractKNearestNeighbor<Vector, Double> knn = this.createInstance( 1, POINTS );
        AbstractKNearestNeighbor<Vector, Double> clone = knn.clone();
        assertNotNull( clone );
        assertNotSame( knn, clone );
        assertEquals( knn.getK(), clone.getK() );
        assertNotSame( knn.getAverager(), clone.getAverager() );
        assertNotSame( knn.getDivergenceFunction(), clone.getDivergenceFunction() );
        assertEquals( knn.getData().size(), clone.getData().size() );
        assertNotSame( knn.getData(), clone.getData() );
    }

    /**
     * Test of evaluate method with 1-NN: each stored input must map back to
     * its own output, a null query must throw, and a small perturbation of a
     * stored input must still find the same neighbor.
     */
    public void testEvaluateSimple()
    {
        System.out.println( "evaluate 1NN" );

        AbstractKNearestNeighbor<Vector, Double> knn = this.createInstance( 1, POINTS );
        int M = -1;
        for (Pair<? extends Vector, ? extends Double> pair : knn.getData())
        {
            M = pair.getFirst().getDimensionality();
            assertEquals( pair.getSecond(), knn.evaluate( pair.getFirst() ) );
        }

        try
        {
            knn.evaluate( null );
            fail( "Should have thrown null-pointer exception" );
        }
        catch (Exception e)
        {
            System.out.println( "Good! Properly thrown exception: " + e );
        }

        // A perturbation within TOLERANCE of a stored point must still
        // resolve to that point's output under 1-NN.
        final double r = TOLERANCE;
        for (Pair<? extends Vector, ? extends Double> pair : knn.getData())
        {
            Vector small = VectorFactory.getDefault().createUniformRandom( M, -r, r, RANDOM );
            Vector perturb = pair.getFirst().plus( small );
            Double estimate = knn.evaluate( perturb );
            double error = pair.getSecond() - estimate;
            double distance = Math.abs( error );
            assertEquals( 0.0, distance, TOLERANCE );
        }
    }

    /**
     * Evaluate with "k" larger than the dataset: the result must be the
     * average of all stored outputs, regardless of the query point.
     */
    public void testEvaluateMany()
    {
        AbstractKNearestNeighbor<Vector, Double> knn = this.createInstance( 1, POINTS );
        final int M = CollectionUtil.getFirst(knn.getData()).getFirst().getDimensionality();

        // A "k" larger than there are datapoints should just return the
        // average value of all outputs.
        knn.setK( knn.getData().size() + 1 );
        Collection<? extends Pair<? extends Vector, Double>> data = knn.getData();
        ArrayList<Double> outputs = new ArrayList<Double>( data.size() );
        for( Pair<? extends Vector,Double> pair : data )
        {
            outputs.add(pair.getSecond());
        }
        double expected = UnivariateStatisticsUtil.computeMean(outputs);

        Vector input = VectorFactory.getDefault().createUniformRandom( M, -10.0, 10.0, RANDOM );
        assertEquals( expected, knn.evaluate( input ), TOLERANCE );
    }

    /**
     * Evaluates 2-NN against known values from the POINTS example.
     */
    public void testEvaluate2NN()
    {
        AbstractKNearestNeighbor<Vector, Double> knn = this.createInstance( 2, POINTS );

        // Neighbors of (7,2) are itself (5.0) and (8,1) (4.0): mean 4.5.
        // Use the delta overload: exact double equality on averages is fragile.
        double value = knn.evaluate( POINTS.get(5).getFirst() );
        assertEquals( 4.5, value, TOLERANCE );

        // Neighbors of (2,3) are itself (0.0) and (5,4) (1.0): mean 0.5.
        value = knn.evaluate( POINTS.get(0).getFirst() );
        assertEquals( 0.5, value, TOLERANCE );
    }

    /**
     * Test of getDivergenceFunction method of the KNN under test.
     */
    public void testGetDivergenceFunction()
    {
        System.out.println( "getDivergenceFunction" );

        int k = RANDOM.nextInt( 10 ) + 1;
        AbstractKNearestNeighbor<Vector, Double> knn = this.createInstance( k, POINTS );
        assertNotNull( knn.getDivergenceFunction() );
    }

    /**
     * Test of setDivergenceFunction method of the KNN under test.
     */
    public void testSetDivergenceFunction()
    {
        System.out.println( "setDivergenceFunction" );

        int k = RANDOM.nextInt( 10 ) + 1;
        AbstractKNearestNeighbor<Vector, Double> knn = this.createInstance( k, POINTS );
        assertNotNull( knn.getDivergenceFunction() );

        DivergenceFunction<? super Vector, ? super Vector> foo =
            knn.getDivergenceFunction();
        knn.setDivergenceFunction( null );
        assertNull( knn.getDivergenceFunction() );
        knn.setDivergenceFunction( foo );
        assertSame( foo, knn.getDivergenceFunction() );
    }

    /**
     * Test of getAverager method of the KNN under test.
     */
    public void testGetAverager()
    {
        System.out.println( "getAverager" );

        int k = RANDOM.nextInt( 10 ) + 1;
        AbstractKNearestNeighbor<Vector, Double> knn = this.createInstance( k, POINTS );
        assertNotNull( knn.getAverager() );
    }

    /**
     * Test of setAverager method of the KNN under test.
     */
    public void testSetAverager()
    {
        System.out.println( "setAverager" );

        int k = RANDOM.nextInt( 10 ) + 1;
        AbstractKNearestNeighbor<Vector, Double> knn = this.createInstance( k, POINTS );
        assertNotNull( knn.getAverager() );

        Summarizer<? super Double, ? extends Double> avg = knn.getAverager();
        knn.setAverager( null );
        assertNull( knn.getAverager() );
        knn.setAverager( avg );
        assertSame( avg, knn.getAverager() );
    }

    /**
     * Test of getData method of the KNN under test: the returned collection
     * must contain exactly the POINTS it was constructed with.
     */
    public void testGetData()
    {
        System.out.println( "getData" );

        int k = RANDOM.nextInt( 10 ) + 1;
        AbstractKNearestNeighbor<Vector, Double> knn = this.createInstance( k, POINTS );
        Collection<? extends Pair<? extends Vector,Double>> data = knn.getData();
        assertNotNull( data );
        assertEquals( POINTS.size(), data.size() );
        for( Pair<? extends Vector,Double> value : data )
        {
            boolean found = false;
            for( Pair<Vector,Double> point : POINTS )
            {
                if( point.getFirst().equals( value.getFirst() ) )
                {
                    found = true;
                    assertEquals( point.getSecond(), value.getSecond() );
                }
            }
            assertTrue( found );
        }
    }

    /**
     * Test of getK method of the KNN under test.
     */
    public void testGetK()
    {
        System.out.println( "getK" );

        int k = RANDOM.nextInt( 10 ) + 1;
        AbstractKNearestNeighbor<Vector, Double> knn = this.createInstance( k, POINTS );
        assertEquals( k, knn.getK() );
    }

    /**
     * Test of setK method of the KNN under test: setting a new "k" takes
     * effect, and a non-positive "k" must be rejected.
     */
    public void testSetK()
    {
        System.out.println( "setK" );

        int k = RANDOM.nextInt( 10 ) + 1;
        AbstractKNearestNeighbor<Vector, Double> knn = this.createInstance( k, POINTS );
        assertEquals( k, knn.getK() );

        // Bump the local "k" and verify the instance is unchanged until
        // setK() is actually called.
        k++;
        assertEquals( k - 1, knn.getK() );
        knn.setK( k );
        assertEquals( k, knn.getK() );

        try
        {
            knn.setK( 0 );
            fail( "Number of neighbors must be > 0" );
        }
        catch (Exception e)
        {
            System.out.println( "Good: " + e );
        }
    }

    /**
     * Tests add: each call must grow the dataset by one and the added pair's
     * input must be retrievable with its output, even when added twice.
     */
    public void testAdd()
    {
        System.out.println( "add" );

        int k = RANDOM.nextInt(10) + 1;
        AbstractKNearestNeighbor<Vector, Double> knn = this.createInstance( k, POINTS );

        InputOutputPair<Vector, Double> pair =
            new DefaultInputOutputPair<Vector,Double>(VectorFactory.getDefault().copyValues(0,0), 6.0);
        int preSize = knn.getData().size();
        assertFalse( knn.getData().contains(pair) );
        knn.add( pair );
        assertEquals( preSize+1, knn.getData().size() );
        boolean found = false;
        for( Pair<? extends Vector,Double> d : knn.getData() )
        {
            if( pair.getFirst().equals( d.getFirst() ) )
            {
                found = true;
                assertEquals( pair.getSecond(), d.getSecond() );
            }
        }
        assertTrue( found );

        // Add it again: duplicates are permitted and grow the dataset.
        knn.add( pair );
        assertEquals( preSize+2, knn.getData().size() );
        found = false;
        for( Pair<? extends Vector,Double> d : knn.getData() )
        {
            if( pair.getFirst().equals( d.getFirst() ) )
            {
                found = true;
                assertEquals( pair.getSecond(), d.getSecond() );
            }
        }
        assertTrue( found );
    }

}