/*
* Copyright (c) 2003, the JUNG Project and the Regents of the University
* of California
* All rights reserved.
*
* This software is open-source under the BSD license; see either
* "license.txt" or
* http://jung.sourceforge.net/license.txt for a description.
*/
/*
* Created on Aug 9, 2004
*
*/
package edu.uci.ics.jung.algorithms.util;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Random;
import java.util.Set;
/**
 * Groups items into a specified number of clusters, based on their proximity in
 * d-dimensional space, using the k-means algorithm. Calls to
 * <code>cluster</code> will terminate when either of the two following
 * conditions is true:
 * <ul>
 * <li><code>max_iterations</code> centroid updates have been performed</li>
 * <li>no centroid moved by more than <code>convergence_threshold</code>
 * during the most recent update</li>
 * </ul>
 *
 * @param <T> the type of the objects being clustered
 * @author Joshua O'Madadhain
 */
public class KMeansClusterer<T>
{
    /** Upper bound on the number of centroid updates performed by <code>cluster</code>. */
    protected int max_iterations;

    /** Iteration stops once no centroid moves farther than this (Euclidean) distance. */
    protected double convergence_threshold;

    /** Source of randomness used to pick the initial centroids. */
    protected Random rand;

    /**
     * Creates an instance whose termination conditions are set according
     * to the parameters. Unlike the setters, this constructor does not
     * validate its arguments.
     *
     * @param max_iterations the maximum number of centroid updates
     * @param convergence_threshold the centroid movement at or below which
     * clustering is considered to have converged
     */
    public KMeansClusterer(int max_iterations, double convergence_threshold)
    {
        this.max_iterations = max_iterations;
        this.convergence_threshold = convergence_threshold;
        this.rand = new Random();
    }

    /**
     * Creates an instance with max iterations of 100 and convergence threshold
     * of 0.001.
     */
    public KMeansClusterer()
    {
        this(100, 0.001);
    }

    /**
     * Returns the maximum number of iterations.
     *
     * @return the maximum number of iterations
     */
    public int getMaxIterations()
    {
        return max_iterations;
    }

    /**
     * Sets the maximum number of iterations.
     *
     * @param max_iterations the new maximum; must be &gt;= 0
     * @throws IllegalArgumentException if <code>max_iterations</code> is negative
     */
    public void setMaxIterations(int max_iterations)
    {
        if (max_iterations < 0)
            throw new IllegalArgumentException("max iterations must be >= 0");
        this.max_iterations = max_iterations;
    }

    /**
     * Returns the convergence threshold.
     *
     * @return the convergence threshold
     */
    public double getConvergenceThreshold()
    {
        return convergence_threshold;
    }

    /**
     * Sets the convergence threshold.
     *
     * @param convergence_threshold the new threshold; must be &gt; 0
     * @throws IllegalArgumentException if <code>convergence_threshold</code>
     * is not positive
     */
    public void setConvergenceThreshold(double convergence_threshold)
    {
        if (convergence_threshold <= 0)
            throw new IllegalArgumentException("convergence threshold " +
                    "must be > 0");
        this.convergence_threshold = convergence_threshold;
    }

    /**
     * Returns a <code>Collection</code> of clusters, where each cluster is
     * represented as a <code>Map</code> of <code>Objects</code> to locations
     * in d-dimensional space.
     *
     * @param object_locations a map of the Objects to cluster, to
     * <code>double</code> arrays that specify their locations in d-dimensional space.
     * @param num_clusters the number of clusters to create
     * @return the clusters, each mapping its member objects to their locations
     * @throws IllegalArgumentException if <code>object_locations</code> is
     * null or empty, or if <code>num_clusters</code> is less than 2 or greater
     * than the number of objects
     * @throws NotEnoughClustersException if the input contains fewer distinct
     * locations than <code>num_clusters</code>
     */
    public Collection<Map<T, double[]>> cluster(Map<T, double[]> object_locations, int num_clusters)
    {
        if (object_locations == null || object_locations.isEmpty())
            throw new IllegalArgumentException("'objects' must be non-empty");
        if (num_clusters < 2 || num_clusters > object_locations.size())
            throw new IllegalArgumentException("number of clusters " +
                    "must be >= 2 and <= number of objects (" +
                    object_locations.size() + ")");

        Set<double[]> centroids = chooseInitialCentroids(object_locations, num_clusters);

        // put items in their initial clusters
        Map<double[], Map<T, double[]>> clusterMap = assignToClusters(object_locations, centroids);

        // keep reconstituting clusters until either
        // (a) the number of iterations reaches max_iterations, or
        // (b) max movement of any centroid is <= convergence_threshold
        // (stability of cluster membership is not checked separately)
        int iterations = 0;
        double max_movement = Double.POSITIVE_INFINITY;
        while (iterations++ < max_iterations && max_movement > convergence_threshold)
        {
            max_movement = 0;
            // insertion-ordered so that iteration does not depend on the
            // identity hash codes of the centroid arrays
            Set<double[]> new_centroids = new LinkedHashSet<double[]>();
            // calculate new mean for each cluster
            for (Map.Entry<double[], Map<T, double[]>> entry : clusterMap.entrySet())
            {
                double[] centroid = entry.getKey();
                Map<T, double[]> elements = entry.getValue();
                if (elements.isEmpty())
                {
                    // a cluster that lost all of its members has no mean to
                    // compute; keep its previous centroid (zero movement) so
                    // that the number of clusters stays constant
                    new_centroids.add(centroid);
                    continue;
                }
                ArrayList<double[]> locations = new ArrayList<double[]>(elements.values());
                double[] mean = DiscreteDistribution.mean(locations);
                max_movement = Math.max(max_movement,
                        Math.sqrt(DiscreteDistribution.squaredError(centroid, mean)));
                new_centroids.add(mean);
            }
            // regenerate cluster membership based on means
            clusterMap = assignToClusters(object_locations, new_centroids);
        }
        return clusterMap.values();
    }

    /**
     * Randomly picks objects until <code>num_clusters</code> pairwise distinct
     * locations have been collected to serve as initial centroids, or until
     * every object has been tried.
     *
     * @param object_locations a map of objects to locations
     * @param num_clusters the number of centroids required
     * @return the chosen centroids, in the order in which they were picked
     * @throws NotEnoughClustersException if every object was tried and fewer
     * than <code>num_clusters</code> distinct locations were found
     */
    private Set<double[]> chooseInitialCentroids(Map<T, double[]> object_locations, int num_clusters)
    {
        // typed copy of the keys: allows random access without an unchecked cast
        ArrayList<T> candidates = new ArrayList<T>(object_locations.keySet());
        Set<T> tried = new HashSet<T>();
        Set<double[]> centroids = new LinkedHashSet<double[]>();
        while (centroids.size() < num_clusters && tried.size() < candidates.size())
        {
            T o = candidates.get((int)(rand.nextDouble() * candidates.size()));
            tried.add(o);
            double[] location = object_locations.get(o);
            if (!containsLocation(centroids, location))
                centroids.add(location);
        }
        // Judge success by the number of centroids found, not by how many
        // objects were tried: the last untried object may be the one that
        // supplies the final centroid.
        if (centroids.size() < num_clusters)
            throw new NotEnoughClustersException();
        return centroids;
    }

    /**
     * Returns true if <code>locations</code> already holds an array whose
     * contents equal those of <code>candidate</code>.
     */
    private static boolean containsLocation(Collection<double[]> locations, double[] candidate)
    {
        for (double[] existing : locations)
        {
            if (Arrays.equals(candidate, existing))
                return true;
        }
        return false;
    }

    /**
     * Assigns each object to the cluster whose centroid is closest to the
     * object. Ties are resolved in favor of the centroid that comes first in
     * the iteration order of <code>centroids</code>.
     *
     * @param object_locations a map of objects to locations
     * @param centroids the centroids of the clusters to be formed
     * @return a map from each centroid to its cluster, where a cluster is a
     * map of the assigned objects to their locations
     */
    protected Map<double[], Map<T, double[]>> assignToClusters(Map<T, double[]> object_locations, Set<double[]> centroids)
    {
        // insertion-ordered: arrays hash by identity, so a plain HashMap would
        // iterate its clusters in an order that can differ from run to run
        Map<double[], Map<T, double[]>> clusterMap = new LinkedHashMap<double[], Map<T, double[]>>();
        for (double[] centroid : centroids)
            clusterMap.put(centroid, new HashMap<T, double[]>());

        for (Map.Entry<T, double[]> object_location : object_locations.entrySet())
        {
            T object = object_location.getKey();
            double[] location = object_location.getValue();

            // find the cluster with the closest centroid
            Iterator<double[]> c_iter = centroids.iterator();
            double[] closest = c_iter.next();
            double distance = DiscreteDistribution.squaredError(location, closest);
            while (c_iter.hasNext())
            {
                double[] centroid = c_iter.next();
                double dist_cur = DiscreteDistribution.squaredError(location, centroid);
                if (dist_cur < distance)
                {
                    distance = dist_cur;
                    closest = centroid;
                }
            }
            clusterMap.get(closest).put(object, location);
        }
        return clusterMap;
    }

    /**
     * Sets the seed used by the internal random number generator.
     * Enables consistent outputs, provided that the map passed to
     * <code>cluster</code> iterates over its entries in a consistent order.
     *
     * @param random_seed the seed for the random number generator
     */
    public void setSeed(int random_seed)
    {
        this.rand = new Random(random_seed);
    }

    /**
     * An exception that indicates that the specified data points cannot be
     * clustered into the number of clusters requested by the user.
     * This will happen if and only if there are fewer distinct points than
     * requested clusters. (If there are fewer total data points than
     * requested clusters, <code>IllegalArgumentException</code> will be thrown.)
     *
     * @author Joshua O'Madadhain
     */
    @SuppressWarnings("serial")
    public static class NotEnoughClustersException extends RuntimeException
    {
        /**
         * Returns a fixed description of the failure.
         */
        @Override
        public String getMessage()
        {
            return "Not enough distinct points in the input data set to form " +
                    "the requested number of clusters";
        }
    }
}