package org.seqcode.ml.clustering.kmeans; import java.util.Collection; import java.util.Vector; import org.seqcode.ml.clustering.Cluster; import org.seqcode.ml.clustering.ClusterRepresentative; import org.seqcode.ml.clustering.ClusteringMethod; import org.seqcode.ml.clustering.DefaultCluster; import org.seqcode.ml.clustering.PairwiseElementMetric; import org.seqcode.ml.clustering.vectorcluster.VectorClusterElement; import java.util.Iterator; /** * @author Timothy Danford */ public class KMeansClustering<X> implements ClusteringMethod<X> { private PairwiseElementMetric<X> metric; private ClusterRepresentative<X> repr; private Vector<X> startMeans; private int numClusters; private int iterations; private Vector<X> elmts; private Vector<DefaultCluster<X>> clusters; private Vector<X> clusterMeans; public KMeansClustering(PairwiseElementMetric<X> m, ClusterRepresentative<X> r, Collection<X> starts) { metric = m; repr = r; numClusters = starts.size(); clusters = new Vector<DefaultCluster<X>>(); for(int c=0; c<numClusters; c++){clusters.add(new DefaultCluster<X>());} clusterMeans = new Vector<X>(starts); startMeans = new Vector<X>(starts); iterations = 10; elmts = new Vector<X>(); } public void setIterations(int i) { iterations = i; } public Collection<Cluster<X>> clusterElements(Collection<X> e) {return(clusterElements(e, 0));} public Collection<Cluster<X>> clusterElements(Collection<X> e, double convergenceDifference) { init(e); boolean converged=false; for(int i = 0; i < iterations && !converged; i++) { Vector<X> oldClusterMeans = (Vector<X>) clusterMeans.clone(); //K-means assignToClusters(); getClusterMeans(); //Check convergence double totalDist=0; for(int c = 0; c < numClusters; c++) { totalDist+= Math.abs(metric.evaluate(oldClusterMeans.get(c), clusterMeans.get(c))); }if(totalDist<=convergenceDifference) converged=true; } return new Vector<Cluster<X>>(clusters); } private void assignToClusters() { for(int i = 0; i < numClusters; i++) { clusters.get(i).clear(); } for(int k = 0; k < elmts.size(); k++) { X e = elmts.get(k); int minCluster = -1; double minDist = 0.0; for(int i = 0; i < numClusters; i++) { double clustDist = metric.evaluate(e, clusterMeans.get(i)); if(minCluster == -1 || clustDist < minDist) { minDist = clustDist; minCluster = i; } } clusters.get(minCluster).addElement(e); } } public Vector<X> getClusterMeans() { for(int i = 0; i < numClusters; i++) { clusterMeans.set(i, repr.getRepresentative(clusters.get(i))); }return(clusterMeans); } private void init(Collection<X> e) { elmts = new Vector<X>(e); for(int i = 0; i < numClusters; i++) { clusters.set(i, new DefaultCluster<X>()); clusterMeans.set(i, startMeans.get(i)); } } public double sumOfSquaredDistance(){ double totalDist =0; for(int k = 0; k < elmts.size(); k++) { X e = elmts.get(k); int minCluster = -1; double minDist = 0.0; for(int i = 0; i < numClusters; i++) { double clustDist = metric.evaluate(e, clusterMeans.get(i)); if(minCluster == -1 || clustDist < minDist) { minDist = clustDist; minCluster = i; } }totalDist += minDist*minDist; } return(totalDist); } public double silhouette(){ double sh = 0.0; for(int k=0; k< elmts.size(); k++){ X e = elmts.get(k); double[] avgDist = new double[numClusters]; int minCluster = -1; double minDist = 0.0; for(int i = 0; i < numClusters; i++) { double clustDist = metric.evaluate(e, clusterMeans.get(i)); if(minCluster == -1 || clustDist < minDist) { minDist = clustDist; minCluster = i; } double currDist = 0.0; for(X ve : clusters.get(i).getElements()){ currDist = currDist + metric.evaluate(e, ve); } currDist = currDist/clusters.get(i).size(); avgDist[i] = currDist; } double neighborDistance = Double.MAX_VALUE; for(int i=0; i<numClusters; i++){ if(i!=minCluster){ // Not its cluster if(avgDist[i]<neighborDistance){ neighborDistance = avgDist[i]; } } } sh = sh + (neighborDistance - avgDist[minCluster])/Math.max(neighborDistance, avgDist[minCluster]); } sh = sh/elmts.size(); return sh; } }