/**
* Copyright (c) 2013 Oculus Info Inc.
* http://www.oculusinfo.com/
*
* Released under the MIT License.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of
* this software and associated documentation files (the "Software"), to deal in
* the Software without restriction, including without limitation the rights to
* use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
* of the Software, and to permit persons to whom the Software is furnished to do
* so, subject to the following conditions:
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
package spimedb.cluster.unsupervised.cluster;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import spimedb.cluster.DataSet;
import spimedb.cluster.Instance;
import spimedb.cluster.feature.Feature;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.*;
/***
*
* Abstract base class that provides many useful features a typical clusterer needs.
*
* Most clusterers will extend this class directly.
* The only method that is required to be implemented by sub-classes is isCandidate() for
* determining whether an instance can be added to a cluster during clustering.
*
* @author slangevin
*
*/
public abstract class AbstractClusterer extends BaseClusterer {
protected final static int DEFAULT_THREAD_POOL = Runtime.getRuntime().availableProcessors();
protected final boolean penalizeMissingFeatures;
protected final boolean firstCandidate;
//protected double maxDistance = 1.0;
protected static final Logger log = LoggerFactory.getLogger("com.oculusinfo");
protected ExecutorService exec; // = Executors.newFixedThreadPool(DEFAULT_THREAD_POOL); //.newSingleThreadExecutor();;
@Override
public void init() {
exec = Executors.newFixedThreadPool(DEFAULT_THREAD_POOL, new MyThreadFactory()); //.newSingleThreadExecutor();
Runtime.getRuntime().addShutdownHook(new Thread(this::terminate));
}
@Override
public void terminate() {
if (exec == null) return;
try {
log.debug("Clusterer thread pool starting shutdown");
exec.shutdown();
try {
if (!exec.awaitTermination(10, TimeUnit.SECONDS)) {
log.error("Clusterer thread pool did not shut down gracefully");
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
exec.shutdownNow();
// log.info("Clusterer thread pool shutdown");
}
catch (Exception e) {
log.error("Clusterer thread pool did not shut down gracefully", e);
}
}
protected static class DistanceResult {
public final Instance i;
public final Cluster c;
public final double distance;
public DistanceResult(Instance i, Cluster c, double distance) {
this.i = i;
this.c = c;
this.distance = distance;
}
public boolean isNull() {
return (c == null);
}
}
/***
* Method for determining whether a cluster is a better candidate than the previous best candidate cluster.
*
* @param inst the instance being clustered
* @param candidate is the cluster
* @param score is the double precision value of the distance score of inst to candidate
* @param best is the current best cluster the inst can be added to
* @param bestScore is the current best score
* @return true if candidate is a better cluster for inst than best
*/
protected abstract boolean isCandidate(Instance inst, Cluster candidate, double score, Cluster best, double bestScore);
public AbstractClusterer() {
this(false, false, true);
}
public AbstractClusterer(boolean firstCandidate) {
this(firstCandidate, false, true);
}
public AbstractClusterer(boolean firstCandidate, boolean onlineUpdate, boolean penalizeMissingFeatures) {
super(onlineUpdate);
this.firstCandidate = firstCandidate;
this.penalizeMissingFeatures = penalizeMissingFeatures;
}
/***
* Method to override the logger for the clusterer to use for output.
*
* @param logger
*/
/* public void setLogger(Logger logger) {
log = logger;
}*/
/***
* Return the executor service the clusterer is using for parallelization.
*
* @return the executor service
*/
public ExecutorService getExecutor() {
return exec;
}
@Override
public ClusterResult doIncrementalCluster(DataSet ds, List<Cluster> clusters) {
return doCluster(ds, clusters);
}
@Override
public ClusterResult doCluster(DataSet ds) {
return doCluster(ds, new LinkedList<>());
}
private static List<List<? extends Instance>> createBlocks(List<? extends Instance> clusters, int blocksize) {
List<List<? extends Instance>> blocks = new LinkedList<>();
int sIdx = 0;
int eIdx = 0;
while (eIdx < clusters.size()) {
eIdx = Math.min(sIdx+blocksize, clusters.size());
blocks.add(new LinkedList<Instance>(clusters.subList(sIdx, eIdx)));
sIdx = eIdx;
}
return blocks;
}
/***
* Public method to find the best cluster for inst to be a member.
*
* The method parallelizes the distance calculations between inst and the clusters
* using the executor service. The cluster that is the best candidate for the inst is returned.
*
* @param inst is the instance being considered
* @param clusters is a collection of clusters to search
* @return the best cluster
*/
public DistanceResult bestCluster(final Instance inst, final List<List<? extends Instance>> clusterBlocks) {
double bestScore = Double.MAX_VALUE;
Cluster bestCluster = null;
CompletionService<DistanceResult> batch = new ExecutorCompletionService<>(getExecutor());
for (final List<? extends Instance> clusters : clusterBlocks) {
batch.submit(() -> {
double bestDist = Double.MAX_VALUE;
Instance bestMatch = null;
for (Instance c : clusters) {
double d = distance(inst, c);
if (d < bestDist) {
bestDist = d;
bestMatch = c;
}
}
return new DistanceResult(inst, (Cluster)bestMatch, bestDist);
});
}
for (int i=0; i < clusterBlocks.size(); i++) {
try {
DistanceResult result = batch.take().get();
if (isCandidate(result.i, result.c, result.distance, bestCluster, bestScore)) {
bestScore = result.distance;
bestCluster = result.c;
if (firstCandidate) break; // best is the first candidate found
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
} catch (Exception e) {
log.error("Error executing cluster distance task: {}", e.getLocalizedMessage());
}
}
return new DistanceResult(inst, bestCluster, bestScore);
}
/***
* Protected method to initiate clustering dataset. The public methods
* doCluster() and doIncrementalCluster() invoke this method.
*
* @param ds the data set to cluster
* @param clusters is a collection of clusters to modify
* @return a collection of modified clusters
*/
protected ClusterResult doCluster(DataSet ds, List<Cluster> clusters) {
double start = System.currentTimeMillis();
// if the clusterer hasn't been initially manually then init it now
if (exec == null) init();
LinkedHashSet<Cluster> modified = new LinkedHashSet<>();
for (Instance inst : ds) {
// Process in batches of blocks of 100 clusters
List<List<? extends Instance>> blocks = createBlocks(clusters, 100);
// double bestStart = System.currentTimeMillis();
Cluster bestCluster = bestCluster(inst, blocks).c;
// double bestTime = System.currentTimeMillis() - bestStart;
// log.debug("Find Best Cluster Time: {} ", bestTime);
if (bestCluster == null) { // no candidate cluster was found - create new one
bestCluster = createCluster();
bestCluster.add(inst);
if (!onlineUpdate) bestCluster.updateCentroid();
clusters.add(bestCluster);
}
else {
bestCluster.add(inst);
}
modified.add(bestCluster);
}
// centroids were not updated online so update them now
if (!onlineUpdate) {
for (Cluster c : modified) {
c.updateCentroid();
}
}
double clusterTime = System.currentTimeMillis() - start;
log.debug("Clustering time (s): {}", clusterTime / 1000);
return new InMemoryClusterResult(new LinkedList<>(modified));
}
@SuppressWarnings("unchecked")
@Override
public double distance(Instance inst1, Instance inst2) {
double totalDist = 0;
try {
for (FeatureTypeDefinition typedef : this.getTypeDefs()) {
if (typedef.distFunc.getWeight() < 0.00001) continue; // skip if weight is near zero
Feature f1 = inst1.getFeature(typedef.featureName);
Feature f2 = inst2.getFeature(typedef.featureName);
double d = 0;
if (f1 == null || f2 == null) {
d = penalizeMissingFeatures ? typedef.distFunc.getWeight() : 0;
}
else {
d = typedef.distFunc.distance(f1, f2) * typedef.distFunc.getWeight();
}
totalDist += d;
}
}
catch (Exception e) {
log.error("Error calculating distance between:\n---\n{}---\n{}---\nException:", inst1.toString(), inst2.toString(), e);
}
return totalDist;
}
private static class MyThreadFactory implements ThreadFactory {
@Override
public Thread newThread(Runnable r) {
return new Thread(r, "Clusterer Pool");
}
}
}