package com.facebook.hive.udf;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDF;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.parse.SemanticException;

import java.util.ArrayList;

/**
 * Performs K-means on a set of points. Each point is represented by an array
 * of DOUBLEs. The set of points is passed in as the first argument as an
 * array of points (i.e., an array of arrays). Every point must have the same
 * number of elements; otherwise an error is raised. The second argument is an
 * integer indicating the number of clusters, K, to cluster the points into.
 * The last argument is an upper bound on the number of iterations the
 * procedure will execute (see below). Note that since all the points are
 * passed in as an array, this is a UDF and not a UDAF. If you wish to perform
 * K-means where the points are distributed across multiple rows, use
 * FB_COLLECT to assemble the points into a single row.
 *
 * <p>The algorithm used here first selects initial cluster centers using the
 * K-means++ heuristic. Then the standard Lloyd's algorithm is performed until
 * convergence or max_iterations iterations have elapsed. Because the initial
 * centers are chosen at random, repeated calls on the same input may return
 * different results.
 *
 * <p>If any arguments are NULL then NULL is returned.
 *
 * <p>The return value is an array of cluster centers. Each center is an array
 * of DOUBLEs whose size is two greater than the dimensionality of the inputs.
 * Letting M denote the dimensionality of the inputs, the first M elements of
 * each center contain the mean coordinates for that cluster. The next element
 * contains the average squared distance of the points from the cluster mean.
 * The last element contains the number of points assigned to that cluster.
 * A cluster that ends up with no points keeps its previous center and reports
 * an average squared distance of 0 and a count of 0.
 *
 * <p>If fewer than K points are passed in the first argument, fewer than K
 * centers are returned: each point becomes its own center, with an average
 * squared distance of 0 and a count of 1.
 *
 * <p>The input arrays are never modified.
 */
@Description(name = "kmeans", value = "_FUNC_(points, K, max_iterations) - Perform K-means clustering on a collection of points represented as arrays and returns an array of cluster centers.")
public class UDFKmeans extends UDF {

  /** Position of the points argument, reported in argument errors. */
  private static final int POINTS_ARG = 0;

  /** Position of the K argument, reported in argument errors. */
  private static final int K_ARG = 1;

  /**
   * Draws an index at random with probability proportional to its weight.
   *
   * <p>If every weight is zero, an index is drawn uniformly at random (0 for
   * an empty array).
   *
   * @param weights sampling weights; expected to be non-negative
   * @return the sampled index into {@code weights}, or -3 if any weight is
   *     negative
   */
  public int sample(double[] weights) throws SemanticException {
    double weightSum = 0;
    for (int ii = 0; ii < weights.length; ++ii) {
      if (weights[ii] < 0.0) {
        return -3;
      }
      weightSum += weights[ii];
    }
    double r = Math.random();
    if (weightSum == 0.0) {
      return (int) (r * weights.length);
    }
    for (int ii = 0; ii < weights.length; ++ii) {
      double p = weights[ii] / weightSum;
      if (r < p) {
        return ii;
      }
      r -= p;
    }
    // Rounding can leave r just above the accumulated probability mass, and a
    // NaN weight makes every comparison above false. Never return an invalid
    // index here: fall back to the last index that could have been drawn.
    for (int ii = weights.length - 1; ii >= 0; --ii) {
      if (weights[ii] > 0.0) {
        return ii;
      }
    }
    return (int) (Math.random() * weights.length);
  }

  /**
   * Returns the squared Euclidean distance between two points.
   *
   * @param center first point
   * @param point second point; must have as many coordinates as {@code center}
   * @return the sum of squared coordinate differences
   * @throws SemanticException if the points have different sizes
   */
  public double squared_dist(ArrayList<Double> center, ArrayList<Double> point)
      throws SemanticException {
    int M = point.size();
    if (M != center.size()) {
      throw new UDFArgumentTypeException(POINTS_ARG, "This should never happen.");
    }
    double dist2 = 0;
    for (int mm = 0; mm < M; ++mm) {
      double diff = center.get(mm) - point.get(mm);
      dist2 += diff * diff;
    }
    return dist2;
  }

  /**
   * Clusters {@code points} into at most {@code K} clusters.
   *
   * @param points the points to cluster; every point must be non-NULL and have
   *     the same number of coordinates. Not modified.
   * @param K the number of clusters; must be positive
   * @param max_iterations upper bound on the number of Lloyd iterations; zero
   *     or less returns the K-means++ seeds with their statistics
   * @return one entry per cluster holding the center's coordinates followed by
   *     the average squared distance of its points and its point count, or
   *     {@code null} if any argument is {@code null}
   * @throws SemanticException if K is not positive, or a point is NULL or has
   *     a different number of coordinates from the first point
   */
  public ArrayList<ArrayList<Double>> evaluate(ArrayList<ArrayList<Double>> points,
      Integer K, Integer max_iterations) throws SemanticException {
    if (K == null || max_iterations == null || points == null) {
      return null;
    }
    if (K <= 0) {
      throw new UDFArgumentTypeException(K_ARG, "K should be positive.");
    }
    int dims = validatePoints(points);
    int numPoints = points.size();
    int numClusters = K;

    if (numPoints < numClusters) {
      // Each point is its own cluster: zero spread and a count of one. Work on
      // copies so the caller's lists are left untouched.
      ArrayList<ArrayList<Double>> result = new ArrayList<ArrayList<Double>>(numPoints);
      for (ArrayList<Double> point : points) {
        ArrayList<Double> center = new ArrayList<Double>(point);
        center.add(0.0);
        center.add(1.0);
        result.add(center);
      }
      return result;
    }

    ArrayList<ArrayList<Double>> centers = initializeCenters(points, numClusters);
    int[] assignments = new int[numPoints];
    // -1 matches no cluster, so the first assignment pass always counts as a change.
    java.util.Arrays.fill(assignments, -1);
    boolean converged = false;
    int maxIterations = max_iterations;
    for (int iter = 0; iter < maxIterations; ++iter) {
      if (!assignPoints(points, centers, assignments)) {
        converged = true;
        break;
      }
      updateCenters(points, centers, assignments, dims);
    }
    if (!converged) {
      // Either no iteration ran or the centers moved after the last assignment
      // pass; reassign so the statistics describe the centers we return.
      assignPoints(points, centers, assignments);
    }
    appendStatistics(points, centers, assignments);
    return centers;
  }

  /**
   * Checks that every point is non-NULL and has as many coordinates as the
   * first one.
   *
   * @return the common number of coordinates, or 0 if there are no points
   */
  private static int validatePoints(ArrayList<ArrayList<Double>> points)
      throws SemanticException {
    int dims = -1;
    for (ArrayList<Double> point : points) {
      if (point == null) {
        throw new UDFArgumentTypeException(POINTS_ARG, "Points must not be NULL.");
      }
      if (dims < 0) {
        dims = point.size();
      } else if (point.size() != dims) {
        throw new UDFArgumentTypeException(POINTS_ARG, "Sizes of tuples do not match.");
      }
    }
    return Math.max(dims, 0);
  }

  /**
   * Picks initial centers with the K-means++ heuristic.
   *
   * <p>The first center is a point chosen uniformly at random. Each later
   * center is a point chosen with probability proportional to its squared
   * distance from the nearest center picked so far. Requires at least one
   * point.
   *
   * @return {@code numClusters} fresh center lists, copied from the input
   */
  private ArrayList<ArrayList<Double>> initializeCenters(
      ArrayList<ArrayList<Double>> points, int numClusters) throws SemanticException {
    int numPoints = points.size();
    ArrayList<ArrayList<Double>> centers = new ArrayList<ArrayList<Double>>(numClusters);
    // Centers are updated in place later, so they must not alias input points.
    centers.add(new ArrayList<Double>(points.get((int) (Math.random() * numPoints))));
    double[] nearestDist2 = new double[numPoints];
    java.util.Arrays.fill(nearestDist2, Double.POSITIVE_INFINITY);
    while (centers.size() < numClusters) {
      // Only the newest center can lower a point's nearest-center distance.
      ArrayList<Double> newest = centers.get(centers.size() - 1);
      for (int jj = 0; jj < numPoints; ++jj) {
        double dist2 = squared_dist(newest, points.get(jj));
        if (dist2 < nearestDist2[jj]) {
          nearestDist2[jj] = dist2;
        }
      }
      centers.add(new ArrayList<Double>(points.get(sample(nearestDist2))));
    }
    return centers;
  }

  /**
   * Assigns every point to its nearest center; ties go to the lowest index.
   *
   * @param assignments current assignment per point; updated in place
   * @return whether any assignment changed
   */
  private boolean assignPoints(ArrayList<ArrayList<Double>> points,
      ArrayList<ArrayList<Double>> centers, int[] assignments) throws SemanticException {
    boolean changed = false;
    for (int jj = 0; jj < points.size(); ++jj) {
      ArrayList<Double> point = points.get(jj);
      // Seed with center 0 rather than a sentinel distance so that every point
      // receives a valid assignment even for huge or NaN distances.
      int best = 0;
      double bestDist2 = squared_dist(centers.get(0), point);
      for (int kk = 1; kk < centers.size(); ++kk) {
        double dist2 = squared_dist(centers.get(kk), point);
        if (dist2 < bestDist2) {
          best = kk;
          bestDist2 = dist2;
        }
      }
      if (assignments[jj] != best) {
        assignments[jj] = best;
        changed = true;
      }
    }
    return changed;
  }

  /**
   * Moves each non-empty cluster's center to the mean of its assigned points.
   * Empty clusters keep their current center instead of collapsing to the
   * origin.
   */
  private static void updateCenters(ArrayList<ArrayList<Double>> points,
      ArrayList<ArrayList<Double>> centers, int[] assignments, int dims) {
    int numClusters = centers.size();
    double[][] sums = new double[numClusters][dims];
    int[] counts = new int[numClusters];
    for (int jj = 0; jj < points.size(); ++jj) {
      int cluster = assignments[jj];
      ArrayList<Double> point = points.get(jj);
      for (int mm = 0; mm < dims; ++mm) {
        sums[cluster][mm] += point.get(mm);
      }
      counts[cluster]++;
    }
    for (int kk = 0; kk < numClusters; ++kk) {
      if (counts[kk] == 0) {
        continue;
      }
      ArrayList<Double> center = centers.get(kk);
      for (int mm = 0; mm < dims; ++mm) {
        center.set(mm, sums[kk][mm] / counts[kk]);
      }
    }
  }

  /**
   * Appends to each center the average squared distance of its assigned
   * points (0 for an empty cluster) followed by its point count.
   */
  private void appendStatistics(ArrayList<ArrayList<Double>> points,
      ArrayList<ArrayList<Double>> centers, int[] assignments) throws SemanticException {
    int numClusters = centers.size();
    double[] sumDist2 = new double[numClusters];
    int[] counts = new int[numClusters];
    for (int jj = 0; jj < points.size(); ++jj) {
      int cluster = assignments[jj];
      sumDist2[cluster] += squared_dist(centers.get(cluster), points.get(jj));
      counts[cluster]++;
    }
    for (int kk = 0; kk < numClusters; ++kk) {
      ArrayList<Double> center = centers.get(kk);
      center.add(counts[kk] > 0 ? sumDist2[kk] / counts[kk] : 0.0);
      center.add((double) counts[kk]);
    }
  }
}