/*
 * File:                MiniBatchKMeansClustererTest.java
 * Authors:             Justin Basilico
 * Authors:             Jeff Piersol
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 *
 * Copyright April 4, 2016, Sandia Corporation. Under the terms of Contract
 * DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
 * or on behalf of the U.S. Government. Export of this program may require a
 * license from the United States Government. See CopyrightHistory.txt for
 * complete details.
 *
 */

package gov.sandia.cognition.learning.algorithm.clustering;

import gov.sandia.cognition.learning.algorithm.clustering.cluster.CentroidCluster;
import gov.sandia.cognition.learning.algorithm.clustering.cluster.ClusterCreator;
import gov.sandia.cognition.learning.algorithm.clustering.cluster.MiniBatchCentroidCluster;
import gov.sandia.cognition.learning.function.distance.EuclideanDistanceMetric;
import gov.sandia.cognition.learning.algorithm.clustering.cluster.VectorMeanMiniBatchCentroidClusterCreator;
import gov.sandia.cognition.learning.algorithm.clustering.divergence.CentroidClusterDivergenceFunction;
import gov.sandia.cognition.learning.algorithm.clustering.initializer.FixedClusterInitializer;
import gov.sandia.cognition.learning.algorithm.clustering.initializer.GreedyClusterInitializer;
import gov.sandia.cognition.learning.function.distance.CosineDistanceMetric;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.mtj.Vector2;
import gov.sandia.cognition.util.NamedValue;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Random;
import junit.framework.TestCase;

/**
 * This class implements JUnit tests for the following classes:
 *
 *     MiniBatchKMeansClusterer
 *
 * @author Justin Basilico
 * @author Jeff Piersol
 * @since 4.0.0
 */
public class MiniBatchKMeansClustererTest
    extends TestCase
{

    /**
     * The distance metric used in tests.
     */
    protected EuclideanDistanceMetric metric;

    /**
     * The cluster creator used in tests.
     */
    protected ClusterCreator<MiniBatchCentroidCluster, Vector> creator;

    /**
     * The random number generator used in tests.
     */
    protected Random random;

    /**
     * The cluster initializer used in tests.
     */
    protected FixedClusterInitializer<MiniBatchCentroidCluster, Vector> initializer;

    /**
     * Creates a new instance of MiniBatchKMeansClustererTest.
     *
     * @param testName The test name.
     */
    public MiniBatchKMeansClustererTest(
        String testName)
    {
        super(testName);

        this.metric = EuclideanDistanceMetric.INSTANCE;
        this.creator = VectorMeanMiniBatchCentroidClusterCreator.INSTANCE;
        // NOTE(review): the generator is unseeded, so runs are not
        // reproducible; consider a fixed seed if this test proves flaky.
        this.random = new Random();
        this.initializer = new GreedyClusterInitializer<>(
            CosineDistanceMetric.INSTANCE, creator, random);
    }

    /**
     * Creates a new clusterer to test. It is built with 0 as its first
     * constructor argument and 100 as its second, followed by this fixture's
     * initializer, metric, creator, and random number generator.
     *
     * @return A new clusterer to test.
     */
    public MiniBatchKMeansClusterer<Vector> createClusterer()
    {
        return new MiniBatchKMeansClusterer<>(
            0, 100, this.initializer, this.metric, this.creator, this.random);
    }

    /**
     * Tests the creation of a MiniBatchKMeansClusterer: the constructor
     * arguments must be visible through the getters, and the requested
     * number of clusters must be settable.
     */
    public void testCreation()
    {
        MiniBatchKMeansClusterer<Vector> kmeans = this.createClusterer();

        assertEquals(0, kmeans.getNumClusters());
        assertSame(this.initializer, kmeans.getInitializer());
        // The metric is expected to be wrapped in a centroid-cluster
        // divergence function, so unwrap it before comparing.
        assertEquals(this.metric,
            ((CentroidClusterDivergenceFunction) kmeans.getDivergenceFunction()).getDivergenceFunction());
        assertSame(this.creator, kmeans.getCreator());

        kmeans.setNumRequestedClusters(1);
        assertEquals(1, kmeans.getNumRequestedClusters());
    }

    /**
     * Tests the clustering of a MiniBatchKMeansClusterer. Covers degenerate
     * inputs (no elements, no requested clusters, more requested clusters
     * than elements), then a five-point, two-cluster problem using a cosine
     * divergence, and finally rejection of a negative cluster count.
     */
    public void testClustering()
    {
        MiniBatchKMeansClusterer<Vector> kmeans = this.createClusterer();
        assertEquals(0, kmeans.getNumElements());

        ArrayList<Vector> elements = new ArrayList<>();

        Vector2 v1 = new Vector2(-2.0, 0.0);
        Vector2 v2 = new Vector2(-1.0, 2.0);
        Vector2 v3 = new Vector2(0.0, 3.0);
        Vector2 v4 = new Vector2(3.0, 1.0);
        Vector2 v5 = new Vector2(3.0, -1.0);

        // No elements and no requested clusters.
        kmeans.setNumRequestedClusters(0);
        Collection<MiniBatchCentroidCluster> clusters = kmeans.learn(elements);
        assertNotNull(clusters);
        assertEquals(0, clusters.size());

        // No elements but one requested cluster.
        kmeans.setNumRequestedClusters(1);
        clusters = kmeans.learn(elements);
        assertNotNull(clusters);
        assertEquals(0, clusters.size());

        // Add a vector to the list of elements.
        elements.add(v1);

        // Try giving no clusters to create with a non-empty list of elements.
        kmeans.setNumRequestedClusters(0);
        clusters = kmeans.learn(elements);
        assertNotNull(clusters);
        assertEquals(0, clusters.size());

        // Try creating one cluster from one element.
        kmeans.setNumRequestedClusters(1);
        clusters = kmeans.learn(elements);
        assertNotNull(clusters);
        assertEquals(1, clusters.size());

        // Try creating two clusters from one element. Should return non-null
        // b/c of greedy initializer.
        kmeans.setNumRequestedClusters(2);
        clusters = kmeans.learn(elements);
        assertNotNull(clusters);
        assertEquals(1, clusters.size());

        // Add some more elements.
        elements.add(v2);
        elements.add(v3);
        elements.add(v4);
        elements.add(v5);

        // Use spherical k-means
        kmeans.setNumRequestedClusters(2);
        kmeans.setDivergenceFunction(new CentroidClusterDivergenceFunction<>(
            CosineDistanceMetric.INSTANCE));
        kmeans.setMinibatchSize(20);
        clusters = kmeans.learn(elements);
        assertNotNull(clusters);
        assertEquals(2, clusters.size());

        // Verify both clusters exist before touching their members.
        CentroidCluster<Vector> first = kmeans.getCluster(0);
        CentroidCluster<Vector> second = kmeans.getCluster(1);
        assertNotNull(first);
        assertNotNull(second);

        // Cluster order is not fixed, so identify the clusters by size:
        // cluster1 is the bigger one (ties go to index 0).
        boolean firstIsBigger =
            first.getMembers().size() >= second.getMembers().size();
        CentroidCluster<Vector> cluster1 = firstIsBigger ? first : second;
        CentroidCluster<Vector> cluster2 = firstIsBigger ? second : first;
        assertEquals(3, cluster1.getMembers().size());
        assertEquals(2, cluster2.getMembers().size());

        // Check for null centroids before normalizing them in place.
        Vector centroid1 = cluster1.getCentroid();
        Vector centroid2 = cluster2.getCentroid();
        assertNotNull(centroid1);
        assertNotNull(centroid2);
        centroid1.unitVectorEquals(); // part of spherical k-means
        centroid2.unitVectorEquals();
        assertFalse(centroid1.equals(centroid2));

        // Expected values are the unit-normalized means of {v1, v2, v3} and
        // {v4, v5}; the tolerance is deliberately loose.
        assertEquals(-0.514, centroid1.get(0), 0.5);
        assertEquals(0.857, centroid1.get(1), 0.5);
        assertEquals(1, centroid2.get(0), 0.5);
        assertEquals(0, centroid2.get(1), 0.5);

        assertTrue(cluster1.getMembers().contains(v1));
        assertTrue(cluster1.getMembers().contains(v2));
        assertTrue(cluster1.getMembers().contains(v3));
        assertTrue(cluster2.getMembers().contains(v4));
        assertTrue(cluster2.getMembers().contains(v5));

        NamedValue<? extends Number> value = kmeans.getPerformance();
        assertNotNull(value);
        System.out.println("Value: " + value.getName() + " = " + value.getValue());

        // A negative cluster count must be rejected. Any other exception
        // type propagates unmasked so its real cause shows in the report.
        try
        {
            kmeans.setNumRequestedClusters(-1);
            fail("Expected an IllegalArgumentException for a negative number of clusters");
        }
        catch (IllegalArgumentException expected)
        {
            // Expected: negative cluster counts are invalid.
        }
    }

}