/*
* Copyright (c) 2003, the JUNG Project and the Regents of the University
* of California
* All rights reserved.
*
* This software is open-source under the BSD license; see either
* "license.txt" or
* http://jung.sourceforge.net/license.txt for a description.
*/
/*
* Created on Aug 9, 2004
*
*/
package edu.uci.ics.jung.algorithms.util;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
/**
 * Groups items into a specified number of clusters, based on their proximity in
 * d-dimensional space, using the k-means algorithm. A call to
 * {@link #cluster(Map, int)} stops refining the clusters as soon as either of
 * the following holds:
 * <ul>
 * <li><code>max_iterations</code> refinement passes have been performed</li>
 * <li>no centroid moved by more than <code>convergence_threshold</code>
 * (Euclidean distance) during the most recent pass</li>
 * </ul>
 *
 * <p>
 * Centroids are tracked as <code>double[]</code> instances held in hash-based
 * collections, so they are distinguished by array identity rather than by
 * content.
 *
 * @param <T>
 *            the type of the objects being clustered
 * @author Joshua O'Madadhain
 */
public class KMeansClusterer<T> {
    /** Upper bound on the number of refinement passes per clustering run. */
    protected int max_iterations;

    /** Refinement stops once no centroid moves farther than this distance. */
    protected double convergence_threshold;

    /** Source of randomness used to pick the initial centroids. */
    protected Random rand;

    /**
     * Creates an instance whose termination conditions are set according to the
     * parameters. Unlike the corresponding setters, this constructor does not
     * validate its arguments.
     *
     * @param max_iterations
     *            the maximum number of refinement passes
     * @param convergence_threshold
     *            the centroid movement at or below which refinement stops
     */
    public KMeansClusterer(int max_iterations, double convergence_threshold) {
        this.max_iterations = max_iterations;
        this.convergence_threshold = convergence_threshold;
        this.rand = new Random();
    }

    /**
     * Creates an instance with max iterations of 100 and convergence threshold
     * of 0.001.
     */
    public KMeansClusterer() {
        this(100, 0.001);
    }

    /**
     * Returns the maximum number of iterations.
     *
     * @return the maximum number of refinement passes
     */
    public int getMaxIterations() {
        return max_iterations;
    }

    /**
     * Sets the maximum number of iterations.
     *
     * @param max_iterations
     *            the new maximum; must be &gt;= 0
     * @throws IllegalArgumentException
     *             if <code>max_iterations</code> is negative
     */
    public void setMaxIterations(int max_iterations) {
        if (max_iterations < 0) {
            throw new IllegalArgumentException("max iterations must be >= 0");
        }
        this.max_iterations = max_iterations;
    }

    /**
     * Returns the convergence threshold.
     *
     * @return the centroid movement at or below which refinement stops
     */
    public double getConvergenceThreshold() {
        return convergence_threshold;
    }

    /**
     * Sets the convergence threshold.
     *
     * @param convergence_threshold
     *            the new threshold; must be &gt; 0
     * @throws IllegalArgumentException
     *             if <code>convergence_threshold</code> is not positive
     */
    public void setConvergenceThreshold(double convergence_threshold) {
        if (convergence_threshold <= 0) {
            throw new IllegalArgumentException(
                    "convergence threshold " + "must be > 0");
        }
        this.convergence_threshold = convergence_threshold;
    }

    /**
     * Returns a <code>Collection</code> of clusters, where each cluster is
     * represented as a <code>Map</code> of <code>Objects</code> to locations in
     * d-dimensional space.
     *
     * <p>
     * Initial centroids are the locations of randomly chosen input objects,
     * with no two initial centroids having equal coordinates. The clusters are
     * then refined until one of the termination conditions described in the
     * class documentation is met. If a cluster loses all of its members during
     * refinement, its centroid is left where it was, so the returned collection
     * may contain an empty cluster.
     *
     * @param object_locations
     *            a map of the Objects to cluster, to <code>double</code> arrays
     *            that specify their locations in d-dimensional space.
     * @param num_clusters
     *            the number of clusters to create
     * @return the clusters, one map of member objects to locations per centroid
     * @throws IllegalArgumentException
     *             if <code>object_locations</code> is null or empty, or if
     *             <code>num_clusters</code> is less than 2 or greater than the
     *             number of objects
     * @throws NotEnoughClustersException
     *             if the input contains fewer distinct locations than
     *             <code>num_clusters</code>
     */
    public Collection<Map<T, double[]>> cluster(
            Map<T, double[]> object_locations, int num_clusters) {
        if (object_locations == null || object_locations.isEmpty()) {
            throw new IllegalArgumentException("'objects' must be non-empty");
        }
        if (num_clusters < 2 || num_clusters > object_locations.size()) {
            throw new IllegalArgumentException("number of clusters "
                    + "must be >= 2 and <= number of objects ("
                    + object_locations.size() + ")");
        }

        Set<double[]> centroids = chooseInitialCentroids(object_locations,
                num_clusters);

        // put items in their initial clusters
        Map<double[], Map<T, double[]>> clusterMap = assignToClusters(
                object_locations, centroids);

        // keep reconstituting clusters until either
        // (a) max_iterations passes have been made, or
        // (b) max movement of any centroid is <= convergence_threshold
        int iterations = 0;
        double max_movement = Double.POSITIVE_INFINITY;
        while (iterations++ < max_iterations
                && max_movement > convergence_threshold) {
            max_movement = 0;
            Set<double[]> new_centroids = new HashSet<double[]>();
            // calculate new mean for each cluster
            for (Map.Entry<double[], Map<T, double[]>> entry : clusterMap
                    .entrySet()) {
                double[] centroid = entry.getKey();
                Map<T, double[]> elements = entry.getValue();
                if (elements.isEmpty()) {
                    // A cluster can lose every member on reassignment (for
                    // instance when another centroid is at least as close to
                    // all of its former points). The mean of no points is
                    // undefined, so keep the centroid where it is; it counts
                    // as not having moved.
                    new_centroids.add(centroid);
                    continue;
                }
                ArrayList<double[]> locations = new ArrayList<double[]>(
                        elements.values());
                double[] mean = DiscreteDistribution.mean(locations);
                max_movement = Math.max(max_movement, Math.sqrt(
                        DiscreteDistribution.squaredError(centroid, mean)));
                new_centroids.add(mean);
            }
            // TODO: check membership of clusters: have they changed?
            // regenerate cluster membership based on means
            clusterMap = assignToClusters(object_locations, new_centroids);
        }
        return clusterMap.values();
    }

    /**
     * Picks <code>num_clusters</code> initial centroids by repeatedly drawing a
     * random object and taking its location, skipping any location whose
     * coordinates equal those of an already chosen centroid.
     *
     * @throws NotEnoughClustersException
     *             if every object has been drawn and there are still fewer
     *             than <code>num_clusters</code> distinct locations
     */
    private Set<double[]> chooseInitialCentroids(
            Map<T, double[]> object_locations, int num_clusters) {
        Set<double[]> centroids = new HashSet<double[]>();
        // Typed snapshot of the keys so they can be indexed at random.
        List<T> candidates = new ArrayList<T>(object_locations.keySet());
        Set<T> tried = new HashSet<T>();

        // Sample with replacement until enough centroids are found or every
        // object has been examined at least once.
        while (centroids.size() < num_clusters
                && tried.size() < candidates.size()) {
            T o = candidates.get((int) (rand.nextDouble() * candidates.size()));
            tried.add(o);
            double[] mean_value = object_locations.get(o);
            if (!containsLocation(centroids, mean_value)) {
                centroids.add(mean_value);
            }
        }

        // Decide on the number of centroids actually found, not on how many
        // objects were tried: the last object drawn may be the one that
        // completes the set, in which case clustering can proceed.
        if (centroids.size() < num_clusters) {
            throw new NotEnoughClustersException();
        }
        return centroids;
    }

    /**
     * Reports whether any array in <code>locations</code> has the same
     * coordinates as <code>candidate</code>.
     */
    private static boolean containsLocation(Collection<double[]> locations,
            double[] candidate) {
        for (double[] cur : locations) {
            if (Arrays.equals(candidate, cur)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Assigns each object to the cluster whose centroid is closest to the
     * object. Ties go to whichever of the tied centroids is encountered first
     * when iterating over <code>centroids</code>.
     *
     * @param object_locations
     *            a map of objects to locations
     * @param centroids
     *            the centroids of the clusters to be formed; must be non-empty
     * @return a map from each centroid to the cluster formed around it, where
     *         a cluster is a map of its member objects to their locations; a
     *         centroid that attracts no objects maps to an empty cluster
     */
    protected Map<double[], Map<T, double[]>> assignToClusters(
            Map<T, double[]> object_locations, Set<double[]> centroids) {
        Map<double[], Map<T, double[]>> clusterMap = new HashMap<double[], Map<T, double[]>>();
        for (double[] centroid : centroids) {
            clusterMap.put(centroid, new HashMap<T, double[]>());
        }
        for (Map.Entry<T, double[]> object_location : object_locations
                .entrySet()) {
            T object = object_location.getKey();
            double[] location = object_location.getValue();
            // find the cluster with the closest centroid
            Iterator<double[]> c_iter = centroids.iterator();
            double[] closest = c_iter.next();
            double distance = DiscreteDistribution.squaredError(location,
                    closest);
            while (c_iter.hasNext()) {
                double[] centroid = c_iter.next();
                double dist_cur = DiscreteDistribution.squaredError(location,
                        centroid);
                // strict comparison: the earlier centroid wins a tie
                if (dist_cur < distance) {
                    distance = dist_cur;
                    closest = centroid;
                }
            }
            clusterMap.get(closest).put(object, location);
        }
        return clusterMap;
    }

    /**
     * Sets the seed used by the internal random number generator. Enables
     * consistent outputs.
     *
     * @param random_seed
     *            the seed for a fresh random number generator, which replaces
     *            the current one
     */
    public void setSeed(int random_seed) {
        this.rand = new Random(random_seed);
    }

    /**
     * An exception that indicates that the specified data points cannot be
     * clustered into the number of clusters requested by the user. This will
     * happen if and only if there are fewer distinct points than requested
     * clusters. (If there are fewer total data points than requested clusters,
     * <code>IllegalArgumentException</code> will be thrown.)
     *
     * @author Joshua O'Madadhain
     */
    @SuppressWarnings("serial")
    public static class NotEnoughClustersException extends RuntimeException {
        /**
         * Returns a fixed description of the failure.
         */
        @Override
        public String getMessage() {
            return "Not enough distinct points in the input data set to form "
                    + "the requested number of clusters";
        }
    }
}