package edu.umd.cloud9.example.clustering; import java.util.Arrays; import java.util.List; import javax.annotation.Nullable; import com.google.common.base.Function; import com.google.common.collect.Lists; public class KMeans { private static final int MAX_ITERATIONS = 30; /** * Performs a k-means on the point set to compute k clusters. * * @param points point set * @param k number of clusters * @return clusters */ public static List<Point>[] run(Point[] points, int k) { Point[] centroids = initialize(points, k); int[] repartition = new int[points.length]; @SuppressWarnings("unchecked") List<Point>[] clusters = (List<Point>[]) new List[k]; int it = 0; int[] tmp = new int[points.length]; do { tmp = repartition.clone(); repartitionStep(points, k, centroids, repartition, clusters); centroidStep(points, k, centroids, clusters); it++; } while (!Arrays.equals(repartition, tmp) && it < MAX_ITERATIONS); return clusters; } public static void dumpClusters(List<Point>[] clusters){ for (List<Point> cluster : clusters) { System.out.println(Lists.transform(cluster, new Function<Point, Double>() { @Override @Nullable public Double apply(@Nullable Point point) { return point.value; } }).toString()); } } /** * Initializes the k-means by randomly picking points in the set. * * @param points point set * @param k number of clusters * @return clusters */ private static Point[] initialize(Point[] points, int k) { Integer[] arr = ExpectationMaximization.sampleNUniquePoints(k, points.length); Point[] centroids = new Point[k]; for (int i=0; i<k; i++) { centroids[i] = new Point(points[arr[i]].value); } // Return return centroids; } /** * Processes the repartition step. * * @param points point set * @param k number of clusters * @param centroids centroids of the clusters * @param repartition repartition array * @param clusters clusters */ private static void repartitionStep(Point[] points, int k, Point[] centroids, int[] repartition, List<Point>[] clusters) { // Initialization of the clusters for (int i = 0; i < k; i++) clusters[i] = Lists.newArrayList(); // Compute repartition for (int i = 0; i < points.length; i++) { int index = 0; double dist = Double.MAX_VALUE; for (int j = 0; j < k; j++) { double dist_tmp = Math.abs(points[i].value - centroids[j].value); if (dist_tmp < dist) { dist = dist_tmp; index = j; } } repartition[i] = index; clusters[index].add(points[i]); } } /** * Processes the centroid step. * * @param points point set * @param k number of clusters * @param centroids centroids of the clusters * @param clusters clusters */ private static void centroidStep(Point[] points, int k, Point[] centroids, List<Point>[] clusters) { for (int i = 0; i < k; i++) { centroids[i] = new Point(0); for (int j = 0; j < clusters[i].size(); j++) { centroids[i].value = centroids[i].value + clusters[i].get(j).value; } centroids[i].value = centroids[i].value * (1.0d / clusters[i].size()); } } }