/* * To change this template, choose Tools | Templates * and open the template in the editor. */ package weka.clusterers; import weka.classifiers.rules.DecisionTableHashKey; import weka.core.Attribute; import weka.core.Capabilities; import weka.core.CapabilitiesHandler; import weka.core.Capabilities.Capability; import weka.core.DistanceFunction; import weka.core.EuclideanDistance; import weka.core.Instance; import weka.core.Instances; import weka.core.Option; import weka.core.OptionHandler; import weka.core.RevisionUtils; import weka.core.Utils; import weka.core.WeightedInstancesHandler; import weka.filters.Filter; import weka.filters.unsupervised.attribute.ReplaceMissingValues; import java.util.Vector; import java.util.Enumeration; import java.util.Random; import java.util.HashMap; public class kModes extends RandomizableClusterer implements NumberOfClustersRequestable, WeightedInstancesHandler, OptionHandler, CapabilitiesHandler { //Private Members private int m_numClusters; private int m_currentIteration; private int m_maxIterations; private DistanceFunction m_distanceFunction; private Instances m_clusterCenters; private Instances[] m_clusterDistribution; private int[] m_previousAssignment; private boolean m_dontReplaceMissing = false; private ReplaceMissingValues m_ReplaceMissingFilter; private double[] m_clusterErrors; //end of Private Members //Constructors public kModes() { m_numClusters = 2; m_maxIterations = 100; m_distanceFunction = new EuclideanDistance(); } //end of Constructors //Randomizable Clusterer public void buildClusterer(Instances data) throws Exception { //is data consistent with cluster capabilities getCapabilities().testWithFail(data); Random rand = new Random(getSeed()); int spaceCount = getMaxUniqueInstances(data); if(spaceCount < m_numClusters || data.numInstances() < m_numClusters) { m_numClusters = data.numInstances(); } m_clusterCenters = new Instances(data, m_numClusters); m_ReplaceMissingFilter = new ReplaceMissingValues(); Instances inst = new Instances(data); inst.setClassIndex(-1); if (!m_dontReplaceMissing) { m_ReplaceMissingFilter.setInputFormat(inst); inst = Filter.useFilter(inst, m_ReplaceMissingFilter); } m_distanceFunction.setInstances(inst); int count = inst.numInstances()-1; int clusterIndex; HashMap initC = new HashMap(); DecisionTableHashKey hk = null; for (int j = count; j >= 0; j--) { clusterIndex = rand.nextInt(j+1); hk = new DecisionTableHashKey(inst.instance(clusterIndex), inst.numAttributes(), true); if (!initC.containsKey(hk)) { m_clusterCenters.add(inst.instance(clusterIndex)); initC.put(hk, null); } inst.swap(j, clusterIndex); if (m_clusterCenters.numInstances() == m_numClusters) { break; } } boolean finished = false; m_currentIteration = 0; m_clusterDistribution = new Instances[m_numClusters]; m_previousAssignment = new int[count+1]; int emptyClustCount = 0; m_clusterErrors = new double[m_numClusters]; while(!finished) { finished = true; m_currentIteration++; for(int i = 0; i < inst.numInstances(); i++) { Instance next = inst.instance(i); int newClust = clusterFilteredInstance(next, true); if(newClust != m_previousAssignment[i]) { m_previousAssignment[i] = newClust; finished = false; } } if(!finished) { m_clusterCenters = new Instances(inst, m_numClusters); for (int i = 0; i < m_numClusters; i++) { m_clusterDistribution[i] = new Instances(inst, 0); } //update Clusters logic for(int i = 0; i < inst.numInstances(); i++) { m_clusterDistribution[m_previousAssignment[i]].add(inst.instance(i)); } for (int i = 0; i < m_numClusters; i++) { if (m_clusterDistribution[i].numInstances() == 0) { emptyClustCount++; } else { recalcCenters(m_clusterDistribution[i]); } } } if (emptyClustCount > 0) { m_numClusters -= emptyClustCount; if (finished) { Instances[] newClusterDistribution = new Instances[m_numClusters]; int index = 0; for (int k = 0; k < m_clusterDistribution.length; k++) { if (m_clusterDistribution[k].numInstances() > 0) { newClusterDistribution[index++] = m_clusterDistribution[k]; } } m_clusterDistribution = newClusterDistribution; } else { m_clusterDistribution = new Instances[m_numClusters]; } } if(!finished) { m_clusterErrors = new double[m_numClusters]; } if(m_currentIteration == m_maxIterations) finished = true; } // Substitue the sum of the similarity errors with the average similarity error. for(int f =0; f < m_clusterErrors.length; f++) { m_clusterErrors[f] /= m_clusterDistribution[f].numInstances(); } } public int clusterInstance(Instance instance) throws Exception { Instance toCluster = null; if (!m_dontReplaceMissing) { m_ReplaceMissingFilter.input(instance); m_ReplaceMissingFilter.batchFinished(); toCluster = m_ReplaceMissingFilter.output(); } else { toCluster = instance; } return clusterFilteredInstance(toCluster, false); } private int clusterFilteredInstance(Instance instance, boolean updateErrors) throws Exception { double minDist = m_distanceFunction.distance(instance, m_clusterCenters.instance(0)); int retValue = 0; for (int i = 1; i < m_numClusters; i++) { double dist = m_distanceFunction.distance(instance, m_clusterCenters.instance(i)); if (dist < minDist) { minDist = dist; retValue = i; } } if (updateErrors) { if(m_distanceFunction instanceof EuclideanDistance){ //Euclidean distance to Squared Euclidean distance minDist *= minDist; } m_clusterErrors[retValue] += minDist; } return retValue; } public int numberOfClusters() { return m_numClusters; } //end of Randomizable Clusterer //NumberOfClustersRequestable public void setNumClusters(int numClusters) throws Exception { if(numClusters <= 0) { throw new Exception("Number of clusters must be > 0"); } m_numClusters = numClusters; } //end of NumberOfClustersRequestable //OptionHandler public Enumeration listOptions() { Vector result = new Vector(); result.addElement(new Option( "\tReplace missing values with mean/mode.\n", "M", 0, "-M")); result.addElement(new Option( "\tNumber of clusters.\n" + "\t(default 2).", "N", 1, "-N <num>")); result.add(new Option( "\tMaximum number of iterations.\n", "I",1,"-I <num>")); Enumeration en = super.listOptions(); while (en.hasMoreElements()) result.addElement(en.nextElement()); return result.elements(); } public void setOptions(String[] options) throws Exception { m_dontReplaceMissing = Utils.getFlag("M", options); String optionString = Utils.getOption('N', options); if (optionString.length() != 0) { setNumClusters(Integer.parseInt(optionString)); } else { setNumClusters(2); } optionString = Utils.getOption("I", options); if (optionString.length() != 0) { setMaxIterations(Integer.parseInt(optionString)); } else { setMaxIterations(50); } super.setOptions(options); } public String[] getOptions() { int i; Vector result = new Vector(); String[] options; if (m_dontReplaceMissing) { result.add("-M"); } result.add("-N"); result.add(""+ numberOfClusters()); result.add("-I"); result.add(""+ getMaxIterations()); options = super.getOptions(); for (i = 0; i < options.length; i++) result.add(options[i]); return (String[]) result.toArray(new String[result.size()]); } //end of OptionHandler //CapabilitiesHandler public Capabilities getCapabilities() { Capabilities result = super.getCapabilities(); result.disableAll(); //class result.enable(Capability.NO_CLASS); result.enable(Capability.NOMINAL_CLASS); // attributes result.enable(Capability.NOMINAL_ATTRIBUTES); result.enable(Capability.MISSING_VALUES); return result; } //end of CapabilitiesHandler //HelperFunctions public void setMaxIterations(int n) throws Exception { if (n <= 0) { throw new Exception("Maximum number of iterations must be > 0"); } m_maxIterations = n; } public int getMaxIterations() { return m_maxIterations; } public int getNumClusters() { return m_numClusters; } private double[] recalcCenters(Instances members) throws Exception { double [] vals = new double[members.numAttributes()]; for (int j = 0; j < members.numAttributes(); j++) { vals[j] = members.meanOrMode(j); if (members.attributeStats(j).missingCount > members.attributeStats(j).nominalCounts[Utils.maxIndex(members.attributeStats(j).nominalCounts)]) vals[j] = Instance.missingValue(); } m_clusterCenters.add(new Instance(1.0,vals)); return vals; } public void setDontReplaceMissingValues(boolean value) { m_dontReplaceMissing = value; } public boolean getDontReplaceMissingValues() { return m_dontReplaceMissing; } private int getMaxUniqueInstances(Instances data) { int retValue = 1; for(int i = 0; i < data.numAttributes(); i++) { weka.core.AttributeStats stats = data.attributeStats(i); retValue *= stats.distinctCount; } return retValue; } //end HelperFunctions //GUI public String toString() { String resultString = new String(); resultString = resultString.concat("Number of clusters: "); resultString = resultString.concat(m_numClusters + "\n"); resultString = resultString.concat("\n Cluster centroids:\n"); for (int i = 0; i < m_numClusters; ++i){ resultString = resultString.concat("Cluster " + i + " centroid: "); resultString = resultString.concat(m_clusterCenters.instance(i).toString() + "\n"); } resultString = resultString.concat("\nCluster average similarity errors:\n"); for (int r = 0; r < m_numClusters; ++r){ resultString = resultString.concat("Cluster " + r + " average similarity error: "); resultString = resultString.concat(m_clusterErrors[r] + "\n"); } resultString = resultString.concat("\nSum of the clusters average similarity errors: " + Utils.sum(m_clusterErrors) + "\n"); // resultString = resultString.concat("\n"); // for (int i = 0; i < m_numClusters; i++){ // resultString = resultString.concat("Cluster " + i + " contains the following instances: \n"); // for( int j = 0; j< m_clusterDistribution[i].numInstances(); j++){ // resultString = resultString.concat(m_clusterDistribution[i].instance(j).toString()); // resultString = resultString.concat("\n"); // } // resultString = resultString.concat("=======================================\n"); // } resultString = resultString.concat("\n"); return resultString; } //GUI Info public String globalInfo() { return "Basic kModes algorithm for clustering data, containing nominal only values."; } public String numClustersTipText() { return "Set the number of clusters required."; } public String maxIterationsTipText() { return "Set maximum number of algorithm cycle iterations."; } public String dontReplaceMissingValuesTipText() { return "Replace missing values globally with mean/mode."; } //end GUI Info //end GUI public static void main (String[] argv) { runClusterer(new kMeans(), argv); } }