/*
 * Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
 * This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
 * http://www.cs.umass.edu/~mccallum/mallet
 * This software is provided under the terms of the Common Public License,
 * version 1.0, as published by http://www.opensource.org. For further
 * information, see the file `LICENSE' included with this distribution.
 */

/**
 * Clusters a set of points via k-Means. The instances that are clustered are
 * expected to be of the type FeatureVector.
 *
 * EMPTY_SINGLE and other changes implemented March 2005.
 * Heuristic cluster selection implemented May 2005.
 *
 * @author Jerod Weinman <A
 *         HREF="mailto:weinman@cs.umass.edu">weinman@cs.umass.edu</A>
 * @author Mike Winter <a href =
 *         "mailto:mike.winter@gmail.com">mike.winter@gmail.com</a>
 */

package cc.mallet.cluster;

import java.util.ArrayList;
import java.util.Random;
import java.util.logging.Logger;

import cc.mallet.pipe.Pipe;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Metric;
import cc.mallet.types.SparseVector;
import cc.mallet.util.VectorStats;

/**
 * KMeans Clusterer
 *
 * Clusters the points into k clusters by minimizing the total intra-cluster
 * variance. It uses a given {@link Metric} to find the distance between
 * {@link Instance}s, which should have {@link SparseVector}s in the data
 * field.
 *
 * A short usage sketch appears at the bottom of this source file.
 */
public class KMeans extends Clusterer {

  private static final long serialVersionUID = 1L;

  // Stop after the total movement of the means is less than this
  static double MEANS_TOLERANCE = 1e-2;

  // Maximum number of iterations
  static int MAX_ITER = 100;

  // Minimum fraction of points that must move for iteration to continue
  static double POINTS_TOLERANCE = .005;

  /**
   * Treat an empty cluster as an error condition.
   */
  public static final int EMPTY_ERROR = 0;

  /**
   * Drop an empty cluster.
   */
  public static final int EMPTY_DROP = 1;

  /**
   * Place into the empty cluster the single instance furthest from its
   * previous cluster mean.
   */
  public static final int EMPTY_SINGLE = 2;

  Random randinator;
  Metric metric;
  int numClusters;
  int emptyAction;
  ArrayList<SparseVector> clusterMeans;

  private static Logger logger = Logger
      .getLogger("edu.umass.cs.mallet.base.cluster.KMeans");

  /**
   * Construct a KMeans object
   *
   * @param instancePipe Pipe for the instances being clustered
   * @param numClusters Number of clusters to use
   * @param metric Metric object to measure instance distances
   * @param emptyAction Specify what should happen when an empty cluster occurs
   */
  public KMeans(Pipe instancePipe, int numClusters, Metric metric,
      int emptyAction) {

    super(instancePipe);

    this.emptyAction = emptyAction;
    this.metric = metric;
    this.numClusters = numClusters;

    this.clusterMeans = new ArrayList<SparseVector>(numClusters);
    this.randinator = new Random();
  }

  /**
   * Construct a KMeans object
   *
   * @param instancePipe Pipe for the instances being clustered
   * @param numClusters Number of clusters to use
   * @param metric Metric object to measure instance distances
   *          <p/>
   *          If an empty cluster occurs, it is considered an error.
   */
  public KMeans(Pipe instancePipe, int numClusters, Metric metric) {
    this(instancePipe, numClusters, metric, EMPTY_ERROR);
  }

  /**
   * Cluster instances
   *
   * @param instances List of instances to cluster
   */
  @Override
  public Clustering cluster(InstanceList instances) {

    assert (instances.getPipe() == this.instancePipe);

    // Initialize clusterMeans
    initializeMeansSample(instances, this.metric);

    int clusterLabels[] = new int[instances.size()];
    ArrayList<InstanceList> instanceClusters = new ArrayList<InstanceList>(
        numClusters);
    int instClust;
    double instClustDist, instDist;
    double deltaMeans = Double.MAX_VALUE;
    double deltaPoints = (double) instances.size();
    int iterations = 0;
    SparseVector clusterMean;

    for (int c = 0; c < numClusters; c++) {
      instanceClusters.add(c, new InstanceList(instancePipe));
    }

    logger.info("Entering KMeans iteration");

    while (deltaMeans > MEANS_TOLERANCE && iterations < MAX_ITER
        && deltaPoints > instances.size() * POINTS_TOLERANCE) {

      iterations++;
      deltaPoints = 0;

      // For each instance, measure its distance to the current cluster
      // means and assign it to the closest cluster by adding it to the
      // corresponding instance list. The mean of each cluster InstanceList
      // is then updated.
      for (int n = 0; n < instances.size(); n++) {

        instClust = 0;
        instClustDist = Double.MAX_VALUE;

        for (int c = 0; c < numClusters; c++) {
          instDist = metric.distance(clusterMeans.get(c),
              (SparseVector) instances.get(n).getData());

          if (instDist < instClustDist) {
            instClust = c;
            instClustDist = instDist;
          }
        }
        // Add to closest cluster & label it such
        instanceClusters.get(instClust).add(instances.get(n));

        if (clusterLabels[n] != instClust) {
          clusterLabels[n] = instClust;
          deltaPoints++;
        }
      }

      deltaMeans = 0;

      for (int c = 0; c < numClusters; c++) {

        if (instanceClusters.get(c).size() > 0) {
          clusterMean = VectorStats.mean(instanceClusters.get(c));

          deltaMeans += metric.distance(clusterMeans.get(c), clusterMean);

          clusterMeans.set(c, clusterMean);

          instanceClusters.set(c, new InstanceList(instancePipe));

        } else {

          logger.info("Empty cluster found.");

          switch (emptyAction) {
            case EMPTY_ERROR:
              return null;

            case EMPTY_DROP:
              logger.fine("Removing cluster " + c);
              clusterMeans.remove(c);
              instanceClusters.remove(c);
              for (int n = 0; n < instances.size(); n++) {

                assert (clusterLabels[n] != c) : "Cluster " + c
                    + " is empty, yet clusterLabels[n] is " + clusterLabels[n];

                if (clusterLabels[n] > c)
                  clusterLabels[n]--;
              }

              numClusters--;
              c--; // <-- note this trickiness: now that we've deleted the
                   // entry, we have to repeat the index to get the next one.
              break;

            case EMPTY_SINGLE:

              // Find the instance furthest from its centroid
              // and make it the mean of the empty cluster.
              double newCentroidDist = 0;
              int newCentroid = 0;
              InstanceList cacheList = null;

              for (int clusters = 0; clusters < clusterMeans.size(); clusters++) {
                SparseVector centroid = clusterMeans.get(clusters);
                InstanceList centInstances = instanceClusters.get(clusters);

                // Don't create new empty clusters.
                if (centInstances.size() <= 1)
                  continue;

                for (int n = 0; n < centInstances.size(); n++) {
                  double currentDist = metric.distance(centroid,
                      (SparseVector) centInstances.get(n).getData());
                  if (currentDist > newCentroidDist) {
                    newCentroid = n;
                    newCentroidDist = currentDist;
                    cacheList = centInstances;
                  }
                }
              }
              if (cacheList == null) {
                // Can't find an instance to move.
                logger.info("Can't find an instance to move. Exiting.");
                return null;
              } else {
                clusterMeans.set(c,
                    (SparseVector) cacheList.get(newCentroid).getData());
              }
              break; // Reseed succeeded; continue with the next cluster.

            default:
              return null;
          }
        }
      }

      logger.info("Iter " + iterations + " deltaMeans = " + deltaMeans);
    }

    if (deltaMeans <= MEANS_TOLERANCE)
      logger.info("KMeans converged with deltaMeans = " + deltaMeans);
    else if (iterations >= MAX_ITER)
      logger.info("Maximum number of iterations (" + MAX_ITER + ") reached.");
    else if (deltaPoints <= instances.size() * POINTS_TOLERANCE)
      logger.info("Minimum number of points (np*" + POINTS_TOLERANCE + "="
          + (int) (instances.size() * POINTS_TOLERANCE)
          + ") moved in last iteration. Saying converged.");

    return new Clustering(instances, numClusters, clusterLabels);
  }

  /**
   * Uses a MAX-MIN heuristic to seed the initial cluster means.
   *
   * @param instList List of instances.
   * @param metric Distance metric.
   */
  private void initializeMeansSample(InstanceList instList, Metric metric) {

    // InstanceList has no remove() and null instances aren't
    // parsed out by most Pipes, so we have to pre-process
    // here and possibly leave some instances without
    // cluster assignments.
    ArrayList<Instance> instances = new ArrayList<Instance>(instList.size());
    for (int i = 0; i < instList.size(); i++) {
      Instance ins = instList.get(i);
      SparseVector sparse = (SparseVector) ins.getData();
      if (sparse.numLocations() == 0)
        continue;
      instances.add(ins);
    }

    // Add the next center that has the MAX of the MIN of the distances from
    // each of the previous j-1 centers (idea from an Andrew Moore tutorial;
    // not sure who came up with it originally).
    for (int i = 0; i < numClusters; i++) {
      double max = 0;
      int selected = 0;
      for (int k = 0; k < instances.size(); k++) {
        double min = Double.MAX_VALUE;
        Instance ins = instances.get(k);
        SparseVector inst = (SparseVector) ins.getData();
        for (int j = 0; j < clusterMeans.size(); j++) {
          SparseVector centerInst = clusterMeans.get(j);
          double dist = metric.distance(centerInst, inst);
          if (dist < min)
            min = dist;
        }
        if (min > max) {
          selected = k;
          max = min;
        }
      }
      Instance newCenter = instances.remove(selected);
      clusterMeans.add((SparseVector) newCenter.getData());
    }
  }

  /**
   * Return the ArrayList of cluster means after a run of the algorithm.
   *
   * @return An ArrayList of SparseVectors representing the cluster means.
   */
  public ArrayList<SparseVector> getClusterMeans() {
    return this.clusterMeans;
  }
}
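
/*
 * Usage sketch (illustrative only; not part of the original MALLET source).
 * A minimal, package-private driver showing how this clusterer is typically
 * invoked. It assumes a previously serialized InstanceList on disk whose
 * Instance data are SparseVectors, and it assumes that
 * cc.mallet.types.NormalizedDotProductMetric is available as the distance;
 * substitute any other Metric implementation as needed.
 */
class KMeansUsageExample {

  public static void main(String[] args) {
    // Load a serialized InstanceList (path given on the command line).
    InstanceList instances = InstanceList.load(new java.io.File(args[0]));

    // Assumed Metric implementation (normalized dot-product distance).
    Metric metric = new cc.mallet.types.NormalizedDotProductMetric();

    // Ask for 10 clusters; drop, rather than fail on, any cluster that empties.
    KMeans kmeans = new KMeans(instances.getPipe(), 10, metric, KMeans.EMPTY_DROP);

    Clustering clustering = kmeans.cluster(instances);
    if (clustering != null)
      System.out.println("KMeans produced " + clustering.getNumClusters()
          + " clusters.");
  }
}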