/*
 * File:                KNearestNeighborKDTreeTest.java
 * Authors:             Kevin R. Dixon
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 *
 * Copyright Aug 5, 2009, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
 * license for use of this work by or on behalf of the U.S. Government.
 * Export of this program may require a license from the United States
 * Government. See CopyrightHistory.txt for complete details.
 *
 */

package gov.sandia.cognition.learning.algorithm.nearest;

import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.math.geometry.KDTree;
import gov.sandia.cognition.learning.function.distance.EuclideanDistanceMetric;
import gov.sandia.cognition.math.Metric;
import gov.sandia.cognition.math.NumberAverager;
import gov.sandia.cognition.math.matrix.Vector;
import java.util.Collection;

/**
 * Unit tests for KNearestNeighborKDTree.
 *
 * @author krdixon
 */
public class KNearestNeighborKDTreeTest
    extends KNearestNeighborTestHarness
{

    /**
     * Creates a new test case for class KNearestNeighborKDTree.
     * @param testName Name of the test.
     */
    public KNearestNeighborKDTreeTest(
        String testName)
    {
        super(testName);
    }

    /**
     * Creates a KNearestNeighborKDTree whose KDTree is built from the given
     * data, using a new CounterEuclidenDistance as the metric and
     * NumberAverager.INSTANCE as the averager.
     *
     * @param k Number of neighbors to consider.
     * @param data Input-output pairs from which the KDTree is built.
     * @return A new KNearestNeighborKDTree over the given data.
     */
    @Override
    public KNearestNeighborKDTree<Vector, Double> createInstance(
        int k,
        Collection<? extends InputOutputPair<Vector, Double>> data)
    {
        // Build the tree from the caller-supplied data rather than from the
        // static POINTS, so that the "data" argument is actually honored.
        KDTree<Vector,Double,InputOutputPair<? extends Vector,Double>> tree =
            new KDTree<Vector, Double, InputOutputPair<? extends Vector, Double>>( data );
        return new KNearestNeighborKDTree<Vector, Double>(
            k, tree, new CounterEuclidenDistance(), NumberAverager.INSTANCE );
    }

    /**
     * Tests the constructors of class KNearestNeighborKDTree.
     */
    public void testConstructors()
    {
        System.out.println( "Constructors" );

        // Default constructor: default k, everything else unset.
        KNearestNeighborKDTree<Vector, Double> knn =
            new KNearestNeighborKDTree<Vector, Double>();
        assertEquals( KNearestNeighborKDTree.DEFAULT_K, knn.getK() );
        assertNull( knn.getAverager() );
        assertNull( knn.getDivergenceFunction() );
        assertNull( knn.getData() );

        // Full constructor: every argument must be stored as given.
        int k = RANDOM.nextInt( 10 ) + 1;
        EuclideanDistanceMetric metric = EuclideanDistanceMetric.INSTANCE;
        NumberAverager averager = NumberAverager.INSTANCE;
        KDTree<Vector,Double,InputOutputPair<? extends Vector,Double>> tree =
            new KDTree<Vector, Double, InputOutputPair<? extends Vector, Double>>( POINTS );
        knn = new KNearestNeighborKDTree<Vector, Double>(
            k, tree, metric, averager );
        assertEquals( k, knn.getK() );
        assertSame( tree, knn.getData() );
        assertSame( metric, knn.getDivergenceFunction() );
        assertSame( averager, knn.getAverager() );
    }

    /**
     * Tests setData: the data can be cleared with null and then replaced
     * with a new KDTree, which getData returns by reference.
     */
    public void testSetData()
    {
        System.out.println( "setData" );

        KNearestNeighborKDTree<Vector, Double> knn =
            this.createInstance( 1, POINTS );

        KDTree<Vector,Double,InputOutputPair<? extends Vector,Double>> tree =
            new KDTree<Vector, Double, InputOutputPair<? extends Vector, Double>>( POINTS );

        knn.setData( null );
        assertNull( knn.getData() );

        knn.setData( tree );
        assertSame( tree, knn.getData() );
    }

    /**
     * Tests setDivergenceFunction: the metric can be cleared with null and
     * then restored, which getDivergenceFunction returns by reference.
     */
    public void testSetMetric()
    {
        System.out.println( "setMetric" );

        KNearestNeighborKDTree<Vector, Double> knn =
            this.createInstance( 1, POINTS );
        Metric<? super Vector> metric = knn.getDivergenceFunction();

        knn.setDivergenceFunction( null );
        assertNull( knn.getDivergenceFunction() );

        knn.setDivergenceFunction( metric );
        assertSame( metric, knn.getDivergenceFunction() );
    }

    /**
     * Tests KNearestNeighborKDTree.Learner: checks its defaults, then learns
     * from POINTS and verifies that the result has the requested k, holds
     * non-null averager and metric objects distinct from the learner's own,
     * and contains every training point.
     */
    @Override
    public void testLearner()
    {
        System.out.println( "Learner" );

        KNearestNeighborKDTree.Learner<Vector,Double> learner =
            new KNearestNeighborKDTree.Learner<Vector, Double>();
        assertEquals( KNearestNeighborKDTree.DEFAULT_K, learner.getK() );
        assertNull( learner.getAverager() );
        assertSame( EuclideanDistanceMetric.INSTANCE, learner.getDivergenceFunction() );

        int k = RANDOM.nextInt( 10 ) + 1;
        learner.setK( k );
        learner.setAverager( NumberAverager.INSTANCE );

        KNearestNeighborKDTree<Vector,Double> knn = learner.learn( POINTS );
        assertEquals( k, knn.getK() );

        // The learned object must not share the learner's averager or metric.
        assertNotSame( learner.getAverager(), knn.getAverager() );
        assertNotNull( knn.getAverager() );
        assertNotSame( learner.getDivergenceFunction(), knn.getDivergenceFunction() );
        assertNotNull( knn.getDivergenceFunction() );

        assertEquals( POINTS.size(), knn.getData().size() );
        assertTrue( knn.getData().containsAll( POINTS ) );
    }

}