/* * To change this template, choose Tools | Templates * and open the template in the editor. */ package weka.clusterers; import weka.core.Attribute; import weka.core.Capabilities; import weka.core.DistanceFunction; import weka.core.EuclideanDistance; import weka.core.Instance; import weka.core.Instances; import weka.core.ManhattanDistance; import weka.core.KPrototypes_DistanceFunction; import weka.core.Option; import weka.core.RevisionUtils; import weka.core.Utils; import weka.core.WeightedInstancesHandler; import weka.core.Capabilities.Capability; import weka.filters.Filter; import weka.filters.unsupervised.attribute.ReplaceMissingValues; import java.util.Enumeration; import java.util.HashMap; import java.util.LinkedList; import java.util.Random; import java.util.Vector; /** * * @author MetalGearRex */ public class BisectingKMeans extends RandomizableClusterer implements NumberOfClustersRequestable, WeightedInstancesHandler { /** for serialization */ static final long serialVersionUID = -3235809600124455376L; /** * replace missing values in training instances */ private ReplaceMissingValues m_ReplaceMissingFilter; /** * number of clusters to generate */ private int m_NumClusters = 2; /** * Replace missing values globally? */ private boolean m_dontReplaceMissing = false; /** * The number of instances in each cluster */ private int [] m_ClusterSizes; /** * Maximum number of iterations to be executed by the K-means subalgorithm */ private int m_MaxIterations = 500; /** * Holds the errors for all clusters */ private double [] m_ClusterErrors; /** the distance function used. */ protected DistanceFunction m_DistanceFunction = new EuclideanDistance(); /** * Assignments obtained */ protected int[] m_Assignments = null; /** * Number of executions of the K-means subalgorithm at each splitting */ private int m_NumExecutions = 2; /** * The resulting clusters */ private Vector<Instances> m_Clusters; /** * The cluster centroids */ private Instance[] m_ClusterCentroids; /** * A hash map that holds to which cluster an instance belongs */ private HashMap<String, Integer> m_ClusterIndices; /** * The possible ways to choose the cluster to split */ private int m_wayToChooseClusterToSplit = 1; /** * the default constructor */ public BisectingKMeans() { super(); m_SeedDefault = 10; setSeed(m_SeedDefault); } /** * Returns a string describing this clusterer * @return a description of the evaluator suitable for * displaying in the explorer/experimenter gui */ public String globalInfo() { return "Cluster data using the bisecting k means algorithm; Can use either " + "the Euclidean distance (default) or the Manhattan distance;" + " If the Manhattan distance is used, then centroids are computed " + "as the component-wise median rather than mean."; } /** * Returns default capabilities of the clusterer. * * @return the capabilities of this clusterer */ public Capabilities getCapabilities() { Capabilities result = super.getCapabilities(); result.disableAll(); result.enable(Capability.NO_CLASS); // attributes result.enable(Capability.NOMINAL_ATTRIBUTES); result.enable(Capability.NUMERIC_ATTRIBUTES); result.enable(Capability.MISSING_VALUES); return result; } /** * Chooses a cluster to split into two * * @param clusters vector of the clusters to choose from * @param seed seed for the random number generator * * @return the index in the vector of the chosen cluster */ private int chooseClusterToSplit(int seed) throws Exception { int clusterIndex = 0; switch (m_wayToChooseClusterToSplit){ case 1: // With highest error double highest_error = -1; for (int i = 0; i < m_Clusters.size(); ++i){ if (highest_error < m_ClusterErrors[i]){ highest_error = m_ClusterErrors[i]; clusterIndex = i; } } break; case 2: // With highest count of instances int maxInstances = 0; for (int i = 0; i < m_Clusters.size(); ++i){ if (maxInstances < m_Clusters.get(i).numInstances()){ clusterIndex = i; maxInstances = m_Clusters.get(i).numInstances(); } } break; default: throw new Exception("BisectingKMeans currently only supports 2 ways to choose a cluster to split. Check the tooltip for description"); } return clusterIndex; } public void buildClusterer(Instances data) throws Exception { getCapabilities().testWithFail(data); m_ReplaceMissingFilter = new ReplaceMissingValues(); Instances instances = new Instances(data); instances.setClassIndex(-1); if (!m_dontReplaceMissing) { m_ReplaceMissingFilter.setInputFormat(instances); instances = Filter.useFilter(instances, m_ReplaceMissingFilter); } // all the instances are assigned to cluster 0 m_Assignments = new int [instances.numInstances()]; m_ClusterCentroids = new Instance[m_NumClusters]; m_ClusterErrors = new double[m_NumClusters]; Random RandomO = new Random(getSeed()); m_ClusterSizes = new int[m_NumClusters]; m_Clusters = new Vector<Instances>(); m_Clusters.add(instances); m_ClusterIndices = new HashMap<String, Integer>(); for (int i = 0; i < instances.numInstances(); ++i){ m_ClusterIndices.put(instances.instance(i).toString(), 0); } while (m_Clusters.size() < m_NumClusters){ int clusterIndex = chooseClusterToSplit(RandomO.nextInt()); Instances clusterToSplit = m_Clusters.get(clusterIndex); double minimumError = 1.79769313486231570e+308d; // largest Java number Instances first = new Instances(data, 0), second = new Instances(data, 0); Instance firstCentroid = null, secondCentroid = null; double firstError = 0, secondError = 0; for (int l = 0; l < m_NumExecutions; l++){ // create and configure the K-Means subalgorithm //weka.clusterers.kMeans kMeans = new weka.clusterers.kMeans(); weka.clusterers.SimpleKMeans kMeans = new weka.clusterers.SimpleKMeans(); kMeans.setDisplayStdDevs(false); kMeans.setDistanceFunction(m_DistanceFunction); kMeans.setDontReplaceMissingValues(m_dontReplaceMissing); kMeans.setMaxIterations(m_MaxIterations); kMeans.setNumClusters(2); // always split into two subclusters //kMeans.setPreserveInstancesOrder(false); // no need for that kMeans.setSeed(RandomO.nextInt()); kMeans.buildClusterer(clusterToSplit); // FIXME: think about supporting other types of error calculating double currentError = kMeans.getSquaredError(); if (currentError < minimumError){ // update the set of clusters with the new clusters first.delete(); second.delete(); for (int i = 0; i < clusterToSplit.numInstances(); ++i){ Instance nextInstance = clusterToSplit.instance(i); if (kMeans.clusterInstance(nextInstance) == 0){ first.add(nextInstance); } else { second.add(nextInstance); } } // FIXME: There should be a better way to get the two centroids. Instances centroids = kMeans.getClusterCentroids(); firstCentroid = centroids.instance(0); centroids.delete(0); secondCentroid = centroids.instance(0); firstError = kMeans.getClusterErrors()[0]; secondError = kMeans.getClusterErrors()[1]; minimumError = currentError; } } // update the set of clusters with the new clusters m_Clusters.set(clusterIndex, first); m_Clusters.add(second); // mark the instances of the split cluster as belonging to one of the two new clusters for (int l = 0; l < first.numInstances(); ++l){ m_ClusterIndices.put(first.instance(l).toString(), clusterIndex); } for (int l = 0; l < second.numInstances(); ++l){ m_ClusterIndices.put(second.instance(l).toString(), m_Clusters.size() - 1); } // update the centroids and errors of the new clusters m_ClusterCentroids[clusterIndex] = firstCentroid; m_ClusterCentroids[m_Clusters.size() - 1] = secondCentroid; m_ClusterErrors[clusterIndex] = firstError; m_ClusterErrors[m_Clusters.size() - 1] = secondError; } // set the indices of respective clusters for each instance for (int i = 0; i < instances.numInstances(); ++i){ m_Assignments[i] = m_ClusterIndices.get(instances.instance(i).toString()); } // set the sizes of each cluster for (int i = 0; i < m_NumClusters; ++i){ m_ClusterSizes[i] = m_Clusters.get(i).numInstances(); } } /** * clusters an instance that has been through the filters * * @param instance the instance to assign a cluster to * @param updateErrors if true, update the within clusters sum of errors * @return a cluster number */ private int clusterProcessedInstance(Instance instance, boolean updateErrors) { double minDist = Integer.MAX_VALUE; int bestCluster = 0; for (int i = 0; i < m_NumClusters; i++) { double dist = m_DistanceFunction.distance(instance, m_ClusterCentroids[i]); if (dist < minDist) { minDist = dist; bestCluster = i; } } if (updateErrors) { if(m_DistanceFunction instanceof EuclideanDistance){ //Euclidean distance to Squared Euclidean distance minDist *= minDist; } m_ClusterErrors[bestCluster] += minDist; } return bestCluster; } /** * Classifies a given instance. * * @param instance the instance to be assigned to a cluster * @return the number of the assigned cluster as an interger * if the class is enumerated, otherwise the predicted value * @throws Exception if instance could not be classified * successfully */ @Override public int clusterInstance(Instance instance) throws Exception { Instance inst = null; if (!m_dontReplaceMissing) { m_ReplaceMissingFilter.input(instance); m_ReplaceMissingFilter.batchFinished(); inst = m_ReplaceMissingFilter.output(); } else { inst = instance; } return clusterProcessedInstance(inst, false); } /** * Returns the tip text for this property * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String numClustersTipText() { return "set number of clusters"; } /** * set the number of clusters to generate * * @param n the number of clusters to generate * @throws Exception if number of clusters is negative */ public void setNumClusters(int n) throws Exception { if (n <= 0) { throw new Exception("Number of clusters must be > 0"); } m_NumClusters = n; } /** * gets the number of clusters to generate * * @return the number of clusters to generate */ public int getNumClusters() { return m_NumClusters; } /** * Returns the tip text for this property * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String maxIterationsTipText() { return "set maximum number of iterations"; } /** * set the maximum number of iterations to be executed * * @param n the maximum number of iterations * @throws Exception if maximum number of iteration is smaller than 1 */ public void setMaxIterations(int n) throws Exception { if (n <= 0) { throw new Exception("Maximum number of iterations must be > 0"); } m_MaxIterations = n; } /** * gets the number of maximum iterations to be executed * * @return the number of clusters to generate */ public int getMaxIterations() { return m_MaxIterations; } /** * Returns the tip text for this property * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String numExecutionsTipText() { return "set number of executions of the K-means subalgorithm"; } /** * set the maximum number of iterations to be executed * * @param n the maximum number of iterations * @throws Exception if maximum number of iteration is smaller than 1 */ public void setNumExecutions(int n) throws Exception { if (n <= 0) { throw new Exception("Number of executions must be > 0"); } m_MaxIterations = n; } /** * gets the number of maximum iterations to be executed * * @return the number of clusters to generate */ public int getNumExecutions() { return m_NumExecutions; } /** * Returns the tip text for this property * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String dontReplaceMissingValuesTipText() { return "Replace missing values globally with mean/mode."; } /** * Sets whether missing values are to be replaced * * @param r true if missing values are to be * replaced */ public void setDontReplaceMissingValues(boolean r) { m_dontReplaceMissing = r; } /** * Gets whether missing values are to be replaced * * @return true if missing values are to be * replaced */ public boolean getDontReplaceMissingValues() { return m_dontReplaceMissing; } /** * Returns the tip text for this property. * * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String distanceFunctionTipText() { return "The distance function to use for instances comparison " + "(default: weka.core.EuclideanDistance). "; } /** * returns the distance function currently in use. * * @return the distance function */ public DistanceFunction getDistanceFunction() { return m_DistanceFunction; } /** * sets the distance function to use for instance comparison. * * @param df the new distance function to use * @throws Exception if instances cannot be processed */ public void setDistanceFunction(DistanceFunction df) throws Exception { if(!(df instanceof EuclideanDistance) && !(df instanceof ManhattanDistance) && !(df instanceof KPrototypes_DistanceFunction)) { throw new Exception("BisectingKMeans currently only supports the Euclidean, Manhattan and KPrototypes distances."); } m_DistanceFunction = df; } /** * Returns the tip text for this property * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String wayToChooseClusterToSplitTipText() { return "Way to choose the cluster to split:" + " 1 = With highest average squared error;" + " 2 = With highest count of instances;"; } /** * Sets the chosen way to choose the cluster to split * * @param w the way to choose the cluster to split * look at the tooltip for description */ public void setWayToChooseClusterToSplit(int w) throws Exception { if ((w <= 0) || (w > 2)) { throw new Exception("BisectingKMeans currently only supports 2 ways to choose a cluster to split. Check the tooltip for description"); } m_wayToChooseClusterToSplit = w; } /** * Gets the chosen way to choose the cluster to split * * @return w the way to choose the cluster to split * look at the tooltip for description */ public int getWayToChooseClusterToSplit() { return m_wayToChooseClusterToSplit; } /** * Returns the number of clusters. * * @return the number of clusters generated for a training dataset. * @throws Exception if number of clusters could not be returned * successfully */ public int numberOfClusters() throws Exception { return m_NumClusters; } public Enumeration listOptions () { Vector result = new Vector(); result.addElement(new Option( "\tnumber of clusters.\n" + "\t(default 2).", "N", 1, "-N <num>")); result.addElement(new Option( "\tReplace missing values with mean/mode.\n", "M", 0, "-M")); result.add(new Option( "\tDistance function to use.\n" + "\t(default: weka.core.EuclideanDistance)", "A", 1,"-A <classname and options>")); result.add(new Option( "\tMaximum number of iterations.\n", "I",1,"-I <num>")); result.addElement(new Option( "\tNumber of executions of the K-means subalgorithm.\n", "X", 1, "-X")); result.addElement(new Option( "\tWay to choose the cluster to split.\n", "W", 1, "-W")); Enumeration en = super.listOptions(); while (en.hasMoreElements()) result.addElement(en.nextElement()); return result.elements(); } /** * Parses a given list of options. <p/> * <!-- options-start --> * Valid options are: <p/> * * <pre> -N <num> * number of clusters. * (default 2). * </pre> * * <pre> -V * Display std. deviations for centroids. * </pre> * * <pre> -M * Replace missing values with mean/mode. * </pre> * * <pre> -S <num> * Random number seed. * (default 10) * </pre> * * <pre> -A <classname and options> * Distance function to be used for instance comparison * (default weka.core.EuclidianDistance) * </pre> * * <pre> -I <num> * Maximum number of iterations of the K-means subalgorithm. * </pre> * * <pre> -O * Preserve order of instances. * </pre> * * <pre> -X * Number of executions of the K-means subalgorithm at each splitting. * </pre> * * <pre> -W * Way to choose the cluster to split. * </pre> * <!-- options-end --> * * @param options the list of options as an array of strings * @throws Exception if an option is not supported */ public void setOptions (String[] options) throws Exception { m_dontReplaceMissing = Utils.getFlag("M", options); String optionString = Utils.getOption('N', options); if (optionString.length() != 0) { setNumClusters(Integer.parseInt(optionString)); } optionString = Utils.getOption("I", options); if (optionString.length() != 0) { setMaxIterations(Integer.parseInt(optionString)); } optionString = Utils.getOption("X", options); if (optionString.length() != 0) { setNumExecutions(Integer.parseInt(optionString)); } optionString = Utils.getOption("W", options); if (optionString.length() != 0) { setWayToChooseClusterToSplit(Integer.parseInt(optionString)); } String distFunctionClass = Utils.getOption('A', options); if(distFunctionClass.length() != 0) { String distFunctionClassSpec[] = Utils.splitOptions(distFunctionClass); if(distFunctionClassSpec.length == 0) { throw new Exception("Invalid DistanceFunction specification string."); } String className = distFunctionClassSpec[0]; distFunctionClassSpec[0] = ""; setDistanceFunction( (DistanceFunction) Utils.forName( DistanceFunction.class, className, distFunctionClassSpec) ); } else { setDistanceFunction(new EuclideanDistance()); } super.setOptions(options); } /** * Gets the current settings of BisectingKMeans * * @return an array of strings suitable for passing to setOptions() */ public String[] getOptions () { int i; Vector result; String[] options; result = new Vector(); if (m_dontReplaceMissing) { result.add("-M"); } result.add("-N"); result.add("" + getNumClusters()); result.add("-A"); result.add((m_DistanceFunction.getClass().getName() + " " + Utils.joinOptions(m_DistanceFunction.getOptions())).trim()); result.add("-I"); result.add(""+ getMaxIterations()); result.add("-X"); result.add(""+ getNumExecutions()); result.add("-В"); result.add(""+ getWayToChooseClusterToSplit()); options = super.getOptions(); for (i = 0; i < options.length; i++) result.add(options[i]); return (String[]) result.toArray(new String[result.size()]); } /** * return a string describing this clusterer * * @return a description of the clusterer as a string */ public String toString() { String resultString = new String(); resultString = resultString.concat("Number of clusters: "); resultString = resultString.concat(m_NumClusters + "\n"); resultString = resultString.concat("Number of executions of the subalgorithm: "); resultString = resultString.concat(m_NumExecutions + "\n"); resultString = resultString.concat("Max iterations of the subalgorithm: "); resultString = resultString.concat(m_MaxIterations + "\n"); resultString = resultString.concat("\n Cluster centroids:\n"); for (int i = 0; i < m_NumClusters; ++i){ resultString = resultString.concat("Cluster " + i + " centroid: "); resultString = resultString.concat(m_ClusterCentroids[i].toString() + "\n"); } resultString = resultString.concat("\nCluster average squared errors:\n"); for (int i = 0; i < m_NumClusters; ++i){ resultString = resultString.concat("Cluster " + i + " average squared error: "); resultString = resultString.concat(m_ClusterErrors[i] + "\n"); } resultString = resultString.concat("\n"); resultString = resultString.concat("\nSum of the clusters average squared errors: " + Utils.sum(m_ClusterErrors) + "\n"); // for (int i = 0; i < m_NumClusters; i++){ // resultString = resultString.concat("Cluster " + i + " contains the following instances: \n"); // for( int j = 0; j< m_Clusters.get(i).numInstances(); j++){ // resultString = resultString.concat(m_Clusters.get(i).instance(j).toString()); // resultString = resultString.concat("\n"); // } // resultString = resultString.concat("=======================================\n"); // } resultString = resultString.concat("\n"); return resultString; } /** * Gets the cluster centroids * * @return the cluster centroids */ public Instance[] getClusterCentroids() { return m_ClusterCentroids; } /** * Gets the cluster errors * * @return the cluster errors */ public double[] getClusterErrors() { return m_ClusterErrors; } /** * Gets the error for all clusters * * @return the error */ public double getErrors() { return Utils.sum(m_ClusterErrors); } /** * Gets the number of instances in each cluster * * @return The number of instances in each cluster */ public int [] getClusterSizes() { return m_ClusterSizes; } /** * Gets the assignments for each instance * @return Array of indexes of the centroid assigned to each instance * @throws Exception if order of instances wasn't preserved or no assignments were made */ public int [] getAssignments() throws Exception{ if(m_Assignments == null){ throw new Exception("No assignments made."); } return m_Assignments; } /** * Returns the revision string. * * @return the revision */ public String getRevision() { return RevisionUtils.extract("$Revision: 1000 $"); } public static void main (String[] argv) { runClusterer(new BisectingKMeans(), argv); } }