/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. */ /* * MatrixHAC.java * Copyright (C) 2001 Mikhail Bilenko * */ /** * Similarity-matrix Implementation of Hierarachical Agglomerative Clustering. * <p> * Valid options are:<p> * * -N <0-10000> <br> * Number of clusters. <p> * * @author Mikhail Bilenko (mbilenko@cs.utexas.edu) * @version $Revision: 1.11 $ */ package weka.clusterers; import java.io.*; import java.util.*; import java.text.*; import weka.core.*; import weka.core.metrics.*; import weka.filters.unsupervised.attribute.Remove; import weka.filters.unsupervised.attribute.Normalize; import weka.filters.Filter; public class HAC extends Clusterer implements SemiSupClusterer, OptionHandler{ /* name of the clusterer */ String m_name = "HAC"; /** Number of clusters */ protected int m_numClusters = -1; /** Number of clusters in the process*/ protected int m_numCurrentClusters = 0; /** ID of current cluster */ protected int m_clusterID = 0; /** Number of seeded clusters */ protected int m_numSeededClusters = 0; /** Dot file name for dumping graph for tree visualization */ protected String m_dotFileName = "user-features.dot"; /** Dot file name for dumping graph for tree visualization */ protected PrintWriter m_dotWriter = null; /** Instances that we are working with */ Instances m_instances; Instances m_descrInstances; /** holds the clusters */ protected ArrayList m_clusters = null; /** * temporary variable holding cluster assignments */ protected int [] m_clusterAssignments; /** distance matrix */ protected double[][] m_distanceMatrix = null; /** cluster similarity type */ public final static int SINGLE_LINK = 0; public final static int COMPLETE_LINK = 1; public final static int GROUP_AVERAGE = 2; public static final Tag[] TAGS_LINKING = { new Tag(SINGLE_LINK, "Single link"), new Tag(COMPLETE_LINK, "Complete link"), new Tag(GROUP_AVERAGE, "Group-average") }; /** Default linking method */ protected int m_linkingType = GROUP_AVERAGE; /** starting index of test data in unlabeledData if transductive clustering */ protected int m_StartingIndexOfTest = -1; /** seeding */ protected boolean m_seedable = false; /** holds the ([seed instance] -> [clusterLabel of seed instance]) mapping */ protected HashMap m_SeedHash = null; /** A 'checksum hash' where indices are hashed to the sum of their attribute values */ protected HashMap m_checksumHash = null; protected double[] m_checksumPerturb = null; /** * holds the random Seed, useful for random selection initialization */ protected int m_randomSeed = 100; protected Random m_randomGen = null; /** instance hash */ protected HashMap m_instancesHash = null; /** reverse instance hash */ protected HashMap m_reverseInstancesHash = null; /** The threshold distance beyond which no clusters are merged (except for one - TODO) */ protected double m_mergeThreshold = 0.8; /** verbose? */ protected boolean m_verbose = false; /** metric used to calculate similarity/distance */ // protected Metric m_metric = new WeightedDotP(); // protected String m_metricName = new String("weka.core.metrics.WeightedDotP"); protected Metric m_metric = new WeightedEuclidean(); protected String m_metricName = new String("weka.core.metrics.WeightedEuclidean"); /** Is the metric (and hence the algorithm) relying on similarities or distances? */ protected boolean m_isDistanceBased = false; /** has the metric has been constructed? a fix for multiple buildClusterer's */ protected boolean m_metricBuilt = false; // =============== // Public methods. // =============== /** empty constructor, required to call using Class.forName */ public HAC() {} /* Constructor */ public HAC(Metric metric) { m_metric = metric; m_metricName = m_metric.getClass().getName(); m_isDistanceBased = metric.isDistanceBased(); } /** Sets training instances */ public void setInstances(Instances instances) { m_instances = instances; } /** Return training instances */ public Instances getInstances() { return m_instances; } /** * Set the number of clusters to generate * * @param n the number of clusters to generate */ public void setNumClusters(int n) { m_numClusters = n; if (m_verbose) { System.out.println("Number of clusters: " + n); } } /** Set the merge threshold */ public void setMergeThreshold(double threshold) { m_mergeThreshold = threshold; } /** Get the merge threshold */ public double getMergeThreshold() { return m_mergeThreshold; } /** * Set the distance metric * * @param s the metric */ public void setMetric (LearnableMetric m) { m_metric = m; m_metricName = m_metric.getClass().getName(); m_isDistanceBased = m.isDistanceBased(); } /** * Get the distance metric * * @returns the distance metric used */ public Metric getMetric () { return m_metric; } /** * Get the distance metric name * * @returns the name of the distance metric used */ public String metricName () { return m_metricName; } /** * We always want to implement SemiSupClusterer from a class extending Clusterer. * We want to be able to return the underlying parent class. * @return parent Clusterer class */ public Clusterer getThisClusterer() { return this; } /** * Cluster given instances to form the specified number of clusters. * * @param data instances to be clustered * @param num_clusters number of clusters to create * @exception Exception if something goes wrong. */ public void buildClusterer(Instances data, int num_clusters) throws Exception { setNumClusters(num_clusters); buildClusterer(data); } /** * Clusters unlabeledData and labeledData (with labels removed), * using labeledData as seeds * * @param labeledData labeled instances to be used as seeds * @param unlabeledData unlabeled instances * @param classIndex attribute index in labeledData which holds class info * @param numClusters number of clusters * @exception Exception if something goes wrong. */ public void buildClusterer(Instances labeledData, Instances unlabeledData, int classIndex, int numClusters) throws Exception { // remove labels of labeledData before putting in SeedHash Instances clusterData = new Instances(labeledData); clusterData.deleteClassAttribute(); // create SeedHash from labeledData if (getSeedable()) { Seeder seeder = new Seeder(clusterData, labeledData); setSeedHash(seeder.getAllSeeds()); } // add unlabeled data to labeled data (labels removed), not the // other way around, so that the hash table entries are consistent // with the labeled data without labels for (int i=0; i<unlabeledData.numInstances(); i++) { clusterData.add(unlabeledData.instance(i)); } if (m_verbose) { System.out.println("combinedData has size: " + clusterData.numInstances() + "\n"); } // learn metric using labeled data, then cluster both the labeled and unlabeled data m_metric.buildMetric(labeledData); m_metricBuilt = true; // check if the number of clusters is dynamically set to the number of classes if (m_numClusters == -1) { m_numClusters = labeledData.numClasses(); System.out.println("DYNAMIC NUMBER OF CLUSTERS, setting to " + m_numClusters); } buildClusterer(clusterData, numClusters); } /** * Clusters unlabeledData and labeledData (with labels removed), * using labeledData as seeds * * @param labeledData labeled instances to be used as seeds * @param unlabeledData unlabeled instances * @param classIndex attribute index in labeledData which holds class info * @param numClusters number of clusters * @param startingIndexOfTest from where test data starts in unlabeledData, useful if clustering is transductive * @exception Exception if something goes wrong. */ public void buildClusterer(Instances labeledData, Instances unlabeledData, int classIndex, int numClusters, int startingIndexOfTest) throws Exception { m_StartingIndexOfTest = startingIndexOfTest + labeledData.numInstances(); buildClusterer(labeledData, unlabeledData, classIndex, numClusters); } /** * Cluster given instances. If no threshold or number of clusters is set, * clustering proceeds until two clusters are left. * * @param data instances to be clustered * @exception Exception if something goes wrong. */ public void buildClusterer(Instances data) throws Exception { m_randomGen = new Random(m_randomSeed); m_dotWriter = new PrintWriter(new BufferedOutputStream(new FileOutputStream(m_dotFileName))); m_dotWriter.println("digraph HAC {\n"); setInstances(data); m_clusterAssignments = new int[m_instances.numInstances()]; if (m_verbose && m_SeedHash != null) { System.out.println("Using seeding ..."); } if (m_instances.checkForNominalAttributes() && m_instances.checkForStringAttributes()) { throw new UnsupportedAttributeTypeException("Cannot handle nominal attributes\n"); } //m_instances = filterInstanceDescriptions(m_instances); // Don't rebuild the metric if it was already trained if (!m_metricBuilt) { m_metric.buildMetric(data); } hashInstances(m_instances); createDistanceMatrix(); cluster(); unhashClusters(); m_dotWriter.println("}"); m_dotWriter.close(); } /** If some of the attributes start with "__", form a separate Instances set with * descriptions and filter them out of the argument dataset. * Return the original dataset without the filtered out attributes */ protected Instances filterInstanceDescriptions(Instances instances) throws Exception { Instances filteredInstances; // Normalize normalizeFilter = new Normalize(); // normalizeFilter.setInputFormat(instances); // instances = Filter.useFilter(instances, normalizeFilter); // System.out.println("Normalized the instance attributes"); // go through the attributes and find the description attributes ArrayList descrIndexList = new ArrayList(); for (int i = 0; i < instances.numAttributes(); i++) { Attribute attr = instances.attribute(i); if (attr.name().startsWith("__")) { descrIndexList.add(new Integer(i)); System.out.println("filtering " + attr); } } // filter out the description attributes if necessary if (descrIndexList.size() > 0) { m_descrInstances = new Instances(instances); // filter out the descriptions first int[] descrIndeces = new int[descrIndexList.size()]; for (int i = 0; i < descrIndexList.size(); i++) { descrIndeces[i] = ((Integer) descrIndexList.get(i)).intValue(); } Remove attributeFilter = new Remove(); attributeFilter.setAttributeIndicesArray(descrIndeces); attributeFilter.setInvertSelection(false); attributeFilter.setInputFormat(instances); filteredInstances = Filter.useFilter(instances, attributeFilter); attributeFilter.setInvertSelection(true); attributeFilter.setInputFormat(instances); m_descrInstances = Filter.useFilter(instances, attributeFilter); } else { filteredInstances = new Instances(instances); } return filteredInstances; } /** * Reset all values that have been learned */ public void resetClusterer() throws Exception{ if (m_metric instanceof LearnableMetric) ((LearnableMetric)m_metric).resetMetric(); m_SeedHash = null; } /** Set the m_SeedHash */ public void setSeedHash(HashMap seedhash) { m_SeedHash = seedhash; m_seedable = true; } /** * Set the random number seed * @param s the seed */ public void setRandomSeed (int s) { m_randomSeed = s; } /** Return the random number seed */ public int getRandomSeed () { return m_randomSeed; } /** Turn seeding on and off * @param seedable should seeding be done? */ public void setSeedable(boolean seedable) { m_seedable = seedable; } /** Turn seeding on and off * @param seedable should seeding be done? */ public boolean getSeedable() { return m_seedable; } /** Read the seeds from a hastable, where every key is an instance and every value is: * a FastVector of Doubles: [(Double) probInCluster0 ... (Double) probInClusterN] * @param seedVector vector containing seeds */ public void seedClusterer(HashMap SeedHash) { if (m_seedable) { m_SeedHash = SeedHash; } } /** * returns the SeedHash * @return seeds hash */ public HashMap getSeedHash() {return m_SeedHash;} /** * Create the hashtable from given Instances; * keys are numeric indeces, values are actual Instances * * @param data Instances * */ protected void hashInstances (Instances data) { int next_value = 0; m_instancesHash = new HashMap(); m_reverseInstancesHash = new HashMap(); m_checksumHash = new HashMap(); // initialize checksum perturbations m_checksumPerturb = new double[data.numAttributes()]; for (int i = 0; i < m_checksumPerturb.length; i++) { m_checksumPerturb[i] = m_randomGen.nextFloat(); } for (Enumeration enum = data.enumerateInstances(); enum.hasMoreElements();) { Instance instance = (Instance) enum.nextElement(); if (!m_instancesHash.containsValue(instance)) { Integer idx = new Integer(next_value); next_value++; m_instancesHash.put(idx, instance); m_reverseInstancesHash.put(instance, idx); // hash the checksum value double [] values = instance.toDoubleArray(); double checksum = 0; for (int i = 0; i < values.length; i++) { checksum += m_checksumPerturb[i] * values[i];} Double checksumIdx = new Double(checksum); if (m_checksumHash.containsKey(checksumIdx)) { Object prev = m_checksumHash.get(checksumIdx); ArrayList chain; if (prev instanceof Integer) { chain = new ArrayList(); Integer prevIdx = (Integer) m_checksumHash.get(checksumIdx); chain.add(prevIdx); } else { //instanceof Arraylist chain = (ArrayList) m_checksumHash.get(checksumIdx); } chain.add(idx); m_checksumHash.put(checksumIdx, chain); } else { // no collisions m_checksumHash.put(checksumIdx, idx); } } else { System.err.println("Already encountered instance, skipping " + instance); } } } /** * assuming m_clusters contains the clusters of indeces, convert it to * clusters containing actual instances */ protected void unhashClusters() throws Exception{ if (m_clusters == null || m_instancesHash == null) throw new Exception ("Clusters or hash not initialized"); ArrayList clusters = new ArrayList(); for (int i = 0; i < m_clusters.size(); i++ ) { Cluster cluster = (Cluster) m_clusters.get(i); Cluster newCluster = new Cluster(); for (int j = 0; j < cluster.size(); j++) { Integer instanceIdx = (Integer) cluster.get(j); double wt = cluster.weightAt(j); newCluster.add((Instance)m_instancesHash.get(instanceIdx), wt); } clusters.add(newCluster); } m_clusters = clusters; } /** * Fill the distance matrix with values using the metric * */ protected void createDistanceMatrix () throws Exception { int n = m_instancesHash.size(); double sim; m_distanceMatrix = new double[n][n]; for (int i = 0; i < n; i++) { for (int j = i+1; j < n; j++) { m_distanceMatrix[i][j] = m_distanceMatrix[j][i] = m_metric.distance((Instance) m_instancesHash.get(new Integer(i)), (Instance) m_instancesHash.get(new Integer(j))); } } } /** * Set the type of clustering * * @param type Clustering type: can be HAC.SINGLE_LINK, HAC.COMPLETE_LINK, * or HAC.GROUP_AVERAGE */ public void setLinkingType (SelectedTag linkingType) { if (linkingType.getTags() == TAGS_LINKING) { m_linkingType = linkingType.getSelectedTag().getID(); } } /** * Get the linking type * * @returns the linking type */ public SelectedTag getLinkingType () { return new SelectedTag(m_linkingType, TAGS_LINKING); } /** * Internal method that initializes distances between seed clusters to * POSITIVE_INFINITY */ protected void initConstraints() { for (int i = 0; i < m_instances.numInstances(); i++) { if (m_clusterAssignments[i] < m_numSeededClusters) { // make distances to elements from other seeded clusters POSITIVE_INFINITY for (int j = i+1; j < m_instances.numInstances(); j++) { if (m_clusterAssignments[j] < m_numSeededClusters && m_clusterAssignments[j] != m_clusterAssignments[i]) { m_distanceMatrix[i][j] = m_distanceMatrix[j][i] = Double.POSITIVE_INFINITY; } } } } } /** * Internal method that produces the actual clusters */ protected void cluster() throws Exception { double last_distance = Double.MIN_VALUE; m_numCurrentClusters = 0; m_numSeededClusters = 0; m_clusters = new ArrayList(); TreeSet leftOverSet = null; // Initialize singleton clusters m_clusterAssignments = new int[m_instances.numInstances()]; for (int i = 0; i < m_instances.numInstances(); i++) { m_clusterAssignments[i] = -1; } // utilize seeds if available if (m_SeedHash != null) { if (m_verbose) { System.out.println("Seeding HAC using " + m_SeedHash.size() + " seeds"); } Iterator iterator = m_SeedHash.entrySet().iterator(); int maxClassIdx = -1; HashSet classIdxSet = new HashSet(); while (iterator.hasNext()) { Map.Entry entry = (Map.Entry) iterator.next(); Instance instance = (Instance) entry.getKey(); Integer instanceIdx = (Integer) m_reverseInstancesHash.get(instance); Integer clusterIdx = (Integer) entry.getValue(); classIdxSet.add(clusterIdx); m_clusterAssignments[instanceIdx.intValue()] = clusterIdx.intValue(); if (clusterIdx.intValue() > maxClassIdx) { maxClassIdx = clusterIdx.intValue(); } } m_numCurrentClusters = m_numSeededClusters = classIdxSet.size(); System.out.println("Seeded " + m_numSeededClusters + " clusters"); // If the seeding is incomplete, need to memorize "unseeded" cluster numbers if (m_numCurrentClusters < m_numClusters) { leftOverSet = new TreeSet(); for (int i = 0; i < m_numClusters; i++) { if (!classIdxSet.contains(new Integer(i))) { leftOverSet.add(new Integer(i)); } } } } // assign unseeded instances to singleton clusters for (int i = 0; i < m_instances.numInstances(); i++) { if (m_clusterAssignments[i] == -1) { // utilize "left over clusters first" if (leftOverSet != null) { Integer clusterIdx = (Integer) leftOverSet.first(); m_clusterAssignments[i] = clusterIdx.intValue(); leftOverSet.remove(clusterIdx); if (leftOverSet.isEmpty()) { leftOverSet = null; } } else { m_clusterAssignments[i] = m_numCurrentClusters; } m_numCurrentClusters++; } } // initialize m_clusters arraylist getIntClusters(); if (m_SeedHash != null) { initConstraints(); } // merge clusters until desired number of clusters is reached double mergeDistance = 0; while (m_numCurrentClusters > m_numClusters && mergeDistance < m_mergeThreshold) { mergeDistance = mergeStep(); if (m_verbose) { System.out.println("Merged with " + (m_numCurrentClusters) + " clusters left; distance=" + mergeDistance); } } System.out.println("Done clustering with " + m_clusters.size() + " clusters"); for (int i = 0; i < m_clusters.size(); i++) System.out.print(((Cluster)m_clusters.get(i)).size() + "\t"); initClusterAssignments(); } /** * Internal method that finds two most similar clusters and merges them */ protected double mergeStep() throws Exception{ double bestDistance = Double.MAX_VALUE; double thisDistance; Cluster thisCluster, nextCluster; ArrayList mergeCandidatesList = new ArrayList(); int cluster1_index, cluster2_index; if (m_verbose) { System.out.println("\nBefore merge step there are " + m_clusters.size() + " clusters; m_numCurrentClusters=" + m_numCurrentClusters); } // find two most similar clusters for (int i = 0; i < m_clusters.size()-1; i++){ thisCluster = (Cluster)m_clusters.get(i); for (int j = i+1; j < m_clusters.size(); j++) { thisDistance = clusterDistance(thisCluster, (Cluster) m_clusters.get(j)); if (m_verbose) { // System.out.println("Distance between " + i + " and " + j + " is " + thisDistance); } // If there is a tie, add to the list of top distances if (thisDistance == bestDistance) { mergeCandidatesList.add(new Integer(i)); mergeCandidatesList.add(new Integer(j)); } else if (thisDistance < bestDistance) { // this is the best distance seen this far mergeCandidatesList.clear(); mergeCandidatesList.add(new Integer(i)); mergeCandidatesList.add(new Integer(j)); bestDistance = thisDistance; } } } // randomly pick a most similar pair from the list of candidates int i1 = (int) (mergeCandidatesList.size() * m_randomGen.nextFloat()); int i2 = (i1 % 2 > 0) ? (i1 - 1) : (i1 + 1); int cluster1Idx = ((Integer) mergeCandidatesList.get(i1)).intValue(); int cluster2Idx = ((Integer) mergeCandidatesList.get(i2)).intValue(); if (m_verbose) { System.out.println("\nMerging clusters " + cluster1Idx + " and " + cluster2Idx + "; distance=" + bestDistance); } System.out.print("Best distance=" + ((float)bestDistance) + "; Merging:\n"); printCluster(cluster1Idx); System.out.print("AND\n"); printCluster(cluster2Idx); System.out.print("\n"); Cluster newCluster = mergeClusters(cluster1Idx, cluster2Idx); // check if the new cluster is sufficiently large and "good" HashMap groupCountMap = new HashMap(); for (int i = 0; i < newCluster.size(); i++) { int idx = ((Integer)newCluster.get(i)).intValue(); Instance instance = m_descrInstances.instance(idx); // get the set of groups String groupString = instance.stringValue(1); StringTokenizer tokenizer = new StringTokenizer(groupString, "|"); while (tokenizer.hasMoreTokens()) { String group = tokenizer.nextToken(); if (groupCountMap.containsKey(group)) { Integer count = (Integer) groupCountMap.get(group); groupCountMap.put(group, new Integer(count.intValue() + 1)); } else { groupCountMap.put(group, new Integer(1)); } } } int largestGroupCount = -1; Iterator iterator = groupCountMap.entrySet().iterator(); while(iterator.hasNext()) { Map.Entry entry = (Map.Entry) iterator.next(); int thisCount = ((Integer)entry.getValue()).intValue(); String group = (String) entry.getKey(); if (thisCount > largestGroupCount && !group.equals("grad")) { largestGroupCount = thisCount; } } // if the most common group includes 80% of cluster members, yell! if ((largestGroupCount + 0.0)/(newCluster.size() + 0.0) > 0.6 && newCluster.size() > 2) { System.out.println("HAPPY JOY JOY! LOOK HERE!"); } // have to remove in order because we're using index, argh if (cluster1Idx > cluster2Idx) { m_clusters.remove(cluster1Idx); m_clusters.remove(cluster2Idx); } else { m_clusters.remove(cluster2Idx); m_clusters.remove(cluster1Idx); } m_clusters.add(newCluster); m_numCurrentClusters--; return bestDistance; } /** * Computes the clusters from the cluster assignments * * @exception Exception if clusters could not be computed successfully */ public ArrayList getIntClusters() throws Exception { m_clusters = new ArrayList(); Cluster [] clusterArray = new Cluster[m_numCurrentClusters]; if (m_verbose) { System.out.println("Cluster assignments: "); for (int i=0; i < m_clusterAssignments.length; i++) { System.out.print(i + ":" + m_clusterAssignments[i] + " "); } } for (int i=0; i < m_instances.numInstances(); i++) { Instance inst = m_instances.instance(i); if(clusterArray[m_clusterAssignments[i]] == null) clusterArray[m_clusterAssignments[i]] = new Cluster(m_clusterID++); clusterArray[m_clusterAssignments[i]].add(new Integer(i), 1); // System.out.println("Adding: " + i + " to cluster: " + m_clusterAssignments[i]); } for (int j =0; j< m_numCurrentClusters; j++) { if (clusterArray[j] == null) { System.out.println("Empty cluster: " + j); // printIntClusters(); setVerbose(true); m_numCurrentClusters--; m_numClusters--; } else { m_clusters.add(clusterArray[j]); String labelString = ""; for (int i = 0; i < clusterArray[j].size(); i++){ Instance inst = m_instances.instance(((Integer) (clusterArray[j].get(i))).intValue()); labelString = labelString + printInstance(inst) + "\\n"; } m_dotWriter.println("node" + clusterArray[j].clusterID + "[label = \"" + labelString + "\"]"); } } // printIntClusters(); return m_clusters; } /** * Computes the final clusters from the cluster assignments, for external access * * @exception Exception if clusters could not be computed successfully */ public ArrayList getClusters() throws Exception { ArrayList finalClusters = new ArrayList(); Cluster [] clusterArray = new Cluster[m_numClusters]; for (int i=0; i < m_instances.numInstances(); i++) { Instance inst = m_instances.instance(i); if(clusterArray[m_clusterAssignments[i]] == null) clusterArray[m_clusterAssignments[i]] = new Cluster(); clusterArray[m_clusterAssignments[i]].add(inst, 1); } for (int j =0; j< m_numClusters; j++) finalClusters.add(clusterArray[j]); return finalClusters; } /** * internal method that returns the distance between two clusters */ protected double clusterDistance(Cluster cluster1, Cluster cluster2) { if (cluster2 == null || cluster1 == null) { System.err.println("PANIC! clusterDistance called with null argument(s)"); try{ printIntClusters(); } catch(Exception e){} } int i1 = ((Integer) cluster1.get(0)).intValue(); int i2 = ((Integer) cluster2.get(0)).intValue(); return m_distanceMatrix[i1][i2]; } protected void checkClusters() { } /** Internal method to merge two clusters and update distances */ protected Cluster mergeClusters (int cluster1Idx, int cluster2Idx) throws Exception { Cluster newCluster = new Cluster(m_clusterID++); Cluster cluster1 = (Cluster) m_clusters.get(cluster1Idx); Cluster cluster2 = (Cluster) m_clusters.get(cluster2Idx); int cluster1FirstIdx =((Integer) cluster1.get(0)).intValue(); int cluster2FirstIdx =((Integer) cluster2.get(0)).intValue(); newCluster.copyElements(cluster1); newCluster.copyElements(cluster2); checkClusters(); // Update the distance matrix depending on the linkage type switch (m_linkingType) { case SINGLE_LINK: // go through all clusters and update the distance from first element // to the first element of the new cluster for (int i = 0; i < m_clusters.size(); i++){ if (i != cluster1Idx && i != cluster2Idx) { // skip these clusters themselves Cluster currentCluster = (Cluster) m_clusters.get(i); int currClusterFirstIdx = ((Integer) currentCluster.get(0)).intValue(); if (m_distanceMatrix[cluster1FirstIdx][currClusterFirstIdx] < m_distanceMatrix[cluster2FirstIdx][currClusterFirstIdx]) { // first cluster is closer, no need to update } else { // second cluster is closer, must update distance between the first representative m_distanceMatrix[cluster1FirstIdx][currClusterFirstIdx] = m_distanceMatrix[currClusterFirstIdx][cluster1FirstIdx] = m_distanceMatrix[cluster2FirstIdx][currClusterFirstIdx]; } // check for infinity links if (m_distanceMatrix[cluster2FirstIdx][currClusterFirstIdx] == Double.POSITIVE_INFINITY) { m_distanceMatrix[cluster1FirstIdx][currClusterFirstIdx] = m_distanceMatrix[currClusterFirstIdx][cluster1FirstIdx] = Double.POSITIVE_INFINITY; } if (m_distanceMatrix[cluster1FirstIdx][currClusterFirstIdx] == Double.POSITIVE_INFINITY) { m_distanceMatrix[cluster2FirstIdx][currClusterFirstIdx] = m_distanceMatrix[currClusterFirstIdx][cluster2FirstIdx] = Double.POSITIVE_INFINITY; } } } break; case COMPLETE_LINK: // go through all clusters and update the distance from first element // to the first element of the new cluster for (int i = 0; i < m_clusters.size(); i++){ if (i != cluster1Idx && i != cluster2Idx) { // skip these clusters themselves Cluster currentCluster = (Cluster) m_clusters.get(i); int currClusterFirstIdx = ((Integer) currentCluster.get(0)).intValue(); if (m_distanceMatrix[cluster1FirstIdx][currClusterFirstIdx] > m_distanceMatrix[cluster2FirstIdx][currClusterFirstIdx]) { // first cluster is closer, no need to update } else { // second cluster is closer, must update distance between the first representative m_distanceMatrix[cluster1FirstIdx][currClusterFirstIdx] = m_distanceMatrix[currClusterFirstIdx][cluster1FirstIdx] = m_distanceMatrix[cluster2FirstIdx][currClusterFirstIdx]; } } } break; case GROUP_AVERAGE: // go through all clusters and update the distance from first element // to the first element of the new cluster for (int i = 0; i < m_clusters.size(); i++){ if (i != cluster1Idx && i != cluster2Idx) { // skip these clusters themselves Cluster currentCluster = (Cluster) m_clusters.get(i); int currClusterFirstIdx = ((Integer) currentCluster.get(0)).intValue(); int cluster1Size = cluster1.size(); int cluster2Size = cluster2.size(); // must update distance between the first representative m_distanceMatrix[cluster1FirstIdx][currClusterFirstIdx] = m_distanceMatrix[currClusterFirstIdx][cluster1FirstIdx] = (m_distanceMatrix[cluster1FirstIdx][currClusterFirstIdx] * cluster1Size + m_distanceMatrix[cluster2FirstIdx][currClusterFirstIdx] * cluster2Size) / (cluster1Size + cluster2Size); } } } String labelString = ""; for (int i = 0; i < newCluster.size(); i++){ Instance inst = m_instances.instance(((Integer) (newCluster.get(i))).intValue()); labelString = labelString + printInstance(inst) + "\\n"; } m_dotWriter.println("node" + newCluster.clusterID + "[label = \"" + labelString + "\"]"); m_dotWriter.println("node" + newCluster.clusterID + "->node" + cluster1.clusterID); m_dotWriter.println("node" + newCluster.clusterID + "->node" + cluster2.clusterID); return newCluster; } /** Print an instance for the dot file */ String printInstance(Instance instance) { String stringToPrint; int[] ascendingSortIndicesOfAttributes = Utils.sort(instance.toDoubleArray()); if (m_descrInstances == null) { stringToPrint = instance.toString(); } else { int idx = ((Integer) m_reverseInstancesHash.get(instance)).intValue(); stringToPrint = (m_descrInstances.instance(idx)).toString() + ": "; } DecimalFormat fmt = new DecimalFormat("0.000"); for (int i = 0; i < 5; i++) { Attribute attrib = m_instances.attribute(ascendingSortIndicesOfAttributes[m_instances.numAttributes()-i-1]); if (instance.value(attrib) > 0) { stringToPrint = stringToPrint + attrib.name() + ": " + fmt.format(instance.value(attrib)) + "\t"; } } return stringToPrint; } /** Update the clusterAssignments for all points in two clusters that are about to be merged */ protected void initClusterAssignments() { m_clusterAssignments = new int[m_instances.numInstances()]; for (int i = 0; i < m_clusters.size(); i++) { Cluster cluster = (Cluster) m_clusters.get(i); for (int j = 0; j < cluster.size(); j++) { Integer idx = (Integer) cluster.get(j); // System.out.println("Instance number: " + idx + " has cluster id: " + i); m_clusterAssignments[idx.intValue()] = i; } } } /** Outputs the current clustering * * @exception Exception if something goes wrong */ public void printClusters() throws Exception { if (m_clusters == null) throw new Exception ("Clusters were not created"); for (int i = 0; i < m_clusters.size(); i++) { System.out.println ("Cluster " + i); printCluster(i); } } /** Outputs the specified cluster * * @exception Exception if something goes wrong */ public void printCluster(int i) throws Exception { if (m_clusters == null) throw new Exception ("Clusters were not created"); Cluster cluster = (Cluster) m_clusters.get(i); for (int j = 0; j < cluster.size(); j++) { // Instance instance = (Instance) m_instancesHash.get((Integer) cluster.elementAt(j)); Object o = cluster.get(j); Instance instance = (o instanceof Instance) ? (Instance)o : m_instances.instance(((Integer)o).intValue()); if (m_descrInstances == null) { System.out.print("\t" + instance); } else { System.out.print("\t"); System.out.println(printInstance(instance)); } } } /** Outputs the current clustering * * @exception Exception if something goes wrong */ public void printIntClusters() throws Exception { if (m_clusters == null) throw new Exception ("Clusters were not created"); for (int i = 0; i < m_clusters.size(); i++) { Cluster cluster = (Cluster) m_clusters.get(i); System.out.println ("Cluster " + i + " consists of " + cluster.size() + " elements"); for (int j = 0; j < cluster.size(); j++) { // Instance instance = (Instance) m_instancesHash.get((Integer) cluster.elementAt(j)); Integer idx = (Integer) cluster.get(j); Instance instance = (Instance) m_instancesHash.get(idx); System.out.println("\t\t" + instance); } } } /** * Clusters an instance. * * @param instance the instance to cluster. * @exception Exception if something goes wrong. */ public int clusterInstance(Instance instance) throws Exception { double bestDistance = Double.MAX_VALUE; int instanceIdx = 0; // if (m_reverseInstancesHash.containsKey(instance)) { // instanceIdx = ((Integer) m_reverseInstancesHash.get(instance)).intValue(); // System.out.println("Located index in m_reverseInstancesHash"); // return m_clusterAssignments[instanceIdx]; // } else { double [] values = instance.toDoubleArray(); double checksum = 0; for (int i = 0; i < values.length; i++) { checksum += m_checksumPerturb[i] * values[i]; } Double checksumIdx = new Double(checksum); if (m_checksumHash.containsKey(checksumIdx)) { Object obj = m_checksumHash.get(checksumIdx); if (obj instanceof Integer) { int idx = ((Integer) obj).intValue(); return m_clusterAssignments[idx]; } else { // instanceof Arraylist ArrayList chain = (ArrayList) obj; for (int i = 0; i < chain.size(); i++) { Integer idx = (Integer) chain.get(i); Instance clusteredInstance = (Instance) m_instancesHash.get(idx); if (matchInstance(instance, clusteredInstance)) { return m_clusterAssignments[idx.intValue()]; } } throw new Exception("UNKNOWN INSTANCE!!!!"); } } else { // unknown checksum throw new Exception("UNKNOWN CHECKSUM!!!!"); // ArrayList candidateClusterList = new ArrayList(); // for (int i = 0; i < m_numClusters; i++) { // Cluster thisCluster = (Cluster) m_clusters.get(i); // double thisDistance = distance (instance, thisCluster); // if (thisDistance < bestDistance) { // candidateClusterList.clear(); // candidateClusterList.add (new Integer(i)); // bestDistance = thisDistance; // } else if (thisDistance == bestDistance) { // candidateClusterList.add (new Integer(i)); // } // } // // randomly pick a candidate // int i = (int) (candidateClusterList.size() * Math.random()); // int clusterIdx = ((Integer) candidateClusterList.get(i)).intValue(); // if (clusterIdx != m_clusterAssignments[instanceIdx]) { // System.out.println("Mismatch: idx=" + clusterIdx + " assigned=" + m_clusterAssignments[instanceIdx]); // } // return clusterIdx; } } /** Internal method: check if two instances match on their attribute values */ protected boolean matchInstance(Instance instance1, Instance instance2) { double [] values1 = instance1.toDoubleArray(); double [] values2 = instance2.toDoubleArray(); for (int i = 0; i < values1.length; i++) { if (values1[i] != values2[i]) { return false; } } return true; } /** * internal method that returns the distance between an instance and a cluster */ protected double distance (Instance instance, Cluster cluster) throws Exception { Integer idx; double distance = 0; switch (m_linkingType) { case SINGLE_LINK: double minDistance = Double.MAX_VALUE; for (int i = 0; i < cluster.size(); i++) { Instance clusterInstance = (Instance) cluster.get(i); double currDistance = m_metric.distance(instance, clusterInstance); if (currDistance < minDistance) minDistance = currDistance; } distance = minDistance; break; case COMPLETE_LINK: double maxDistance = Double.MIN_VALUE; for (int i = 0; i < cluster.size(); i++) { Instance clusterInstance = (Instance) cluster.get(i); double currDistance = m_metric.distance(instance, clusterInstance); if (currDistance > maxDistance) maxDistance = currDistance; } distance = maxDistance; break; case GROUP_AVERAGE: double avgDistance = 0; for (int i = 0; i < cluster.size(); i++) { Instance clusterInstance = (Instance) cluster.get(i); avgDistance += m_metric.distance(instance, clusterInstance); } distance = avgDistance/cluster.size(); break; default: throw new Exception("Unknown linkage type!"); } return distance; } /** * set the verbosity level of the clusterer * @param verbose messages on(true) or off (false) */ public void setVerbose (boolean verbose) { m_verbose = verbose; } /** * get the verbosity level of the clusterer * @return verbose messages on(true) or off (false) */ public boolean getVerbose () { return m_verbose; } /** * Returns an enumeration describing the available options * * @return an enumeration of all the available options **/ public Enumeration listOptions() { Vector newVector = new Vector(2); newVector.addElement(new Option("\tThreshold.\n" +"\t(default=MAX_DOUBLE)", "T", 1,"-T <0-MAX_DOUBLE>")); newVector.addElement(new Option("\tNumber of clusters.\n" +"a\t(default=-1)", "N", 1,"-N <-1-MAX_INT100%>")); return newVector.elements(); } /** * Parses a given list of options. * * Valid options are:<p> * * -A <0-100> <br> * Acuity. <p> * * -C <0-100> <br> * Cutoff. <p> * * @param options the list of options as an array of strings * @exception Exception if an option is not supported * **/ public void setOptions(String[] options) throws Exception { String optionString; optionString = Utils.getOption('N', options); if (optionString.length() != 0) { setNumClusters(Integer.parseInt(optionString)); } } /** * Gets the current settings of Greedy Agglomerative Clustering * * @return an array of strings suitable for passing to setOptions() */ public String [] getOptions() { String [] options = new String [70]; int current = 0; options[current++] = "-N"; options[current++] = "" + m_numClusters; if (m_linkingType == SINGLE_LINK) { options[current++] = "-I"; } else if (m_linkingType == COMPLETE_LINK) { options[current++] = "-C"; } else if (m_linkingType == GROUP_AVERAGE) { options[current++] = "-G"; } if (m_seedable) { options[current++] = "-S"; } options[current++] = "-M"; options[current++] = m_metric.getClass().getName(); if (m_metric instanceof OptionHandler) { String[] metricOptions = ((OptionHandler)m_metric).getOptions(); for (int i = 0; i < metricOptions.length; i++) { options[current++] = metricOptions[i]; } } while (current < options.length) { options[current++] = ""; } return options; } /** * Train the clusterer using specified parameters * * @param instances Instances to be used for training */ public void trainClusterer (Instances instances) throws Exception { if (m_metric instanceof LearnableMetric) { if (((LearnableMetric)m_metric).getTrainable()) { ((LearnableMetric)m_metric).learnMetric(instances); } else { throw new Exception ("Metric is not trainable"); } } else { throw new Exception ("Metric is not trainable"); } } /** returns objective function, needed for compatibility with SemiSupClusterer */ public double objectiveFunction() { return Double.NaN; } /** Return the number of clusters */ public int getNumClusters() { return m_numClusters; } /** A duplicate function to conform to Clusterer abstract class. * @returns the number of clusters */ public int numberOfClusters() { return getNumClusters(); } /** * get an array of random indeces out of n possible values. * if the number of requested indeces is larger then maxIdx, returns * maxIdx permuted values * @param maxIdx - the maximum index of the set * @param numIdxs number of indexes to return * @return an array of indexes */ public static int[] randomSubset(int numIdxs, int maxIdx) { Random r = new Random(maxIdx + numIdxs); int[] indeces = new int[maxIdx]; for (int i = 0; i < maxIdx; i++) { indeces[i] = i; } // permute the indeces randomly for (int i = 0; i < indeces.length; i++) { int idx = r.nextInt (maxIdx); int temp = indeces[idx]; indeces[idx] = indeces[i]; indeces[i] = temp; } int []returnIdxs = new int[Math.min(numIdxs,maxIdx)]; for (int i = 0; i < returnIdxs.length; i++) { returnIdxs[i] = indeces[i]; } return returnIdxs; } // Main method for testing this class public static void main(String [] argv) { try { //////// Iris data //String datafile = "/u/ml/software/weka-latest/data/iris.arff"; // String datafile = "/u/mbilenko/ml/tivoli/user-features-GroupClassGrad.arff"; String datafile = "/u/mbilenko/ml/tivoli/data/user-features-processClass.arff"; // String datafile = "/u/mbilenko/weka/data/glass.arff"; // set up the data FileReader reader = new FileReader (datafile); Instances data = new Instances (reader); // filter out bad attributes for tivoli clustering String [] filteredProcesses = {"pico", "twm", "Xvnc", "lpr", "fvwm2", "xclock", "FvwmButtons", "FvwmPager", "ymessenger.bin", "vim", "vi", "xemacs", "xscreensaver", "gnome-panel", "gnome-settings-daemon", "gconfd-2", "xlock", "kdesud", "ssh", "tasklist_applet", "panel", "gnome-session", "gnome-smproxy", "MozillaFirebird-bin", "nautilus", "mutt", "mixer_applet2", "metacity", "bonobo-activation-server", "csh", "nautilus-throbber", "xmms", "realplay", "konqueror", "knode", "kdesktop_lock", "kwrapper", "artsd", "esd", "gnome-panel", "gnome-terminal", "mail", "gnome-name-service", "deskguide_applet", "sawfish", "gaim", "konsole", "opera", "enlightenment", "6", "wmaker"}; System.out.println("filtered=" + filteredProcesses.length); int[] descrIndeces = new int[filteredProcesses.length]; for (int i = 0; i < descrIndeces.length; i++) { Attribute attr = data.attribute(filteredProcesses[i]); System.out.println(i + ": " + attr); descrIndeces[i] = attr.index(); } Remove attributeFilter = new Remove(); attributeFilter.setAttributeIndicesArray(descrIndeces); attributeFilter.setInvertSelection(false); attributeFilter.setInputFormat(data); data = Filter.useFilter(data, attributeFilter); // Make the last attribute be the class int theClass = data.numAttributes(); data.setClassIndex(theClass-1); // starts with 0 // int numClusters = data.numClasses(); Instances clusterData = new Instances(data); clusterData.deleteClassAttribute(); WeightedEuclidean euclidean = new WeightedEuclidean(clusterData.numAttributes()); WeightedDotP dotp = new WeightedDotP(clusterData.numAttributes()); // HAC hac = new HAC(euclidean); HAC hac = new HAC(dotp); hac.setVerbose(false); clusterData = hac.filterInstanceDescriptions(clusterData); // cluster without seeding System.out.println("\nClustering the user data ...\n"); hac.setLinkingType(new SelectedTag(COMPLETE_LINK, TAGS_LINKING)); // trim the instances // int i = 6; // while (i < clusterData.numInstances()) { // clusterData.delete(i); //} // cluster with seeding // ArrayList seedArray = new ArrayList(); // for (int i = 0; i < 19; i++) { // seedArray.add(clusterData.instance(i)); // } // seedArray.add(clusterData.instance(0)); // seedArray.add(clusterData.instance(1)); // seedArray.add(clusterData.instance(2)); // seedArray.add(clusterData.instance(3)); // seedArray.add(clusterData.instance(4)); // seedArray.add(clusterData.instance(50)); // seedArray.add(clusterData.instance(51)); // seedArray.add(clusterData.instance(52)); // seedArray.add(clusterData.instance(53)); // seedArray.add(clusterData.instance(54)); // seedArray.add(clusterData.instance(100)); // seedArray.add(clusterData.instance(101)); // seedArray.add(clusterData.instance(102)); // seedArray.add(clusterData.instance(103)); // seedArray.add(clusterData.instance(104)); // Seeder seeder = new Seeder(clusterData, data); // seeder.setVerbose(false); // seeder.createSeeds(seedArray); // HashMap seedHash = seeder.getSeeds(); // hac.setSeedHash(seedHash); HashMap classInstanceHash = new HashMap(); // get the data for each class for (int i = 0; i < data.numInstances(); i++) { Instance instance = data.instance(i); Integer classValue = new Integer((int) instance.classValue()); if (classInstanceHash.containsKey(classValue)) { ArrayList classList = (ArrayList) classInstanceHash.get(classValue); classList.add(new Integer(i)); System.out.println("Seen class; now has " + classList.size() + " elements"); } else { // unseen class System.out.println("Unseen class " + classValue); ArrayList classList = new ArrayList(); classList.add(new Integer(i)); classInstanceHash.put(classValue, classList); } } // sample from the classes that have more than 1 instance double seedProportion = 0.7; ArrayList seedArray = new ArrayList(); Iterator iterator = classInstanceHash.entrySet().iterator(); while (iterator.hasNext()) { Map.Entry entry = (Map.Entry) iterator.next(); ArrayList classList = (ArrayList) entry.getValue(); System.out.println("Classlist for " + entry.getKey() + " has " + classList.size() + " elements\n"); if (classList.size() > 1) { int [] seedIndeces = randomSubset((int) ((classList.size() + 0.0) * seedProportion), classList.size()); System.out.println("Seeding for class " + entry.getKey() + " using " + seedIndeces.length); for (int i = 0; i < seedIndeces.length; i++) { seedArray.add(clusterData.instance(((Integer)(classList.get(seedIndeces[i]))).intValue())); System.out.println("Adding seed " + classList.get(seedIndeces[i])); } } } Seeder seeder = new Seeder(clusterData, data); seeder.setVerbose(false); seeder.createSeeds(seedArray); HashMap seedHash = seeder.getSeeds(); hac.setSeedHash(seedHash); hac.buildClusterer(clusterData, 1); hac.printClusters(); // System.out.println("Cluster assignments: "); // for (int i=0; i < hac.m_clusterAssignments.length; i++) { // System.out.print(i + ":" + hac.m_clusterAssignments[i] + " "); // } // System.out.println("\n\n"); // for (int j = 0; j < clusterData.numInstances(); j++) { // System.out.println(j + ":" + hac.clusterInstance(clusterData.instance(j))); // } //////////////////////////////////////////////////// // HI-DIM TESTING //////////////////////////////////////////////////// //////// Text data - 300 documents // datafile = "/u/ml/software/weka-latest/data/20newsgroups/different-100_fromCCS.arff"; // System.out.println("\nClustering diff-100 newsgroup data with seeding, using constrained HAC...\n"); // // set up the data // reader = new FileReader (datafile); // data = new Instances (reader); // System.out.println("Initial data has size: " + data.numInstances()); // // Make the last attribute be the class // theClass = data.numAttributes(); // data.setClassIndex(theClass-1); // starts with 0 // numClusters = data.numClasses(); // WeightedDotP dotp = new WeightedDotP(data.numAttributes()); // hac = new HAC (dotp); // // cluster with seeding // Instances seeds = new Instances(data, 0, 5); // seeds.add(data.instance(100)); // seeds.add(data.instance(101)); // seeds.add(data.instance(102)); // seeds.add(data.instance(103)); // seeds.add(data.instance(104)); // seeds.add(data.instance(200)); // seeds.add(data.instance(201)); // seeds.add(data.instance(202)); // seeds.add(data.instance(203)); // seeds.add(data.instance(204)); // System.out.println("Labeled data has size: " + seeds.numInstances() + ", number of attributes: " + data.numAttributes()); // data.delete(204); // data.delete(203); // data.delete(202); // data.delete(201); // data.delete(200); // data.delete(104); // data.delete(103); // data.delete(102); // data.delete(101); // data.delete(100); // data.delete(4); // data.delete(3); // data.delete(2); // data.delete(1); // data.delete(0); // System.out.println("Unlabeled data has size: " + data.numInstances()); // // Remove the class labels before clustering // clusterData = new Instances(data); // // clusterData.deleteAttributeAt(theClass-1); // clusterData.deleteClassAttribute(); // hac.setVerbose(false); // hac.setSeedable(true); // hac.buildClusterer(seeds, clusterData, theClass, numClusters); } catch (Exception e) { e.printStackTrace(); } } }