package quickml.supervised.tree.decisionTree.reducers;
import com.google.common.base.Optional;
import com.google.common.collect.Iterables;
import com.twitter.common.stats.ReservoirSampler;
import org.junit.Assert;
import org.junit.Test;
import quickml.data.instances.ClassifierInstance;
import quickml.supervised.tree.branchFinders.SplittingUtilsTest;
import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter;
import quickml.supervised.tree.reducers.AttributeStats;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
/**
* Created by alexanderhawk on 6/25/15.
*/
public class DTNumOldBranchReducerTest {
@Test
public void getDeterministicSplitTest2() {
List<ClassifierInstance> instances = SplittingUtilsTest.getExtendedInstances();
String attribute = "t";
int numNumericBins = 6;
// List<I> instances, String, attribute, int numNumericBins) {
Assert.assertEquals(0, 0);
}
@Test
public void allValuesSameTest(){
double[] x = {0, 0, 0, 0};
Assert.assertTrue(DTNumBranchReducer.allValuesSame(x));
double[] y = {0, 0, 0, 1};
Assert.assertTrue(!DTNumBranchReducer.allValuesSame(y));
}
@Test
public void getBinDividerPointsTest(){
List<Double> valuesList = Arrays.<Double>asList(1.0, 2.0, 3.0, 4.0, 5.0);
Optional<double[]> splits = DTNumBranchReducer.getBinDividerPoints(4, valuesList);
Assert.assertEquals(splits.get()[0], 2.5, 1E-5);
valuesList = Arrays.<Double>asList(1.0, 2.0, 3.0, 4.0);
splits = DTNumBranchReducer.getBinDividerPoints(4, valuesList);
Assert.assertEquals(splits.get()[0], 1.5, 1E-5);
valuesList = Arrays.<Double>asList(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0);
splits = DTNumBranchReducer.getBinDividerPoints(4, valuesList);
Assert.assertEquals(splits.get()[0], 2.5, 1E-5);
valuesList = Arrays.<Double>asList(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0);
splits = DTNumBranchReducer.getBinDividerPoints(4, valuesList);
Assert.assertEquals(splits.get()[2], 6.5, 1E-5);
}
@Test
public void getDeterministicSplitTest(){
List<ClassifierInstance> instances = SplittingUtilsTest.getInstances();
Optional<double[]> splits =DTNumBranchReducer.<ClassifierInstance>getDeterministicSplit(instances, "t", 4);
Assert.assertEquals(splits.get()[0], 1.5, 1E-5);
}
@Test
public void fillReservoirSampler(){
List<ClassifierInstance> instances = SplittingUtilsTest.getExtendedInstances();
ReservoirSampler<Double> rs = DTNumBranchReducer.<ClassifierInstance>fillReservoirSampler(instances, "t", 4);
Assert.assertEquals(Iterables.size(rs.getSamples()), 4);
for (int i = 1; i< 20; i++ ) {
instances.addAll(SplittingUtilsTest.getExtendedInstances());
}
Random rand = new Random();
for (int i = 1; i< 1000; i++ ) {
int one = rand.nextInt(instances.size());
int two = rand.nextInt(instances.size());
ClassifierInstance ci1 = instances.get(one);
ClassifierInstance ci2 = instances.get(two);
instances.set(one, ci2);
instances.set(two, ci1);
}
rs = DTNumBranchReducer.<ClassifierInstance>fillReservoirSampler(instances, "t", 4);
Assert.assertEquals(Iterables.size(rs.getSamples()), 4);
}
@Test
public void getAttributeStatsOptionalTest(){
//fix
List<ClassifierInstance> instances = SplittingUtilsTest.getExtendedInstances();
double []splits ={2.5, 4.5, 6.5};
Optional<AttributeStats<ClassificationCounter>> attributeStatsOptional = DTNumBranchReducer.<ClassifierInstance>getAttributeStatsOptional("t", splits, instances);
Assert.assertTrue(attributeStatsOptional.isPresent());
AttributeStats<ClassificationCounter> attributeStats = attributeStatsOptional.get();
Assert.assertEquals(attributeStats.getStatsOnEachValue().size(), 4);
ClassificationCounter cc = attributeStats.getStatsOnEachValue().get(0);
Assert.assertEquals(cc.getTotal(), 2.0, 1E-5);
Assert.assertTrue(cc.allClassifications().contains(0.0) && !cc.allClassifications().contains(1.0));
splits =new double[1];
splits[0] = 1.5;
attributeStatsOptional = DTNumBranchReducer.<ClassifierInstance>getAttributeStatsOptional("t", splits, instances);
cc = attributeStatsOptional.get().getStatsOnEachValue().get(1);
Assert.assertTrue(cc.allClassifications().contains(0.0) && cc.allClassifications().contains(1.0));
Assert.assertEquals(cc.getCount(1.0), 6.0, 1E-5);
Assert.assertEquals(cc.getCount(0.0), 1.0, 1E-5);
}
}