/*
* File: KMeansClustererTest.java
* Authors: Justin Basilico
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright March 16, 2006, Sandia Corporation. Under the terms of Contract
* DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
* or on behalf of the U.S. Government. Export of this program may require a
* license from the United States Government. See CopyrightHistory.txt for
* complete details.
*
*/
package gov.sandia.cognition.learning.algorithm.clustering;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.algorithm.clustering.cluster.CentroidCluster;
import gov.sandia.cognition.learning.algorithm.clustering.divergence.CentroidClusterDivergenceFunction;
import gov.sandia.cognition.learning.algorithm.clustering.initializer.GreedyClusterInitializer;
import gov.sandia.cognition.learning.function.distance.EuclideanDistanceMetric;
import gov.sandia.cognition.learning.algorithm.clustering.cluster.VectorMeanCentroidClusterCreator;
import gov.sandia.cognition.learning.algorithm.clustering.initializer.FixedClusterInitializer;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.mtj.Vector2;
import gov.sandia.cognition.util.NamedValue;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Random;
import junit.framework.TestCase;
/**
 * JUnit tests for {@link KMeansClusterer}.
 *
 * The fixtures (distance metric, cluster creator, random number generator,
 * greedy initializer and cluster divergence function) are built once per
 * test instance in the constructor and shared by the clusterers that
 * {@link #createClusterer()} produces.
 *
 * @author Justin Basilico
 * @since 1.0
 */
public class KMeansClustererTest
    extends TestCase
{

    /** The distance metric used in tests. */
    protected EuclideanDistanceMetric metric = null;

    /** The cluster creator used in tests. */
    protected VectorMeanCentroidClusterCreator creator = null;

    /** The random number generator used in tests. */
    protected Random random = null;

    /** The cluster initializer used in tests. */
    protected GreedyClusterInitializer<CentroidCluster<Vector>, Vector> initializer = null;

    /** The cluster divergence function used in tests. */
    protected CentroidClusterDivergenceFunction<Vector> clusterMetric = null;

    /**
     * Creates a new instance of KMeansClustererTest and builds the shared
     * fixtures used by the individual tests.
     *
     * @param testName The test name.
     */
    public KMeansClustererTest(
        String testName)
    {
        super(testName);

        this.metric = EuclideanDistanceMetric.INSTANCE;
        this.creator = VectorMeanCentroidClusterCreator.INSTANCE;
        this.random = new Random();
        this.initializer =
            new GreedyClusterInitializer<CentroidCluster<Vector>, Vector>(
                metric, creator, random);
        this.clusterMetric =
            new CentroidClusterDivergenceFunction<Vector>(metric);
    }

    /**
     * Creates a new clusterer to test, wired to this test's initializer,
     * divergence function and creator. The first two constructor arguments
     * are 0 and 100; the tests below treat the first as the number of
     * requested clusters.
     *
     * @return A new clusterer to test.
     */
    public KMeansClusterer<Vector, CentroidCluster<Vector>> createClusterer()
    {
        return new KMeansClusterer<Vector, CentroidCluster<Vector>>(
            0, 100, this.initializer, this.clusterMetric, this.creator);
    }

    /**
     * Tests the creation of a KMeansClusterer.
     *
     * If this test fails, contact Justin Basilico.
     */
    public void testCreation()
    {
        KMeansClusterer<Vector, CentroidCluster<Vector>> kmeans =
            this.createClusterer();

        // Verify the values handed to the constructor. The requested
        // cluster count is the property being configured here (it is the
        // one set and read back below), so that is what gets checked.
        assertEquals(0, kmeans.getNumRequestedClusters());
        assertSame(this.initializer, kmeans.getInitializer());
        assertSame(this.clusterMetric, kmeans.getDivergenceFunction());
        assertSame(this.creator, kmeans.getCreator());

        kmeans.setNumRequestedClusters(1);
        assertEquals(1, kmeans.getNumRequestedClusters());
    }

    /**
     * Tests the clustering of a KMeansClusterer.
     *
     * If this test fails, contact Justin Basilico.
     */
    public void testClustering()
    {
        KMeansClusterer<Vector, CentroidCluster<Vector>> kmeans =
            this.createClusterer();
        assertEquals( 0, kmeans.getNumElements() );

        ArrayList<Vector> elements = new ArrayList<>();
        Collection<CentroidCluster<Vector>> clusters = null;

        // Two tight pairs of points that are far apart from each other.
        Vector2 v1 = new Vector2(1.0, 1.0);
        Vector2 v2 = new Vector2(1.0, 1.2);
        Vector2 v3 = new Vector2(4.0, 4.0);
        Vector2 v4 = new Vector2(4.0, 4.2);

        // Try giving no clusters to create with an empty list of elements.
        kmeans.setNumRequestedClusters(0);
        clusters = kmeans.learn(elements);
        assertNotNull(clusters);
        assertEquals(0, clusters.size());

        // Try requesting one cluster from an empty list of elements.
        kmeans.setNumRequestedClusters(1);
        clusters = kmeans.learn(elements);
        assertNotNull(clusters);
        assertEquals(0, clusters.size());

        // Add a vector to the list of elements.
        elements.add(v1);

        // Try giving no clusters to create with a non-empty list of elements.
        kmeans.setNumRequestedClusters(0);
        clusters = kmeans.learn(elements);
        assertNotNull(clusters);
        assertEquals(0, clusters.size());

        // Try creating one cluster from one element.
        kmeans.setNumRequestedClusters(1);
        clusters = kmeans.learn(elements);
        assertNotNull(clusters);
        assertEquals(1, clusters.size());

        // Try creating two clusters from one element. Only one should be
        // returned.
        kmeans.setNumRequestedClusters(2);
        clusters = kmeans.learn(elements);
        assertNotNull(clusters);
        assertEquals(1, clusters.size());

        // Add some more elements.
        elements.add(v2);
        elements.add(v3);
        elements.add(v4);

        // Two clusters from four elements: each pair should end up together.
        kmeans.setNumRequestedClusters(2);
        clusters = kmeans.learn(elements);
        assertNotNull(clusters);
        assertEquals(2, clusters.size());

        CentroidCluster<Vector> cluster1 = kmeans.getCluster(0);
        CentroidCluster<Vector> cluster2 = kmeans.getCluster(1);
        assertNotNull(cluster1);
        assertNotNull(cluster2);
        assertEquals(2, cluster1.getMembers().size());
        assertEquals(2, cluster2.getMembers().size());

        Vector centroid1 = cluster1.getCentroid();
        Vector centroid2 = cluster2.getCentroid();
        assertNotNull(centroid1);
        assertNotNull(centroid2);
        assertFalse(centroid1.equals(centroid2));

        // The order of the two clusters is not fixed, so determine which
        // cluster is the one nearest v1 before checking centroids and
        // membership.
        if (metric.evaluate(v1, centroid1) < metric.evaluate(v1, centroid2))
        {
            assertEquals(new Vector2(1.0, 1.1), centroid1);
            assertEquals(new Vector2(4.0, 4.1), centroid2);
            assertTrue(cluster1.getMembers().contains(v1));
            assertTrue(cluster1.getMembers().contains(v2));
            assertTrue(cluster2.getMembers().contains(v3));
            assertTrue(cluster2.getMembers().contains(v4));
        }
        else
        {
            assertEquals(new Vector2(1.0, 1.1), centroid2);
            assertEquals(new Vector2(4.0, 4.1), centroid1);
            assertTrue(cluster2.getMembers().contains(v1));
            assertTrue(cluster2.getMembers().contains(v2));
            assertTrue(cluster1.getMembers().contains(v3));
            assertTrue(cluster1.getMembers().contains(v4));
        }

        NamedValue<? extends Number> value = kmeans.getPerformance();
        assertNotNull( value );
        System.out.println( "Value: " + value.getName() + " = " + value.getValue() );

        // A negative number of requested clusters must be rejected. Using
        // fail() inside the try block lets any other, unexpected exception
        // propagate with its own stack trace instead of being hidden behind
        // a generic assertion failure.
        try
        {
            kmeans.setNumRequestedClusters(-1);
            fail("Expected an IllegalArgumentException for a negative "
                + "number of requested clusters");
        }
        catch (IllegalArgumentException expected)
        {
            // Expected: negative cluster counts are invalid.
        }
    }

    /**
     * Tests that clustering completes and yields no null clusters when the
     * initializer provides duplicate starting clusters.
     */
    public void testClusteringAvoidNulls()
    {
        KMeansClusterer<Vector, CentroidCluster<Vector>> instance =
            this.createClusterer();

        // Create a bad initializer: it ignores numClusters and always seeds
        // three clusters from the first three elements, even when some of
        // those elements are equal.
        instance.setInitializer(new FixedClusterInitializer<CentroidCluster<Vector>, Vector>()
        {
            @Override
            public ArrayList<CentroidCluster<Vector>> initializeClusters(
                final int numClusters,
                final Collection<? extends Vector> elements)
            {
                final ArrayList<CentroidCluster<Vector>> result = new ArrayList<>();
                result.add(new CentroidCluster<Vector>(CollectionUtil.getElement(elements, 0)));
                result.add(new CentroidCluster<Vector>(CollectionUtil.getElement(elements, 1)));
                result.add(new CentroidCluster<Vector>(CollectionUtil.getElement(elements, 2)));
                return result;
            }
        });

        // Only two distinct points, and the first two elements are equal,
        // so two of the three seed clusters start from the same vector.
        // NOTE(review): presumably this leaves one cluster without members
        // during learning - confirm against KMeansClusterer.
        ArrayList<Vector> data = new ArrayList<>();
        Vector2 v1 = new Vector2(1.0, 1.0);
        Vector2 v2 = new Vector2(2.0, 2.0);
        data.add(v1.clone());
        data.add(v1.clone());
        data.add(v2.clone());
        data.add(v1.clone());
        data.add(v2.clone());
        data.add(v2.clone());

        instance.setNumRequestedClusters(3);
        Collection<CentroidCluster<Vector>> result = instance.learn(data);

        // Learning must complete without an exception and must not hand
        // back null clusters.
        assertNotNull(result);
        for (CentroidCluster<Vector> cluster : result)
        {
            assertNotNull(cluster);
        }
    }

}