package quickml.supervised.downsampling;
import com.google.common.collect.Lists;
import org.junit.Test;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import quickml.collections.MapUtils;
import quickml.data.*;
import quickml.data.instances.ClassifierInstance;
import quickml.data.instances.Instance;
import quickml.data.instances.InstanceWithAttributesMap;
import quickml.supervised.PredictiveModelBuilder;
import quickml.supervised.classifier.AbstractClassifier;
import quickml.supervised.classifier.Classifier;
import quickml.supervised.classifier.TreeBuilderTestUtils;
import quickml.supervised.classifier.downsampling.DownsamplingClassifier;
import quickml.supervised.classifier.downsampling.DownsamplingClassifierBuilder;
import quickml.supervised.ensembles.randomForest.randomDecisionForest.RandomDecisionForest;
import quickml.supervised.ensembles.randomForest.randomDecisionForest.RandomDecisionForestBuilder;
import quickml.supervised.tree.decisionTree.DecisionTree;
import quickml.supervised.tree.decisionTree.DecisionTreeBuilder;
import quickml.supervised.tree.decisionTree.scorers.GRPenalizedGiniImpurityScorerFactory;
import java.io.IOException;
import java.io.Serializable;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import static junit.framework.Assert.assertEquals;
import static junit.framework.Assert.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
 * Tests for {@link DownsamplingClassifierBuilder}.
 *
 * <p>{@link #simpleTest()} wraps a mocked model builder and checks that the probability reported by
 * the resulting {@link DownsamplingClassifier} is close to the rate at which positive labels were
 * generated. {@link #simpleBmiTest()} wraps a random decision forest builder and checks the forest
 * size and the build time.
 *
 * Created by ian on 4/24/14.
 */
public class DownsamplingClassifierBuilderTest {
    /** Probability with which each instance generated in {@link #simpleTest()} is labelled {@code true}. */
    private static final double POSITIVE_LABEL_RATE = 0.05;
    /** Largest accepted absolute difference between the reported probability and {@link #POSITIVE_LABEL_RATE}. */
    private static final double PROBABILITY_TOLERANCE = 0.01;

    /**
     * Builds a downsampling classifier around a mocked builder whose model always predicts the
     * fraction of {@code true} labels it was trained on, then asserts that the probability reported
     * for {@code Boolean.TRUE} is within {@link #PROBABILITY_TOLERANCE} of the generation rate.
     */
    @Test
    public void simpleTest() {
        PredictiveModelBuilder mockPredictiveModelBuilder = mock(PredictiveModelBuilder.class);
        when(mockPredictiveModelBuilder.buildPredictiveModel(Mockito.any(Iterable.class))).thenAnswer(new Answer<Classifier>() {
            @Override
            public Classifier answer(final InvocationOnMock invocationOnMock) throws Throwable {
                // Mockito hands the arguments over as Object[]; the cast cannot be checked at runtime.
                @SuppressWarnings("unchecked")
                Iterable<Instance<AttributesMap, Serializable>> instances =
                        (Iterable<Instance<AttributesMap, Serializable>>) invocationOnMock.getArguments()[0];
                int total = 0;
                int positives = 0;
                for (Instance<AttributesMap, Serializable> instance : instances) {
                    total++;
                    if (instance.getLabel().equals(true)) {
                        positives++;
                    }
                }
                // The stub model reports the positive fraction of whatever data the wrapped builder received.
                return new SamePredictionClassifier((double) positives / (double) total);
            }
        });
        // NOTE(review): 0.2 is presumably the target minority proportion after downsampling - confirm
        // against DownsamplingClassifierBuilder.
        DownsamplingClassifierBuilder downsamplingClassifierBuilder = new DownsamplingClassifierBuilder(mockPredictiveModelBuilder, 0.2);
        List<InstanceWithAttributesMap> data = Lists.newArrayList();
        for (int x = 0; x < 10000; x++) {
            data.add(new ClassifierInstance(AttributesMap.newHashMap(), (MapUtils.random.nextDouble() < POSITIVE_LABEL_RATE)));
        }
        DownsamplingClassifier predictiveModel = downsamplingClassifierBuilder.buildPredictiveModel(data);
        // SamePredictionClassifier keys its prediction map by attribute *values*, so the attributes
        // must contain Boolean.TRUE as a value for a TRUE entry to exist in the prediction.
        AttributesMap map = AttributesMap.newHashMap();
        map.put("true", Boolean.TRUE);
        final double correctedMinorityInstanceOccurance = predictiveModel.getProbability(map, Boolean.TRUE);
        double error = Math.abs(POSITIVE_LABEL_RATE - correctedMinorityInstanceOccurance);
        // The message is built from the same constants as the check so the two cannot drift apart.
        assertTrue(String.format("Error should be < %s but was %s (prob=%s, desired=%s)",
                        PROBABILITY_TOLERANCE, error, correctedMinorityInstanceOccurance, POSITIVE_LABEL_RATE),
                error < PROBABILITY_TOLERANCE);
    }

    /**
     * Builds a downsampling classifier around a random decision forest builder, round-trips it
     * through {@link TreeBuilderTestUtils#serializeDeserialize}, and asserts bounds on the number
     * of trees and on the time spent building.
     *
     * @throws IOException            declared for the serialization round trip
     * @throws ClassNotFoundException declared for the serialization round trip
     */
    @Test
    public void simpleBmiTest() throws IOException, ClassNotFoundException {
        final DecisionTreeBuilder<ClassifierInstance> tb = new DecisionTreeBuilder<>().scorerFactory(new GRPenalizedGiniImpurityScorerFactory());
        final RandomDecisionForestBuilder urfb = new RandomDecisionForestBuilder(tb);
        final DownsamplingClassifierBuilder dpmb = new DownsamplingClassifierBuilder(urfb, 0.1);
        final List<ClassifierInstance> instances = TreeBuilderTestUtils.getIntegerInstances(1000);
        final long startTime = System.currentTimeMillis();
        final DownsamplingClassifier downsamplingClassifier = dpmb.buildPredictiveModel(instances);
        // Capture the elapsed time right after the build so the timing assertion measures what its
        // message describes, not the serialization round trip below.
        final long buildMillis = System.currentTimeMillis() - startTime;
        TreeBuilderTestUtils.serializeDeserialize(downsamplingClassifier);
        RandomDecisionForest randomDecisionForest = (RandomDecisionForest) downsamplingClassifier.wrappedClassifier;
        final List<DecisionTree> decisionTrees = randomDecisionForest.decisionTrees;
        final int treeSize = decisionTrees.size();
        // The value is not asserted on; the access is kept because it fails if the forest is empty.
        decisionTrees.get(0).root.getSize();
        org.testng.Assert.assertTrue(treeSize < 400, "Forest getSize should be less than 400");
        org.testng.Assert.assertTrue(buildMillis < 20000, "Building this root should take far less than 20 seconds");
    }

    /**
     * Stub classifier that assigns one fixed probability to every value found in the attributes
     * it is asked to predict for.
     */
    private static class SamePredictionClassifier extends AbstractClassifier {
        private static final long serialVersionUID = 8241616760952568181L;
        private final double prediction;

        /**
         * @param prediction the probability reported for every attribute value
         */
        public SamePredictionClassifier(double prediction) {
            this.prediction = prediction;
        }

        /**
         * Returns a prediction map keyed by the attribute values (not by class labels), each mapped
         * to the fixed prediction.
         */
        @Override
        public PredictionMap predict(AttributesMap attributes) {
            Map<Serializable, Double> map = new HashMap<>();
            for (Serializable value : attributes.values()) {
                map.put(value, prediction);
            }
            return new PredictionMap(map);
        }

        /** Ignores {@code attributesToIgnore} and delegates to {@link #predict(AttributesMap)}. */
        @Override
        public PredictionMap predictWithoutAttributes(AttributesMap attributes, Set<String> attributesToIgnore) {
            return predict(attributes);
        }
    }
}