/*
* Copyright 2015 MiLaboratory.com
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.milaboratory.core.clustering;
import com.milaboratory.core.sequence.Alphabet;
import com.milaboratory.core.sequence.Sequence;
import com.milaboratory.core.tree.NeighborhoodIterator;
import com.milaboratory.core.tree.SequenceTreeMap;
import com.milaboratory.core.tree.TreeSearchParameters;
import com.milaboratory.util.CanReportProgress;
import com.milaboratory.util.Factory;
import java.util.*;
import static com.milaboratory.core.tree.SequenceTreeMap.Node;
public final class Clustering<T, S extends Sequence<S>> implements CanReportProgress {
final Collection<T> inputObjects;
final SequenceExtractor<T, S> sequenceExtractor;
final ClusteringStrategy<T, S> strategy;
final List<Cluster<T>> clusters = new ArrayList<>();
volatile int progress;
public Clustering(Collection<T> inputObjects, SequenceExtractor<T, S> sequenceExtractor, ClusteringStrategy<T, S> strategy) {
this.inputObjects = inputObjects;
this.sequenceExtractor = sequenceExtractor;
this.strategy = strategy;
}
@Override
public double getProgress() {
return (1.0 * progress) / inputObjects.size();
}
@Override
public boolean isFinished() {
return progress == inputObjects.size();
}
public List<Cluster<T>> performClustering() {
try {
if (inputObjects.isEmpty())
return clusters;
final Comparator<Cluster<T>> clusterComparator = getComparatorOfClusters(strategy, sequenceExtractor);
// For performance
final TreeSearchParameters params = strategy.getSearchParameters();
final int maxDepth = strategy.getMaxClusterDepth();
final List<T> objects = new ArrayList<>(inputObjects);
Collections.sort(objects, getComparatorOfObjectsRegardingSequences(strategy, sequenceExtractor));
@SuppressWarnings("unchecked")
Alphabet<S> alphabet = sequenceExtractor.getSequence(objects.get(0)).getAlphabet();
final Factory<T[]> arrayFactory = new Factory<T[]>() {
@Override
public T[] create() {
return (T[]) new Object[1];
}
};
final SequenceTreeMap<S, T[]> tree = new SequenceTreeMap<>(alphabet);
for (T object : objects) {
T[] array = tree.createIfAbsent(sequenceExtractor.getSequence(object), arrayFactory);
if (array[0] == null)
array[0] = object;
else {
array = Arrays.copyOf(array, array.length + 1);
array[array.length - 1] = object;
tree.put(sequenceExtractor.getSequence(object), array);
}
}
Node<T[]> current;
final HashSet<Node<T[]>> processedNodes = new HashSet<>();
ArrayList<Cluster<T>> previousLayer = new ArrayList<>(), nextLayer = new ArrayList<>(), tmp;
T[] temp;
boolean inTree;
for (int i = 0; i < objects.size(); ++i) {
this.progress = i;
T object = objects.get(i);
//checking whether object is already clusterized
if ((temp = tree.get(sequenceExtractor.getSequence(object))) == null)
continue;
inTree = false;
for (T t : temp)
if (t == object) {
inTree = true;
break;
}
if (!inTree)
continue;
//<-object in not yet clusterized
Cluster<T> tempCluster = new Cluster<>(object);
clusters.add(tempCluster);
previousLayer.clear();
previousLayer.add(tempCluster);
for (int depth = 0; depth < maxDepth; ++depth) {
nextLayer.clear();
for (Cluster<T> previousCluster : previousLayer) {
NeighborhoodIterator<S, T[]> iterator = tree
.getNeighborhoodIterator(sequenceExtractor
.getSequence(previousCluster.head), params, null);
processedNodes.clear();
while ((current = iterator.nextNode()) != null) {
if (!processedNodes.add(current))
continue;
T[] currentObjects = current.getObject();
T matchedObject = null;
boolean allNulls = true;
for (int j = 0; j < currentObjects.length; j++) {
if (currentObjects[j] == null)
continue;
matchedObject = currentObjects[j];
if (strategy.compare(previousCluster.head, matchedObject) <= 0
|| !strategy.canAddToCluster(previousCluster, matchedObject, iterator)) {
allNulls = false;
continue;
}
nextLayer.add(tempCluster = new Cluster<>(matchedObject, previousCluster));
previousCluster.add(tempCluster);
currentObjects[j] = null;
}
assert matchedObject != null;
if (allNulls)
tree.remove(sequenceExtractor.getSequence(matchedObject));
}
if (previousCluster.children != null)
Collections.sort(previousCluster.children, clusterComparator);
}
Collections.sort(nextLayer, clusterComparator);
tmp = nextLayer;
nextLayer = previousLayer;
previousLayer = tmp;
}
}
return clusters;
} finally {
progress = inputObjects.size();
}
}
public List<Cluster<T>> getClusters() {
if (progress != inputObjects.size())
throw new IllegalStateException("Not yet clustered.");
return clusters;
}
static <T, S extends Sequence> Comparator<Cluster<T>>
getComparatorOfClusters(final Comparator<T> objectComparator, final SequenceExtractor<T, S> extractor) {
return new Comparator<Cluster<T>>() {
@Override
public int compare(Cluster<T> o1, Cluster<T> o2) {
int i = objectComparator.compare(o2.head, o1.head);
return i == 0 ?
extractor.getSequence(o2.head).compareTo(extractor.getSequence(o1.head))
: i;
}
};
}
static <T, S extends Sequence> Comparator<T>
getComparatorOfObjectsRegardingSequences(final Comparator<T> objectComparator, final SequenceExtractor<T, S> extractor) {
return new Comparator<T>() {
@Override
public int compare(T o1, T o2) {
int i = objectComparator.compare(o2, o1);
return i == 0 ?
extractor.getSequence(o2).compareTo(extractor.getSequence(o1))
: i;
}
};
}
public static <T, S extends Sequence<S>> List<Cluster<T>> performClustering(Collection<T> inputObjects,
SequenceExtractor<T, S> sequenceExtractor,
ClusteringStrategy<T, S> strategy) {
return new Clustering<T, S>(inputObjects, sequenceExtractor, strategy).performClustering();
}
}