package quickml.supervised.inspection;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import org.junit.Assert;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import quickml.data.AttributesMap;
import quickml.data.instances.Instance;
import quickml.data.instances.InstanceImpl;
import java.io.Serializable;
import java.util.List;
import java.util.Map;
/**
* Created by alexanderhawk on 11/17/14.
*/
public class CategoricalDistributionSamplerTest {
private static final Logger logger = LoggerFactory.getLogger(CategoricalDistributionSamplerTest.class);
@Test
public void testStoresCountsCorrectly() {
List<Instance<AttributesMap, Serializable>> instances = getInstances();
CategoricalDistributionSampler categoricalDistributionSampler = new CategoricalDistributionSampler(instances, 1.0, "v1");
Map<Serializable, Long> distribution = categoricalDistributionSampler.getHistogramOfCountsForValues();
Assert.assertTrue(distribution.get("cat1").equals(1L));
Assert.assertTrue(distribution.get("cat2").equals(3L));
}
@Test
public void samplesDistCorrectly() {
List<Instance<AttributesMap, Serializable>> instances = getInstances();
//add same instances multiple times since we'll be randomly (nearly randomly) removing instances
for (int i = 0; i<10; i++) {
instances.addAll(getInstances());
}
CategoricalDistributionSampler categoricalDistributionSampler = new CategoricalDistributionSampler(instances, .5, "v1");
Map<Serializable, Long> actualHistogram = categoricalDistributionSampler.getHistogramOfCountsForValues();
Map<Serializable, Double> sampledDistribution = Maps.newHashMap();
double tolerance = 0.1;
double samples = 800;
for (int i = 0; i<samples; i++) {
String val = (String) categoricalDistributionSampler.sampleHistogram();
if (sampledDistribution.containsKey(val)) {
sampledDistribution.put(val, (Double) (sampledDistribution.get(val)).doubleValue() + 1.0);
} else {
sampledDistribution.put(val, 1.0);
}
}
double total = 0;
for (Serializable val : actualHistogram.keySet()) {
total += actualHistogram.get(val).doubleValue();
}
Map<Serializable, Double> actualDistribution = Maps.newHashMap();
for (Serializable val : actualHistogram.keySet()) {
actualDistribution.put(val, actualHistogram.get(val).doubleValue()/total);
sampledDistribution.put(val, sampledDistribution.get(val).doubleValue()/samples);
}
for(Serializable val : actualDistribution.keySet()) {
logger.info("for val: " + val + ", the actual prob is: " + actualDistribution.get(val).doubleValue() + ". Sampled prob is " + sampledDistribution.get(val).doubleValue());
Assert.assertTrue(actualDistribution.get(val).doubleValue() < sampledDistribution.get(val).doubleValue() + tolerance &&
actualDistribution.get(val).doubleValue() > sampledDistribution.get(val).doubleValue() - tolerance);
}
}
@Test
public void updatesCorrectly() {
List<Instance<AttributesMap, Serializable>> instances = getInstances();
CategoricalDistributionSampler categoricalDistributionSampler = new CategoricalDistributionSampler(instances, 1.0, "v1");
categoricalDistributionSampler.updateDistributionSampler(getNewInstances(), 1.0, "v1");
Map<Serializable, Long> actualHistogram = categoricalDistributionSampler.getHistogramOfCountsForValues();
Map<Serializable, Double> sampledDistribution = Maps.newHashMap();
double tolerance = 0.1;
double samples = 400;
for (int i = 0; i<samples; i++) {
String val = (String) categoricalDistributionSampler.sampleHistogram();
if (sampledDistribution.containsKey(val)) {
sampledDistribution.put(val, (Double) (sampledDistribution.get(val)).doubleValue() + 1.0);
} else {
sampledDistribution.put(val, 1.0);
}
}
double total = 0;
for (Serializable val : actualHistogram.keySet()) {
total += actualHistogram.get(val).doubleValue();
}
Map<Serializable, Double> actualDistribution = Maps.newHashMap();
for (Serializable val : actualHistogram.keySet()) {
actualDistribution.put(val, actualHistogram.get(val).doubleValue()/total);
sampledDistribution.put(val, sampledDistribution.get(val).doubleValue()/samples);
}
for(Serializable val : actualDistribution.keySet()) {
logger.info("for val: " + val + ", the actual prob should be .5 an d is: " + actualDistribution.get(val).doubleValue() + ". Sampled prob is " + sampledDistribution.get(val).doubleValue());
Assert.assertTrue(actualDistribution.get(val).doubleValue() < sampledDistribution.get(val).doubleValue() + tolerance &&
actualDistribution.get(val).doubleValue() > sampledDistribution.get(val).doubleValue() - tolerance);
}
}
private List<Instance<AttributesMap, Serializable>> getInstances(){
List<Instance<AttributesMap, Serializable>> instances = Lists.newArrayList();
//instance 1
AttributesMap attributesMap = AttributesMap.newHashMap();
attributesMap.put("v1", "cat1");
Instance<AttributesMap, Serializable> instance = new InstanceImpl<AttributesMap, Serializable>(attributesMap,1.0);
instances.add(instance);
//instance 2
attributesMap = AttributesMap.newHashMap();
attributesMap.put("v1", "cat2");
instance = new InstanceImpl<AttributesMap, Serializable>(attributesMap,1.0);
instances.add(instance);
//instance 3
attributesMap = AttributesMap.newHashMap();
attributesMap.put("v1", "cat2");
instance = new InstanceImpl<AttributesMap, Serializable>(attributesMap,1.0);
instances.add(instance);
//instance 4
attributesMap = AttributesMap.newHashMap();
attributesMap.put("v1", "cat2");
instance = new InstanceImpl<AttributesMap, Serializable>(attributesMap,1.0);
instances.add(instance);
return instances;
}
private List<Instance<AttributesMap, Serializable>> getNewInstances(){
List<Instance<AttributesMap, Serializable>> instances = Lists.newArrayList();
//instance 1
AttributesMap attributesMap = AttributesMap.newHashMap();
attributesMap.put("v1", "cat2");
Instance<AttributesMap, Serializable> instance = new InstanceImpl<AttributesMap, Serializable>(attributesMap,1.0);
instances.add(instance);
//instance 2
attributesMap = AttributesMap.newHashMap();
attributesMap.put("v1", "cat1");
instance = new InstanceImpl<AttributesMap, Serializable>(attributesMap,1.0);
instances.add(instance);
//instance 3
attributesMap = AttributesMap.newHashMap();
attributesMap.put("v1", "cat1");
instance = new InstanceImpl<AttributesMap, Serializable>(attributesMap,1.0);
instances.add(instance);
//instance 4
attributesMap = AttributesMap.newHashMap();
attributesMap.put("v1", "cat1");
instance = new InstanceImpl<AttributesMap, Serializable>(attributesMap,1.0);
instances.add(instance);
return instances;
}
}