package mil.nga.giat.geowave.analytic.sample.function; import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import mil.nga.giat.geowave.analytic.AnalyticItemWrapper; import mil.nga.giat.geowave.analytic.AnalyticItemWrapperFactory; import mil.nga.giat.geowave.analytic.PropertyManagement; import mil.nga.giat.geowave.analytic.ScopedJobConfiguration; import mil.nga.giat.geowave.analytic.SimpleFeatureItemWrapperFactory; import mil.nga.giat.geowave.analytic.clustering.CentroidPairing; import mil.nga.giat.geowave.analytic.clustering.NestedGroupCentroidAssignment; import mil.nga.giat.geowave.analytic.distance.DistanceFn; import mil.nga.giat.geowave.analytic.kmeans.AssociationNotification; import mil.nga.giat.geowave.analytic.param.CentroidParameters; import mil.nga.giat.geowave.analytic.param.ParameterEnum; import mil.nga.giat.geowave.analytic.param.SampleParameters; import mil.nga.giat.geowave.analytic.sample.RandomProbabilitySampleFn; import mil.nga.giat.geowave.analytic.sample.SampleProbabilityFn; import mil.nga.giat.geowave.mapreduce.GeoWaveConfiguratorBase; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.mapreduce.JobContext; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * Rank objects using their distance to the closest centroid of a set of * centroids. The specific rank is determined by the probability of the point * meeting being a centroid, modeled in the implementation of * {@link SampleProbabilityFn}. * * The farther the distance, the higher the rank. * * @formatter:off Properties: * * "CentroidDistanceBasedSamplingRankFunction.KMeansConfig.data_store_configuration" * - The class used to determine the prefix class name for te * GeoWave Data Store parameters for a connection to collect the * starting set of centroids. Defaults to * {@link CentroidDistanceBasedSamplingRankFunction}. * * * "CentroidDistanceBasedSamplingRankFunction.KMeansConfig.probability_function" * - implementation of {@link SampleProbabilityFn} * * "CentroidDistanceBasedSamplingRankFunction.KMeansConfig.distance_function" * - {@link DistanceFn} * * "CentroidDistanceBasedSamplingRankFunction.KMeansConfig.centroid_factory" * - {@link AnalyticItemWrapperFactory} to wrap the centroid data * with the appropriate centroid wrapper * {@link AnalyticItemWrapper} * * @ee CentroidManagerGeoWave * * * @formatter:on * * See {@link GeoWaveConfiguratorBase} for information for * configuration GeoWave Data Store for consumption of starting * set of centroids. * * @param <T> * The data type for the object being sampled */ public class CentroidDistanceBasedSamplingRankFunction<T> implements SamplingRankFunction<T> { protected static final Logger LOGGER = LoggerFactory.getLogger(CentroidDistanceBasedSamplingRankFunction.class); private SampleProbabilityFn sampleProbabilityFn; private NestedGroupCentroidAssignment<T> nestedGroupCentroidAssigner; private final Map<String, Double> groupToConstant = new HashMap<String, Double>(); protected AnalyticItemWrapperFactory<T> itemWrapperFactory;; public static void setParameters( final Configuration config, final Class<?> scope, final PropertyManagement runTimeProperties ) { NestedGroupCentroidAssignment.setParameters( config, scope, runTimeProperties); runTimeProperties.setConfig( new ParameterEnum[] { SampleParameters.Sample.PROBABILITY_FUNCTION, CentroidParameters.Centroid.WRAPPER_FACTORY_CLASS, }, config, scope); } @SuppressWarnings("unchecked") @Override public void initialize( final JobContext context, final Class<?> scope, final Logger logger ) throws IOException { final ScopedJobConfiguration config = new ScopedJobConfiguration( context.getConfiguration(), scope); try { sampleProbabilityFn = config.getInstance( SampleParameters.Sample.PROBABILITY_FUNCTION, SampleProbabilityFn.class, RandomProbabilitySampleFn.class); } catch (final Exception e) { throw new IOException( e); } try { itemWrapperFactory = config.getInstance( CentroidParameters.Centroid.WRAPPER_FACTORY_CLASS, AnalyticItemWrapperFactory.class, SimpleFeatureItemWrapperFactory.class); itemWrapperFactory.initialize( context, scope, logger); } catch (final Exception e1) { throw new IOException( e1); } try { nestedGroupCentroidAssigner = new NestedGroupCentroidAssignment<T>( context, scope, logger); } catch (final Exception e1) { throw new IOException( e1); } } /** * */ @Override public double rank( final int sampleSize, final T value ) { final AnalyticItemWrapper<T> item = itemWrapperFactory.create(value); final List<AnalyticItemWrapper<T>> centroids = new ArrayList<AnalyticItemWrapper<T>>(); double weight; try { weight = nestedGroupCentroidAssigner.findCentroidForLevel( item, new AssociationNotification<T>() { @Override public void notify( final CentroidPairing<T> pairing ) { try { centroids.addAll(nestedGroupCentroidAssigner.getCentroidsForGroup(pairing .getCentroid() .getGroupID())); } catch (final IOException e) { throw new RuntimeException( e); } } }); } catch (final IOException e) { throw new RuntimeException( e); } return sampleProbabilityFn.getProbability( weight, getNormalizingConstant( centroids.get( 0).getGroupID(), centroids), sampleSize); } private double getNormalizingConstant( final String groupID, final List<AnalyticItemWrapper<T>> centroids ) { if (!groupToConstant.containsKey(groupID)) { double constant = 0.0; for (final AnalyticItemWrapper<T> centroid : centroids) { constant += centroid.getCost(); } groupToConstant.put( groupID, constant); } return groupToConstant.get( groupID).doubleValue(); } }