/* * File: NearestNeighborTestHarness.java * Authors: Kevin R. Dixon * Company: Sandia National Laboratories * Project: Cognitive Foundry * * Copyright Aug 10, 2009, Sandia Corporation. * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive * license for use of this work by or on behalf of the U.S. Government. * Export of this program may require a license from the United States * Government. See CopyrightHistory.txt for complete details. * */ package gov.sandia.cognition.learning.algorithm.nearest; import gov.sandia.cognition.learning.data.DefaultInputOutputPair; import gov.sandia.cognition.learning.data.InputOutputPair; import gov.sandia.cognition.learning.function.distance.EuclideanDistanceMetric; import gov.sandia.cognition.math.DivergenceFunction; import gov.sandia.cognition.math.matrix.Vector; import gov.sandia.cognition.math.matrix.VectorFactory; import gov.sandia.cognition.math.matrix.Vectorizable; import java.util.Arrays; import java.util.Collection; import java.util.List; import junit.framework.TestCase; import java.util.Random; /** * Unit tests for NearestNeighborTestHarness. * * @author krdixon */ public abstract class NearestNeighborTestHarness extends TestCase { /** * Random number generator to use for a fixed random seed. */ public final Random RANDOM = new Random( 1 ); /** * Default tolerance of the regression tests, {@value}. */ public final double TOLERANCE = 1e-5; public static class CounterEuclidenDistance extends EuclideanDistanceMetric { public int evaluations = 0; @Override public double evaluate(Vectorizable first, Vectorizable second) { this.evaluations++; return super.evaluate(first, second); } } /** * Example from http://en.wikipedia.org/wiki/Kd-tree#Construction */ @SuppressWarnings("unchecked") public static List<DefaultInputOutputPair<Vector,Double>> POINTS = Arrays.asList( new DefaultInputOutputPair<Vector,Double>(VectorFactory.getDefault().copyValues(2, 3), 0.0), new DefaultInputOutputPair<Vector,Double>(VectorFactory.getDefault().copyValues(5, 4), 1.0), new DefaultInputOutputPair<Vector,Double>(VectorFactory.getDefault().copyValues(9, 6), 2.0), new DefaultInputOutputPair<Vector,Double>(VectorFactory.getDefault().copyValues(4, 7), 3.0), new DefaultInputOutputPair<Vector,Double>(VectorFactory.getDefault().copyValues(8, 1), 4.0), new DefaultInputOutputPair<Vector,Double>(VectorFactory.getDefault().copyValues(7, 2), 5.0) ); /** * Tests for class NearestNeighborTestHarness. * @param testName Name of the test. */ public NearestNeighborTestHarness( String testName) { super(testName); } /** * Creates an instance. * @param data Pairs. * @return KNN */ public abstract AbstractNearestNeighbor<Vector, Double> createInstance( Collection<? extends InputOutputPair<Vector,Double>> data ); /** * Test constructors */ public abstract void testConstructors(); /** * Test learner */ public abstract void testLearner(); /** * Tests to make sure you implemented the createInstance method properly. */ public void testCreateInstance() { System.out.println( "createInstance" ); AbstractNearestNeighbor<Vector, Double> nn = this.createInstance( POINTS ); assertEquals( POINTS.size(), nn.getData().size() ); for( InputOutputPair<? extends Vector,Double> d : nn.getData() ) { boolean found = false; for( InputOutputPair<Vector,Double> point : POINTS ) { if( point.getFirst().equals( d.getFirst() ) ) { found = true; assertEquals( point.getSecond(), d.getSecond() ); } } assertTrue( found ); } } /** * Test clone */ public void testClone() { System.out.println( "Clone" ); AbstractNearestNeighbor<Vector, Double> nn = this.createInstance( POINTS ); @SuppressWarnings("unchecked") AbstractNearestNeighbor<Vector, Double> clone = (AbstractNearestNeighbor<Vector, Double>) nn.clone(); assertNotNull( clone ); assertNotSame( nn, clone ); assertNotSame( nn.getDivergenceFunction(), clone.getDivergenceFunction() ); assertEquals( nn.getData().size(), clone.getData().size() ); assertNotSame( nn.getData(), clone.getData() ); } /** * Test of getDivergenceFunction method, of class gov.sandia.isrc.learning.util.function.KNearestNeighbor. */ public void testGetDivergenceFunction() { System.out.println( "getDivergenceFunction" ); AbstractNearestNeighbor<Vector, Double> nn = this.createInstance( POINTS ); assertNotNull( nn.getDivergenceFunction() ); } /** * Test of setDivergenceFunction method, of class gov.sandia.isrc.learning.util.function.KNearestNeighbor. */ public void testSetDivergenceFunction() { System.out.println( "setDivergenceFunction" ); AbstractNearestNeighbor<Vector, Double> nn = this.createInstance( POINTS ); assertNotNull( nn.getDivergenceFunction() ); DivergenceFunction<? super Vector, ? super Vector> foo = nn.getDivergenceFunction(); nn.setDivergenceFunction( null ); assertNull( nn.getDivergenceFunction() ); nn.setDivergenceFunction( foo ); assertSame( foo, nn.getDivergenceFunction() ); } /** * Test of getData method, of class gov.sandia.isrc.learning.util.function.KNearestNeighbor. */ public void testGetData() { System.out.println( "getData" ); AbstractNearestNeighbor<Vector, Double> nn = this.createInstance( POINTS ); Collection<InputOutputPair<? extends Vector,Double>> data = nn.getData(); assertNotNull( data ); assertEquals( POINTS.size(), data.size() ); for( InputOutputPair<? extends Vector,Double> value : data ) { boolean found = false; for( InputOutputPair<Vector,Double> point : POINTS ) { if( point.getFirst().equals( value.getFirst() ) ) { found = true; assertEquals( point.getSecond(), value.getSecond() ); } } assertTrue( found ); } } /** * Tests add. */ public void testAdd() { System.out.println( "add" ); AbstractNearestNeighbor<Vector, Double> nn = this.createInstance( POINTS ); InputOutputPair<Vector,Double> pair = new DefaultInputOutputPair<Vector,Double>(VectorFactory.getDefault().copyValues(0,0), 6.0); int preSize = nn.getData().size(); assertFalse( nn.getData().contains(pair) ); nn.add( pair ); assertEquals( preSize+1, nn.getData().size() ); boolean found = false; for( InputOutputPair<? extends Vector,Double> d : nn.getData() ) { if( pair.getFirst().equals( d.getFirst() ) ) { found = true; assertEquals( pair.getSecond(), d.getSecond() ); } } assertTrue( found ); // Add it again. nn.add( pair ); assertEquals( preSize+2, nn.getData().size() ); found = false; for( InputOutputPair<? extends Vector,Double> d : nn.getData() ) { if( pair.getFirst().equals( d.getFirst() ) ) { found = true; assertEquals( pair.getSecond(), d.getSecond() ); } } assertTrue( found ); } /** * Test of evaluate method, of class gov.sandia.isrc.learning.util.function.KNearestNeighbor. */ public void testEvaluateSimple() { System.out.println( "evaluate 1NN" ); AbstractNearestNeighbor<Vector, Double> nn = this.createInstance( POINTS ); int M = -1; for (InputOutputPair<? extends Vector, ? extends Double> pair : nn.getData()) { M = pair.getFirst().getDimensionality(); assertEquals( pair.getSecond(), nn.evaluate( pair.getFirst() ) ); } try { nn.evaluate( null ); fail( "Should have thrown null-pointer exception" ); } catch (Exception e) { System.out.println( "Good! Properly thrown exception: " + e ); } final double r = TOLERANCE; for (InputOutputPair<? extends Vector, ? extends Double> pair : nn.getData()) { Vector small = VectorFactory.getDefault().createUniformRandom( M, -r, r, RANDOM ); Vector perturb = pair.getFirst().plus( small ); Double estimate = nn.evaluate( perturb ); double error = pair.getSecond() - estimate; double distance = Math.abs( error ); assertEquals( 0.0, distance, TOLERANCE ); } } }