package mil.nga.giat.geowave.analytic.kmeans.serial;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import mil.nga.giat.geowave.analytic.AnalyticItemWrapper;
import mil.nga.giat.geowave.analytic.AnalyticItemWrapperFactory;
import mil.nga.giat.geowave.analytic.clustering.CentroidPairing;
import mil.nga.giat.geowave.analytic.kmeans.AssociationNotification;
import mil.nga.giat.geowave.analytic.kmeans.CentroidAssociationFn;
import mil.nga.giat.geowave.analytic.sample.SampleNotification;
import mil.nga.giat.geowave.analytic.sample.Sampler;
import org.apache.commons.lang3.tuple.Pair;
/**
 * Builds an initial set of candidate centroids from an in-memory collection
 * of points by repeatedly associating points with the current candidates and
 * sampling new candidates from the resulting pairings.
 * <p>
 * NOTE(review): the class name suggests this is the seeding step of the
 * "k-means parallel" (k-means||) algorithm run in a single process; confirm
 * against the project documentation.
 *
 * @param <T>
 *            the underlying item type wrapped by {@link AnalyticItemWrapper}
 */
public class KMeansParallelInitialize<T>
{
    /** Function used to associate every point with a candidate centroid. */
    private CentroidAssociationFn<T> centroidAssociationFn = new CentroidAssociationFn<T>();

    /**
     * Controls the number of sampling rounds performed by {@link #runLocal}:
     * {@code max(1, (int) log2(psi))}. The default of 5.0 yields two rounds.
     */
    private double psi = 5.0;

    /** Sampler that picks new candidate centroids from the current pairings. */
    private final Sampler<T> sampler = new Sampler<T>();

    /**
     * Factory used to wrap each sampled item as a centroid. There is no
     * default; it is dereferenced whenever the sampler reports a selected
     * item, so it must be set before {@link #runLocal} is called.
     */
    private AnalyticItemWrapperFactory<T> centroidFactory;

    /** Statistics collector; reset at the start of each {@link #runLocal} call. */
    private final AnalyticStats stats = new StatsMap();

    /**
     * @return the function used to associate points with centroids
     */
    public CentroidAssociationFn<T> getCentroidAssociationFn() {
        return centroidAssociationFn;
    }

    /**
     * @param centroidAssociationFn
     *            the function used to associate points with centroids
     */
    public void setCentroidAssociationFn(
            final CentroidAssociationFn<T> centroidAssociationFn ) {
        this.centroidAssociationFn = centroidAssociationFn;
    }

    /**
     * @return the value from which the number of sampling rounds is derived
     */
    public double getPsi() {
        return psi;
    }

    /**
     * @param psi
     *            the value from which the number of sampling rounds is
     *            derived; at least one round is always performed
     */
    public void setPsi(
            final double psi ) {
        this.psi = psi;
    }

    /**
     * @return the sampler used to select new candidate centroids
     */
    public Sampler<T> getSampler() {
        return sampler;
    }

    /**
     * @return the factory used to wrap sampled items, or {@code null} if none
     *         has been set
     */
    public AnalyticItemWrapperFactory<T> getCentroidFactory() {
        return centroidFactory;
    }

    /**
     * @param centroidFactory
     *            the factory used to wrap sampled items as centroids
     */
    public void setCentroidFactory(
            final AnalyticItemWrapperFactory<T> centroidFactory ) {
        this.centroidFactory = centroidFactory;
    }

    /**
     * @return the statistics collected by the most recent {@link #runLocal}
     *         call
     */
    public AnalyticStats getStats() {
        return stats;
    }

    /**
     * Runs the initialization over the given points.
     * <p>
     * The first point returned by {@code pointSet} seeds the sample set. The
     * points are then associated with the sample set and, for
     * {@code max(1, (int) log2(psi))} rounds, new candidates are sampled from
     * the pairings and the association is recomputed. The value returned by
     * each association pass is recorded under
     * {@link AnalyticStats.StatValue#COST}.
     * <p>
     * {@code pointSet} is handed to the association function on every pass in
     * addition to being read for the seed, so it presumably must support
     * repeated iteration.
     *
     * @param pointSet
     *            the points to initialize from
     * @return a pair whose left side holds the pairings reported by the last
     *         association pass and whose right side holds the accumulated
     *         sample set (the seed plus every sampled item); both lists are
     *         empty when {@code pointSet} has no elements
     */
    public Pair<List<CentroidPairing<T>>, List<AnalyticItemWrapper<T>>> runLocal(
            final Iterable<AnalyticItemWrapper<T>> pointSet ) {
        stats.reset();

        final List<AnalyticItemWrapper<T>> sampleSet = new ArrayList<AnalyticItemWrapper<T>>();
        final List<CentroidPairing<T>> pairingSet = new ArrayList<CentroidPairing<T>>();

        // Nothing to seed from: return empty results instead of letting
        // Iterator.next() fail with NoSuchElementException.
        final Iterator<AnalyticItemWrapper<T>> seedIterator = pointSet.iterator();
        if (!seedIterator.hasNext()) {
            return Pair.of(
                    pairingSet,
                    sampleSet);
        }
        sampleSet.add(seedIterator.next());

        // Collects each reported pairing and counts it against its centroid.
        final AssociationNotification<T> assocFn = new AssociationNotification<T>() {
            @Override
            public void notify(
                    final CentroidPairing<T> pairing ) {
                pairingSet.add(pairing);
                pairing.getCentroid().incrementAssociationCount(
                        1);
            }
        };

        // Wraps every sampled item and adds it to the candidate set. It only
        // captures the sample set, so one instance serves all rounds.
        final SampleNotification<T> sampleCollector = new SampleNotification<T>() {
            @Override
            public void notify(
                    final T item,
                    final boolean partial ) {
                sampleSet.add(centroidFactory.create(item));
            }
        };

        double normalizingConstant = associateAndRecordCost(
                pointSet,
                sampleSet,
                assocFn);

        // The cast truncates toward zero; at least one round is always run.
        final int rounds = Math.max(
                1,
                (int) (Math.log(psi) / Math.log(2)));
        for (int round = 0; round < rounds; round++) {
            sampler.sample(
                    pairingSet,
                    sampleCollector,
                    normalizingConstant);

            // Discard the previous pass's results before re-associating
            // against the enlarged sample set.
            pairingSet.clear();
            for (final AnalyticItemWrapper<T> centroid : sampleSet) {
                centroid.resetAssociatonCount();
            }

            normalizingConstant = associateAndRecordCost(
                    pointSet,
                    sampleSet,
                    assocFn);
        }

        return Pair.of(
                pairingSet,
                sampleSet);
    }

    /**
     * Associates the points with the given centroids and records the value
     * returned by the association function as the COST statistic.
     */
    private double associateAndRecordCost(
            final Iterable<AnalyticItemWrapper<T>> pointSet,
            final List<AnalyticItemWrapper<T>> centroids,
            final AssociationNotification<T> assocFn ) {
        final double cost = centroidAssociationFn.compute(
                pointSet,
                centroids,
                assocFn);
        stats.notify(
                AnalyticStats.StatValue.COST,
                cost);
        return cost;
    }
}