package mil.nga.giat.geowave.analytic.mapreduce.kmeans; import java.io.IOException; import java.util.List; import mil.nga.giat.geowave.analytic.AnalyticItemWrapper; import mil.nga.giat.geowave.analytic.AnalyticItemWrapperFactory; import mil.nga.giat.geowave.analytic.ScopedJobConfiguration; import mil.nga.giat.geowave.analytic.SimpleFeatureItemWrapperFactory; import mil.nga.giat.geowave.analytic.clustering.CentroidManagerGeoWave; import mil.nga.giat.geowave.analytic.clustering.CentroidPairing; import mil.nga.giat.geowave.analytic.clustering.DistortionGroupManagement; import mil.nga.giat.geowave.analytic.clustering.DistortionGroupManagement.DistortionDataAdapter; import mil.nga.giat.geowave.analytic.clustering.DistortionGroupManagement.DistortionEntry; import mil.nga.giat.geowave.analytic.clustering.NestedGroupCentroidAssignment; import mil.nga.giat.geowave.analytic.extract.CentroidExtractor; import mil.nga.giat.geowave.analytic.extract.SimpleFeatureCentroidExtractor; import mil.nga.giat.geowave.analytic.kmeans.AssociationNotification; import mil.nga.giat.geowave.analytic.mapreduce.CountofDoubleWritable; import mil.nga.giat.geowave.analytic.param.CentroidParameters; import mil.nga.giat.geowave.analytic.param.GlobalParameters; import mil.nga.giat.geowave.analytic.param.JumpParameters; import mil.nga.giat.geowave.mapreduce.GeoWaveWritableInputMapper; import mil.nga.giat.geowave.mapreduce.input.GeoWaveInputKey; import mil.nga.giat.geowave.mapreduce.output.GeoWaveOutputKey; import org.apache.hadoop.io.ObjectWritable; import org.apache.hadoop.io.Text; import org.apache.hadoop.mapreduce.Mapper; import org.apache.hadoop.mapreduce.Reducer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.vividsolutions.jts.geom.Point; /** * Calculate the distortation. * <p/> * See Catherine A. Sugar and Gareth M. James (2003). * "Finding the number of clusters in a data set: An information theoretic approach" * Journal of the American Statistical Association 98 (January): 750–763 * * @formatter:off Context configuration parameters include: * <p/> * "KMeansDistortionMapReduce.Common.DistanceFunctionClass" -> * {@link mil.nga.giat.geowave.analytic.distance.DistanceFn} used * to determine distance to centroid * <p/> * "KMeansDistortionMapReduce.Centroid.WrapperFactoryClass" -> * {@link AnalyticItemWrapperFactory} to extract wrap spatial * objects with Centroid management functions * <p/> * "KMeansDistortionMapReduce.Centroid.ExtractorClass" -> * {@link mil.nga.giat.geowave.analytic.extract.CentroidExtractor} * <p/> * "KMeansDistortionMapReduce.Jump.CountOfCentroids" -> May be * different from actual. * @formatter:on * @see CentroidManagerGeoWave */ public class KMeansDistortionMapReduce { protected static final Logger LOGGER = LoggerFactory.getLogger(KMeansDistortionMapReduce.class); public static class KMeansDistortionMapper extends GeoWaveWritableInputMapper<Text, CountofDoubleWritable> { private NestedGroupCentroidAssignment<Object> nestedGroupCentroidAssigner; private final Text outputKeyWritable = new Text( "1"); private final CountofDoubleWritable outputValWritable = new CountofDoubleWritable(); private CentroidExtractor<Object> centroidExtractor; private AnalyticItemWrapperFactory<Object> itemWrapperFactory; AssociationNotification<Object> centroidAssociationFn = new AssociationNotification<Object>() { @Override public void notify( final CentroidPairing<Object> pairing ) { outputKeyWritable.set(pairing.getCentroid().getGroupID()); final double extraFromItem[] = pairing.getPairedItem().getDimensionValues(); final double extraCentroid[] = pairing.getCentroid().getDimensionValues(); final Point p = centroidExtractor.getCentroid(pairing.getPairedItem().getWrappedItem()); final Point centroid = centroidExtractor.getCentroid(pairing.getCentroid().getWrappedItem()); // calculate error for dp // using identity matrix for the common covariance, therefore // E[(p - c)^-1 * cov * (p - c)] => (px - cx)^2 + (py - cy)^2 double expectation = 0.0; for (int i = 0; i < extraCentroid.length; i++) { expectation += Math.pow( extraFromItem[i] - extraCentroid[i], 2); } expectation += (Math.pow( p.getCoordinate().x - centroid.getCoordinate().x, 2) + Math.pow( p.getCoordinate().y - centroid.getCoordinate().y, 2)); // + Math.pow( // p.getCoordinate().z - centroid.getCoordinate().z, // 2)); outputValWritable.set( expectation, 1); } }; @Override protected void mapNativeValue( final GeoWaveInputKey key, final Object value, final org.apache.hadoop.mapreduce.Mapper<GeoWaveInputKey, ObjectWritable, Text, CountofDoubleWritable>.Context context ) throws IOException, InterruptedException { nestedGroupCentroidAssigner.findCentroidForLevel( itemWrapperFactory.create(value), centroidAssociationFn); context.write( outputKeyWritable, outputValWritable); } @SuppressWarnings("unchecked") @Override protected void setup( final Mapper<GeoWaveInputKey, ObjectWritable, Text, CountofDoubleWritable>.Context context ) throws IOException, InterruptedException { super.setup(context); final ScopedJobConfiguration config = new ScopedJobConfiguration( context.getConfiguration(), KMeansDistortionMapReduce.class, KMeansDistortionMapReduce.LOGGER); try { nestedGroupCentroidAssigner = new NestedGroupCentroidAssignment<Object>( context, KMeansDistortionMapReduce.class, KMeansDistortionMapReduce.LOGGER); } catch (final Exception e1) { throw new IOException( e1); } try { centroidExtractor = config.getInstance( CentroidParameters.Centroid.EXTRACTOR_CLASS, CentroidExtractor.class, SimpleFeatureCentroidExtractor.class); } catch (final Exception e1) { throw new IOException( e1); } try { itemWrapperFactory = config.getInstance( CentroidParameters.Centroid.WRAPPER_FACTORY_CLASS, AnalyticItemWrapperFactory.class, SimpleFeatureItemWrapperFactory.class); } catch (final Exception e1) { throw new IOException( e1); } } } public static class KMeansDistorationCombiner extends Reducer<Text, CountofDoubleWritable, Text, CountofDoubleWritable> { final CountofDoubleWritable outputValue = new CountofDoubleWritable(); @Override public void reduce( final Text key, final Iterable<CountofDoubleWritable> values, final Reducer<Text, CountofDoubleWritable, Text, CountofDoubleWritable>.Context context ) throws IOException, InterruptedException { double expectation = 0; double ptCount = 0; for (final CountofDoubleWritable value : values) { expectation += value.getValue(); ptCount += value.getCount(); } outputValue.set( expectation, ptCount); context.write( key, outputValue); } } public static class KMeansDistortionReduce extends Reducer<Text, CountofDoubleWritable, GeoWaveOutputKey, DistortionEntry> { private Integer expectedK = null; final protected Text output = new Text( ""); private CentroidManagerGeoWave<Object> centroidManager; private String batchId; @Override public void reduce( final Text key, final Iterable<CountofDoubleWritable> values, final Reducer<Text, CountofDoubleWritable, GeoWaveOutputKey, DistortionEntry>.Context context ) throws IOException, InterruptedException { double expectation = 0.0; final List<AnalyticItemWrapper<Object>> centroids = centroidManager.getCentroidsForGroup(key.toString()); // it is possible that the number of items in a group are smaller // than the cluster final Integer kCount; if (expectedK == null) { kCount = centroids.size(); } else { kCount = expectedK; } if (centroids.size() == 0) { return; } final double numDimesions = 2 + centroids.get( 0).getExtraDimensions().length; double ptCount = 0; for (final CountofDoubleWritable value : values) { expectation += value.getValue(); ptCount += value.getCount(); } if (ptCount > 0) { expectation /= ptCount; final Double distortion = Math.pow( expectation / numDimesions, -(numDimesions / 2)); final DistortionEntry entry = new DistortionEntry( key.toString(), batchId, kCount, distortion); context.write( new GeoWaveOutputKey( DistortionDataAdapter.ADAPTER_ID, DistortionGroupManagement.DISTORTIONS_INDEX_LIST), entry); } } @Override protected void setup( final Reducer<Text, CountofDoubleWritable, GeoWaveOutputKey, DistortionEntry>.Context context ) throws IOException, InterruptedException { super.setup(context); final ScopedJobConfiguration config = new ScopedJobConfiguration( context.getConfiguration(), KMeansDistortionMapReduce.class, KMeansDistortionMapReduce.LOGGER); final int k = config.getInt( JumpParameters.Jump.COUNT_OF_CENTROIDS, -1); if (k > 0) { expectedK = k; } try { centroidManager = new CentroidManagerGeoWave<Object>( context, KMeansDistortionMapReduce.class, KMeansDistortionMapReduce.LOGGER); } catch (final Exception e) { KMeansDistortionMapReduce.LOGGER.warn( "Unable to initialize centroid manager", e); throw new IOException( "Unable to initialize centroid manager", e); } batchId = config.getString( GlobalParameters.Global.PARENT_BATCH_ID, centroidManager.getBatchId()); } } }