package com.compomics.util.math.clustering;
import com.compomics.util.gui.waiting.waitinghandlers.WaitingHandlerCLIImpl;
import com.compomics.util.waiting.WaitingHandler;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Random;
import no.uib.jsparklines.renderers.util.Util;
/**
 * K-means clustering.
 * <p>
 * The samples are partitioned into a fixed number of clusters by alternating
 * between assigning every sample to its closest centroid (Euclidean distance)
 * and moving every centroid to the mean of its members. The iteration stops
 * when the assignment no longer changes, when the maximum number of iterations
 * is reached or when the run is canceled via the waiting handler.
 *
 * @author Harald Barsnes
 */
public class KMeansClustering {

    /**
     * The number of clusters.
     */
    private final int numClusters;
    /**
     * The number of samples.
     */
    private final int numSamples;
    /**
     * The number of values for each sample.
     */
    private final int numValues;
    /**
     * The sample data, one row per sample. The array handed to the constructor
     * is stored as is, i.e., it is not copied.
     */
    private final double[][] samples;
    /**
     * The sample identifiers, indexed like the sample data.
     */
    private final String[] sampleIds;
    /**
     * The centroids, one row per cluster.
     */
    private final double[][] centroids;
    /**
     * The current cluster each sample belongs to.
     */
    private final int[] clusters;
    /**
     * The maximum number of iterations.
     */
    private int maxIterations = 500; // @TODO: what should the default be..?

    /**
     * Constructor. Validates the input and picks the initial centroids.
     *
     * @param samples the data, one row per sample
     * @param sampleIds the sample identifiers
     * @param numClusters the number of clusters
     * @throws IllegalArgumentException if there are no samples, if a sample is
     * null or has fewer values than the first sample, if there are fewer
     * identifiers than samples, or if the number of clusters is smaller than
     * one or bigger than the number of samples
     */
    public KMeansClustering(double samples[][], String[] sampleIds, int numClusters) {

        validateInput(samples, sampleIds, numClusters);

        this.samples = samples;
        this.sampleIds = sampleIds;
        this.numSamples = samples.length;
        this.numValues = samples[0].length;
        this.numClusters = numClusters;

        // every sample starts out in cluster 0 until the first assignment
        this.clusters = new int[this.numSamples];
        this.centroids = new double[this.numClusters][this.numValues];

        initializeCentroids();
    }

    /**
     * Constructor. Reads the samples from a tab separated file without header
     * where the first column holds the sample identifiers.
     *
     * @param dataFile the file with the data
     * @param numClusters the number of clusters
     * @throws IllegalArgumentException if the file cannot be read, contains no
     * samples or inconsistent rows, or if the number of clusters is invalid
     */
    public KMeansClustering(File dataFile, int numClusters) {
        this(readDataFromFile(dataFile), numClusters);
    }

    /**
     * Constructor used to unpack the data read from a file.
     *
     * @param sampleData the sample data and identifiers
     * @param numClusters the number of clusters
     */
    private KMeansClustering(SampleData sampleData, int numClusters) {
        this(sampleData.getSamples(), sampleData.getSampleIds(), numClusters);
    }

    /**
     * Check that the constructor arguments describe a usable data set.
     *
     * @param samples the data
     * @param sampleIds the sample identifiers
     * @param numClusters the number of clusters
     * @throws IllegalArgumentException if the input is not usable
     */
    private static void validateInput(double[][] samples, String[] sampleIds, int numClusters) {

        if (samples == null || samples.length == 0) {
            throw new IllegalArgumentException("At least one sample is required!");
        }
        if (samples[0] == null) {
            throw new IllegalArgumentException("Sample 0 is null!");
        }

        // the first sample defines the number of values; longer rows are
        // tolerated (their extra values are ignored), shorter rows are not
        int expectedValues = samples[0].length;

        for (int sampleIndex = 0; sampleIndex < samples.length; sampleIndex++) {
            double[] sample = samples[sampleIndex];
            if (sample == null) {
                throw new IllegalArgumentException("Sample " + sampleIndex + " is null!");
            }
            if (sample.length < expectedValues) {
                throw new IllegalArgumentException("Sample " + sampleIndex + " has " + sample.length
                        + " values, expected at least " + expectedValues + "!");
            }
        }

        if (sampleIds == null || sampleIds.length < samples.length) {
            int numIds = sampleIds == null ? 0 : sampleIds.length;
            throw new IllegalArgumentException("A sample identifier is required for every sample! #identifiers: "
                    + numIds + ", #samples: " + samples.length);
        }

        if (numClusters < 1) {
            throw new IllegalArgumentException("The number of clusters has to be at least one! #clusters: " + numClusters);
        }
        if (numClusters > samples.length) {
            throw new IllegalArgumentException("The number of clusters cannot be bigger than the number of samples! #clusters: " + numClusters + ", #samples: " + samples.length);
        }
    }

    /**
     * Set the initial centroids to randomly selected samples. The samples are
     * drawn without replacement (partial Fisher-Yates shuffle) so that no
     * sample is used for more than one centroid.
     */
    private void initializeCentroids() {

        int[] candidates = new int[numSamples];
        for (int i = 0; i < numSamples; i++) {
            candidates[i] = i;
        }

        Random rand = new Random();

        for (int centroidNumber = 0; centroidNumber < numClusters; centroidNumber++) {

            // pick among the samples not yet used; numClusters <= numSamples
            // guarantees that the bound is at least one
            int pick = centroidNumber + rand.nextInt(numSamples - centroidNumber);
            int selectedSample = candidates[pick];
            candidates[pick] = candidates[centroidNumber];
            candidates[centroidNumber] = selectedSample;

            System.arraycopy(samples[selectedSample], 0, centroids[centroidNumber], 0, numValues);
        }
    }

    /**
     * Run the k-means clustering.
     *
     * @param waitingHandler the waiting handler used to check for
     * cancellation, if null the run cannot be canceled
     */
    public void kMeanCluster(WaitingHandler waitingHandler) {

        // assign the samples to the initial centroids
        assignToClusters();

        boolean clustersChanged = true;
        int iterationCounter = 0;

        // iterate until the clustering no longer changes
        while (clustersChanged
                && iterationCounter < maxIterations
                && (waitingHandler == null || !waitingHandler.isRunCanceled())) {

            // calculate the new centroids
            calculateNewCentroids();

            // assign the samples to the new centroids
            clustersChanged = assignToClusters();

            iterationCounter++;
        }
    }

    /**
     * Assign the samples to the clusters. On ties the centroid with the lowest
     * index wins.
     *
     * @return true if the clustering changed
     */
    private boolean assignToClusters() {

        boolean clustersChanged = false;

        for (int sampleNumber = 0; sampleNumber < numSamples; sampleNumber++) {

            double minimumDistance = Double.MAX_VALUE;
            int selectedCentroidNumber = 0;

            // find the closest centroid
            for (int centroidNumber = 0; centroidNumber < numClusters; centroidNumber++) {
                double distance = distSampleToCentroid(sampleNumber, centroidNumber);
                if (distance < minimumDistance) {
                    minimumDistance = distance;
                    selectedCentroidNumber = centroidNumber;
                }
            }

            // check if the sample's cluster assignment changed
            if (clusters[sampleNumber] != selectedCentroidNumber) {
                clustersChanged = true;
            }

            // add to the closest cluster
            clusters[sampleNumber] = selectedCentroidNumber;
        }

        return clustersChanged;
    }

    /**
     * Calculate new centroids as the mean of the samples currently assigned to
     * each cluster. A cluster without members keeps its previous centroid
     * instead of being moved to the origin.
     */
    private void calculateNewCentroids() {

        double[][] sums = new double[numClusters][numValues];
        int[] totalInCluster = new int[numClusters];

        // single pass over the samples, accumulating per cluster
        for (int sampleNumber = 0; sampleNumber < numSamples; sampleNumber++) {
            int clusterNumber = clusters[sampleNumber];
            double[] clusterSum = sums[clusterNumber];
            double[] sample = samples[sampleNumber];
            for (int valueNumber = 0; valueNumber < numValues; valueNumber++) {
                clusterSum[valueNumber] += sample[valueNumber];
            }
            totalInCluster[clusterNumber]++;
        }

        for (int centroidNumber = 0; centroidNumber < numClusters; centroidNumber++) {
            int members = totalInCluster[centroidNumber];
            if (members > 0) {
                for (int valueNumber = 0; valueNumber < numValues; valueNumber++) {
                    centroids[centroidNumber][valueNumber] = sums[centroidNumber][valueNumber] / members;
                }
            }
        }
    }

    /**
     * Calculate the Euclidean distance between a sample and a centroid.
     *
     * @param sampleNumber the sample number
     * @param centroidNumber the centroid number
     * @return the Euclidean distance
     */
    private double distSampleToCentroid(int sampleNumber, int centroidNumber) {
        return euclideanDistance(samples[sampleNumber], centroids[centroidNumber], numValues);
    }

    /**
     * Calculate the Euclidean distance between two samples.
     *
     * @param sampleNumber1 the sample number of the first sample
     * @param sampleNumber2 the sample number of the second sample
     * @return the Euclidean distance
     */
    private double distSampleToSample(int sampleNumber1, int sampleNumber2) {
        return euclideanDistance(samples[sampleNumber1], samples[sampleNumber2], numValues);
    }

    /**
     * Calculate the Euclidean distance between the first length values of two
     * vectors.
     *
     * @param first the first vector
     * @param second the second vector
     * @param length the number of values to include
     * @return the Euclidean distance
     */
    private static double euclideanDistance(double[] first, double[] second, int length) {
        double sumOfSquares = 0;
        for (int i = 0; i < length; i++) {
            double difference = first[i] - second[i];
            sumOfSquares += difference * difference;
        }
        return Math.sqrt(sumOfSquares);
    }

    /**
     * Main method for testing purposes.
     *
     * @param args the command line arguments
     */
    public static void main(String[] args) {

        // example with direct input

        // the sample data
        double samples[][] = new double[][]{
            {1.018387278, 0.983270041, 1.063472453, 0.713225975, 0.731043734, 0.973387687, 0.936300274, 1.039486067, 1.134279088, 0.986361721},
            {0.981590377, 1.02987824, 1.089762055, 0.927909537, 0.745221317, 0.709817942, 0.655031878, 1.047253604, 0.952668566, 1.037703939},
            {1.03694662, 1.080079418, 1.041962748, 1.258192406, 1.342060684, 0.996528485, 0.924128553, 0.936412377, 0.920298185, 0.918456169},
            {0.990287761, 0.892692992, 0.914314664, 1.279408351, 1.31410923, 0.941641721, 0.910025757, 1.064973225, 1.041036986, 1.049735711},
            {1.040106591, 0.938527051, 0.965804511, 0.695864906, 0.813267072, 1.064862452, 1.128367944, 0.9798703, 1.268314349, 0.890250862},
            {1.283690338, 1.221861511, 1.237727692, 1.131154141, 0.991934148, 0.962126821, 0.943197586, 0.872215846, 0.912011518, 0.829430491},
            {0.981473817, 0.805082739, 0.979007845, 0.685868656, 0.467881815, 1.30464142, 1.031580941, 1.120770021, 1.163524042, 0.948936962},
            {0.935739165, 0.961540471, 0.948513884, 1.1214119, 1.139158941, 0.952546774, 1.061539826, 0.967187465, 0.969725485, 1.066965917},
            {0.98084797, 0.99517748, 0.967601553, 1.408483587, 1.242533492, 0.809655819, 1.012664473, 0.972120169, 0.90671428, 1.064156888},
            {1.114446123, 1.024968093, 1.034149441, 0.783212889, 0.801006499, 0.983516619, 1.026256729, 0.996830977, 0.975588315, 0.942473673},
            {0.905988305, 0.908986417, 0.925003413, 1.19651456, 1.106383596, 0.997060333, 1.030914868, 1.07807453, 1.146596783, 1.079137402},
            {1.040040646, 1.049901339, 0.989359079, 1.017323675, 1.008910963, 0.983004953, 0.984566787, 1.040902927, 1.02390089, 1.015875601},
            {1.038052043, 0.999666309, 1.011292944, 0.862294159, 0.878858798, 0.98299443, 0.963822514, 0.982571918, 0.975889047, 1.009450539},
            {0.821272331, 0.767589262, 0.817114369, 1.059135199, 0.884487875, 1.091284726, 1.022820961, 1.148307617, 1.032334252, 1.167097238},
            {1.016334545, 1.090488723, 0.981954941, 1.223423201, 1.07287664, 0.967790703, 0.894805565, 1.103557481, 1.031495908, 1.028484672},
            {0.991456092, 0.665417264, 0.862248473, 1.005142654, 0.919656901, 1.244190762, 1.056869139, 1.031395099, 0.898937035, 0.946095374}
        };

        // the sample identifiers
        String sampleNames[] = new String[]{"O95071", "Q6ZT21", "Q99590", "Q14517", "Q9P219", "Q14692",
            "Q8TF74", "Q13427", "Q9ULD9", "Q9UPN9", "P51805", "Q92621", "Q5SRE5", "Q8TB73", "Q96CP6", "Q13671"};

        KMeansClustering kMeansClustering = new KMeansClustering(samples, sampleNames, 5);
//
//        // example with input from file - tab separated input, no header, first column assumed to be the sample ids
//        KMeansClustering kMeansClustering = new KMeansClustering(new File("C:\\Users\\hba041\\Desktop\\clustering_data.txt"), 30);

        // print the initial centroids
        System.out.println("Centroids initialized at:");
        kMeansClustering.printCentroids();
        System.out.print("\n");

        // execute the clustering
        kMeansClustering.kMeanCluster(new WaitingHandlerCLIImpl());

        // print the clustering results
        kMeansClustering.printClusters();

        // print the centroid results
        System.out.println("Centroids finalized at:");
        kMeansClustering.printCentroids();
        System.out.print("\n");
    }

    /**
     * Print the centroids to standard out, one tab separated line per
     * centroid with the values rounded to two decimals.
     */
    public void printCentroids() {
        for (int centroidNumber = 0; centroidNumber < numClusters; centroidNumber++) {
            System.out.print(" " + (centroidNumber + 1) + "\t\t");
            for (int valueNumber = 0; valueNumber < numValues; valueNumber++) {
                if (valueNumber > 0) {
                    System.out.print("\t");
                }
                System.out.print(Util.roundDouble(centroids[centroidNumber][valueNumber], 2));
            }
            System.out.println();
        }
    }

    /**
     * Print the current clusters to standard out, listing the members of each
     * cluster with their values rounded to two decimals.
     */
    public void printClusters() {
        for (int clusterIndex = 0; clusterIndex < numClusters; clusterIndex++) {
            System.out.println("Cluster " + (clusterIndex + 1) + " includes:");
            for (int sampleIndex = 0; sampleIndex < numSamples; sampleIndex++) {
                if (clusters[sampleIndex] == clusterIndex) {
                    System.out.print(" " + sampleIds[sampleIndex] + "\t");
                    for (int valueNumber = 0; valueNumber < numValues; valueNumber++) {
                        if (valueNumber > 0) {
                            System.out.print("\t");
                        }
                        System.out.print(Util.roundDouble(samples[sampleIndex][valueNumber], 2));
                    }
                    System.out.println();
                }
            }
            System.out.println();
        }
    }

    /**
     * Get the sample names of all the members in the given cluster.
     *
     * @param clusterIndex the index of the cluster
     * @return the sample names of all the members in the given cluster, empty
     * if no sample is assigned to the given index
     */
    public ArrayList<String> getClusterMembers(int clusterIndex) {
        ArrayList<String> clusterMembers = new ArrayList<String>();
        for (int sampleIndex = 0; sampleIndex < numSamples; sampleIndex++) {
            if (clusters[sampleIndex] == clusterIndex) {
                clusterMembers.add(sampleIds[sampleIndex]);
            }
        }
        return clusterMembers;
    }

    /**
     * Returns a hashmap with the values for the members in the given cluster.
     * Key: sample id, value: the data points. Note that members sharing the
     * same sample id end up as a single entry in the map.
     *
     * @param clusterIndex the index of the cluster
     * @return the values for the members in the given cluster
     */
    public HashMap<String, ArrayList<Double>> getClusterMembersData(int clusterIndex) {
        HashMap<String, ArrayList<Double>> clusterMembers = new HashMap<String, ArrayList<Double>>();
        for (int sampleIndex = 0; sampleIndex < numSamples; sampleIndex++) {
            if (clusters[sampleIndex] == clusterIndex) {
                ArrayList<Double> values = new ArrayList<Double>(numValues);
                for (int valueNumber = 0; valueNumber < numValues; valueNumber++) {
                    values.add(samples[sampleIndex][valueNumber]);
                }
                clusterMembers.put(sampleIds[sampleIndex], values);
            }
        }
        return clusterMembers;
    }

    /**
     * Read sample data from a tab separated file without header. The first
     * column of every line is used as the sample identifier and the remaining
     * columns are parsed as doubles. The number of values per sample is taken
     * from the first line that has any values; values beyond that number on
     * other lines are ignored.
     *
     * @param dataFile the file to read from
     * @return the sample data
     * @throws IllegalArgumentException if the file cannot be read or if a
     * line has fewer values than expected
     * @throws NumberFormatException if a value cannot be parsed as a double
     */
    private static SampleData readDataFromFile(File dataFile) {

        ArrayList<String> ids = new ArrayList<String>();
        ArrayList<double[]> rows = new ArrayList<double[]>();
        int valuesPerSample = 0;

        BufferedReader reader = null;

        try {
            reader = new BufferedReader(new FileReader(dataFile));

            String line;
            while ((line = reader.readLine()) != null) {

                // note: split drops trailing empty strings, so a line with
                // only tabs results in an empty array
                String[] columns = line.split("\\t");
                ids.add(columns.length > 0 ? columns[0] : "");

                double[] row = new double[Math.max(0, columns.length - 1)];
                for (int i = 1; i < columns.length; i++) {
                    row[i - 1] = Double.parseDouble(columns[i]);
                }
                rows.add(row);

                if (valuesPerSample == 0) {
                    valuesPerSample = row.length;
                }
            }
        } catch (IOException e) {
            throw new IllegalArgumentException("Unable to read the sample data from " + dataFile + ".", e);
        } finally {
            if (reader != null) {
                try {
                    reader.close();
                } catch (IOException ignored) {
                    // nothing left to do: either the data has been read in full
                    // or the original exception is already on its way out
                }
            }
        }

        int sampleCount = rows.size();
        String[] sampleNames = ids.toArray(new String[sampleCount]);
        double[][] data = new double[sampleCount][valuesPerSample];

        for (int i = 0; i < sampleCount; i++) {
            double[] row = rows.get(i);
            if (row.length < valuesPerSample) {
                // every line is a sample, so the line number is the index plus one
                throw new IllegalArgumentException("Line " + (i + 1) + " of " + dataFile + " has " + row.length
                        + " values, expected at least " + valuesPerSample + ".");
            }
            System.arraycopy(row, 0, data[i], 0, valuesPerSample);
        }

        return new SampleData(data, sampleNames);
    }

    /**
     * Returns the number of clusters.
     *
     * @return the number of clusters
     */
    public int getNumberOfClusters() {
        return numClusters;
    }

    /**
     * Returns the maximum number of iterations.
     *
     * @return the maximum number of iterations
     */
    public int getMaxIterations() {
        return maxIterations;
    }

    /**
     * Set the maximum number of iterations.
     *
     * @param maxIterations the maximum number of iterations
     */
    public void setMaxIterations(int maxIterations) {
        this.maxIterations = maxIterations;
    }

    /**
     * Sample data: the values and the identifiers of a set of samples.
     */
    private static class SampleData {

        /**
         * The sample data.
         */
        private final double[][] samples;
        /**
         * The sample identifiers.
         */
        private final String[] sampleIds;

        /**
         * Constructor.
         *
         * @param samples the data
         * @param sampleIds the sample identifiers
         */
        public SampleData(double samples[][], String sampleIds[]) {
            this.samples = samples;
            this.sampleIds = sampleIds;
        }

        /**
         * Returns the samples.
         *
         * @return the samples
         */
        public double[][] getSamples() {
            return samples;
        }

        /**
         * Returns the sample identifiers.
         *
         * @return the sample identifiers
         */
        public String[] getSampleIds() {
            return sampleIds;
        }
    }
}