/*
 * Copyright (c) 2003, the JUNG Project and the Regents of the University
 * of California
 * All rights reserved.
 *
 * This software is open-source under the BSD license; see either
 * "license.txt" or
 * http://jung.sourceforge.net/license.txt for a description.
 */
/*
 * Created on Aug 9, 2004
 *
 */
package edu.uci.ics.jung.algorithms.util;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Random;
import java.util.Set;

/**
 * Groups items into a specified number of clusters, based on their proximity in
 * d-dimensional space, using the k-means algorithm. Calls to
 * <code>cluster</code> stop refining the clusters as soon as either of the
 * following conditions is true:
 * <ul>
 * <li><code>max_iterations</code> refinement iterations have been
 * performed</li>
 * <li>during the most recent iteration no centroid moved by more than
 * <code>convergence_threshold</code>, where the movement of a centroid is the
 * square root of <code>DiscreteDistribution.squaredError</code> between its
 * old and new positions</li>
 * </ul>
 *
 * @param <T>
 *            the type of the objects being clustered
 * @author Joshua O'Madadhain
 */
public class KMeansClusterer<T> {

    /** Upper bound on the number of refinement iterations per call. */
    protected int max_iterations;

    /** Refinement stops once no centroid moves by more than this amount. */
    protected double convergence_threshold;

    /** Source of randomness used to pick the initial centroids. */
    protected Random rand;

    /**
     * Creates an instance whose termination conditions are set according to the
     * parameters. Unlike the corresponding setters, this constructor does not
     * validate its arguments.
     *
     * @param max_iterations
     *            the maximum number of refinement iterations
     * @param convergence_threshold
     *            the centroid movement at or below which refinement stops
     */
    public KMeansClusterer(int max_iterations, double convergence_threshold) {
        this.max_iterations = max_iterations;
        this.convergence_threshold = convergence_threshold;
        this.rand = new Random();
    }

    /**
     * Creates an instance with max iterations of 100 and convergence threshold
     * of 0.001.
     */
    public KMeansClusterer() {
        this(100, 0.001);
    }

    /**
     * Returns the maximum number of iterations.
     *
     * @return the maximum number of refinement iterations
     */
    public int getMaxIterations() {
        return max_iterations;
    }

    /**
     * Sets the maximum number of iterations.
     *
     * @param max_iterations
     *            the new maximum; must be &gt;= 0
     * @throws IllegalArgumentException
     *             if <code>max_iterations</code> is negative
     */
    public void setMaxIterations(int max_iterations) {
        if (max_iterations < 0) {
            throw new IllegalArgumentException("max iterations must be >= 0");
        }
        this.max_iterations = max_iterations;
    }

    /**
     * Returns the convergence threshold.
     *
     * @return the convergence threshold
     */
    public double getConvergenceThreshold() {
        return convergence_threshold;
    }

    /**
     * Sets the convergence threshold.
     *
     * @param convergence_threshold
     *            the new threshold; must be &gt; 0
     * @throws IllegalArgumentException
     *             if <code>convergence_threshold</code> is not positive
     */
    public void setConvergenceThreshold(double convergence_threshold) {
        if (convergence_threshold <= 0) {
            throw new IllegalArgumentException(
                    "convergence threshold " + "must be > 0");
        }
        this.convergence_threshold = convergence_threshold;
    }

    /**
     * Returns a <code>Collection</code> of clusters, where each cluster is
     * represented as a <code>Map</code> of objects to their locations in
     * d-dimensional space. The location arrays in the returned maps are the
     * same array instances that were supplied in
     * <code>object_locations</code>, not copies.
     *
     * @param object_locations
     *            a map of the objects to cluster, to <code>double</code> arrays
     *            that specify their locations in d-dimensional space
     * @param num_clusters
     *            the number of clusters to create
     * @return the clusters; a cluster that lost all of its members during
     *         refinement is returned as an empty map
     * @throws IllegalArgumentException
     *             if <code>object_locations</code> is null or empty, or if
     *             <code>num_clusters</code> is less than 2 or greater than the
     *             number of objects
     * @throws NotEnoughClustersException
     *             if the input contains fewer distinct locations than
     *             <code>num_clusters</code>
     */
    public Collection<Map<T, double[]>> cluster(
            Map<T, double[]> object_locations, int num_clusters) {
        if (object_locations == null || object_locations.isEmpty()) {
            throw new IllegalArgumentException("'objects' must be non-empty");
        }
        if (num_clusters < 2 || num_clusters > object_locations.size()) {
            throw new IllegalArgumentException("number of clusters "
                    + "must be >= 2 and <= number of objects ("
                    + object_locations.size() + ")");
        }

        Set<double[]> centroids = chooseInitialCentroids(object_locations,
                num_clusters);

        // put items in their initial clusters
        Map<double[], Map<T, double[]>> clusterMap = assignToClusters(
                object_locations, centroids);

        // keep reconstituting clusters until either
        // (a) max_iterations iterations have been performed, or
        // (b) max movement of any centroid is <= convergence_threshold
        int iterations = 0;
        double max_movement = Double.POSITIVE_INFINITY;
        while (iterations++ < max_iterations
                && max_movement > convergence_threshold) {
            max_movement = 0;
            Set<double[]> new_centroids = new HashSet<double[]>();

            // calculate new mean for each cluster
            for (Map.Entry<double[], Map<T, double[]>> entry : clusterMap
                    .entrySet()) {
                double[] centroid = entry.getKey();
                Map<T, double[]> elements = entry.getValue();

                if (elements.isEmpty()) {
                    // The mean of an empty cluster is undefined; keep the
                    // previous centroid (zero movement) so the number of
                    // clusters stays constant and the cluster can pick up
                    // members again on a later reassignment.
                    new_centroids.add(centroid);
                    continue;
                }

                ArrayList<double[]> locations = new ArrayList<double[]>(
                        elements.values());
                double[] mean = DiscreteDistribution.mean(locations);
                max_movement = Math.max(max_movement, Math.sqrt(
                        DiscreteDistribution.squaredError(centroid, mean)));
                new_centroids.add(mean);
            }

            // TODO: check membership of clusters: have they changed?

            // regenerate cluster membership based on means
            clusterMap = assignToClusters(object_locations, new_centroids);
        }
        return clusterMap.values();
    }

    /**
     * Randomly picks <code>num_clusters</code> objects whose locations are
     * pairwise distinct and returns those locations as the initial centroids.
     * The returned set holds the location arrays from
     * <code>object_locations</code> themselves, not copies.
     *
     * @throws NotEnoughClustersException
     *             if every object has been tried and fewer than
     *             <code>num_clusters</code> distinct locations were found
     */
    private Set<double[]> chooseInitialCentroids(
            Map<T, double[]> object_locations, int num_clusters) {
        Set<double[]> centroids = new HashSet<double[]>();
        // Typed snapshot of the keys so they can be indexed without an
        // unchecked cast.
        ArrayList<T> candidates = new ArrayList<T>(object_locations.keySet());
        Set<T> tried = new HashSet<T>();

        while (centroids.size() < num_clusters
                && tried.size() < candidates.size()) {
            // The index is derived from nextDouble() (rather than nextInt())
            // on purpose: it keeps the selection sequence for a given seed
            // unchanged.
            T candidate = candidates
                    .get((int) (rand.nextDouble() * candidates.size()));
            if (!tried.add(candidate)) {
                // Already examined: its location was either accepted or
                // rejected as a duplicate the first time around.
                continue;
            }
            double[] location = object_locations.get(candidate);
            if (!containsLocation(centroids, location)) {
                centroids.add(location);
            }
        }

        // Decide on the number of centroids actually found. Checking whether
        // every object has been tried is not sufficient: the last untried
        // object may have supplied the final centroid (this is always the
        // case when num_clusters equals the number of objects and all
        // locations are distinct).
        if (centroids.size() < num_clusters) {
            throw new NotEnoughClustersException();
        }
        return centroids;
    }

    /**
     * Returns true if <code>locations</code> already holds an array with the
     * same contents as <code>location</code>. Arrays do not override
     * <code>equals</code>, so the set itself cannot detect such duplicates.
     */
    private static boolean containsLocation(Set<double[]> locations,
            double[] location) {
        for (double[] existing : locations) {
            if (Arrays.equals(location, existing)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Assigns each object to the cluster whose centroid is closest to the
     * object, as measured by <code>DiscreteDistribution.squaredError</code>.
     * When several centroids are equally close, the one encountered first
     * while iterating over <code>centroids</code> is used.
     *
     * @param object_locations
     *            a map of objects to locations
     * @param centroids
     *            the centroids of the clusters to be formed
     * @return a map from each centroid to its cluster, where a cluster is a
     *         map of the objects assigned to that centroid to their
     *         locations; a centroid to which no object was assigned maps to
     *         an empty cluster
     */
    protected Map<double[], Map<T, double[]>> assignToClusters(
            Map<T, double[]> object_locations, Set<double[]> centroids) {
        Map<double[], Map<T, double[]>> clusterMap =
                new HashMap<double[], Map<T, double[]>>();
        for (double[] centroid : centroids) {
            clusterMap.put(centroid, new HashMap<T, double[]>());
        }

        for (Map.Entry<T, double[]> object_location : object_locations
                .entrySet()) {
            T object = object_location.getKey();
            double[] location = object_location.getValue();

            // find the cluster with the closest centroid
            Iterator<double[]> c_iter = centroids.iterator();
            double[] closest = c_iter.next();
            double distance = DiscreteDistribution.squaredError(location,
                    closest);

            while (c_iter.hasNext()) {
                double[] centroid = c_iter.next();
                double dist_cur = DiscreteDistribution.squaredError(location,
                        centroid);
                // strict comparison: on a tie the earlier centroid wins
                if (dist_cur < distance) {
                    distance = dist_cur;
                    closest = centroid;
                }
            }
            clusterMap.get(closest).put(object, location);
        }

        return clusterMap;
    }

    /**
     * Sets the seed used by the internal random number generator. Enables
     * consistent outputs.
     *
     * @param random_seed
     *            the seed for a new internal <code>Random</code> instance,
     *            which replaces the current one
     */
    public void setSeed(int random_seed) {
        this.rand = new Random(random_seed);
    }

    /**
     * An exception that indicates that the specified data points cannot be
     * clustered into the number of clusters requested by the user. This will
     * happen if and only if there are fewer distinct points than requested
     * clusters. (If there are fewer total data points than requested clusters,
     * <code>IllegalArgumentException</code> will be thrown.)
     *
     * @author Joshua O'Madadhain
     */
    @SuppressWarnings("serial")
    public static class NotEnoughClustersException extends RuntimeException {
        @Override
        public String getMessage() {
            return "Not enough distinct points in the input data set to form "
                    + "the requested number of clusters";
        }
    }
}