package clear.util.cluster;

import clear.util.tuple.JIntDoubleTuple;
import com.carrotsearch.hppc.IntOpenHashSet;
import com.carrotsearch.hppc.cursors.IntCursor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Random;

/**
 * K-means clustering of sparse vectors using cosine similarity.
 *
 * <p>Each input vector is an array of {@link JIntDoubleTuple}s whose {@code i}
 * field is a dimension index in {@code [0, D-1]} and whose {@code d} field is
 * the value at that dimension. Centroids are stored densely in one flat array
 * of length {@code K * D}.
 *
 * <p>Progress is reported on {@code System.out}.
 */
public class Kmeans {
    /** Fixed seed for picking the initial centroids. */
    private static final int RAND_SEED = 0;

    /** Number of input vectors. */
    private final int N;
    /** Dimensionality of the vector space. */
    private final int D;
    /** Input vectors, one sparse vector per row. */
    private final JIntDoubleTuple[][] v_unit;
    /** Flattened K x D centroid matrix; see {@link #getCentroidIndex(int, int)}. */
    private double[] d_centroid;
    /** Euclidean norm of each centroid, indexed by cluster. */
    private double[] d_scala;

    /**
     * @param unit      sparse input vectors; the array is used directly, not copied.
     * @param dimension size of the vector space; every tuple index must be in
     *                  {@code [0, dimension-1]}.
     */
    public Kmeans(JIntDoubleTuple[][] unit, int dimension) {
        v_unit = unit;
        N = unit.length;
        D = dimension;
    }

    /**
     * K-means clustering.
     *
     * <p>Runs at most {@code N / K} iterations. The quality score (printed as
     * "RSS") is the mean cosine similarity between every vector and the
     * centroid of its cluster, so larger is better. Iteration stops early when
     * the score no longer improves (the previous clustering is returned) or
     * when it reaches {@code threshold}.
     *
     * @param K         number of clusters; must be in {@code [1, N]}.
     * @param threshold score at or above which the current clustering is accepted.
     * @return each row represents a cluster, and each column represents a tuple
     *         of (index of a unit vector, similarity to the centroid).
     * @throws IllegalArgumentException if {@code K} is not in {@code [1, N]}.
     */
    public ArrayList<ArrayList<JIntDoubleTuple>> cluster(int K, double threshold) {
        // K > N would spin forever while drawing K distinct seeds, and K == 0
        // would divide by zero below; reject both up front.
        if (K <= 0 || K > N) {
            throw new IllegalArgumentException(
                    "K must be in [1, " + N + "], but was " + K);
        }

        ArrayList<ArrayList<JIntDoubleTuple>> currCluster = null;
        ArrayList<ArrayList<JIntDoubleTuple>> prevCluster = null;
        double prevRss = -1;
        double currRss;

        initCentroids(K);
        int max = N / K;

        for (int iter = 0; iter < max; iter++) {
            System.out.println("\n===== Iteration: " + iter + " =====\n");

            currCluster = getClusters(K);
            updateCentroids(K, currCluster);
            currRss = getRSS(K, currCluster);

            // No improvement over the previous round: keep the previous clustering.
            if (prevRss >= currRss) {
                return prevCluster;
            }
            // Good enough: accept the current clustering.
            if (currRss >= threshold) {
                break;
            }

            prevRss = currRss;
            prevCluster = currCluster;
        }

        return currCluster;
    }

    /**
     * Initializes the centroids as copies of K distinct, randomly chosen input
     * vectors and records each centroid's Euclidean norm.
     */
    private void initCentroids(int K) {
        IntOpenHashSet set = new IntOpenHashSet();
        Random rand = new Random(RAND_SEED);

        d_centroid = new double[K * D];
        d_scala = new double[K];

        // Draw until K distinct vector indices have been collected.
        while (set.size() < K) {
            set.add(rand.nextInt(N));
        }

        int k = 0;
        for (IntCursor cur : set) {
            double scala = 0;
            for (JIntDoubleTuple tup : v_unit[cur.value]) {
                d_centroid[getCentroidIndex(k, tup.i)] = tup.d;
                scala += tup.d * tup.d;
            }
            d_scala[k++] = Math.sqrt(scala);
        }
    }

    /**
     * Recomputes every centroid as the component-wise mean of the vectors
     * assigned to its cluster, and refreshes the cached centroid norms.
     *
     * @param cluster current assignment; row k lists the members of cluster k.
     */
    private void updateCentroids(int K, ArrayList<ArrayList<JIntDoubleTuple>> cluster) {
        Arrays.fill(d_centroid, 0);
        Arrays.fill(d_scala, 0);

        System.out.print("Updating centroids: ");

        for (int k = 0; k < K; k++) {
            ArrayList<JIntDoubleTuple> ck = cluster.get(k);

            // Sum the member vectors into the centroid's row.
            for (JIntDoubleTuple member : ck) {
                for (JIntDoubleTuple tup : v_unit[member.i]) {
                    d_centroid[getCentroidIndex(k, tup.i)] += tup.d;
                }
            }

            int size = ck.size();
            double scala = 0;

            for (int i = k * D; i < (k + 1) * D; i++) {
                // Average every non-zero component, negative ones included.
                // An empty cluster has an all-zero row, so size is never
                // used as a divisor when it is 0.
                if (d_centroid[i] != 0) {
                    d_centroid[i] /= size;
                    scala += d_centroid[i] * d_centroid[i];
                }
            }

            d_scala[k] = Math.sqrt(scala);
            System.out.print(".");
        }

        System.out.println();
    }

    /**
     * Assigns every input vector to the centroid it is most similar to.
     * Each cluster contains indices of {@link Kmeans#v_unit}.
     *
     * @return K lists; list k holds (vector index, similarity to centroid k)
     *         for every vector assigned to cluster k.
     */
    private ArrayList<ArrayList<JIntDoubleTuple>> getClusters(int K) {
        ArrayList<ArrayList<JIntDoubleTuple>> cluster = new ArrayList<>(K);
        for (int k = 0; k < K; k++) {
            cluster.add(new ArrayList<JIntDoubleTuple>());
        }

        System.out.print("Clustering: ");

        for (int i = 0; i < N; i++) {
            JIntDoubleTuple[] unit = v_unit[i];
            int bestK = -1;
            // Start below any possible cosine so a similarity of exactly -1
            // still selects a cluster.
            double bestSim = Double.NEGATIVE_INFINITY;

            for (int k = 0; k < K; k++) {
                double sim = cosine(unit, k);
                if (sim > bestSim) {
                    bestK = k;
                    bestSim = sim;
                }
            }

            cluster.get(bestK).add(new JIntDoubleTuple(i, bestSim));

            if (i % 10000 == 0) {
                System.out.print(".");
            }
        }

        System.out.println();

        for (int k = 0; k < K; k++) {
            System.out.printf("%4d: %d\n", k, cluster.get(k).size());
        }

        return cluster;
    }

    /**
     * Maps a (cluster, dimension) pair to its slot in the flat centroid array.
     *
     * @param k     [0, K-1].
     * @param index [0, D-1].
     */
    private int getCentroidIndex(int k, int index) {
        return k * D + index;
    }

    /**
     * Computes the clustering score that is printed as "RSS": the mean cosine
     * similarity between each vector and the centroid of its cluster.
     *
     * @return the same value that is printed.
     */
    private double getRSS(int K, ArrayList<ArrayList<JIntDoubleTuple>> cluster) {
        double sim = 0;

        System.out.print("Calulating RSS: ");

        for (int k = 0; k < K; k++) {
            for (JIntDoubleTuple tup : cluster.get(k)) {
                sim += cosine(v_unit[tup.i], k);
            }
            System.out.print(".");
        }

        System.out.println();

        // Normalize exactly once so the printed and returned values agree.
        sim /= N;
        System.out.println("RSS = " + sim);
        return sim;
    }

    /**
     * Cosine similarity between a sparse vector and the centroid of cluster k.
     *
     * @return the cosine, or 0 when either the vector or the centroid has
     *         zero norm (the cosine is undefined there).
     */
    private double cosine(JIntDoubleTuple[] unit, int k) {
        double scala = 0;
        double dot = 0;

        for (JIntDoubleTuple tup : unit) {
            dot += tup.d * d_centroid[getCentroidIndex(k, tup.i)];
            scala += tup.d * tup.d;
        }

        double denom = Math.sqrt(scala) * d_scala[k];
        // Avoid 0/0 = NaN, which would leave the vector without any cluster.
        return (denom == 0) ? 0 : dot / denom;
    }
}