package mil.nga.giat.geowave.analytic.mapreduce.kmeans.runner;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import mil.nga.giat.geowave.analytic.AnalyticItemWrapper;
import mil.nga.giat.geowave.analytic.IndependentJobRunner;
import mil.nga.giat.geowave.analytic.PropertyManagement;
import mil.nga.giat.geowave.analytic.SimpleFeatureItemWrapperFactory;
import mil.nga.giat.geowave.analytic.clustering.CentroidManager;
import mil.nga.giat.geowave.analytic.clustering.CentroidManager.CentroidProcessingFn;
import mil.nga.giat.geowave.analytic.clustering.CentroidManagerGeoWave;
import mil.nga.giat.geowave.analytic.clustering.NestedGroupCentroidAssignment;
import mil.nga.giat.geowave.analytic.distance.DistanceFn;
import mil.nga.giat.geowave.analytic.distance.FeatureCentroidDistanceFn;
import mil.nga.giat.geowave.analytic.extract.SimpleFeatureCentroidExtractor;
import mil.nga.giat.geowave.analytic.mapreduce.MapReduceJobController;
import mil.nga.giat.geowave.analytic.mapreduce.MapReduceJobRunner;
import mil.nga.giat.geowave.analytic.param.CentroidParameters;
import mil.nga.giat.geowave.analytic.param.ClusteringParameters;
import mil.nga.giat.geowave.analytic.param.CommonParameters;
import mil.nga.giat.geowave.analytic.param.FormatConfiguration;
import mil.nga.giat.geowave.analytic.param.ParameterEnum;
import org.apache.hadoop.conf.Configuration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Runs the K-means map-reduce job repeatedly until the centroids of ALL groups
 * have converged or the configured maximum number of iterations is reached.
 * <p>
 * After each job run the centroids are re-read through a freshly constructed
 * {@link CentroidManager}. For every group, each centroid is paired with its
 * predecessor of the same name from the prior iteration, the distance the
 * centroids moved is measured, and the superseded centroids are deleted.
 *
 * @param <T>
 *            the type of the item wrapped by each centroid
 */
public class KMeansIterationsJobRunner<T> implements
        MapReduceJobRunner,
        IndependentJobRunner
{
    protected static final Logger LOGGER = LoggerFactory.getLogger(KMeansIterationsJobRunner.class);

    /** Runner for a single K-means iteration. */
    private final KMeansJobRunner jobRunner = new KMeansJobRunner();

    /**
     * Distance below which a group is considered converged. Overwritten from
     * the run-time properties at the start of each
     * {@link #run(Configuration, PropertyManagement)}.
     */
    private double convergenceTol = 0.0001;

    public KMeansIterationsJobRunner() {}

    /**
     * Creates the centroid manager used to inspect the centroids after an
     * iteration. Sub-classes may override this to supply a different manager.
     *
     * @param config
     *            the Hadoop configuration (not used by this implementation)
     * @param runTimeProperties
     *            the properties handed to the {@link CentroidManagerGeoWave}
     * @return a new centroid manager
     * @throws IOException
     *             declared for overriding implementations and for the manager
     *             construction
     */
    protected CentroidManager<T> constructCentroidManager(
            final Configuration config,
            final PropertyManagement runTimeProperties )
            throws IOException {
        return new CentroidManagerGeoWave<T>(
                runTimeProperties);
    }

    /**
     * Forwards the input format configuration to the single-iteration job
     * runner.
     *
     * @param inputFormatConfiguration
     *            the input format configuration to use for each iteration
     */
    public void setInputFormatConfiguration(
            final FormatConfiguration inputFormatConfiguration ) {
        jobRunner.setInputFormatConfiguration(inputFormatConfiguration);
    }

    /**
     * Forwards the reducer count to the single-iteration job runner.
     *
     * @param reducerCount
     *            the reducer count to hand to the underlying job runner
     */
    public void setReducerCount(
            final int reducerCount ) {
        jobRunner.setReducerCount(reducerCount);
    }

    /**
     * Iterates the K-means job until convergence or until the iteration limit
     * is exhausted.
     * <p>
     * Reads the convergence tolerance (default: the current tolerance,
     * initially 0.0001), the distance function class (default:
     * {@link FeatureCentroidDistanceFn}) and the maximum iteration count
     * (default: 15) from the run-time properties.
     *
     * @param config
     *            the Hadoop configuration passed to each job run
     * @param runTimeProperties
     *            the run-time properties
     * @return the non-zero status of the first failing job run; otherwise 0,
     *         including when the iteration limit is reached without
     *         convergence
     * @throws Exception
     *             if a job run or the convergence check fails
     */
    @SuppressWarnings("unchecked")
    @Override
    public int run(
            final Configuration config,
            final PropertyManagement runTimeProperties )
            throws Exception {

        convergenceTol = runTimeProperties.getPropertyAsDouble(
                ClusteringParameters.Clustering.CONVERGANCE_TOLERANCE,
                convergenceTol);

        final DistanceFn<T> distanceFunction = runTimeProperties.getClassInstance(
                CommonParameters.Common.DISTANCE_FUNCTION_CLASS,
                DistanceFn.class,
                FeatureCentroidDistanceFn.class);

        int maxIterationCount = runTimeProperties.getPropertyAsInt(
                ClusteringParameters.Clustering.MAX_ITERATIONS,
                15);

        boolean converged = false;

        while (!converged && (maxIterationCount > 0)) {
            final int status = runJob(
                    config,
                    runTimeProperties);
            if (status != 0) {
                return status;
            }

            // new one each time to force a refresh of the centroids
            final CentroidManager<T> centroidManager = constructCentroidManager(
                    config,
                    runTimeProperties);

            // check for convergence
            converged = checkForConvergence(
                    centroidManager,
                    distanceFunction);

            maxIterationCount--;
        }
        return 0;
    }

    /**
     * Runs a single K-means iteration, first filling in defaults for the
     * centroid extractor, the wrapper factory and the distance function when
     * those properties have not been set.
     *
     * @param config
     *            the Hadoop configuration
     * @param runTimeProperties
     *            the run-time properties (defaults are stored into them)
     * @return the status returned by the underlying job runner
     * @throws Exception
     *             if the underlying job runner fails
     */
    protected int runJob(
            final Configuration config,
            final PropertyManagement runTimeProperties )
            throws Exception {

        runTimeProperties.storeIfEmpty(
                CentroidParameters.Centroid.EXTRACTOR_CLASS,
                SimpleFeatureCentroidExtractor.class);
        runTimeProperties.storeIfEmpty(
                CentroidParameters.Centroid.WRAPPER_FACTORY_CLASS,
                SimpleFeatureItemWrapperFactory.class);
        runTimeProperties.storeIfEmpty(
                CommonParameters.Common.DISTANCE_FUNCTION_CLASS,
                FeatureCentroidDistanceFn.class);

        return jobRunner.run(
                config,
                runTimeProperties);
    }

    /**
     * Visits every group known to the centroid manager, computes the movement
     * of its centroids, removes the superseded centroids, and updates the
     * reducer count for the next iteration.
     *
     * @param centroidManager
     *            the manager providing (and deleting) the centroids
     * @param distanceFunction
     *            the function measuring how far a centroid moved
     * @return {@code true} only if group processing reported status 0 and no
     *         group exceeded the convergence tolerance
     * @throws IOException
     *             if the centroid manager fails while processing the groups
     */
    private boolean checkForConvergence(
            final CentroidManager<T> centroidManager,
            final DistanceFn<T> distanceFunction )
            throws IOException {

        final AtomicInteger grpCount = new AtomicInteger(
                0);
        final AtomicInteger failuresCount = new AtomicInteger(
                0);
        final AtomicInteger centroidCount = new AtomicInteger(
                0);

        final int processingStatus = centroidManager.processForAllGroups(new CentroidProcessingFn<T>() {
            @Override
            public int processGroup(
                    final String groupID,
                    final List<AnalyticItemWrapper<T>> centroids ) {
                grpCount.incrementAndGet();
                // halved: the list holds the prior iteration's centroids as
                // well as the new ones
                centroidCount.addAndGet(centroids.size() / 2);

                if (LOGGER.isTraceEnabled()) {
                    LOGGER.trace(
                            "Parent Group: {} ",
                            groupID);
                    for (final AnalyticItemWrapper<T> troid : centroids) {
                        // diagnostic detail only; keep at the guarded level
                        LOGGER.trace(
                                "Child Group: {} ",
                                troid.getID());
                    }
                }

                failuresCount.addAndGet(computeCostAndCleanUp(
                        groupID,
                        centroids,
                        centroidManager,
                        distanceFunction));
                return 0;
            }
        });

        // update default based on data size
        // NOTE(review): this product is 0 when no group was processed or when
        // every group held fewer than two centroids; confirm the job runner
        // tolerates a reducer count of 0.
        setReducerCount(grpCount.get() * centroidCount.get());

        return (processingStatus == 0) && (failuresCount.get() == 0);
    }

    /**
     * Measures how far the centroids of one group moved between the prior and
     * the current iteration and deletes the centroids that were superseded.
     * <p>
     * The given list is sorted in place by name and then by iteration ID so
     * that a centroid and its successor are adjacent. For each adjacent pair
     * sharing a name, the distance between the two is accumulated and the
     * older one is scheduled for deletion. The accumulated distance is then
     * divided by the total number of centroids in the list and compared with
     * the convergence tolerance.
     *
     * @param groupID
     *            the ID of the group being processed (not used by this
     *            implementation)
     * @param centroids
     *            the centroids of the group; sorted in place by this method
     * @param centroidManager
     *            the manager used to delete the superseded centroids
     * @param distanceFunction
     *            the function measuring the distance between two centroids
     * @return 0 if the averaged distance is below the convergence tolerance,
     *         1 otherwise
     * @throws RuntimeException
     *             wrapping the {@link IOException} raised when deleting the
     *             superseded centroids fails
     */
    protected int computeCostAndCleanUp(
            final String groupID,
            final List<AnalyticItemWrapper<T>> centroids,
            final CentroidManager<T> centroidManager,
            final DistanceFn<T> distanceFunction ) {

        double distance = 0;
        final List<String> deletionKeys = new ArrayList<String>();

        // sort by id and then by iteration
        Collections.sort(
                centroids,
                new Comparator<AnalyticItemWrapper<T>>() {
                    @Override
                    public int compare(
                            final AnalyticItemWrapper<T> arg0,
                            final AnalyticItemWrapper<T> arg1 ) {
                        final int byName = arg0.getName().compareTo(
                                arg1.getName());
                        if (byName != 0) {
                            return byName;
                        }
                        // explicit three-way comparison; subtracting the two
                        // values could overflow
                        final int iteration0 = arg0.getIterationID();
                        final int iteration1 = arg1.getIterationID();
                        return (iteration0 < iteration1) ? -1 : ((iteration0 == iteration1) ? 0 : 1);
                    }
                });

        AnalyticItemWrapper<T> prior = null;
        for (final AnalyticItemWrapper<T> centroid : centroids) {
            if (prior == null) {
                prior = centroid;
                continue;
            }
            else if (!prior.getName().equals(
                    centroid.getName())) {
                // 'prior' has no successor: a centroid without assigned
                // points. This occurs when the number of centroids exceeds
                // the number of points in a cluster (an edge case). Whether
                // it should be deleted is an open question; for now it is
                // kept and only reported.
                LOGGER.warn(
                        "Centroid is no longer viable {} from group {}",
                        prior.getID(),
                        prior.getGroupID());
                prior = centroid;
                continue;
            }

            // the prior run centroids are still present from the geowave data
            // store; their priors do not exist in the map
            distance += distanceFunction.measure(
                    prior.getWrappedItem(),
                    centroid.getWrappedItem());
            deletionKeys.add(prior.getID());

            if (LOGGER.isTraceEnabled()) {
                LOGGER.trace(
                        "Within group {} replace {} with {}",
                        new Object[] {
                            prior.getGroupID(),
                            prior.getID(),
                            centroid.getID()
                        });
            }
            // the pair is consumed; start looking for the next pair
            prior = null;
        }

        // Guard the division: for an empty list 0.0 / 0 is NaN, which never
        // compares below the tolerance and would report the group as not
        // converged on every iteration.
        if (!centroids.isEmpty()) {
            distance /= centroids.size();
        }

        try {
            centroidManager.delete(deletionKeys.toArray(new String[deletionKeys.size()]));
        }
        catch (final IOException e) {
            throw new RuntimeException(
                    e);
        }

        return (distance < convergenceTol) ? 0 : 1;
    }

    /**
     * Collects the parameters used by this runner together with those of the
     * centroid manager, the nested group centroid assignment and the
     * underlying single-iteration job runner.
     *
     * @return the combined set of parameters
     */
    @Override
    public Collection<ParameterEnum<?>> getParameters() {
        final Set<ParameterEnum<?>> params = new HashSet<ParameterEnum<?>>();
        params.addAll(Arrays.asList(new ParameterEnum<?>[] {
            CentroidParameters.Centroid.INDEX_ID,
            CentroidParameters.Centroid.DATA_TYPE_ID,
            CentroidParameters.Centroid.DATA_NAMESPACE_URI,
            CentroidParameters.Centroid.EXTRACTOR_CLASS,
            CentroidParameters.Centroid.WRAPPER_FACTORY_CLASS,
            ClusteringParameters.Clustering.MAX_REDUCER_COUNT,
            ClusteringParameters.Clustering.MAX_ITERATIONS,
            ClusteringParameters.Clustering.CONVERGANCE_TOLERANCE,
            CommonParameters.Common.DISTANCE_FUNCTION_CLASS
        }));

        params.addAll(CentroidManagerGeoWave.getParameters());
        params.addAll(NestedGroupCentroidAssignment.getParameters());
        params.addAll(jobRunner.getParameters());
        return params;
    }

    /**
     * Runs the iterations using a Hadoop configuration derived from the
     * run-time properties.
     *
     * @param runTimeProperties
     *            the run-time properties
     * @return the status of {@link #run(Configuration, PropertyManagement)}
     * @throws Exception
     *             if the run fails
     */
    @Override
    public int run(
            final PropertyManagement runTimeProperties )
            throws Exception {
        return this.run(
                MapReduceJobController.getConfiguration(runTimeProperties),
                runTimeProperties);
    }
}