/* * Copyright 2004-2010 Information & Software Engineering Group (188/1) * Institute of Software Technology and Interactive Systems * Vienna University of Technology, Austria * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.ifs.tuwien.ac.at/dm/somtoolbox/license.html * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package at.tuwien.ifs.somtoolbox.clustering; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executors; import java.util.concurrent.ThreadPoolExecutor; import org.apache.commons.lang.NotImplementedException; import prefuse.data.Node; import prefuse.data.Tree; import at.tuwien.ifs.somtoolbox.SOMToolboxException; import at.tuwien.ifs.somtoolbox.clustering.functions.ClusterElementFunctions; import at.tuwien.ifs.somtoolbox.structures.ElementWithIndex; import at.tuwien.ifs.somtoolbox.util.Indices2D; import at.tuwien.ifs.somtoolbox.util.StdErrProgressWriter; /** * @author Rudolf Mayer * @version $Id: WardClustering.java 3932 2010-11-09 16:56:38Z mayer $ */ public class WardClustering<E> implements HierarchicalClusteringAlgorithm<E> { private int numberOfCPUs = 1; private CountDownLatch doneSignal; private ThreadPoolExecutor e; protected ArrayList<HierarchicalCluster<E>>[] clusterLevels; protected double[] clusterLevelMergeCosts; protected ClusterElementFunctions<E> elementDistance; protected double threshold = Double.MIN_VALUE; protected int targetSize = 1; /** * Stores the clusters; starts from being a list of clusters containing only one element (n elements), up to a final * set of clusters; there will be at least one single cluster containing all elements (if the {@link #targetSize} or * {@link #threshold} permit that), or potentially any other number m <= n. Intermediate results, with numbers of * clusters l, m < l < n, are stored in #clusterLevels. */ protected List<HierarchicalCluster<E>> clusters; /** Turn on some debugging */ protected boolean debug = false; /** Do full clustering tree. */ public WardClustering(ClusterElementFunctions<E> elementDistance) { this.elementDistance = elementDistance; } /** Do clustering with a specified threshold to stop building the tree. */ public WardClustering(ClusterElementFunctions<E> elementDistance, double threshold) { this(elementDistance); this.threshold = threshold; } /** Do clustering with a specified target number of clusters to reach. */ public WardClustering(ClusterElementFunctions<E> elementDistance, int targetSize) { this(elementDistance); this.targetSize = targetSize; } @Override public Tree getPrefuseTree() { Tree tree = new Tree(); tree.addColumn(HierarchicalCluster.COLUMN_NAME_LEVEL, String.class); tree.addColumn(HierarchicalCluster.COLUMN_NAME_CONTENT, String.class); tree.addColumn(HierarchicalCluster.COLUMN_NAME_CONTENT_LONG, String.class); final Node root = tree.addRoot(); if (clusters.size() == 1) { clusters.get(0).buildPrefuseTree(tree, root); } else { for (HierarchicalCluster<E> cluster : clusters) { cluster.buildPrefuseTree(tree, tree.addChild(root)); } } return tree; } @Override public List<HierarchicalCluster<E>> doCluster(List<E> data) { init(data); if (threshold != Double.MIN_VALUE) { double minESSIncrease = getInitialMinESS(clusters); // System.out.println("minESS: " + minESS); // System.out.println("threshold: " + threshold); // System.out.println(ArrayUtils.toString(clusters.toArray())); while (minESSIncrease < threshold && clusters.size() > 1) { // merge clusters // System.out.println("minESS: " + minESS); // System.out.println("threshold: " + threshold); minESSIncrease = clusterStep(clusters).getMergeCostIncrease(); } } else { StdErrProgressWriter progress = new StdErrProgressWriter(data.size() - targetSize, "Merging clusters "); while (clusters.size() > targetSize) { // merge clusters clusterStep(clusters); progress.progress(); } return clusters; } return clusters; } @SuppressWarnings( { "unchecked" }) protected void init(List<E> data) { clusterLevels = new ArrayList[data.size()]; clusterLevelMergeCosts = new double[data.size()]; if (data.get(0) instanceof Cluster) { clusters = (List<HierarchicalCluster<E>>) data; } else { clusters = new ArrayList<HierarchicalCluster<E>>(); for (int i = 0; i < data.size(); i++) { HierarchicalCluster<E> e2 = new HierarchicalCluster<E>(data.get(i), "Cluster " + i); if (data.get(i) instanceof ElementWithIndex) { e2.setLabel(((ElementWithIndex) data.get(i)).getLabel()); } clusters.add(e2); } } clusterLevels[clusters.size() - 1] = new ArrayList<HierarchicalCluster<E>>(clusters); } public double getInitialMinESS(List<HierarchicalCluster<E>> clusters) { double[][] ess = new double[clusters.size()][clusters.size()]; double minESS = Double.MAX_VALUE; for (int i = 0; i < clusters.size(); i++) { for (int j = i + 1; j < clusters.size(); j++) { HierarchicalCluster<E> cluster1 = clusters.get(i); HierarchicalCluster<E> cluster2 = clusters.get(j); HierarchicalCluster<E> merged = new HierarchicalCluster<E>(cluster1, cluster2); merged.setMergeCost(ess(merged)); cluster1.setMergeCost(ess(cluster1)); cluster2.setMergeCost(ess(cluster2)); double increase = merged.getMergeCostIncrease(); ess[i][j] = ess[j][i] = increase; if (increase < minESS) { minESS = increase; // short-cut if (minESS == 0) { return minESS; } } } } return minESS; } public double ess(Cluster<E> cluster) { if (cluster.size() == 1) { return 0; } double e = 0; E meanLine = elementDistance.meanObject(cluster); for (E e2 : cluster) { e += elementDistance.distance(e2, meanLine); } return e; } public HierarchicalCluster<E> clusterStep(List<HierarchicalCluster<E>> clusters) { HierarchicalCluster<E> cMerged = null; // we need to build the cross-product of all elements, i.e. clusters.size()*clusters.size() combinations, // forming a symmetric matrix // we can cut off a bit by just taking one half of the matrix, and cutting the diagonal if (numberOfCPUs > 1 && clusters.size() > numberOfCPUs) { // we are taking the hard-way for multi-core clustering. a technically much easier solution would be: // - build a list of all needed combinations of clusters i and j (i.e. one half of the symmetric matrix) // - cut the list in numberOfCPUs pieces, and process each of them, then merge // this is easy to do, but needs a list of ((clusters.size()*clusters.size()/2 - clusters.size()) entries, // potentially a lot of memory // // thus, we calculate indices for starting and ending offsets in the matrix final Indices2D[] indices = getIndices(clusters); doneSignal = new CountDownLatch(indices.length); // important: we need to reset this in each iteration! // System.out.println("distributed clustering, count: " + doneSignal.getCount()); ArrayList<ClusterThread> threads = new ArrayList<ClusterThread>((int) doneSignal.getCount()); // create & execute all the threads for (Indices2D indice : indices) { final ClusterThread thread = new ClusterThread(indice); threads.add(thread); e.execute(thread); } try { doneSignal.await(); // wait for all processes to finish } catch (InterruptedException ie) { } // merge the results double minESSIncrease = Double.MAX_VALUE; for (int i = 0; i < threads.size(); i++) { final HierarchicalCluster<E> mergedCluster = threads.get(i).mergedCluster; // System.out.println(mergedCluster); final double increase = mergedCluster.getMergeCostIncrease(); if (increase < minESSIncrease) { cMerged = mergedCluster; minESSIncrease = increase; } } } else { // System.out.println("normal clustering"); cMerged = findOptiomalClusterMerger(clusters); } if (debug) { System.out.println("\nMerging clusters with size " + cMerged.getLeftNode().size() + " & " + cMerged.getRightNode().size() + ", ESS: " + cMerged.getMergeCostIncrease()); } clusters.remove(cMerged.getLeftNode()); clusters.remove(cMerged.getRightNode()); clusters.add(cMerged); clusterLevels[clusters.size() - 1] = new ArrayList<HierarchicalCluster<E>>(clusters); clusterLevelMergeCosts[clusters.size() - 1] = cMerged.getMergeCost(); return cMerged; } private HierarchicalCluster<E> findOptiomalClusterMerger(List<HierarchicalCluster<E>> clusters) { return findOptiomalClusterMerger(clusters, 0, 0 + 1, clusters.size() - 1, clusters.size() - 1); } private HierarchicalCluster<E> findOptiomalClusterMerger(List<HierarchicalCluster<E>> clusters, int startX, int startY, int endX, int endY) { double minESSIncrease = Double.MAX_VALUE; HierarchicalCluster<E> cMerged = null; // System.out.println("find best cluster in " + startX + ", " + startY + ", " + endX + ", " + endY); for (int i = startX; i <= endX; i++) { for (int j = i == startX ? startY : i + 1; j <= (i == endX ? endY : clusters.size() - 1); j++) { // System.out.println(i + ", " + j); HierarchicalCluster<E> c1 = clusters.get(i); HierarchicalCluster<E> c2 = clusters.get(j); HierarchicalCluster<E> cnew = new HierarchicalCluster<E>(c1, c2); cnew.setMergeCost(ess(cnew)); c1.setMergeCost(ess(c1)); c2.setMergeCost(ess(c2)); double increase = cnew.getMergeCostIncrease(); if (increase < minESSIncrease) { minESSIncrease = increase; cMerged = cnew; // shortcut, stop if we have already a minimal ESS // FIXME: maybe if there are more elements with the same ESS, they should all be merged in the same // step? if (minESSIncrease == startX) { System.out.println("\nBreaking calculation, minESSIncrease is 0"); i = endX; break; } } } } return cMerged; } /** Returns the clustering at a certain level indicated by the merge cost for that level */ public ArrayList<HierarchicalCluster<E>> getClustersByThreshold(double threshold) throws SOMToolboxException { if (threshold < 0) { throw new SOMToolboxException("Can only get positive number of clusters!"); } // FIXME continue merging an under-developed clustering tree. if (clusterLevels[0] == null) { throw new NotImplementedException( "Continuing of cluster merging of a not fully grown tree is not yet implemented!"); } for (int i = 0; i < clusterLevelMergeCosts.length; i++) { double mergeCost = clusterLevelMergeCosts[i]; if (mergeCost <= threshold) { return clusterLevels[i]; } } return clusterLevels[clusterLevels.length - 1]; } /** * Returns the clustering at a certain level indicated by the relative merge cost for that level, compared to the * costs of merging all data items */ public ArrayList<HierarchicalCluster<E>> getClustersByRelativeThreshold(double percent) throws SOMToolboxException { return getClustersByThreshold(clusterLevelMergeCosts[0] * percent); } /** Returns the clustering at a certain level, where the level equals the number of clusters */ public ArrayList<HierarchicalCluster<E>> getClustersAtLevel(int num) throws SOMToolboxException { if (num <= 0) { throw new SOMToolboxException("Can only get positive number of clusters!"); } else if (num > clusterLevels.length) { throw new SOMToolboxException("Cluster index " + num + " out of bounds (" + (clusterLevels.length - 1) + ") !"); } else { return clusterLevels[num - 1]; } } @Override public HashMap<Integer, ArrayList<HierarchicalCluster<E>>> getClustersAtLevel() throws SOMToolboxException { HashMap<Integer, ArrayList<HierarchicalCluster<E>>> result = new HashMap<Integer, ArrayList<HierarchicalCluster<E>>>( clusterLevels.length); for (int i = 1; i <= clusterLevels.length; i++) { result.put(i, getClustersAtLevel(i)); } return result; } public void setDebug(boolean debug) { this.debug = debug; System.out.println("Ward clustering, set debug to: " + debug); } public void setNumberOfCPUs(int numberOfCPUs) { System.out.println("Ward clustering, working with " + numberOfCPUs + " CPUs."); this.numberOfCPUs = numberOfCPUs; e = (ThreadPoolExecutor) Executors.newFixedThreadPool(numberOfCPUs); doneSignal = new CountDownLatch(numberOfCPUs); // System.out.println("Set donesignal: " + doneSignal.getCount() + ", executor thread: " + e.getCorePoolSize()); } /** find matrix indices for starting & ending the symmetrical matrix indices */ private Indices2D[] getIndices(List<HierarchicalCluster<E>> clusters) { // the number of elements in the matrix is 1+2+3...n final int size = clusters.size(); int elementCount = (int) ((size - 1) * size / 2d); int splitSize = (int) Math.ceil(elementCount / (double) numberOfCPUs); // there might be cases in which we can't use the last CPU // e.g. we have 36 combinations, and 7 CPUs // ==> the splitsize would be 5.14, ceiling up to 6 (as 5*7 == 35 and thus not covering all combinations) // but then, already 6 splits (6*6) is covering all combinations, thus we don't need the last split! int splits = (numberOfCPUs - 1) * splitSize == elementCount ? numberOfCPUs - 1 : numberOfCPUs; Indices2D[] indices = new Indices2D[splits]; // System.out.println("\nclusters:" + clusters.size() + ", elementCount: " + elementCount + ", splitsize:" + // splitSize); int index = 0; for (int i = 0; i < size; i++) { for (int j = i + 1; j < size; j++) { final int indicesIndex = index / splitSize; // System.out.println("size: " + size + ", (index + 1) / splitSize: " + indicesIndex); // System.out.println(index + " => " + i + "/" + j); if (index % splitSize == 0) { // start index indices[indicesIndex] = new Indices2D(i, j, 0, 0); // System.out.println("set start for split " + indicesIndex); } if ((index + 1) % splitSize == 0 || index + 1 == elementCount) { // end index indices[indicesIndex].setEnd(i, j); // System.out.println("set end for split " + indicesIndex); } index++; } } // System.out.println("\nclusters:" + clusters.size() + ", elementCount: " + elementCount + ", splitsize:" + // splitSize + " => " // + Arrays.toString(indices)); return indices; } public List<HierarchicalCluster<E>> getClusters() { return clusters; } class ClusterThread implements Runnable { private HierarchicalCluster<E> mergedCluster; private Indices2D coords; public ClusterThread(Indices2D coords) { this.coords = coords; } @Override public void run() { mergedCluster = findOptiomalClusterMerger(clusters, coords.startX, coords.startY, coords.endX, coords.endY); // System.out.println("found cluster for " + coords + " : " + mergedCluster); doneSignal.countDown(); } } }