package cc.mallet.cluster; import java.util.logging.Logger; import cc.mallet.cluster.neighbor_evaluator.NeighborEvaluator; import cc.mallet.cluster.util.ClusterUtils; import cc.mallet.pipe.Pipe; import cc.mallet.types.Instance; import cc.mallet.util.MalletProgressMessageLogger; import gnu.trove.TIntArrayList; /** * Greedily merges Instances until convergence. New merges are scored * using {@link NeighborEvaluator}. * * Differs from {@link GreedyAgglomerative} in that one cluster is * created at a time. That is, nodes are added to a cluster until * convergence. Then, a new cluster is created from the remaining * nodes. This reduces the number of comparisons from O(n^2) to * O(nlg|n|). * * @author "Aron Culotta" <culotta@degas.cs.umass.edu> * @version 1.0 * @since 1.0 * @see GreedyAgglomerative */ public class GreedyAgglomerativeByDensity extends GreedyAgglomerative { private static final long serialVersionUID = 1L; private static Logger progressLogger = MalletProgressMessageLogger.getLogger(GreedyAgglomerativeByDensity.class.getName()+"-pl"); /** * If true, perform greedy agglomerative clustering on the clusters * at the end of convergence. This may alleviate the greediness of * the byDensity clustering algorithm. */ boolean doPostConvergenceMerges; /** * Integers representing the Instance indices that have not yet been placed in a cluster. */ TIntArrayList unclusteredInstances; /** * Index of an Instance in the cluster currently being created. */ int instanceBeingClustered; /** * Randomness to order instanceBeingClustered. */ java.util.Random random; /** * * @param instancePipe Pipe for each underying {@link Instance}. * @param evaluator To score potential merges. * @param stoppingThreshold Clustering converges when the evaluator score is below this value. * @param doPostConvergenceMerges If true, perform greedy * agglomerative clustering on the clusters at the end of * convergence. This may alleviate the greediness of the byDensity * clustering algorithm. * @return */ public GreedyAgglomerativeByDensity (Pipe instancePipe, NeighborEvaluator evaluator, double stoppingThreshold, boolean doPostConvergenceMerges, java.util.Random random) { super(instancePipe, evaluator, stoppingThreshold); this.doPostConvergenceMerges = doPostConvergenceMerges; this.random = random; this.instanceBeingClustered = -1; } public boolean converged (Clustering clustering) { return converged; } /** * Reset convergence to false and clear state so a new round of * clustering can begin. */ public void reset () { super.reset(); this.unclusteredInstances = null; this.instanceBeingClustered = -1; } public Clustering improveClustering (Clustering clustering) { if (instanceBeingClustered == -1) sampleNextInstanceToCluster(clustering); int clusterIndex = clustering.getLabel(instanceBeingClustered); double bestScore = Double.NEGATIVE_INFINITY; int clusterToMerge = -1; int instanceToMerge = -1; for (int i = 0; i < unclusteredInstances.size(); i++) { int neighbor = unclusteredInstances.get(i); int neighborCluster = clustering.getLabel(neighbor); double score = getScore(clustering, clusterIndex, neighborCluster); if (score > bestScore) { bestScore = score; clusterToMerge = neighborCluster; instanceToMerge = neighbor; } } if (bestScore < stoppingThreshold) { // Move on to next instance to cluster. sampleNextInstanceToCluster(clustering); if (instanceBeingClustered != -1 && unclusteredInstances.size() != 0) return improveClustering(clustering); else { // Converged and no more instances to cluster. if (doPostConvergenceMerges) { throw new UnsupportedOperationException("PostConvergenceMerges not yet implemented."); } converged = true; progressLogger.info("Converged with score " + bestScore); } } else { // Merge and continue. progressLogger.info("Merging " + clusterIndex + "(" + clustering.size(clusterIndex) + " nodes) and " + clusterToMerge + "(" + clustering.size(clusterToMerge) + " nodes) [" + bestScore + "] numClusters=" + clustering.getNumClusters()); updateScoreMatrix(clustering, clusterIndex, clusterToMerge); unclusteredInstances.remove(unclusteredInstances.indexOf(instanceToMerge)); clustering = ClusterUtils.mergeClusters(clustering, clusterIndex, clusterToMerge); } return clustering; } private void sampleNextInstanceToCluster (Clustering clustering) { if (unclusteredInstances == null) fillUnclusteredInstances(clustering.getNumInstances()); instanceBeingClustered = (unclusteredInstances.size() == 0) ? -1 : unclusteredInstances.remove(0); } private void fillUnclusteredInstances (int size) { unclusteredInstances = new TIntArrayList(size); for (int i = 0; i < size; i++) unclusteredInstances.add(i); unclusteredInstances.shuffle(random); } public String toString () { return "class=" + this.getClass().getName() + "\nstoppingThreshold=" + stoppingThreshold + "\ndoPostConvergenceMerges=" + doPostConvergenceMerges + "\nneighborhoodEvaluator=[" + evaluator + "]"; } }