package cc.mallet.cluster; import java.util.logging.Logger; import cc.mallet.cluster.neighbor_evaluator.AgglomerativeNeighbor; import cc.mallet.cluster.neighbor_evaluator.Neighbor; import cc.mallet.cluster.neighbor_evaluator.NeighborEvaluator; import cc.mallet.cluster.util.ClusterUtils; import cc.mallet.cluster.util.PairwiseMatrix; import cc.mallet.pipe.Pipe; import cc.mallet.types.Instance; import cc.mallet.types.InstanceList; import cc.mallet.util.MalletProgressMessageLogger; /** * Greedily merges Instances until convergence. New merges are scored * using {@link NeighborEvaluator}. * * @author "Aron Culotta" <culotta@degas.cs.umass.edu> * @version 1.0 * @since 1.0 * @see HillClimbingClusterer */ public class GreedyAgglomerative extends HillClimbingClusterer { private static final long serialVersionUID = 1L; private static Logger progressLogger = MalletProgressMessageLogger.getLogger(GreedyAgglomerative.class.getName()+"-pl"); /** * Converged when merge score is below this value. */ protected double stoppingThreshold; /** * True if should stop clustering. */ protected boolean converged; /** * Cache for calls to {@link NeighborhoodEvaluator}. In some * experiments, reduced running time by nearly half. */ protected PairwiseMatrix scoreCache; /** * * @param instancePipe Pipe for each underying {@link Instance}. * @param evaluator To score potential merges. * @param stoppingThreshold Clustering converges when the evaluator score is below this value. * @return */ public GreedyAgglomerative (Pipe instancePipe, NeighborEvaluator evaluator, double stoppingThreshold) { super(instancePipe, evaluator); this.stoppingThreshold = stoppingThreshold; this.converged = false; } /** * * @param instances * @return A singleton clustering (each Instance in its own cluster). */ public Clustering initializeClustering (InstanceList instances) { reset(); return ClusterUtils.createSingletonClustering(instances); } public boolean converged (Clustering clustering) { return converged; } /** * Reset convergence to false so a new round of clustering can begin. */ public void reset () { converged = false; scoreCache = null; evaluator.reset(); } /** * For each pair of clusters, calculate the score of the {@link Neighbor} * that would result from merging the two clusters. Choose the merge that * obtains the highest score. If no merge improves score, return original * Clustering * * @param clustering * @return */ public Clustering improveClustering (Clustering clustering) { double bestScore = Double.NEGATIVE_INFINITY; int[] toMerge = new int[]{-1,-1}; for (int i = 0; i < clustering.getNumClusters(); i++) { for (int j = i + 1; j < clustering.getNumClusters(); j++) { double score = getScore(clustering, i, j); if (score > bestScore) { bestScore = score; toMerge[0] = i; toMerge[1] = j; } } } converged = (bestScore < stoppingThreshold); if (!(converged)) { progressLogger.info("Merging " + toMerge[0] + "(" + clustering.size(toMerge[0]) + " nodes) and " + toMerge[1] + "(" + clustering.size(toMerge[1]) + " nodes) [" + bestScore + "] numClusters=" + clustering.getNumClusters()); updateScoreMatrix(clustering, toMerge[0], toMerge[1]); clustering = ClusterUtils.mergeClusters(clustering, toMerge[0], toMerge[1]); } else { progressLogger.info("Converged with score " + bestScore); } return clustering; } /** * * @param clustering * @param i * @param j * @return The score for merging these two clusters. */ protected double getScore (Clustering clustering, int i, int j) { if (scoreCache == null) scoreCache = new PairwiseMatrix(clustering.getNumInstances()); int[] ci = clustering.getIndicesWithLabel(i); int[] cj = clustering.getIndicesWithLabel(j); if (scoreCache.get(ci[0], cj[0]) == 0.0) { double val = evaluator.evaluate( new AgglomerativeNeighbor(clustering, ClusterUtils.copyAndMergeClusters(clustering, i, j), ci, cj)); for (int ni = 0; ni < ci.length; ni++) for (int nj = 0; nj < cj.length; nj++) scoreCache.set(ci[ni], cj[nj], val); } return scoreCache.get(ci[0], cj[0]); } /** * Resets the values of clusters that have been merged. * @param clustering * @param i * @param j */ protected void updateScoreMatrix (Clustering clustering, int i, int j) { int size = clustering.getNumInstances(); int[] ci = clustering.getIndicesWithLabel(i); for (int ni = 0; ni < ci.length; ni++) { for (int nj = 0; nj < size; nj++) if (ci[ni] != nj) scoreCache.set(ci[ni], nj, 0.0); } int[] cj = clustering.getIndicesWithLabel(j); for (int ni = 0; ni < cj.length; ni++) { for (int nj = 0; nj < size; nj++) if (cj[ni] != nj) scoreCache.set(cj[ni], nj, 0.0); } } public String toString () { return "class=" + this.getClass().getName() + "\nstoppingThreshold=" + stoppingThreshold + "\nneighborhoodEvaluator=[" + evaluator + "]"; } }