/*
* Copyright 2004-2010 Information & Software Engineering Group (188/1)
* Institute of Software Technology and Interactive Systems
* Vienna University of Technology, Austria
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.ifs.tuwien.ac.at/dm/somtoolbox/license.html
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package at.tuwien.ifs.somtoolbox.visualization.clustering;
import java.util.Vector;
import org.apache.commons.lang.ArrayUtils;
import at.tuwien.ifs.somtoolbox.layers.metrics.DistanceMetric;
import at.tuwien.ifs.somtoolbox.layers.metrics.L2Metric;
import at.tuwien.ifs.somtoolbox.layers.metrics.MetricException;
/**
* A Cluster used in KMeans clustering. Has a centroid and a number of indices of a data set assigned to it.
*
* @see KMeans
* @author Robert Neumayer
* @version $Id: Cluster.java 3583 2010-05-21 10:07:41Z mayer $
*/
public class Cluster {

    /** Centroids with more dimensions than this are not printed by the debug output. */
    private static final int MAX_DIM_DEBUG = 500;

    /** Maximum number of assigned indices that the debug output prints. */
    private static final int MAX_INDICES_DEBUG = 150;

    /** Row indices (into the data matrix) of the instances assigned to this cluster. */
    private final Vector<Integer> indices = new Vector<Integer>();

    /** Current centroid; <code>null</code> until one has been set or calculated. */
    private double[] centroid;

    /** Metric used for every distance computation of this cluster. */
    private final DistanceMetric distanceFunction;

    /** Creates an empty cluster without a centroid, using the Euclidean ({@link L2Metric}) distance. */
    public Cluster() {
        this((double[]) null, new L2Metric());
    }

    /**
     * Creates an empty cluster with the given centroid, using the Euclidean ({@link L2Metric}) distance.
     *
     * @param centroid the initial centroid; the array is stored as is, not copied.
     */
    public Cluster(double[] centroid) {
        this(centroid, new L2Metric());
    }

    /**
     * Creates an empty cluster with the given centroid and distance function.
     *
     * @param centroid the initial centroid; the array is stored as is, not copied.
     * @param distanceFunction the metric used for all distance computations.
     */
    public Cluster(double[] centroid, DistanceMetric distanceFunction) {
        this.centroid = centroid;
        this.distanceFunction = distanceFunction;
    }

    /**
     * Creates an empty cluster without a centroid, using the given distance function.
     *
     * @param distanceFunction the metric used for all distance computations.
     */
    public Cluster(DistanceMetric distanceFunction) {
        this((double[]) null, distanceFunction);
    }

    /**
     * Calculate the centroid of this cluster. This is done by summing up all individual values divided by the number of
     * instances assigned to it. If no instance is assigned, the current centroid is left untouched. An existing
     * centroid array that is large enough is updated in place; otherwise a new array is allocated.
     *
     * @param data the data set.
     */
    public void calculateCentroid(double[][] data) {
        final int instanceCount = indices.size();
        if (instanceCount == 0) {
            // nothing assigned: keep whatever centroid we currently have
            return;
        }
        final int dim = data[indices.get(0)].length;
        if (centroid == null || centroid.length < dim) {
            // no usable centroid yet (e.g. cluster created without one)
            centroid = new double[dim];
        }
        for (int attributeIndex = 0; attributeIndex < dim; attributeIndex++) {
            centroid[attributeIndex] = 0d;
        }
        for (int instanceIndex = 0; instanceIndex < instanceCount; instanceIndex++) {
            final double[] row = data[indices.get(instanceIndex)];
            for (int attributeIndex = 0; attributeIndex < row.length; attributeIndex++) {
                // divide each summand (rather than the final sum) to keep results identical to earlier versions
                centroid[attributeIndex] += row[attributeIndex] / instanceCount;
            }
        }
    }

    /**
     * Removes the first occurrence of the given data index from this cluster. The argument is the value that was
     * passed to {@link #addIndex(int)}, not a position within the list of assigned indices.
     *
     * @param instanceIndex the data index to remove.
     */
    public void removeInstanceIndex(int instanceIndex) {
        // box explicitly so that remove(Object) is chosen instead of remove(int position)
        indices.remove(Integer.valueOf(instanceIndex));
    }

    /**
     * Add the index of a data point to this cluster.
     *
     * @param index to add.
     */
    public void addIndex(int index) {
        indices.add(Integer.valueOf(index));
    }

    /**
     * Set the centroid of this cluster. The array is stored as is, not copied.
     *
     * @param centroid to set.
     */
    public void setCentroid(double[] centroid) {
        this.centroid = centroid;
    }

    /**
     * Get a copy of the centroid of this cluster. Changes to the returned array do not affect the cluster.
     *
     * @return a copy of the centroid, or <code>null</code> if no centroid has been set or calculated yet.
     */
    public double[] getCentroid() {
        return centroid == null ? null : centroid.clone();
    }

    /**
     * Get the data indices assigned to this cluster.
     *
     * @return the internal list of indices (not a copy); modifying it modifies the cluster.
     */
    public Vector<Integer> getIndices() {
        return this.indices;
    }

    /** @return the number of data indices currently assigned to this cluster. */
    public int getNumberOfInstances() {
        return indices.size();
    }

    /**
     * Prints debug information about this cluster to standard output: the centroid (unless it has more than
     * {@value #MAX_DIM_DEBUG} dimensions), the value of {@link #SSE(double[][])}, and up to
     * {@value #MAX_INDICES_DEBUG} assigned indices together with their distance to the centroid.
     *
     * @param data the data set the assigned indices refer to.
     */
    public void printClusterIndices(double[][] data) {
        printCentroid();
        System.out.println("\tSSE: " + SSE(data));
        final int shown = Math.min(indices.size(), MAX_INDICES_DEBUG);
        for (int i = 0; i < shown; i++) {
            final int dataIndex = indices.get(i);
            System.out.println("\tindex " + dataIndex + " / " + getDistanceToCentroid(data[dataIndex]));
        }
        printSuppressedIndexCount();
    }

    /**
     * Prints debug information about this cluster to standard output: the centroid (unless it has more than
     * {@value #MAX_DIM_DEBUG} dimensions) and up to {@value #MAX_INDICES_DEBUG} assigned indices.
     */
    public void printClusterIndices() {
        printCentroid();
        final int shown = Math.min(indices.size(), MAX_INDICES_DEBUG);
        for (int i = 0; i < shown; i++) {
            System.out.println("\tindex " + indices.get(i));
        }
        printSuppressedIndexCount();
    }

    /** Prints the centroid, or a placeholder line if its dimensionality exceeds {@link #MAX_DIM_DEBUG}. */
    private void printCentroid() {
        if (centroid != null && centroid.length > MAX_DIM_DEBUG) {
            System.out.println("< Surpressing centroid debug output due to high dimensionality (" + centroid.length
                    + ") >");
        } else {
            System.out.println("\tCentroid: " + ArrayUtils.toString(centroid));
        }
    }

    /** Prints how many assigned indices were left out of the debug output, if any. */
    private void printSuppressedIndexCount() {
        if (indices.size() > MAX_INDICES_DEBUG) {
            System.out.println("Surpressing output of " + (indices.size() - MAX_INDICES_DEBUG)
                    + " additional elements.");
        }
    }

    /**
     * Returns all the instances belonging to this cluster according to the given data set. The rows of the result are
     * the very same arrays as in <code>data</code>, not copies.
     *
     * @param data instances.
     * @return plain matrix of all assigned instances; empty if no instance is assigned.
     */
    public double[][] getInstances(double[][] data) {
        final int instanceCount = indices.size();
        // only the outer array is needed, the rows are taken from data
        final double[][] instances = new double[instanceCount][];
        for (int i = 0; i < instanceCount; i++) {
            instances[i] = data[indices.get(i)];
        }
        return instances;
    }

    /**
     * Calculate the error sum (called SSE) for this cluster: the sum of the distances between the cluster's centroid
     * and all instances assigned to it. The values returned by the distance function are added up as they are; this
     * method applies no squaring of its own. If the distance function fails for an instance, the stack trace is
     * printed and that instance does not contribute to the sum.
     *
     * @param data matrix to compute the SSE for.
     * @return the SSE value for this cluster; 0 if no instance is assigned.
     */
    public double SSE(double[][] data) {
        double sse = 0d;
        for (int i = 0; i < indices.size(); i++) {
            try {
                sse += distanceFunction.distance(data[indices.get(i)], centroid);
            } catch (MetricException e) {
                // best effort: report the problem and continue with the remaining instances
                e.printStackTrace();
            }
        }
        return sse;
    }

    /**
     * SSE again, this time the average one (i.e. divided by the number of instances within this cluster).
     *
     * @param data matrix to compute the average SSE for.
     * @return {@link #SSE(double[][])} divided by the number of assigned instances; <code>NaN</code> if no instance
     *         is assigned.
     */
    public double averageSSE(double[][] data) {
        return SSE(data) / this.getNumberOfInstances();
    }

    /**
     * Get the distance of a given instance to this cluster's centroid.
     *
     * @param instance some instance.
     * @return the distance according to the used distance function, or 0 if the distance function failed (in which
     *         case the stack trace is printed).
     */
    public double getDistanceToCentroid(double[] instance) {
        try {
            return distanceFunction.distance(centroid, instance);
        } catch (MetricException e) {
            e.printStackTrace();
        }
        return 0d;
    }

    /**
     * Get the numbers of occurrences of each attribute in this cluster. An attribute counts as occurring in an
     * instance if its value in that instance is greater than zero.
     *
     * @param data the data set the assigned indices refer to; its first row determines the number of attributes.
     * @return array holding, for each attribute, the number of assigned instances it occurs in, or -1 if it occurs in
     *         none of them.
     */
    public int[] getNumberOfAttributeOccurrences(double[][] data) {
        // a fresh int[] is already filled with zeros
        final int[] counts = new int[data[0].length];
        for (int i = 0; i < indices.size(); i++) {
            final double[] row = data[indices.get(i)];
            for (int j = 0; j < row.length; j++) {
                if (row[j] > 0d) {
                    counts[j]++;
                }
            }
        }
        for (int i = 0; i < counts.length; i++) {
            if (counts[i] == 0) {
                counts[i] = -1;
            }
        }
        return counts;
    }

    /**
     * Get the instance with the maximum SSE of all instances assigned to this cluster, i.e. the assigned instance with
     * the largest distance to the centroid. Instances for which the distance function fails are skipped after
     * printing the stack trace.
     *
     * @param data the data set the assigned indices refer to.
     * @return the data index of that instance, or -1 if none was found (e.g. no instance is assigned).
     */
    public int getInstanceIndexWithMaxSSE(double[][] data) {
        int index = -1;
        double maxSSE = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < indices.size(); i++) {
            try {
                final int dataIndex = indices.get(i);
                final double currentSSE = distanceFunction.distance(data[dataIndex], centroid);
                if (currentSSE > maxSSE) {
                    maxSSE = currentSSE;
                    index = dataIndex;
                }
            } catch (MetricException e) {
                e.printStackTrace();
            }
        }
        return index;
    }
}