/**
* Copyright 2014, Emory University
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package edu.emory.clir.clearnlp.cluster;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.stream.Collectors;
import com.carrotsearch.hppc.cursors.IntCursor;
import edu.emory.clir.clearnlp.collection.pair.DoubleIntPair;
import edu.emory.clir.clearnlp.collection.pair.ObjectDoublePair;
import edu.emory.clir.clearnlp.collection.set.IntHashSet;
import edu.emory.clir.clearnlp.util.BinUtils;
import edu.emory.clir.clearnlp.util.MathUtils;
/**
* Kmeans++ algorithm.
* @since 3.0.3
* @author Jinho D. Choi ({@code jinho.choi@emory.edu})
*/
public class KmeansClustering extends AbstractCluster
{
	/** Number of clusters to produce. */
	final protected int K;
	/** Worker-pool size for the parallel phases. */
	final protected int NUM_THREADS;
	/** Upper bound on EM iterations. */
	final protected int MAX_ITERATIONS;
	/** Stop once the RSS gain between consecutive iterations falls below this value. */
	final protected double RSS_THRESHOLD;
	/** Per-point distance to the nearest chosen centroid; non-null only during initialization.
	 *  Declared volatile so worker threads see the array published by {@link #initialization()};
	 *  tasks write disjoint index ranges, so element access itself is unsynchronized. */
	volatile double[] D2;

	/**
	 * @param k             number of clusters.
	 * @param maxIterations maximum number of EM iterations.
	 * @param rssThreshold  convergence threshold on the RSS improvement.
	 * @param numThreads    number of threads used for the parallel phases.
	 */
	public KmeansClustering(int k, int maxIterations, double rssThreshold, int numThreads)
	{
		super();
		K = k;
		MAX_ITERATIONS = maxIterations;
		RSS_THRESHOLD = rssThreshold;
		NUM_THREADS = numThreads;
	}

	/**
	 * Runs k-means++ initialization followed by EM iterations.
	 * @return the clusters from the last maximization step.
	 */
	@Override
	public List<Cluster> cluster()
	{
		ObjectDoublePair<List<Cluster>> current = new ObjectDoublePair<>(null, 0);
		List<SparseVector> centroids = initialization();
		double previousRSS;

		for (int i=0; i<MAX_ITERATIONS; i++)
		{
			BinUtils.LOG.info(String.format("Iteration: %d\n", i));
			previousRSS = current.d;
			current = maximization(centroids);
			centroids = expectation(current.o);
			// Converged: the RSS no longer improves by at least the threshold.
			if (current.d - previousRSS < RSS_THRESHOLD) break;
		}

		return current.o;
	}

//	==================================== Initialization ====================================

	/**
	 * Picks {@code K} seed centroids, k-means++ style: the first uniformly at random
	 * (fixed seed 1 for reproducibility), the rest biased toward points far from all
	 * centroids chosen so far.
	 * <p>
	 * NOTE(review): the inner loop draws a fresh {@code rand.nextDouble()} per index and
	 * compares it against the cumulative distribution, rather than drawing once and
	 * locating that value in the CDF as canonical k-means++ does. This biases selection
	 * toward low indices; preserved as-is to keep behavior identical — confirm intent.
	 * @return the {@code K} seed centroids (references into {@code s_points}).
	 */
	private List<SparseVector> initialization()
	{
		IntHashSet centroidSet = new IntHashSet();
		int newCentroid, i, N = s_points.size();
		Random rand = new Random(1);
		double[] cum;
		double sum;

		BinUtils.LOG.info("Initialization:");
		D2 = new double[N];
		Arrays.fill(D2, Double.MAX_VALUE);
		centroidSet.add(newCentroid = rand.nextInt(N));
		D2[newCentroid] = 0;

		while (centroidSet.size() < K)
		{
			sum = computeD2(centroidSet, newCentroid);
			cum = cumulativeD2(N, sum);

			for (i=0; i<N; i++)
			{
				if (!centroidSet.contains(i) && rand.nextDouble() < cum[i])
				{
					centroidSet.add(newCentroid = i);
					D2[newCentroid] = 0;
					break;
				}
			}
		}

		D2 = null;	// release the scratch array; it is only needed during seeding
		BinUtils.LOG.info(centroidSet.toString()+"\n");

		List<SparseVector> centroids = new ArrayList<>(K);
		for (IntCursor c : centroidSet) centroids.add(s_points.get(c.value));
		return centroids;
	}

	/**
	 * Updates {@code D2} with the distance from each non-centroid point to the newly
	 * chosen centroid (keeping the minimum over all centroids), in parallel.
	 * @return the sum of the updated {@code D2} entries.
	 */
	private double computeD2(IntHashSet centroidSet, int newCentroidID)
	{
		ExecutorService pool = Executors.newFixedThreadPool(NUM_THREADS);
		List<Future<Double>> list = new ArrayList<>(NUM_THREADS);
		SparseVector newCentroid = s_points.get(newCentroidID);
		double centroidNorm = newCentroid.euclideanNorm();
		int i, j, N = s_points.size(), GAP = gap();
		Callable<Double> task;

		for (i=0; i<N; i+=GAP)
		{
			if ((j = i + GAP) > N) j = N;
			task = new InitializationTask(centroidSet, newCentroid, centroidNorm, i, j);
			list.add(pool.submit(task));
		}

		// Let the non-daemon worker threads terminate once the submitted tasks finish;
		// previously the pool was never shut down and its threads leaked.
		pool.shutdown();

		double sum = 0;

		try {for (Future<Double> f : list) sum += f.get();}
		catch (InterruptedException e) {Thread.currentThread().interrupt(); e.printStackTrace();}
		catch (Exception e) {e.printStackTrace();}

		return sum;
	}

	/**
	 * @return the cumulative distribution of {@code D2} normalized by {@code sum}.
	 * NOTE(review): if {@code sum} is 0 (all points coincide with centroids) every entry
	 * becomes NaN and no new centroid is ever selected — confirm inputs exclude this case.
	 */
	private double[] cumulativeD2(int N, double sum)
	{
		double[] c = new double[N];
		c[0] = D2[0] / sum;

		for (int i=1; i<N; i++)
			c[i] = c[i-1] + D2[i]/sum;

		return c;
	}

	/** Updates {@code D2[begin_index..end_index)} against one new centroid and sums the slice. */
	private class InitializationTask implements Callable<Double>
	{
		private final IntHashSet   centroid_set;
		private final SparseVector new_centroid;
		private final int          begin_index;
		private final int          end_index;
		private final double       centroid_norm;

		public InitializationTask(IntHashSet centroidSet, SparseVector newCentroid, double centroidNorm, int beginIndex, int endIndex)
		{
			centroid_set  = centroidSet;
			new_centroid  = newCentroid;
			centroid_norm = centroidNorm;
			begin_index   = beginIndex;
			end_index     = endIndex;
		}

		@Override
		public Double call()
		{
			SparseVector point;
			double sum = 0;

			for (int i=begin_index; i<end_index; i++)
			{
				if (centroid_set.contains(i)) continue;
				point = s_points.get(i);
				// Distance is (1 - cosine similarity); keep the minimum over all centroids seen so far.
				D2[i] = Math.min(D2[i], 1-cosineSimilarity(new_centroid, centroid_norm, point, point.euclideanNorm()));
				sum += D2[i];
			}

			return sum;
		}
	}

//	==================================== Maximization ====================================

	/**
	 * Assigns every point to its most similar centroid, in parallel.
	 * @return the {@code K} clusters paired with the total RSS (sum of best similarities).
	 */
	private ObjectDoublePair<List<Cluster>> maximization(List<SparseVector> centroids)
	{
		List<Future<ObjectDoublePair<List<Cluster>>>> list = new ArrayList<>(NUM_THREADS);
		ExecutorService pool = Executors.newFixedThreadPool(NUM_THREADS);
		double[] centroidNorms = euclideanNorms(centroids);
		Callable<ObjectDoublePair<List<Cluster>>> task;
		int i, j, N = s_points.size(), GAP = gap();
		BinUtils.LOG.info("- Maximization: ");

		for (i=0; i<N; i+=GAP)
		{
			if ((j = i + GAP) > N) j = N;
			task = new MaximizationTask(centroids, centroidNorms, i, j);
			list.add(pool.submit(task));
		}

		// Allow worker threads to exit after the submitted tasks complete (fixes a thread leak).
		pool.shutdown();

		List<Cluster> clusters = centroids.stream().map(c -> new Cluster()).collect(Collectors.toCollection(ArrayList::new));
		ObjectDoublePair<List<Cluster>> p;
		double rss = 0;

		try
		{
			// Merge each task's partial clusters and accumulate the partial RSS values.
			for (Future<ObjectDoublePair<List<Cluster>>> f : list)
			{
				p = f.get();
				rss += p.d;
				for (i=0; i<K; i++) clusters.get(i).merge(p.o.get(i));
			}
		}
		catch (InterruptedException e) {Thread.currentThread().interrupt(); e.printStackTrace();}
		catch (Exception e) {e.printStackTrace();}

		BinUtils.LOG.info(String.format("%f\n", rss));
		return new ObjectDoublePair<List<Cluster>>(clusters, rss);
	}

	/** Assigns points in {@code [begin_index, end_index)} to their best centroid. */
	private class MaximizationTask implements Callable<ObjectDoublePair<List<Cluster>>>
	{
		private final List<SparseVector> centroid_list;
		private final double[]           centroid_norms;
		private final int                begin_index;
		private final int                end_index;

		public MaximizationTask(List<SparseVector> centroidList, double[] centNorms, int beginIndex, int endIndex)
		{
			centroid_list  = centroidList;
			centroid_norms = centNorms;
			begin_index    = beginIndex;
			end_index      = endIndex;
		}

		@Override
		public ObjectDoublePair<List<Cluster>> call()
		{
			List<Cluster> clusters = centroid_list.stream().map(c -> new Cluster()).collect(Collectors.toCollection(ArrayList::new));
			double rss = 0;

			for (int i=begin_index; i<end_index; i++)
			{
				DoubleIntPair max = max(centroid_list, centroid_norms, s_points.get(i));
				clusters.get(max.i).addPoint(s_points.get(i));
				rss += max.d;
			}

			return new ObjectDoublePair<>(clusters, rss);
		}
	}

	/**
	 * @return the (similarity, centroid-index) pair of the centroid most similar to
	 * {@code point}. The sentinel -10000 is safely below the cosine range [-1, 1].
	 */
	private DoubleIntPair max(List<SparseVector> centroids, double[] centroidNorms, SparseVector point)
	{
		DoubleIntPair max = new DoubleIntPair(-10000d, 0);
		double d, pointNorm = point.euclideanNorm();

		for (int k=centroidNorms.length-1; k>=0; k--)
		{
			d = cosineSimilarity(centroids.get(k), centroidNorms[k], point, pointNorm);
			if (d > max.d) max.set(d, k);
		}

		return max;
	}

	/** @return the Euclidean norm of each point, in order. */
	private double[] euclideanNorms(List<SparseVector> points)
	{
		return points.stream().mapToDouble(point -> point.euclideanNorm()).toArray();
	}

//	==================================== Expectation ====================================

	/**
	 * Recomputes each cluster's centroid as the mean of its points, in parallel.
	 * @return the new centroids in cluster order.
	 */
	private List<SparseVector> expectation(List<Cluster> clusters)
	{
		ExecutorService pool = Executors.newFixedThreadPool(NUM_THREADS);
		List<Future<SparseVector>> list = new ArrayList<>(K);
		Callable<SparseVector> task;
		BinUtils.LOG.info("- Expectation\n");

		for (int i=0; i<K; i++)
		{
			task = new ExpectationTask(clusters.get(i));
			list.add(pool.submit(task));
		}

		// Allow worker threads to exit after the submitted tasks complete (fixes a thread leak).
		pool.shutdown();

		List<SparseVector> centroids = new ArrayList<>();

		try {for (Future<SparseVector> f : list) centroids.add(f.get());}
		catch (InterruptedException e) {Thread.currentThread().interrupt(); e.printStackTrace();}
		catch (Exception e) {e.printStackTrace();}

		return centroids;
	}

	/** Averages one cluster's points into a new centroid vector. */
	private class ExpectationTask implements Callable<SparseVector>
	{
		private final Cluster cluster;

		public ExpectationTask(Cluster cluster)
		{
			this.cluster = cluster;
		}

		@Override
		public SparseVector call()
		{
			SparseVector centroid = new SparseVector(-1);
			for (SparseVector v : cluster.getPointSet()) centroid.add(v);
			// NOTE(review): assumes the cluster is non-empty; size() == 0 would divide by zero — confirm.
			centroid.divide(cluster.size());
			return centroid;
		}
	}

	/** Cosine similarity given precomputed Euclidean norms of both vectors. */
	private double cosineSimilarity(SparseVector centroid, double centNorm, SparseVector point, double pointNorm)
	{
		return centroid.dotProduct(point) / (centNorm * pointNorm);
	}

	/** @return the number of points each worker task processes (ceil(N / NUM_THREADS)). */
	private int gap()
	{
		return (int)Math.ceil(MathUtils.divide(s_points.size(), NUM_THREADS));
	}
}