/*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
/*
* ConstrainedKMeans.java
* Copyright (C) 2009-2010 Aristotle University of Thessaloniki, Thessaloniki, Greece
*/
package mulan.classifier.meta;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.rules.DecisionTableHashKey;
import weka.clusterers.NumberOfClustersRequestable;
import weka.clusterers.RandomizableClusterer;
import weka.clusterers.SimpleKMeans;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.core.Capabilities.Capability;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;
/**
<!-- globalinfo-start -->
* Cluster data using a capacity-constrained k means algorithm that keeps the clusters balanced (each cluster holds at most ceil(n/k) instances)
* <p/>
<!-- globalinfo-end -->
*
<!-- options-start -->
* Valid options are: <p/>
*
* <pre> -N <num>
* number of clusters.
* (default 2).</pre>
*
* <pre> -S <num>
* Random number seed.
* (default 10)</pre>
*
<!-- options-end -->
*
* @author Mark Hall (mhall@cs.waikato.ac.nz)
* @author Eibe Frank (eibe@cs.waikato.ac.nz)
* @version $Revision: 1.29 $
* @see RandomizableClusterer
*/
public class ConstrainedKMeans extends RandomizableClusterer implements NumberOfClustersRequestable, WeightedInstancesHandler {
/** for serialization **/
static final long serialVersionUID = -3235809600124455376L;
// one list of bucketInstance items per cluster, kept sorted by distance;
// allocated in setNumClusters(). NOTE(review): stays null if setNumClusters()
// is never called, in which case buildClusterer() throws a NullPointerException
private ArrayList[] bucket;
// capacity of each bucket: ceil(numInstances / numClusters); this is the
// balance constraint enforced during cluster assignment
private int bucketSize;
// hard cap on k-means iterations; defaults to 0, so buildClusterer() performs
// only a single iteration unless setMaxIterations() is called first
private int maxIterations;
/**
 * Returns the revision string of this clusterer.
 * FIX: the original stub threw an UnsupportedOperationException at
 * runtime, which crashed any Weka tool that queries the revision.
 *
 * @return the revision string
 */
@Override
public String getRevision() {
    return "1.29"; // matches the @version tag in the class javadoc
}
/**
 * A wrapper used by the constrained assignment step: it keeps an
 * instance's distances to all cluster centroids plus the distance to
 * the cluster it is currently assigned to, and orders items by that
 * current distance (closest first) so the farthest item in a bucket is
 * always at the end.
 */
static public class bucketInstance implements Comparable {

    // distances to every cluster centroid (defensively copied on set)
    double[] distances;
    // distance to the currently assigned cluster's centroid
    double distance;

    public bucketInstance() {
    }

    public void setDistances(double[] x) {
        distances = new double[x.length];
        System.arraycopy(x, 0, distances, 0, x.length);
    }

    public void setDistance(double x) {
        distance = x;
    }

    public double[] getDistances() {
        return distances;
    }

    public double getDistance() {
        return distance;
    }

    /**
     * Orders bucket items by their current distance, ascending.
     * FIX: uses Double.compare instead of the original subtraction/==
     * test, which violated the compareTo contract for NaN values (every
     * comparison involving NaN returned 1, breaking the insertion sort).
     *
     * @param ci the bucketInstance to compare against
     * @return a negative value, zero, or a positive value as this item's
     * distance is less than, equal to, or greater than the other's
     */
    public int compareTo(Object ci) {
        return Double.compare(this.distance, ((bucketInstance) ci).getDistance());
    }
}
/**
 * replace missing values in training instances; the same fitted filter
 * is applied to every instance passed to clusterInstance()
 */
private ReplaceMissingValues m_ReplaceMissingFilter;
/**
 * number of clusters to generate (default 2); may shrink during
 * buildClusterer() when clusters become empty
 */
private int m_NumClusters = 2;
/**
 * holds the cluster centroids
 */
private Instances m_ClusterCentroids;
/**
 * Holds the standard deviations of the numeric attributes in each cluster
 */
private Instances m_ClusterStdDevs;
/**
 * For each cluster, holds the frequency counts for the values of each
 * nominal attribute
 */
private int[][][] m_ClusterNominalCounts;
/**
 * The number of instances in each cluster
 */
private int[] m_ClusterSizes;
/**
 * attribute min values (NaN until a non-missing value has been seen)
 */
private double[] m_Min;
/**
 * attribute max values (NaN until a non-missing value has been seen)
 */
private double[] m_Max;
/**
 * Keep track of the number of iterations completed before convergence
 */
private int m_Iterations = 0;
/**
 * Holds the squared errors for all clusters
 */
private double[] m_squaredErrors;
/**
 * Default constructor; sets the default random seed to 10 (matching the
 * "-S" option documented in the class javadoc).
 */
public ConstrainedKMeans() {
    super();
    // m_SeedDefault is inherited from RandomizableClusterer
    m_SeedDefault = 10;
    setSeed(m_SeedDefault);
}
/**
 * Returns a string describing this clusterer.
 *
 * @return a description of the clusterer suitable for
 * displaying in the explorer/experimenter gui
 */
public String globalInfo() {
    final String description = "Cluster data using the k means algorithm";
    return description;
}
/**
 * Returns default capabilities of the clusterer.
 *
 * @return the capabilities of this clusterer
 */
@Override
public Capabilities getCapabilities() {
    final Capabilities caps = super.getCapabilities();
    caps.disableAll();
    // clustering needs no class attribute
    caps.enable(Capability.NO_CLASS);
    // supported attribute types
    caps.enable(Capability.NOMINAL_ATTRIBUTES);
    caps.enable(Capability.NUMERIC_ATTRIBUTES);
    caps.enable(Capability.MISSING_VALUES);
    return caps;
}
/**
 * Sets the maximum number of k-means iterations performed by
 * buildClusterer(). NOTE(review): the field defaults to 0, so if this
 * is never called buildClusterer() stops after a single iteration.
 *
 * @param x the maximum number of iterations
 */
public void setMaxIterations(int x) {
    maxIterations = x;
}
/**
 * Generates the constrained clustering. Standard k-means, except that
 * instances are assigned through fixed-capacity "buckets"
 * (capacity = ceil(n / k)) so the clusters stay balanced.
 *
 * @param data set of instances serving as training data
 * @throws Exception if the clusterer has not been
 * generated successfully
 */
public void buildClusterer(Instances data) throws Exception {
    // FIX: make sure the buckets exist even when setNumClusters() was
    // never called (bucket would otherwise be null and we would NPE below)
    if (bucket == null || bucket.length < m_NumClusters) {
        bucket = new ArrayList[m_NumClusters];
    }
    for (int i = 0; i < m_NumClusters; i++) {
        bucket[i] = new ArrayList<bucketInstance>();
    }
    // capacity constraint: each cluster holds at most ceil(n / k) instances
    bucketSize = (int) Math.ceil(data.numInstances() / (double) m_NumClusters);
    // can clusterer handle the data?
    getCapabilities().testWithFail(data);
    m_Iterations = 0;
    m_ReplaceMissingFilter = new ReplaceMissingValues();
    Instances instances = new Instances(data);
    instances.setClassIndex(-1);
    m_ReplaceMissingFilter.setInputFormat(instances);
    instances = Filter.useFilter(instances, m_ReplaceMissingFilter);
    // per-attribute ranges, used by norm(); NaN marks "not seen yet"
    m_Min = new double[instances.numAttributes()];
    m_Max = new double[instances.numAttributes()];
    for (int i = 0; i < instances.numAttributes(); i++) {
        m_Min[i] = m_Max[i] = Double.NaN;
    }
    m_ClusterCentroids = new Instances(instances, m_NumClusters);
    int[] clusterAssignments = new int[instances.numInstances()];
    for (int i = 0; i < instances.numInstances(); i++) {
        updateMinMax(instances.instance(i));
    }
    // pick initial centroids: random distinct instances
    Random RandomO = new Random(getSeed());
    int instIndex;
    HashMap initC = new HashMap();
    DecisionTableHashKey hk = null;
    for (int j = instances.numInstances() - 1; j >= 0; j--) {
        instIndex = RandomO.nextInt(j + 1);
        hk = new DecisionTableHashKey(instances.instance(instIndex),
                instances.numAttributes(), true);
        if (!initC.containsKey(hk)) {
            m_ClusterCentroids.add(instances.instance(instIndex));
            initC.put(hk, null);
        }
        instances.swap(j, instIndex);
        if (m_ClusterCentroids.numInstances() == m_NumClusters) {
            break;
        }
    }
    // fewer distinct instances than requested clusters is possible
    m_NumClusters = m_ClusterCentroids.numInstances();
    int i;
    boolean converged = false;
    int emptyClusterCount;
    Instances[] tempI = new Instances[m_NumClusters];
    m_squaredErrors = new double[m_NumClusters];
    m_ClusterNominalCounts = new int[m_NumClusters][instances.numAttributes()][0];
    while (!converged) {
        // reset buckets for this pass
        for (int j = 0; j < m_NumClusters; j++) {
            bucket[j] = new ArrayList<bucketInstance>();
        }
        emptyClusterCount = 0;
        m_Iterations++;
        converged = true;
        for (i = 0; i < instances.numInstances(); i++) {
            Instance toCluster = instances.instance(i);
            int newC = clusterProcessedInstance(toCluster, true);
            if (newC != clusterAssignments[i]) {
                converged = false;
            }
            clusterAssignments[i] = newC;
        }
        // maxIterations defaults to 0, so without setMaxIterations()
        // only one pass is performed (original behavior, preserved)
        if (m_Iterations > maxIterations) {
            converged = true;
        }
        // update centroids
        m_ClusterCentroids = new Instances(instances, m_NumClusters);
        for (i = 0; i < m_NumClusters; i++) {
            tempI[i] = new Instances(instances, 0);
        }
        for (i = 0; i < instances.numInstances(); i++) {
            tempI[clusterAssignments[i]].add(instances.instance(i));
        }
        for (i = 0; i < m_NumClusters; i++) {
            double[] vals = new double[instances.numAttributes()];
            if (tempI[i].numInstances() == 0) {
                // empty cluster
                emptyClusterCount++;
            } else {
                for (int j = 0; j < instances.numAttributes(); j++) {
                    vals[j] = tempI[i].meanOrMode(j);
                    m_ClusterNominalCounts[i][j] =
                            tempI[i].attributeStats(j).nominalCounts;
                }
                m_ClusterCentroids.add(new DenseInstance(1.0, vals));
            }
        }
        if (emptyClusterCount > 0) {
            m_NumClusters -= emptyClusterCount;
            if (converged) {
                // FIX: keep the instance lists of the surviving (non-empty)
                // clusters. The original code replaced tempI with an array
                // of nulls here, so the std-dev computation below threw a
                // NullPointerException whenever a cluster emptied on the
                // final iteration.
                Instances[] nonEmpty = new Instances[m_NumClusters];
                int index = 0;
                for (int k = 0; k < tempI.length; k++) {
                    if (tempI[k].numInstances() > 0) {
                        nonEmpty[index++] = tempI[k];
                    }
                }
                tempI = nonEmpty;
            } else {
                tempI = new Instances[m_NumClusters];
            }
        }
        if (!converged) {
            m_squaredErrors = new double[m_NumClusters];
            m_ClusterNominalCounts = new int[m_NumClusters][instances.numAttributes()][0];
        }
    }
    // reset buckets so later clusterInstance() calls start from empty ones
    for (int j = 0; j < m_NumClusters; j++) {
        bucket[j] = new ArrayList<bucketInstance>();
    }
    // derive per-cluster statistics (std devs for numeric attributes only)
    m_ClusterStdDevs = new Instances(instances, m_NumClusters);
    m_ClusterSizes = new int[m_NumClusters];
    for (i = 0; i < m_NumClusters; i++) {
        double[] vals2 = new double[instances.numAttributes()];
        for (int j = 0; j < instances.numAttributes(); j++) {
            if (instances.attribute(j).isNumeric()) {
                vals2[j] = Math.sqrt(tempI[i].variance(j));
            } else {
                vals2[j] = Utils.missingValue();
            }
        }
        m_ClusterStdDevs.add(new DenseInstance(1.0, vals2));
        m_ClusterSizes[i] = tempI[i].numInstances();
    }
}
/**
 * Clusters an instance that has already been through the missing-value
 * filter, respecting the capacity constraint: the instance is placed in
 * the closest bucket; if that bucket overflows (more than bucketSize
 * items), the bucket's farthest item is evicted and re-assigned to its
 * next-best cluster, cascading until every bucket is within capacity.
 *
 * @param instance the instance to assign a cluster to
 * @param updateErrors if true, update the within clusters sum of errors
 * @return a cluster number
 */
private int clusterProcessedInstance(Instance instance, boolean updateErrors) {
    // calculate distance from bucket centers
    double[] distance = new double[m_NumClusters];
    for (int i = 0; i < m_NumClusters; i++) {
        distance[i] = distance(instance, m_ClusterCentroids.instance(i)); // create a bucket item from the instance
    }
    bucketInstance ci = new bucketInstance();
    ci.setDistances(distance);
    // assing item to closest bucket
    int bestCluster;
    boolean finished;
    do {
        finished = true;
        // add to closestBucket
        bestCluster = Utils.minIndex(distance);
        ci.setDistance(distance[bestCluster]);
        //* insert sort — keeps each bucket ordered by distance, so the
        // farthest item is always the last element
        int j;
        for (j = 0; j < bucket[bestCluster].size() && ((bucketInstance) bucket[bestCluster].get(j)).compareTo(ci) < 0; j++) {
        }
        bucket[bestCluster].add(j, ci);
        //*/
        /* simple insert
        bucket[closestBucket].add(ci);
        //*/
        if (bucket[bestCluster].size() > bucketSize) {
            // bucket overflowed: evict the farthest item and retry it on its
            // next-best cluster; its distance to this cluster is poisoned
            // with MAX_VALUE so it cannot be re-inserted here
            ci = (bucketInstance) bucket[bestCluster].remove(bucket[bestCluster].size() - 1);
            distance = ci.getDistances();
            distance[bestCluster] = Double.MAX_VALUE;
            ci.setDistances(distance);
            finished = false;
        }
    } while (!finished);
    if (updateErrors) {
        // NOTE(review): distance here belongs to the LAST item processed
        // (possibly an evicted one), not necessarily to the original
        // instance — confirm this is the intended error measure
        m_squaredErrors[bestCluster] += distance[bestCluster];
    }
    return bestCluster;
}
/**
 * Classifies a given instance: pushes it through the same missing-value
 * filter that was fitted during training, then assigns it to a cluster
 * without updating the squared-error statistics.
 *
 * @param instance the instance to be assigned to a cluster
 * @return the number of the assigned cluster as an integer
 * @throws Exception if instance could not be classified successfully
 */
@Override
public int clusterInstance(Instance instance) throws Exception {
    m_ReplaceMissingFilter.input(instance);
    m_ReplaceMissingFilter.batchFinished();
    return clusterProcessedInstance(m_ReplaceMissingFilter.output(), false);
}
/**
 * Calculates the distance between two instances by walking their
 * (possibly sparse) value arrays in parallel: attributes stored in only
 * one instance are compared against an implicit 0.
 *
 * @param first the first instance
 * @param second the second instance
 * @return the sum of squared per-attribute differences; note this is
 * NOT square-rooted or divided by the attribute count (see the
 * commented-out return at the bottom), so it is not bounded by 1
 */
private double distance(Instance first, Instance second) {
    double distance = 0;
    int firstI, secondI;
    // index(p) yields the attribute index of the p-th stored value
    for (int p1 = 0, p2 = 0;
            p1 < first.numValues() || p2 < second.numValues();) {
        if (p1 >= first.numValues()) {
            firstI = m_ClusterCentroids.numAttributes(); // sentinel: past the end
        } else {
            firstI = first.index(p1);
        }
        if (p2 >= second.numValues()) {
            secondI = m_ClusterCentroids.numAttributes(); // sentinel: past the end
        } else {
            secondI = second.index(p2);
        }
        double diff;
        if (firstI == secondI) {
            // both instances store this attribute
            diff = difference(firstI,
                    first.valueSparse(p1),
                    second.valueSparse(p2));
            p1++;
            p2++;
        } else if (firstI > secondI) {
            // attribute not stored in first: treat its value as 0
            diff = difference(secondI,
                    0, second.valueSparse(p2));
            p2++;
        } else {
            // attribute not stored in second: treat its value as 0
            diff = difference(firstI,
                    first.valueSparse(p1), 0);
            p1++;
        }
        distance += diff * diff;
    }
    //return Math.sqrt(distance / m_ClusterCentroids.numAttributes());
    return distance;
}
/**
 * Computes the difference between two values of the same attribute.
 * Nominal attributes contribute 0/1; numeric attributes are min-max
 * normalized first, and missing values are penalized by comparing
 * against the worst case.
 *
 * @param index the attribute index
 * @param val1 the first value
 * @param val2 the second value
 * @return the difference
 */
private double difference(int index, double val1, double val2) {
    final boolean missing1 = Utils.isMissingValue(val1);
    final boolean missing2 = Utils.isMissingValue(val2);
    switch (m_ClusterCentroids.attribute(index).type()) {
        case Attribute.NOMINAL:
            // 1 when either value is missing or the categories differ
            return (missing1 || missing2 || (int) val1 != (int) val2) ? 1 : 0;
        case Attribute.NUMERIC:
            if (missing1 && missing2) {
                // both unknown: assume maximal difference
                return 1;
            }
            if (!missing1 && !missing2) {
                return norm(val1, index) - norm(val2, index);
            }
            // exactly one value is missing: use the larger possible
            // normalized distance from the known value
            double diff = missing2 ? norm(val1, index) : norm(val2, index);
            return (diff < 0.5) ? 1.0 - diff : diff;
        default:
            return 0;
    }
}
/**
 * Min-max normalizes a given value of a numeric attribute.
 *
 * @param x the value to be normalized
 * @param i the attribute's index
 * @return the normalized value, or 0 when the attribute's range is
 * unknown (never observed) or degenerate (min == max)
 */
private double norm(double x, int i) {
    final boolean rangeUnknown = Double.isNaN(m_Min[i]);
    final boolean rangeDegenerate = Utils.eq(m_Max[i], m_Min[i]);
    return (rangeUnknown || rangeDegenerate)
            ? 0
            : (x - m_Min[i]) / (m_Max[i] - m_Min[i]);
}
/**
 * Updates the per-attribute minimum and maximum values with the values
 * of a new instance. A NaN in m_Min marks an attribute whose range has
 * not been observed yet.
 *
 * @param instance the new instance
 */
private void updateMinMax(Instance instance) {
    final int numAtts = m_ClusterCentroids.numAttributes();
    for (int att = 0; att < numAtts; att++) {
        if (instance.isMissing(att)) {
            continue; // missing values do not affect the range
        }
        final double value = instance.value(att);
        if (Double.isNaN(m_Min[att])) {
            // first observed value for this attribute
            m_Min[att] = value;
            m_Max[att] = value;
        } else if (value < m_Min[att]) {
            m_Min[att] = value;
        } else if (value > m_Max[att]) {
            m_Max[att] = value;
        }
    }
}
/**
 * Returns the number of clusters.
 *
 * @return the number of clusters generated for a training dataset
 * @throws Exception if number of clusters could not be returned
 * successfully
 */
public int numberOfClusters() throws Exception {
    return this.m_NumClusters;
}
/**
 * Returns an enumeration describing the available options.
 *
 * @return an enumeration of all the available options
 */
@Override
public Enumeration listOptions() {
    // use a typed vector instead of the original raw one
    Vector<Option> result = new Vector<Option>();
    result.addElement(new Option(
            "\tnumber of clusters.\n" + "\t(default 2).",
            "N", 1, "-N <num>"));
    // append the superclass options (e.g. the -S random seed)
    Enumeration en = super.listOptions();
    while (en.hasMoreElements()) {
        result.addElement((Option) en.nextElement());
    }
    return result.elements();
}
/**
 * Returns the tip text for the numClusters property.
 *
 * @return tip text suitable for displaying in the
 * explorer/experimenter gui
 */
public String numClustersTipText() {
    final String tip = "set number of clusters";
    return tip;
}
/**
 * Sets the number of clusters to generate and allocates one bucket
 * (capacity-constrained assignment list) per cluster.
 *
 * @param n the number of clusters to generate
 * @throws Exception if the number of clusters is not positive
 */
public void setNumClusters(int n) throws Exception {
    if (n > 0) {
        m_NumClusters = n;
        bucket = new ArrayList[n];
    } else {
        throw new Exception("Number of clusters must be > 0");
    }
}
/**
 * Gets the number of clusters to generate.
 *
 * @return the number of clusters to generate
 */
public int getNumClusters() {
    return this.m_NumClusters;
}
/**
 * Parses a given list of options: "-N &lt;num&gt;" sets the number of
 * clusters (default 2); all remaining options, including the
 * "-S &lt;num&gt;" random seed (default 10), are delegated to the
 * superclass.
 *
 * @param options the list of options as an array of strings
 * @throws Exception if an option is not supported
 */
@Override
public void setOptions(String[] options)
        throws Exception {
    final String numClustersString = Utils.getOption('N', options);
    if (numClustersString.length() > 0) {
        setNumClusters(Integer.parseInt(numClustersString));
    }
    super.setOptions(options);
}
/**
 * Gets the current settings of the clusterer.
 *
 * @return an array of strings suitable for passing to setOptions()
 */
@Override
public String[] getOptions() {
    // typed collection instead of the original raw Vector
    Vector<String> result = new Vector<String>();
    result.add("-N");
    result.add("" + getNumClusters());
    // append the superclass options (random seed etc.)
    for (String option : super.getOptions()) {
        result.add(option);
    }
    return result.toArray(new String[result.size()]);
}
/**
 * Returns a human readable description of the clusterer: number of
 * iterations, within-cluster sum of squared errors and, per cluster,
 * the centroid (mean/mode) and the standard deviations.
 *
 * @return a description of the clusterer as a string
 */
@Override
public String toString() {
    // determine the widest numeric centroid value so the columns line up
    int maxWidth = 0;
    for (int i = 0; i < m_NumClusters; i++) {
        for (int j = 0; j < m_ClusterCentroids.numAttributes(); j++) {
            if (m_ClusterCentroids.attribute(j).isNumeric()) {
                double width = Math.log(Math.abs(m_ClusterCentroids.instance(i).value(j))) /
                        Math.log(10.0);
                width += 1.0;
                if ((int) width > maxWidth) {
                    maxWidth = (int) width;
                }
            }
        }
    }
    // FIX: StringBuilder instead of StringBuffer (no synchronization
    // needed) and no quadratic String += in the padding loop
    StringBuilder padded = new StringBuilder("N/A");
    for (int i = 0; i < maxWidth + 2; i++) {
        padded.append(" ");
    }
    String naString = padded.toString();
    StringBuilder temp = new StringBuilder();
    temp.append("\nkMeans\n======\n");
    temp.append("\nNumber of iterations: " + m_Iterations + "\n");
    temp.append("Within cluster sum of squared errors: " + Utils.sum(m_squaredErrors));
    temp.append("\n\nCluster centroids:\n");
    for (int i = 0; i < m_NumClusters; i++) {
        temp.append("\nCluster " + i + "\n\t");
        temp.append("Mean/Mode: ");
        for (int j = 0; j < m_ClusterCentroids.numAttributes(); j++) {
            if (m_ClusterCentroids.attribute(j).isNominal()) {
                // nominal: print the label of the modal value
                temp.append(" " + m_ClusterCentroids.attribute(j).
                        value((int) m_ClusterCentroids.instance(i).value(j)));
            } else {
                temp.append(" " + Utils.doubleToString(m_ClusterCentroids.instance(i).value(j),
                        maxWidth + 5, 4));
            }
        }
        temp.append("\n\tStd Devs: ");
        for (int j = 0; j < m_ClusterStdDevs.numAttributes(); j++) {
            if (m_ClusterStdDevs.attribute(j).isNumeric()) {
                temp.append(" " + Utils.doubleToString(m_ClusterStdDevs.instance(i).value(j),
                        maxWidth + 5, 4));
            } else {
                temp.append(" " + naString);
            }
        }
    }
    temp.append("\n\n");
    return temp.toString();
}
/**
 * Gets the cluster centroids.
 *
 * @return the cluster centroids
 */
public Instances getClusterCentroids() {
    return this.m_ClusterCentroids;
}
/**
 * Gets the standard deviations of the numeric attributes in each
 * cluster.
 *
 * @return the standard deviations of the numeric attributes
 * in each cluster
 */
public Instances getClusterStandardDevs() {
    return this.m_ClusterStdDevs;
}
/**
 * Returns, for each cluster, the frequency counts for the values of
 * each nominal attribute.
 *
 * @return the counts
 */
public int[][][] getClusterNominalCounts() {
    return this.m_ClusterNominalCounts;
}
/**
 * Gets the squared error summed over all clusters.
 *
 * @return the squared error
 */
public double getSquaredError() {
    return Utils.sum(this.m_squaredErrors);
}
/**
 * Gets the number of instances in each cluster.
 *
 * @return The number of instances in each cluster
 */
public int[] getClusterSizes() {
    return this.m_ClusterSizes;
}
/**
 * Main method for testing this class from the command line.
 *
 * @param argv should contain the following arguments: <p>
 * -t training file [-N number of clusters]
 */
public static void main(String[] argv) {
    // FIX: run THIS clusterer; the original ran a plain SimpleKMeans,
    // which ignored the balancing constraint entirely (copy-paste bug)
    runClusterer(new ConstrainedKMeans(), argv);
}
}