package mil.nga.giat.geowave.analytic.clustering;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.hadoop.io.Writable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import mil.nga.giat.geowave.analytic.AnalyticItemWrapperFactory;
import mil.nga.giat.geowave.core.index.ByteArrayId;
import mil.nga.giat.geowave.core.index.NumericIndexStrategy;
import mil.nga.giat.geowave.core.index.StringUtils;
import mil.nga.giat.geowave.core.index.sfc.data.MultiDimensionalNumericData;
import mil.nga.giat.geowave.core.store.CloseableIterator;
import mil.nga.giat.geowave.core.store.DataStore;
import mil.nga.giat.geowave.core.store.adapter.AdapterPersistenceEncoding;
import mil.nga.giat.geowave.core.store.adapter.AdapterStore;
import mil.nga.giat.geowave.core.store.adapter.IndexedAdapterPersistenceEncoding;
import mil.nga.giat.geowave.core.store.adapter.WritableDataAdapter;
import mil.nga.giat.geowave.core.store.data.IndexedPersistenceEncoding;
import mil.nga.giat.geowave.core.store.data.PersistentDataset;
import mil.nga.giat.geowave.core.store.data.field.FieldReader;
import mil.nga.giat.geowave.core.store.data.field.FieldUtils;
import mil.nga.giat.geowave.core.store.data.field.FieldVisibilityHandler;
import mil.nga.giat.geowave.core.store.data.field.FieldWriter;
import mil.nga.giat.geowave.core.store.dimension.NumericDimensionField;
import mil.nga.giat.geowave.core.store.filter.DistributableQueryFilter;
import mil.nga.giat.geowave.core.store.filter.QueryFilter;
import mil.nga.giat.geowave.core.store.index.CommonIndexModel;
import mil.nga.giat.geowave.core.store.index.CommonIndexValue;
import mil.nga.giat.geowave.core.store.index.Index;
import mil.nga.giat.geowave.core.store.index.IndexStore;
import mil.nga.giat.geowave.core.store.index.NullIndex;
import mil.nga.giat.geowave.core.store.index.PrimaryIndex;
import mil.nga.giat.geowave.core.store.query.Query;
import mil.nga.giat.geowave.core.store.query.QueryOptions;
/**
 * Chooses, for each group, the cluster count k whose distortion value shows the
 * largest change relative to the next smaller recorded cluster count, and keeps
 * the centroids produced for that k.
 *
 * In a multi-group setting, each group may have a different optimal k. Thus,
 * the optimal batch may be different for each group. Each batch is associated
 * with a different value k; its identifier is the parent batch identifier
 * suffixed with {@code "_" + k}.
 *
 * After the appropriate batch is chosen for each group, the group's centroids
 * are transferred from that batch into a single, caller-provided batch
 * identifier (the parent batch).
 */
public class DistortionGroupManagement
{
    static final Logger LOGGER = LoggerFactory.getLogger(DistortionGroupManagement.class);

    /** Index named "DISTORTIONS" under which distortion entries are stored and queried. */
    public static final PrimaryIndex DISTORTIONS_INDEX = new NullIndex(
            "DISTORTIONS");

    /** Unmodifiable single-element list holding the id of {@link #DISTORTIONS_INDEX}. */
    public static final List<ByteArrayId> DISTORTIONS_INDEX_LIST = Collections.unmodifiableList(Arrays
            .asList(DISTORTIONS_INDEX.getId()));

    final DataStore dataStore;
    final IndexStore indexStore;
    final AdapterStore adapterStore;

    /**
     * Creates a manager over the given stores and registers the distortion
     * index and the {@link DistortionDataAdapter} with them.
     *
     * @param dataStore
     *            store queried for distortion entries and centroids
     * @param indexStore
     *            store that receives {@link #DISTORTIONS_INDEX}
     * @param adapterStore
     *            store that receives a {@link DistortionDataAdapter}
     */
    public DistortionGroupManagement(
            final DataStore dataStore,
            final IndexStore indexStore,
            final AdapterStore adapterStore ) {
        this.dataStore = dataStore;
        this.indexStore = indexStore;
        this.adapterStore = adapterStore;
        indexStore.addIndex(DISTORTIONS_INDEX);
        adapterStore.addAdapter(new DistortionDataAdapter());
    }

    /**
     * Query filter that accepts only distortion entries whose data id carries
     * the configured batch id.
     */
    public static class BatchIdFilter implements
            DistributableQueryFilter
    {
        String batchId;

        /**
         * No-arg constructor, presumably for reflective instantiation before
         * {@link #fromBinary(byte[])} is called — confirm.
         */
        public BatchIdFilter() {}

        /**
         * @param batchId
         *            the batch id an entry must carry to be accepted
         */
        public BatchIdFilter(
                final String batchId ) {
            super();
            this.batchId = batchId;
        }

        /**
         * Parses the entry's data id and compares its batch component with
         * this filter's batch id.
         */
        @Override
        public boolean accept(
                final CommonIndexModel indexModel,
                final IndexedPersistenceEncoding<?> persistenceEncoding ) {
            // Only the data id matters here; the distortion value is a placeholder.
            final DistortionEntry entry = new DistortionEntry(
                    persistenceEncoding.getDataId(),
                    0.0);
            return entry.batchId.equals(batchId);
        }

        /** Serializes the batch id. */
        @Override
        public byte[] toBinary() {
            return StringUtils.stringToBinary(batchId);
        }

        /** Restores the batch id written by {@link #toBinary()}. */
        @Override
        public void fromBinary(
                final byte[] bytes ) {
            batchId = StringUtils.stringFromBinary(bytes);
        }
    }

    /**
     * Query selecting all distortion entries of one batch. It has no index
     * constraints and relies solely on a {@link BatchIdFilter}.
     */
    public static class BatchIdQuery implements
            Query
    {
        String batchId;

        /** No-arg constructor; leaves the batch id unset. */
        public BatchIdQuery() {}

        /**
         * @param batchId
         *            the batch whose distortion entries should be returned
         */
        public BatchIdQuery(
                final String batchId ) {
            super();
            this.batchId = batchId;
        }

        /** @return a single {@link BatchIdFilter} for this query's batch id */
        @Override
        public List<QueryFilter> createFilters(
                final CommonIndexModel indexModel ) {
            return Collections.<QueryFilter> singletonList(new BatchIdFilter(
                    batchId));
        }

        /** @return true only for {@link NullIndex} instances */
        @Override
        public boolean isSupported(
                final Index<?, ?> index ) {
            return index instanceof NullIndex;
        }

        /** @return an empty list: this query imposes no index constraints */
        @Override
        public List<MultiDimensionalNumericData> getIndexConstraints(
                final NumericIndexStrategy indexStrategy ) {
            return Collections.emptyList();
        }
    }

    /**
     * For every group that has distortion entries under {@code batchId}, picks
     * the cluster count k with the largest distortion jump. It then asks the
     * centroid manager to transfer that group's centroids from batch
     * {@code batchId + "_" + k}.
     *
     * @param itemWrapperFactory
     *            passed to the {@link CentroidManagerGeoWave}
     * @param dataTypeId
     *            passed to the {@link CentroidManagerGeoWave}
     * @param indexId
     *            passed to the {@link CentroidManagerGeoWave}
     * @param batchId
     *            parent batch id; distortion entries are looked up under this
     *            id, and per-k centroid batches are expected under
     *            {@code batchId + "_" + k}
     * @param level
     *            passed to the {@link CentroidManagerGeoWave}
     * @return 0 on success; 1 if a checked exception occurred (the exception
     *         is logged). Runtime exceptions propagate to the caller.
     */
    public <T> int retainBestGroups(
            final AnalyticItemWrapperFactory<T> itemWrapperFactory,
            final String dataTypeId,
            final String indexId,
            final String batchId,
            final int level ) {
        try {
            // Group id -> all (cluster count, distortion) pairs recorded for this batch.
            final Map<String, DistortionGroup> groupDistortions = new HashMap<String, DistortionGroup>();
            try (CloseableIterator<DistortionEntry> it = dataStore.query(
                    new QueryOptions(
                            new DistortionDataAdapter(),
                            DISTORTIONS_INDEX),
                    new BatchIdQuery(
                            batchId))) {
                while (it.hasNext()) {
                    final DistortionEntry entry = it.next();
                    final String groupID = entry.getGroupId();
                    DistortionGroup grp = groupDistortions.get(groupID);
                    if (grp == null) {
                        grp = new DistortionGroup(
                                groupID);
                        groupDistortions.put(
                                groupID,
                                grp);
                    }
                    grp.addPair(
                            entry.getClusterCount(),
                            entry.getDistortionValue());
                }
            }
            final CentroidManagerGeoWave<T> centroidManager = new CentroidManagerGeoWave<T>(
                    dataStore,
                    indexStore,
                    adapterStore,
                    itemWrapperFactory,
                    dataTypeId,
                    indexId,
                    batchId,
                    level);
            for (final DistortionGroup grp : groupDistortions.values()) {
                final int optimalK = grp.bestCount();
                // Centroids computed for cluster count k live in batch "<parent>_<k>".
                final String kbatchId = batchId + "_" + optimalK;
                centroidManager.transferBatch(
                        kbatchId,
                        grp.getGroupID());
            }
        }
        catch (final RuntimeException ex) {
            // Let unchecked exceptions escape; only checked ones are reported via the return code.
            throw ex;
        }
        catch (final Exception ex) {
            LOGGER.error(
                    "Cannot determine groups for batch {}",
                    batchId,
                    ex);
            return 1;
        }
        return 0;
    }

    /**
     * One distortion measurement: the distortion value recorded for a given
     * batch, group and cluster count. It is persisted with the data id
     * {@code batchId/groupId/clusterCount}.
     */
    public static class DistortionEntry implements
            Writable
    {
        /** Separator between batch id, group id and cluster count in a data id. */
        private static final String DATA_ID_SEPARATOR = "/";

        private String groupId;
        private String batchId;
        private Integer clusterCount;
        private Double distortionValue;

        /** Empty entry, to be populated by {@link #readFields(DataInput)}. */
        public DistortionEntry() {}

        public DistortionEntry(
                final String groupId,
                final String batchId,
                final Integer clusterCount,
                final Double distortionValue ) {
            this.groupId = groupId;
            this.batchId = batchId;
            this.clusterCount = clusterCount;
            this.distortionValue = distortionValue;
        }

        /**
         * Rebuilds an entry from a data id in the format produced by
         * {@link #getDataId()}. The batch id is everything before the first
         * separator and the cluster count is everything after the last one.
         * A group id that itself contains the separator is therefore kept
         * intact. The batch id must not contain the separator.
         *
         * @throws IllegalArgumentException
         *             if the id has fewer than two separators or the cluster
         *             count is not an integer
         */
        private DistortionEntry(
                final ByteArrayId dataId,
                final Double distortionValue ) {
            final String dataIdStr = StringUtils.stringFromBinary(dataId.getBytes());
            final int firstSep = dataIdStr.indexOf(DATA_ID_SEPARATOR);
            final int lastSep = dataIdStr.lastIndexOf(DATA_ID_SEPARATOR);
            if ((firstSep < 0) || (firstSep == lastSep)) {
                throw new IllegalArgumentException(
                        "Malformed distortion data id: " + dataIdStr);
            }
            batchId = dataIdStr.substring(
                    0,
                    firstSep);
            groupId = dataIdStr.substring(
                    firstSep + DATA_ID_SEPARATOR.length(),
                    lastSep);
            clusterCount = Integer.parseInt(dataIdStr.substring(lastSep + DATA_ID_SEPARATOR.length()));
            this.distortionValue = distortionValue;
        }

        public String getGroupId() {
            return groupId;
        }

        public Integer getClusterCount() {
            return clusterCount;
        }

        public Double getDistortionValue() {
            return distortionValue;
        }

        /** @return the data id {@code batchId/groupId/clusterCount} */
        private ByteArrayId getDataId() {
            return new ByteArrayId(
                    batchId + DATA_ID_SEPARATOR + groupId + DATA_ID_SEPARATOR + clusterCount);
        }

        @Override
        public void write(
                final DataOutput out )
                throws IOException {
            out.writeUTF(groupId);
            out.writeUTF(batchId);
            out.writeInt(clusterCount);
            out.writeDouble(distortionValue);
        }

        @Override
        public void readFields(
                final DataInput in )
                throws IOException {
            groupId = in.readUTF();
            batchId = in.readUTF();
            clusterCount = in.readInt();
            distortionValue = in.readDouble();
        }
    }

    /** Accumulates the (cluster count, distortion) pairs seen for one group. */
    private static class DistortionGroup
    {
        final String groupID;
        final List<Pair<Integer, Double>> clusterCountToDistortion = new ArrayList<Pair<Integer, Double>>();

        public DistortionGroup(
                final String groupID ) {
            this.groupID = groupID;
        }

        public void addPair(
                final Integer count,
                final Double distortion ) {
            clusterCountToDistortion.add(Pair.of(
                    count,
                    distortion));
        }

        public String getGroupID() {
            return groupID;
        }

        /**
         * Returns the cluster count whose distortion value increased the most
         * over the next smaller recorded cluster count. The value preceding the
         * smallest cluster count is taken as 0.0, so the first jump equals the
         * first distortion value. On ties, the smallest cluster count wins.
         * Sorts the recorded pairs by cluster count as a side effect.
         *
         * NOTE(review): this mirrors the "jump" method, whose transformed
         * distortion is defined as 0 for k = 0. That assumes the stored values
         * are already transformed so that larger is better — confirm against
         * the code that writes them.
         *
         * @return the chosen cluster count, or -1 if no pair produced a
         *         comparable (non-NaN) jump, e.g. when no pairs were added
         */
        public int bestCount() {
            Collections.sort(
                    clusterCountToDistortion,
                    new Comparator<Pair<Integer, Double>>() {
                        @Override
                        public int compare(
                                final Pair<Integer, Double> arg0,
                                final Pair<Integer, Double> arg1 ) {
                            return arg0.getKey().compareTo(
                                    arg1.getKey());
                        }
                    });
            // Start below any finite jump so a non-empty group always yields a count.
            double maxJump = Double.NEGATIVE_INFINITY;
            int jumpIdx = -1;
            double previousDistortion = 0.0;
            for (final Pair<Integer, Double> pair : clusterCountToDistortion) {
                final double distortion = pair.getValue();
                final double jump = distortion - previousDistortion;
                if (jump > maxJump) {
                    maxJump = jump;
                    jumpIdx = pair.getKey();
                }
                previousDistortion = distortion;
            }
            return jumpIdx;
        }
    }

    /**
     * Data adapter persisting {@link DistortionEntry} objects. Only the
     * distortion value is stored as a field; batch id, group id and cluster
     * count are encoded in the data id.
     */
    public static class DistortionDataAdapter implements
            WritableDataAdapter<DistortionEntry>
    {
        public static final ByteArrayId ADAPTER_ID = new ByteArrayId(
                "distortion");
        private static final ByteArrayId DISTORTION_FIELD_ID = new ByteArrayId(
                "distortion");

        // NOTE(review): not written by toBinary(), so an adapter restored via
        // fromBinary() has no visibility handler — confirm this is intended.
        private final FieldVisibilityHandler<DistortionEntry, Object> distortionVisibilityHandler;

        /** Creates an adapter without a visibility handler. */
        public DistortionDataAdapter() {
            this(
                    null);
        }

        /**
         * @param distortionVisibilityHandler
         *            visibility handler for the distortion field, or null for
         *            none
         */
        public DistortionDataAdapter(
                final FieldVisibilityHandler<DistortionEntry, Object> distortionVisibilityHandler ) {
            this.distortionVisibilityHandler = distortionVisibilityHandler;
        }

        @Override
        public ByteArrayId getAdapterId() {
            return ADAPTER_ID;
        }

        /** Every {@link DistortionEntry} is supported. */
        @Override
        public boolean isSupported(
                final DistortionEntry entry ) {
            return true;
        }

        @Override
        public ByteArrayId getDataId(
                final DistortionEntry entry ) {
            return entry.getDataId();
        }

        /**
         * Rebuilds the entry from its data id and the stored distortion field.
         */
        @Override
        public DistortionEntry decode(
                final IndexedAdapterPersistenceEncoding data,
                final PrimaryIndex index ) {
            return new DistortionEntry(
                    data.getDataId(),
                    (Double) data.getAdapterExtendedData().getValue(
                            DISTORTION_FIELD_ID));
        }

        /**
         * Encodes the entry with no common index values and the distortion
         * value as the only extended field.
         */
        @Override
        public AdapterPersistenceEncoding encode(
                final DistortionEntry entry,
                final CommonIndexModel indexModel ) {
            final Map<ByteArrayId, Object> fieldIdToValueMap = new HashMap<ByteArrayId, Object>();
            fieldIdToValueMap.put(
                    DISTORTION_FIELD_ID,
                    entry.getDistortionValue());
            return new AdapterPersistenceEncoding(
                    getAdapterId(),
                    entry.getDataId(),
                    new PersistentDataset<CommonIndexValue>(),
                    new PersistentDataset<Object>(
                            fieldIdToValueMap));
        }

        /** @return the default Double reader for the distortion field, otherwise null */
        @SuppressWarnings({
            "unchecked",
            "rawtypes"
        })
        @Override
        public FieldReader<Object> getReader(
                final ByteArrayId fieldId ) {
            if (DISTORTION_FIELD_ID.equals(fieldId)) {
                return (FieldReader) FieldUtils.getDefaultReaderForClass(Double.class);
            }
            return null;
        }

        /** Adapter has no state to serialize. */
        @Override
        public byte[] toBinary() {
            return new byte[] {};
        }

        /** No-op counterpart of {@link #toBinary()}. */
        @Override
        public void fromBinary(
                final byte[] bytes ) {}

        /**
         * @return the default Double writer for the distortion field (using the
         *         visibility handler when one was supplied), otherwise null
         */
        @SuppressWarnings({
            "unchecked",
            "rawtypes"
        })
        @Override
        public FieldWriter<DistortionEntry, Object> getWriter(
                final ByteArrayId fieldId ) {
            if (!DISTORTION_FIELD_ID.equals(fieldId)) {
                return null;
            }
            if (distortionVisibilityHandler != null) {
                return (FieldWriter) FieldUtils.getDefaultWriterForClass(
                        Double.class,
                        distortionVisibilityHandler);
            }
            return (FieldWriter) FieldUtils.getDefaultWriterForClass(Double.class);
        }

        /**
         * Index-model dimensions occupy positions 0..n-1 and the distortion
         * field occupies position n.
         *
         * @return the position of {@code fieldId}, or -1 if unknown
         */
        @Override
        public int getPositionOfOrderedField(
                final CommonIndexModel model,
                final ByteArrayId fieldId ) {
            int i = 0;
            for (final NumericDimensionField<? extends CommonIndexValue> dimensionField : model.getDimensions()) {
                if (fieldId.equals(dimensionField.getFieldId())) {
                    return i;
                }
                i++;
            }
            if (fieldId.equals(DISTORTION_FIELD_ID)) {
                return i;
            }
            return -1;
        }

        /**
         * Inverse of {@link #getPositionOfOrderedField(CommonIndexModel, ByteArrayId)}.
         *
         * @return the field id at {@code position}, or null if out of range
         */
        @Override
        public ByteArrayId getFieldIdForPosition(
                final CommonIndexModel model,
                final int position ) {
            final int numDimensions = model.getDimensions().length;
            if ((position >= 0) && (position < numDimensions)) {
                return model.getDimensions()[position].getFieldId();
            }
            if (position == numDimensions) {
                return DISTORTION_FIELD_ID;
            }
            return null;
        }
    }
}