package quickml.supervised.inspection;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import org.junit.Assert;
import org.junit.Ignore;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import quickml.data.AttributesMap;
import quickml.data.instances.Instance;
import quickml.data.instances.InstanceImpl;
import java.io.Serializable;
import java.util.List;
import java.util.Map;
/**
 * Tests for {@link NumericDistributionSampler}: checks the histogram of counts it
 * builds from a list of instances, and that values drawn with
 * {@code sampleHistogram()} fall into bins with roughly the same relative
 * frequencies as that histogram.
 *
 * Created by alexanderhawk on 11/17/14.
 */
public class NumericDistributionSamplerTest {
    private static final Logger logger = LoggerFactory.getLogger(NumericDistributionSamplerTest.class);

    /** Name of the single attribute carried by every fixture instance. */
    private static final String ATTRIBUTE = "v1";

    /**
     * Last constructor argument handed to the sampler. The count test expects exactly the
     * bin indices {@code 0 .. BINS - 1} to be populated, so this is presumably the number
     * of histogram bins (NOTE(review): confirm against NumericDistributionSampler).
     */
    private static final int BINS = 4;

    /** Largest accepted absolute gap between a bin's actual and sampled probability. */
    private static final double TOLERANCE = 0.1;

    /**
     * Three copies of the four fixture instances must yield a count of 3 in each of
     * the bins 0..3.
     */
    @Test
    public void testStoresCountsCorrectly() {
        NumericDistributionSampler numericDistributionSampler = newSampler(getReplicatedInstances(2));
        Map<Integer, Long> distribution = numericDistributionSampler.getHistogramOfCountsForValues();
        for (int bin = 0; bin < BINS; bin++) {
            // assertEquals reports a missing (null) or wrong count instead of failing with a bare
            // AssertionError or a NullPointerException.
            Assert.assertEquals("count stored for bin " + bin, Long.valueOf(3L), distribution.get(bin));
        }
    }

    /**
     * Draws 800 samples and checks that each bin's sampled frequency is within
     * {@link #TOLERANCE} of the bin's share of the stored histogram.
     */
    @Test
    public void samplesDistCorrectly() {
        //add same instances multiple times since we'll be randomly (nearly randomly) removing instances
        NumericDistributionSampler numericDistributionSampler = newSampler(getReplicatedInstances(10));
        assertSamplingMatchesHistogram(numericDistributionSampler, 800, TOLERANCE,
                "for bin: {}, the actual prob is: {}. Sampled prob is {}");
    }

    /**
     * Same frequency check as {@link #samplesDistCorrectly()}, run after the sampler has been
     * updated with {@link #getNewInstances()}.
     *
     * Previously this method carried only {@code @Ignore}; without {@code @Test} JUnit 4 does not
     * register the method at all, so it never showed up as an ignored test.
     */
    @Test
    @Ignore("this test is needed if the class it tests turns out to be needed")
    public void updatesCorrectly() {
        //add same instances multiple times since we'll be randomly (nearly randomly) removing instances
        NumericDistributionSampler numericDistributionSampler = newSampler(getReplicatedInstances(10));
        numericDistributionSampler.updateDistributionSampler(getNewInstances(), 1.0, ATTRIBUTE, BINS);
        assertSamplingMatchesHistogram(numericDistributionSampler, 400, TOLERANCE,
                "for bin: {}, the actual prob should be .5 an d is: {}. Sampled prob is {}");
    }

    /**
     * Builds a sampler over {@code instances} for {@link #ATTRIBUTE} with {@link #BINS}.
     * The meaning of the {@code 1.0} argument is not visible from this test
     * (NOTE(review): confirm against the NumericDistributionSampler constructor).
     */
    private NumericDistributionSampler newSampler(List<Instance<AttributesMap, Serializable>> instances) {
        return new NumericDistributionSampler(instances, 1.0, ATTRIBUTE, BINS);
    }

    /**
     * Samples the given sampler {@code samples} times, converts both the sampled bin counts and
     * the sampler's own histogram into relative frequencies, logs them per bin, and asserts that
     * they differ by at most {@code tolerance} for every bin of the histogram.
     *
     * @param sampler   sampler under test
     * @param samples   number of values to draw
     * @param tolerance maximum accepted absolute difference per bin
     * @param logFormat SLF4J format with three placeholders: bin, actual probability, sampled probability
     */
    private void assertSamplingMatchesHistogram(NumericDistributionSampler sampler, int samples,
                                                double tolerance, String logFormat) {
        Map<Integer, Long> actualHistogram = sampler.getHistogramOfCountsForValues();

        // Count how many drawn values land in each bin.
        Map<Integer, Integer> sampledCounts = Maps.newHashMap();
        for (int i = 0; i < samples; i++) {
            double val = (Double) sampler.sampleHistogram();
            int bin = sampler.getBinIndex(val);
            Integer seen = sampledCounts.get(bin);
            sampledCounts.put(bin, seen == null ? 1 : seen + 1);
        }

        double total = 0;
        for (Long count : actualHistogram.values()) {
            total += count.doubleValue();
        }

        for (Map.Entry<Integer, Long> entry : actualHistogram.entrySet()) {
            Integer bin = entry.getKey();
            double actualProb = entry.getValue().doubleValue() / total;
            // A bin that was never drawn counts as frequency 0, so the assertion below reports it
            // instead of the lookup throwing a NullPointerException.
            Integer seen = sampledCounts.get(bin);
            double sampledProb = (seen == null ? 0 : seen) / (double) samples;

            // Explicit array works with both the Object[] and the varargs form of Logger.info.
            logger.info(logFormat, new Object[] {bin, actualProb, sampledProb});
            Assert.assertEquals("sampled probability for bin " + bin + " is not within tolerance of the actual probability",
                    actualProb, sampledProb, tolerance);
        }
    }

    /**
     * Returns the base fixture from {@link #getInstances()} followed by {@code extraCopies}
     * further copies of it, all in one mutable list.
     */
    private List<Instance<AttributesMap, Serializable>> getReplicatedInstances(int extraCopies) {
        List<Instance<AttributesMap, Serializable>> instances = getInstances();
        for (int i = 0; i < extraCopies; i++) {
            instances.addAll(getInstances());
        }
        return instances;
    }

    /**
     * Builds four fresh instances whose {@code "v1"} attribute is 0.25, 0.5, 0.75 and 1.0
     * respectively; each is constructed with 1.0 as the second InstanceImpl argument.
     */
    private List<Instance<AttributesMap, Serializable>> getInstances() {
        List<Instance<AttributesMap, Serializable>> instances = Lists.newArrayList();
        for (double value : new double[] {0.25, 0.5, 0.75, 1.0}) {
            AttributesMap attributesMap = AttributesMap.newHashMap();
            attributesMap.put(ATTRIBUTE, value);
            instances.add(new InstanceImpl<AttributesMap, Serializable>(attributesMap, 1.0));
        }
        return instances;
    }

    /**
     * Builds the four instances used to update the sampler in {@link #updatesCorrectly()}:
     * {@code "v1"} is "cat2", "cat1", "cat1", "cat1" and the second InstanceImpl argument is
     * 1.0, 1.0, 0.75, 0.75.
     *
     * NOTE(review): these attribute values are Strings although the class under test is a
     * numeric sampler; this looks carried over from a categorical test. Confirm before
     * re-enabling updatesCorrectly().
     */
    private List<Instance<AttributesMap, Serializable>> getNewInstances() {
        String[] values = {"cat2", "cat1", "cat1", "cat1"};
        double[] labels = {1.0, 1.0, 0.75, 0.75};

        List<Instance<AttributesMap, Serializable>> instances = Lists.newArrayList();
        for (int i = 0; i < values.length; i++) {
            AttributesMap attributesMap = AttributesMap.newHashMap();
            attributesMap.put(ATTRIBUTE, values[i]);
            instances.add(new InstanceImpl<AttributesMap, Serializable>(attributesMap, labels[i]));
        }
        return instances;
    }
}