/**
* Copyright 2013-2015 Pierre Merienne
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.github.pmerienne.trident.ml.clustering;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import com.github.pmerienne.trident.ml.util.MathUtil;
/**
 * Online (sequential) k-means clusterer.
 * <p>
 * Lifecycle:
 * <ul>
 * <li>The first calls to {@link #update(double[])} only buffer samples and return {@code null}.</li>
 * <li>Once {@code 10 * nbCluster} samples are buffered, the centroids are seeded with k-means++.</li>
 * <li>From then on, each update assigns the sample to its nearest centroid {@code c}. It then
 * moves that centroid with {@code c <- c + (x - c) / n}, where {@code n} is the number of
 * samples assigned to that centroid since seeding.</li>
 * </ul>
 * <p>
 * Inspired from
 * http://www.cs.princeton.edu/courses/archive/fall08/cos436/Duda/C/sk_means.htm
 *
 * @author pmerienne
 *
 */
public class KMeans implements Clusterer, Serializable {

    private static final long serialVersionUID = 338231277453149972L;

    /**
     * Number of buffered samples per cluster required before the centroids are seeded.
     * <p>
     * k-means++ consumes one buffered sample per centroid, so this must be at least 1. Larger
     * values give the seeding more candidates to choose from. The value 10 is a heuristic.
     */
    private static final int INIT_SAMPLES_PER_CLUSTER = 10;

    /** Number of samples assigned to each centroid since seeding; {@code null} until ready. */
    private List<Long> counts = null;

    /** Current centroids, indexed by cluster; {@code null} until ready. */
    private double[][] centroids;

    /** Samples buffered before seeding; cleared once the centroids are chosen. */
    private List<double[]> initFeatures = new ArrayList<double[]>();

    /** Number of clusters (k). */
    private Integer nbCluster;

    /**
     * Creates an untrained clusterer.
     *
     * @param nbCluster number of clusters (k); must be strictly positive
     * @throws IllegalArgumentException if {@code nbCluster} is {@code null} or not positive
     */
    public KMeans(Integer nbCluster) {
        // Fail fast: a null or non-positive k would otherwise only blow up on the first update().
        if (nbCluster == null || nbCluster < 1) {
            throw new IllegalArgumentException("nbCluster must be a positive integer, got " + nbCluster);
        }
        this.nbCluster = nbCluster;
    }

    /**
     * Returns the index of the centroid nearest to {@code features} (Euclidean distance).
     *
     * @throws IllegalStateException if the centroids have not been seeded yet
     */
    @Override
    public Integer classify(double[] features) {
        if (!this.isReady()) {
            throw new IllegalStateException("KMeans is not ready yet");
        }
        return this.nearestCentroid(features);
    }

    /**
     * Feeds one sample to the model.
     * <p>
     * Before seeding, the sample is buffered and {@code null} is returned. The call that completes
     * the buffer triggers the k-means++ seeding and also returns {@code null}. Afterwards, the
     * nearest centroid is moved towards the sample and its index is returned.
     *
     * @return the index of the updated centroid, or {@code null} while still buffering
     */
    @Override
    public Integer update(double[] features) {
        if (!this.isReady()) {
            this.initIfPossible(features);
            return null;
        }
        Integer nearestCentroid = this.classify(features);

        long count = this.counts.get(nearestCentroid) + 1;
        this.counts.set(nearestCentroid, count);

        // Sequential k-means step: c <- c + (x - c) / n
        double[] centroid = this.centroids[nearestCentroid];
        double[] delta = MathUtil.mult(MathUtil.subtract(features, centroid), 1.0 / count);
        this.centroids[nearestCentroid] = MathUtil.add(centroid, delta);
        return nearestCentroid;
    }

    /**
     * Returns the Euclidean distance from {@code features} to every centroid.
     * <p>
     * Note: despite the name, these are raw distances (lower means closer), not probabilities.
     *
     * @throws IllegalStateException if the centroids have not been seeded yet
     */
    @Override
    public double[] distribution(double[] features) {
        if (!this.isReady()) {
            throw new IllegalStateException("KMeans is not ready yet");
        }
        double[] distances = new double[this.nbCluster];
        for (int i = 0; i < this.nbCluster; i++) {
            distances[i] = MathUtil.euclideanDistance(this.centroids[i], features);
        }
        return distances;
    }

    /**
     * Returns the current centroids, or {@code null} if not seeded yet.
     * <p>
     * The internal array is returned as-is (no copy), so modifications affect the model.
     */
    @Override
    public double[][] getCentroids() {
        return this.centroids;
    }

    /**
     * Returns the index of the non-null centroid nearest to {@code features}.
     * <p>
     * Returns 0 if no centroid is strictly closer than {@link Double#MAX_VALUE}.
     */
    protected Integer nearestCentroid(double[] features) {
        int nearestIndex = 0;
        double minDistance = Double.MAX_VALUE;
        for (int i = 0; i < this.centroids.length; i++) {
            double[] centroid = this.centroids[i];
            // During seeding the centroids are filled one by one; skip those not chosen yet.
            if (centroid == null) {
                continue;
            }
            double distance = MathUtil.euclideanDistance(centroid, features);
            if (distance < minDistance) {
                minDistance = distance;
                nearestIndex = i;
            }
        }
        return nearestIndex;
    }

    /** Returns {@code true} once the centroids (and their counts) have been seeded. */
    protected boolean isReady() {
        return this.counts != null && this.centroids != null;
    }

    /**
     * Buffers {@code features} and seeds the centroids once enough samples are available.
     * <p>
     * The threshold is {@code INIT_SAMPLES_PER_CLUSTER * nbCluster} buffered samples.
     */
    protected void initIfPossible(double[] features) {
        this.initFeatures.add(features);
        if (this.initFeatures.size() >= INIT_SAMPLES_PER_CLUSTER * this.nbCluster) {
            this.initCentroids();
        }
    }

    /**
     * Init clusters using the k-means++ algorithm. (Arthur, D. and
     * Vassilvitskii, S. (2007). "k-means++: the advantages of careful seeding".
     * <p>
     * Each chosen centroid is removed from {@link KMeans#initFeatures}. The buffer is cleared
     * at the end.
     *
     * @throws IllegalStateException if fewer samples than clusters are buffered; in that case no
     *         state is modified
     */
    protected void initCentroids() {
        // Check before touching any state so a failure cannot leave a half-initialised model.
        if (this.initFeatures.size() < this.nbCluster) {
            throw new IllegalStateException("Cannot seed " + this.nbCluster + " centroids from "
                    + this.initFeatures.size() + " samples");
        }

        this.counts = new ArrayList<Long>(this.nbCluster);
        for (int i = 0; i < this.nbCluster; i++) {
            this.counts.add(0L);
        }
        this.centroids = new double[this.nbCluster][];
        Random random = new Random();

        // Choose one centroid uniformly at random from among the data points.
        this.centroids[0] = this.initFeatures.remove(random.nextInt(this.initFeatures.size()));

        for (int j = 1; j < this.nbCluster; j++) {
            // Pick the next centroid with probability proportional to D(x)^2.
            double[] cumulativeWeights = this.computeDxs();
            this.centroids[j] = this.initFeatures.remove(sampleIndex(cumulativeWeights, random));
        }
        this.initFeatures.clear();
    }

    /**
     * Samples an index with probability proportional to its weight.
     * <p>
     * The weights are given as a non-empty array of non-decreasing cumulative sums.
     */
    private static int sampleIndex(double[] cumulativeWeights, Random random) {
        int last = cumulativeWeights.length - 1;
        double total = cumulativeWeights[last];
        if (!(total > 0)) {
            // All weights are zero (every candidate coincides with a centroid) or NaN:
            // no candidate is preferable, so choose uniformly.
            return random.nextInt(cumulativeWeights.length);
        }
        double r = random.nextDouble() * total;
        for (int i = 0; i <= last; i++) {
            // Strict '>' so that candidates with zero weight are never selected.
            if (cumulativeWeights[i] > r) {
                return i;
            }
        }
        // Only reachable if rounding made r == total: return the last candidate with positive weight.
        for (int i = last; i > 0; i--) {
            if (cumulativeWeights[i] > cumulativeWeights[i - 1]) {
                return i;
            }
        }
        return 0;
    }

    /**
     * Computes cumulative k-means++ weights over {@link KMeans#initFeatures}.
     * <p>
     * D(x) is the distance between sample x and the nearest centroid chosen so far.
     *
     * @return an array whose element {@code i} is the sum of D(x)^2 over the first {@code i + 1}
     *         buffered samples
     */
    protected double[] computeDxs() {
        double[] dxs = new double[this.initFeatures.size()];
        // Must be a double: an int accumulator silently truncates (and saturates) via the
        // implicit narrowing cast of the compound assignment.
        double sum = 0.0;
        for (int i = 0; i < dxs.length; i++) {
            double[] features = this.initFeatures.get(i);
            double[] nearest = this.centroids[this.nearestCentroid(features)];
            double distance = MathUtil.euclideanDistance(features, nearest);
            sum += distance * distance;
            dxs[i] = sum;
        }
        return dxs;
    }

    /** Discards all learned state; the next updates start buffering samples again. */
    @Override
    public void reset() {
        this.counts = null;
        this.centroids = null;
        this.initFeatures = new ArrayList<double[]>();
    }
}