/*
* #%L
* gitools-core
* %%
* Copyright (C) 2013 Universitat Pompeu Fabra - Biomedical Genomics group
* %%
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public
* License along with this program. If not, see
* <http://www.gnu.org/licenses/gpl-3.0.html>.
* #L%
*/
package org.gitools.analysis.clustering.kmeans;
import org.gitools.analysis.clustering.AbstractClusteringMethod;
import org.gitools.analysis.clustering.ClusteringData;
import org.gitools.analysis.clustering.ClusteringException;
import org.gitools.analysis.clustering.MatrixClusteringData;
import org.gitools.analysis.clustering.distance.DistanceMeasure;
import org.gitools.analysis.clustering.distance.EuclideanDistance;
import org.gitools.analysis.clustering.hierarchical.HierarchicalClusterer;
import org.gitools.api.analysis.Clusters;
import org.gitools.api.analysis.IAggregator;
import org.gitools.api.analysis.IProgressMonitor;
import org.gitools.heatmap.header.HierarchicalClusterNamer;
import java.util.*;
import static com.google.common.base.Predicates.notNull;
import static com.google.common.collect.Iterables.filter;
import static com.google.common.collect.Iterables.size;
/**
 * Clustering method that groups the identifiers of one matrix dimension with
 * a {@link KMeansPlusPlusClusterer}.
 * <p>
 * Defaults: 300 iterations, 6 clusters and {@link EuclideanDistance}.
 */
public class KMeansPlusPlusMethod extends AbstractClusteringMethod {

    /** Property name fired when the iteration limit changes. */
    public static final String PROPERTY_ITERATIONS = "iterations";

    /** Property name fired when the requested number of clusters changes. */
    public static final String PROPERTY_NUMCLUSTERS = "numClusters";

    /** Property name fired when the distance measure changes. */
    public static final String PROPERTY_DISTANCE = "distance";

    private Long iterations = 300L;
    private Long numClusters = 6L;
    private DistanceMeasure distance = EuclideanDistance.get();

    public KMeansPlusPlusMethod() {
        super("K-means++");
    }

    /**
     * Clusters the identifiers of the clustering dimension of the given data.
     * <p>
     * Identifiers whose values are all {@code null} are not clustered; they
     * are reported under the {@code "Empty"} entry of the result instead.
     *
     * @param clusterData the data to cluster
     * @param monitor     progress monitor handed to the clusterer
     * @return the resulting clusters, or {@code null} when {@code clusterData}
     *         is not a {@link MatrixClusteringData}
     * @throws ClusteringException declared by the overridden method
     */
    @Override
    public Clusters cluster(ClusteringData clusterData, IProgressMonitor monitor) throws ClusteringException {

        if (!(clusterData instanceof MatrixClusteringData)) {
            return null;
        }

        MatrixClusteringData data = (MatrixClusteringData) clusterData;

        KMeansPlusPlusClusterer clusterer = new KMeansPlusPlusClusterer(
                getNumClusters().intValue(),
                getIterations().intValue(),
                getDistance());

        // Gather the points to cluster; identifiers without data go to noData
        Set<String> noData = new HashSet<>();
        List<Slide> points = collectPoints(data, noData);

        // Cluster data
        List<CentroidCluster<Slide>> clusters = clusterer.cluster(points, monitor);

        // Calculate cluster weight
        assignWeights(clusters, data.getLayer().getAggregator());

        // Sort clusters using their natural ordering
        // NOTE(review): presumably this ordering is driven by the weight set above - confirm in Cluster
        Collections.sort(clusters);

        // Name and return clusters
        return new KMeansClusters(clusters, noData);
    }

    /**
     * Builds one {@link Slide} per identifier of the clustering dimension.
     *
     * @param data   the matrix data being clustered
     * @param noData receives the identifiers whose values are all {@code null}
     * @return the slides that contain at least one non-null value
     */
    private static List<Slide> collectPoints(MatrixClusteringData data, Set<String> noData) {

        List<Slide> points = new ArrayList<>(data.getClusteringDimension().size());
        int aggregationSize = data.getAggregationDimension().size();

        for (String identifier : data.getClusteringDimension()) {

            Iterable<Double> point = data.getMatrix()
                    .newPosition()
                    .set(data.getClusteringDimension(), identifier)
                    .iterate(data.getLayer(), data.getAggregationDimension());

            // Skip all null rows/columns
            if (size(filter(point, notNull())) == 0) {
                noData.add(identifier);
                continue;
            }

            points.add(new Slide(identifier, point, aggregationSize));
        }

        return points;
    }

    /**
     * Sets the weight of every cluster to the aggregate of the per-point
     * aggregates divided by the number of points in the cluster.
     *
     * @param clusters   the clusters to weight
     * @param aggregator the aggregator applied to each point and to the
     *                   collected per-point results
     */
    private static void assignWeights(List<CentroidCluster<Slide>> clusters, IAggregator aggregator) {

        for (Cluster<Slide> cluster : clusters) {

            List<Slide> members = cluster.getPoints();
            List<Double> values = new ArrayList<>();

            for (Slide slide : members) {
                values.add(aggregator.aggregate(slide.getPoint()));
            }

            cluster.setWeight(aggregator.aggregate(values) / members.size());
        }
    }

    /**
     * @return the iteration limit handed to the clusterer
     */
    public Long getIterations() {
        return iterations;
    }

    /**
     * Sets the iteration limit and fires {@link #PROPERTY_ITERATIONS}.
     * The old value is reported as {@code null}.
     */
    public void setIterations(Long iterations) {
        this.iterations = iterations;
        firePropertyChange(PROPERTY_ITERATIONS, null, iterations);
    }

    /**
     * @return the distance measure handed to the clusterer
     */
    public DistanceMeasure getDistance() {
        return distance;
    }

    /**
     * Sets the distance measure and fires {@link #PROPERTY_DISTANCE}.
     * The old value is reported as {@code null}.
     */
    public void setDistance(DistanceMeasure distance) {
        this.distance = distance;
        firePropertyChange(PROPERTY_DISTANCE, null, distance);
    }

    /**
     * @return the number of clusters handed to the clusterer
     */
    public Long getNumClusters() {
        return numClusters;
    }

    /**
     * Sets the number of clusters and fires {@link #PROPERTY_NUMCLUSTERS}.
     * The old value is reported as {@code null}.
     */
    public void setNumClusters(Long numClusters) {
        this.numClusters = numClusters;
        firePropertyChange(PROPERTY_NUMCLUSTERS, null, numClusters);
    }

    /**
     * A clusterable point: an identifier together with its values and the
     * declared number of values.
     */
    public static class Slide implements Clusterable {

        private final String identifier;
        private final Iterable<Double> point;
        private final int size;

        /**
         * @param identifier the identifier this point belongs to
         * @param point      the values of the point
         * @param size       the value reported by {@link #size()}
         */
        public Slide(String identifier, Iterable<Double> point, int size) {
            this.identifier = identifier;
            this.point = point;
            this.size = size;
        }

        public String getIdentifier() {
            return identifier;
        }

        @Override
        public Iterable<Double> getPoint() {
            return point;
        }

        @Override
        public int size() {
            return size;
        }
    }

    /**
     * {@link Clusters} implementation backed by a map from cluster label to
     * the identifiers in that cluster.
     * <p>
     * The map keeps insertion order: the given clusters in list order,
     * followed by the {@code "Empty"} entry.
     */
    public static class KMeansClusters implements Clusters {

        private final Map<String, Set<String>> clusters;

        /**
         * @param clusters the clusters to expose, labelled by their list index
         * @param noData   identifiers stored under the {@code "Empty"} label
         */
        public KMeansClusters(List<? extends Cluster<Slide>> clusters, Set<String> noData) {

            // LinkedHashMap keeps the order of the incoming list, so the
            // sorting done by the caller is still visible when iterating
            this.clusters = new LinkedHashMap<>();

            int digits = HierarchicalClusterNamer.calculateDigits(clusters.size());

            int index = 0;
            for (Cluster<Slide> cluster : clusters) {

                Set<String> identifiers = new HashSet<>();
                for (Slide slide : cluster.getPoints()) {
                    identifiers.add(slide.getIdentifier());
                }

                this.clusters.put(HierarchicalClusterNamer.createLabel(index, digits), identifiers);
                index++;
            }

            this.clusters.put("Empty", noData);
        }

        /**
         * @return the cluster labels, in insertion order
         */
        @Override
        public Collection<String> getClusters() {
            return clusters.keySet();
        }

        /**
         * @return the identifiers of the given cluster, or {@code null} when
         *         the label is unknown
         */
        @Override
        public Set<String> getItems(String cluster) {
            return clusters.get(cluster);
        }

        /**
         * @return the backing label-to-identifiers map (not a copy)
         */
        @Override
        public Map<String, Set<String>> getClustersMap() {
            return clusters;
        }
    }
}