package quickml.supervised.inspection;
import com.google.common.collect.*;
import quickml.data.AttributesMap;
import quickml.data.instances.Instance;
import quickml.supervised.tree.constants.MissingValue;
import java.io.Serializable;
import java.util.List;
import java.util.Map;
import java.util.Random;
/**
 * Builds an empirical distribution over the values one attribute takes in a list of
 * instances, and draws random values from that distribution.
 *
 * <p>Observed values are tallied into a histogram. The histogram is then converted into a
 * range map that partitions {@code [0, 1)} into half-open intervals whose widths are the
 * observed relative frequencies, so that looking up a uniform random number yields a value
 * with probability proportional to its count. Attribute values that are {@code null} are
 * tallied under {@code MissingValue.MISSING_VALUE.name()}.
 *
 * <p>Calling either {@code updateDistributionSampler} overload again adds to the existing
 * counts rather than replacing them.
 *
 * Created by alexanderhawk on 11/14/14.
 */
public class CategoricalDistributionSampler {

    /** Source of the uniform random numbers consumed by {@link #sampleHistogram()}. */
    public static Random rand = new Random();

    /** Number of times each attribute value has been sampled so far. */
    Map<Serializable, Long> histogramOfCountsForValues = Maps.newHashMap();

    /** Maps sub-intervals of {@code [0, 1)} to attribute values; rebuilt after every update. */
    ImmutableRangeMap<Double, Serializable> attributeValueRangeMap;

    /** Total number of instances tallied into the histogram across all updates. */
    double actualSamples = 0;

    /**
     * Creates a sampler from at most {@code samplesToDraw} of the given instances.
     *
     * @param instances     instances whose attribute values are tallied
     * @param samplesToDraw upper bound on the number of instances to tally
     * @param attribute     name of the attribute to read from each instance
     */
    public CategoricalDistributionSampler(List<Instance<AttributesMap, Serializable>> instances, int samplesToDraw, String attribute) {
        updateDistributionSampler(instances, samplesToDraw, attribute);
    }

    /**
     * Creates a sampler from a fraction of the given instances.
     *
     * @param instances                   instances whose attribute values are tallied
     * @param percentageOfAllSamplesToUse multiplier applied to {@code instances.size()} to obtain
     *                                    the number of samples; despite the name it is used as a
     *                                    fraction (e.g. {@code 0.5} for half), not as a 0-100 value
     * @param attribute                   name of the attribute to read from each instance
     */
    public CategoricalDistributionSampler(List<Instance<AttributesMap, Serializable>> instances, double percentageOfAllSamplesToUse, String attribute) {
        updateDistributionSampler(instances, percentageOfAllSamplesToUse, attribute);
    }

    /**
     * Returns the live histogram of counts per attribute value.
     *
     * @return the internal (mutable) map from attribute value to the number of times it was tallied
     */
    public Map<Serializable, Long> getHistogramOfCountsForValues() {
        return histogramOfCountsForValues;
    }

    /**
     * Tallies a fraction of {@code newInstances} into the histogram and rebuilds the range map.
     *
     * @param newInstances                instances to tally
     * @param percentageOfAllSamplesToUse multiplier applied to {@code newInstances.size()}; the
     *                                    product is truncated to an int
     * @param attribute                   name of the attribute to read from each instance
     */
    public void updateDistributionSampler(List<Instance<AttributesMap, Serializable>> newInstances, double percentageOfAllSamplesToUse, String attribute) {
        int samplesToDraw = (int) (percentageOfAllSamplesToUse * newInstances.size());
        updateHistogramOfCountsForValues(newInstances, samplesToDraw, attribute);
        createAttributeValueRangeMap();
    }

    /**
     * Tallies at most {@code samplesToDraw} of {@code newInstances} into the histogram and
     * rebuilds the range map.
     *
     * @param newInstances  instances to tally
     * @param samplesToDraw upper bound on the number of instances to tally; values {@code <= 0}
     *                      tally nothing
     * @param attribute     name of the attribute to read from each instance
     */
    public void updateDistributionSampler(List<Instance<AttributesMap, Serializable>> newInstances, int samplesToDraw, String attribute) {
        updateHistogramOfCountsForValues(newInstances, samplesToDraw, attribute);
        createAttributeValueRangeMap();
    }

    /**
     * Draws one attribute value with probability proportional to its tallied count.
     *
     * @return a tallied attribute value, or {@code null} when nothing has been tallied
     *         (the range map is empty, so no interval contains the random number)
     */
    public Serializable sampleHistogram() {
        return attributeValueRangeMap.get(rand.nextDouble());
    }

    /**
     * Rebuilds {@link #attributeValueRangeMap} from the current histogram by laying the values
     * end to end on {@code [0, 1)}, each occupying a width of {@code count / actualSamples}.
     */
    private void createAttributeValueRangeMap() {
        ImmutableRangeMap.Builder<Double, Serializable> builder = ImmutableRangeMap.builder();
        double cumulativeCount = 0;
        for (Map.Entry<Serializable, Long> entry : histogramOfCountsForValues.entrySet()) {
            double lowerCount = cumulativeCount;
            cumulativeCount += entry.getValue().doubleValue();
            // Adjacent intervals share the exact same boundary expression, so they never overlap.
            Range<Double> range = Range.closedOpen(lowerCount / actualSamples, cumulativeCount / actualSamples);
            builder.put(range, entry.getKey());
        }
        attributeValueRangeMap = builder.build();
    }

    /**
     * Tallies at most {@code samplesToDraw} instances into the histogram.
     *
     * <p>When fewer than half of the instances are requested, instances are taken at a regular
     * stride starting from index 0 so the sample spans the whole list. Otherwise the last
     * {@code samplesToDraw} instances (or all of them, if fewer exist) are taken.
     */
    private void updateHistogramOfCountsForValues(List<Instance<AttributesMap, Serializable>> instances, int samplesToDraw, String attribute) {
        int size = instances.size();
        if (samplesToDraw <= 0 || size == 0) {
            // Nothing to draw; also keeps the stride computation below free of division by zero.
            return;
        }
        if (samplesToDraw < size / 2) {
            // The stride is at least 2 here, so the loop always advances.
            int stride = size / samplesToDraw;
            int drawn = 0;
            for (int i = 0; i < size && drawn < samplesToDraw; i += stride) {
                tallyInstance(instances.get(i), attribute);
                drawn++;
            }
        } else {
            int firstIndex = Math.max(0, size - samplesToDraw);
            for (int i = size - 1; i >= firstIndex; i--) {
                tallyInstance(instances.get(i), attribute);
            }
        }
    }

    /**
     * Tallies one instance's value for {@code attribute}, substituting the missing-value marker
     * for {@code null} because the range map builder does not accept null values.
     */
    private void tallyInstance(Instance<AttributesMap, Serializable> instance, String attribute) {
        Serializable val = instance.getAttributes().get(attribute);
        if (val == null) {
            val = MissingValue.MISSING_VALUE.name();
        }
        updateHistogram(val, histogramOfCountsForValues);
        actualSamples++;
    }

    /** Increments the count stored for {@code val}, starting from 1 when it is not yet present. */
    private void updateHistogram(Serializable val, Map<Serializable, Long> localHistogramOfCountsForValues) {
        Long currentCount = localHistogramOfCountsForValues.get(val);
        long updatedCount = (currentCount == null) ? 1L : currentCount.longValue() + 1L;
        localHistogramOfCountsForValues.put(val, Long.valueOf(updatedCount));
    }
}