package ca.pfv.spmf.algorithms.clustering.kmeans;
/* This file is copyright (c) 2008-2013 Philippe Fournier-Viger
*
* This file is part of the SPMF DATA MINING SOFTWARE
* (http://www.philippe-fournier-viger.com/spmf).
*
* SPMF is free software: you can redistribute it and/or modify it under the
* terms of the GNU General Public License as published by the Free Software
* Foundation, either version 3 of the License, or (at your option) any later
* version.
*
* SPMF is distributed in the hope that it will be useful, but WITHOUT ANY
* WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
* A PARTICULAR PURPOSE. See the GNU General Public License for more details.
* You should have received a copy of the GNU General Public License along with
* SPMF. If not, see <http://www.gnu.org/licenses/>.
*/
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import ca.pfv.spmf.algorithms.clustering.distanceFunctions.DistanceFunction;
import ca.pfv.spmf.patterns.cluster.ClusterWithMean;
import ca.pfv.spmf.patterns.cluster.DoubleArray;
import ca.pfv.spmf.tools.MemoryLogger;
/**
 * An implementation of the K-means algorithm (J. MacQueen, 1967).
 * <br/><br/>
 *
 * The K-means algorithm steps are (text from Wikipedia): 1) Choose the number of clusters, k.
 * 2) Randomly generate k clusters and determine the cluster centers, or directly
 * generate k random points as cluster centers. 3) Assign each point to the
 * nearest cluster center. 4) Recompute the new cluster centers. 5) Repeat the two
 * previous steps until some convergence criterion is met (usually that the
 * assignment hasn't changed).
 * <br/><br/>
 *
 * In this implementation, the k initial cluster centers are random vectors whose
 * components are drawn uniformly between the smallest and the largest value found
 * in the input, and the main loop stops as soon as an iteration moves no vector
 * to a different cluster.
 *
 * @author Philippe Fournier-Viger
 */
public class AlgoKMeans {
    // The list of clusters generated
    protected List<ClusterWithMean> clusters = null;
    // A random number generator because K-Means is a randomized algorithm
    protected final static Random random = new Random(System.currentTimeMillis());
    // For statistics
    protected long startTimestamp; // the start time of the latest execution
    protected long endTimestamp; // the end time of the latest execution
    long iterationCount; // the number of iterations that was performed
    /* The distance function to be used for clustering */
    protected DistanceFunction distanceFunction = null;

    /**
     * Default constructor
     */
    public AlgoKMeans() {
    }

    /**
     * Run the K-Means algorithm on the vectors stored in a file.
     * <p>
     * Each line of the input file that is not blank and does not start with
     * '#', '%' or '@' is parsed as one vector of double values separated by
     * whitespace. If k is 1, or if the file contains a single vector, all
     * vectors are put in a single cluster. If k exceeds the number of vectors,
     * it is reduced to the number of vectors. If the file contains no vector,
     * an empty list is returned.
     *
     * @param inputFile an input file path containing a list of vectors of double values
     * @param k the desired number of clusters
     * @param distanceFunction the distance function used to compare vectors with cluster means
     * @return a list of clusters (some of them may be empty)
     * @throws NumberFormatException if a value in the input file is not a valid double
     * @throws IOException exception if an error occurs while reading the file
     */
    public List<ClusterWithMean> runAlgorithm(String inputFile, int k, DistanceFunction distanceFunction)
            throws NumberFormatException, IOException {
        // record the start time
        startTimestamp = System.currentTimeMillis();
        // reset the number of iterations
        iterationCount = 0;
        this.distanceFunction = distanceFunction;

        // read the vectors from the input file
        List<DoubleArray> vectors = readVectors(inputFile);

        if (vectors.isEmpty()) {
            // Nothing to cluster (empty or comment-only file).
            clusters = new ArrayList<ClusterWithMean>();
        } else {
            // Get the size of vectors (dimension of the first vector)
            int vectorsSize = vectors.get(0).data.length;

            if (k == 1 || vectors.size() == 1) {
                // SPECIAL CASES: one cluster requested, or only one vector available
                clusters = createSingleCluster(vectors, vectorsSize);
            } else {
                // if the user asks for more clusters than there is data,
                // we set k to the number of data points.
                if (k > vectors.size()) {
                    k = vectors.size();
                }
                // Find the smallest and largest values over all vectors; they bound
                // the components of the random initial cluster centers.
                // Starting from +/- infinity keeps the bounds correct for negative
                // values and for values beyond the int range.
                double minValue = Double.POSITIVE_INFINITY;
                double maxValue = Double.NEGATIVE_INFINITY;
                for (DoubleArray vector : vectors) {
                    for (double value : vector.data) {
                        if (value < minValue) {
                            minValue = value;
                        }
                        if (value > maxValue) {
                            maxValue = value;
                        }
                    }
                }
                applyAlgorithm(k, distanceFunction, vectors, minValue, maxValue,
                        vectorsSize);
            }
        }
        // check memory usage
        MemoryLogger.getInstance().checkMemory();
        // record end time
        endTimestamp = System.currentTimeMillis();
        // return the clusters
        return clusters;
    }

    /**
     * Read the vectors contained in an input file.
     * Blank lines and lines starting with '#', '%' or '@' (comments and metadata)
     * are skipped; every other line is split on runs of whitespace and each
     * token is parsed as a double.
     *
     * @param inputFile the input file path
     * @return the vectors, in file order (possibly empty)
     * @throws NumberFormatException if a token is not a valid double
     * @throws IOException if the file cannot be read
     */
    private List<DoubleArray> readVectors(String inputFile) throws NumberFormatException, IOException {
        List<DoubleArray> vectors = new ArrayList<DoubleArray>();
        BufferedReader reader = new BufferedReader(new FileReader(inputFile));
        try {
            String line;
            // for each line until the end of the file
            while ((line = reader.readLine()) != null) {
                String trimmed = line.trim();
                // skip blank lines, comments and metadata
                if (trimmed.isEmpty() || trimmed.charAt(0) == '#'
                        || trimmed.charAt(0) == '%' || trimmed.charAt(0) == '@') {
                    continue;
                }
                // Split on runs of whitespace so that leading or repeated spaces do
                // not yield empty tokens (which Double.parseDouble rejects).
                String[] tokens = trimmed.split("\\s+");
                double[] vector = new double[tokens.length];
                for (int i = 0; i < tokens.length; i++) {
                    vector[i] = Double.parseDouble(tokens[i]);
                }
                vectors.add(new DoubleArray(vector));
            }
        } finally {
            // always release the file, even if parsing fails
            reader.close();
        }
        return vectors;
    }

    /**
     * Build a list containing a single cluster that holds all the given vectors,
     * with its mean recomputed from those vectors.
     *
     * @param vectors the vectors to put in the cluster
     * @param vectorsSize the dimension of the vectors
     * @return a one-element list containing the cluster
     */
    private List<ClusterWithMean> createSingleCluster(List<DoubleArray> vectors, int vectorsSize) {
        ClusterWithMean cluster = new ClusterWithMean(vectorsSize);
        for (DoubleArray vector : vectors) {
            cluster.addVector(vector);
        }
        // The mean must be set before recomputeClusterMean() is called
        // (same order as the original k == 1 branch); recomputeClusterMean()
        // presumably fills this array in place.
        cluster.setMean(new DoubleArray(new double[vectorsSize]));
        cluster.recomputeClusterMean();
        List<ClusterWithMean> result = new ArrayList<ClusterWithMean>();
        result.add(cluster);
        return result;
    }

    /**
     * Apply the K-means algorithm and store the result in the clusters field.
     * @param k the parameter k
     * @param distanceFunction a distance function
     * @param vectors the list of initial vectors
     * @param minValue the min value
     * @param maxValue the max value
     * @param vectorsSize the vector size
     */
    void applyAlgorithm(int k, DistanceFunction distanceFunction,
            List<DoubleArray> vectors, double minValue, double maxValue,
            int vectorsSize) {
        // apply kmeans
        clusters = applyKMeans(k, distanceFunction, vectors, minValue, maxValue, vectorsSize);
    }

    /**
     * Apply the K-means algorithm to a list of vectors.
     * @param k the parameter k
     * @param distanceFunction a distance function
     * @param vectors the list of initial vectors
     * @param minValue the min value used to generate random initial centers
     * @param maxValue the max value used to generate random initial centers
     * @param vectorsSize the vector size
     * @return the clusters found (some of them may be empty)
     */
    List<ClusterWithMean> applyKMeans(int k, DistanceFunction distanceFunction,
            List<DoubleArray> vectors, double minValue, double maxValue,
            int vectorsSize) {
        // SPECIAL CASE: a single vector forms a single cluster. Its mean is set
        // so that later distance/SSE computations on the cluster work.
        if (vectors.size() == 1) {
            return createSingleCluster(vectors, vectorsSize);
        }

        // (1) Randomly generate k empty clusters with a random mean (cluster center)
        List<ClusterWithMean> newClusters = new ArrayList<ClusterWithMean>();
        for (int i = 0; i < k; i++) {
            DoubleArray meanVector = generateRandomVector(minValue, maxValue, vectorsSize);
            ClusterWithMean cluster = new ClusterWithMean(vectorsSize);
            cluster.setMean(meanVector);
            newClusters.add(cluster);
        }

        // assignment[i] is the cluster currently holding the i-th vector (null
        // before its first assignment). This avoids scanning every cluster with
        // contains() for every vector at each iteration.
        ClusterWithMean[] assignment = new ClusterWithMean[vectors.size()];

        // (2) Repeat the two next steps until the assignment hasn't changed
        while (true) {
            iterationCount++;
            boolean changed = false;
            // (2.1) Assign each point to the nearest cluster center.
            int index = 0;
            for (DoubleArray vector : vectors) {
                // find the nearest cluster (first one wins on ties)
                ClusterWithMean nearestCluster = null;
                double distanceToNearestCluster = Double.MAX_VALUE;
                for (ClusterWithMean cluster : newClusters) {
                    double distance = distanceFunction.calculateDistance(cluster.getmean(), vector);
                    if (distance < distanceToNearestCluster) {
                        nearestCluster = cluster;
                        distanceToNearestCluster = distance;
                    }
                }
                ClusterWithMean containingCluster = assignment[index];
                // Move the vector if its nearest cluster is not the one holding it.
                // The null check avoids an NPE when no distance was below
                // Double.MAX_VALUE (e.g. NaN or infinite distances).
                if (nearestCluster != null && nearestCluster != containingCluster) {
                    if (containingCluster != null) {
                        containingCluster.remove(vector);
                    }
                    nearestCluster.addVector(vector);
                    assignment[index] = nearestCluster;
                    changed = true;
                }
                index++;
            }
            // check the memory usage
            MemoryLogger.getInstance().checkMemory();
            if (!changed) { // exit condition for main loop
                break;
            }
            // (2.2) Recompute the new cluster means
            for (ClusterWithMean cluster : newClusters) {
                cluster.recomputeClusterMean();
            }
        }
        return newClusters;
    }

    /**
     * Generate a random vector whose components are drawn uniformly in
     * [minValue, maxValue).
     * @param minValue the minimum value allowed
     * @param maxValue the maximum value allowed
     * @param vectorsSize the desired vector size
     * @return the random vector
     */
    DoubleArray generateRandomVector(double minValue, double maxValue,
            int vectorsSize) {
        // create a new vector
        double[] vector = new double[vectorsSize];
        // for each position generate a random number
        for (int i = 0; i < vectorsSize; i++) {
            vector[i] = (random.nextDouble() * (maxValue - minValue)) + minValue;
        }
        // return the vector
        return new DoubleArray(vector);
    }

    /**
     * Compute the sum of squared errors: for every vector of every cluster,
     * the squared distance (according to the current distance function)
     * between the vector and its cluster mean.
     * @param clusters the clusters
     * @return the SSE
     */
    double getSSE(List<ClusterWithMean> clusters) {
        double sse = 0;
        for (ClusterWithMean cluster : clusters) {
            for (DoubleArray vector : cluster.getVectors()) {
                sse += Math.pow(distanceFunction.calculateDistance(vector, cluster.getmean()), 2);
            }
        }
        return sse;
    }

    /**
     * Save the non-empty clusters to an output file, one cluster per line.
     * Must be called after runAlgorithm().
     * @param output the output file path
     * @throws IOException exception if there is some writing error.
     */
    public void saveToFile(String output) throws IOException {
        BufferedWriter writer = new BufferedWriter(new FileWriter(output));
        try {
            boolean clusterWritten = false;
            for (ClusterWithMean cluster : clusters) {
                // skip empty clusters
                if (cluster.getVectors().size() < 1) {
                    continue;
                }
                // Separate clusters with a line return, without a trailing one,
                // even when the last clusters of the list are empty.
                if (clusterWritten) {
                    writer.newLine();
                }
                writer.write(cluster.toString());
                clusterWritten = true;
            }
        } finally {
            // always release the file, even if writing fails
            writer.close();
        }
    }

    /**
     * Print statistics of the latest execution to System.out.
     */
    public void printStatistics() {
        System.out.println("========== KMEANS - STATS ============");
        System.out.println(" Distance function: " + distanceFunction.getName());
        System.out.println(" Total time ~: " + (endTimestamp - startTimestamp)
                + " ms");
        System.out.println(" SSE (Sum of Squared Errors) (lower is better) : " + getSSE(clusters));
        System.out.println(" Max memory:" + MemoryLogger.getInstance().getMaxMemory() + " mb ");
        System.out.println(" Iteration count: " + iterationCount);
        System.out.println("=====================================");
    }
}