package quickml.supervised.tree.scorers;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import quickml.supervised.tree.decisionTree.scorers.GRPenalizedGiniImpurityScorer;
import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter;
import quickml.supervised.tree.reducers.AttributeStats;
import java.util.Arrays;
public class GiniImpurityScorerTest {
Scorer<ClassificationCounter> scorer;
@Before
public void setUp(){
ClassificationCounter a = new ClassificationCounter();
a.addClassification("a", 4);
ClassificationCounter b = new ClassificationCounter();
b.addClassification("a", 4);
scorer = new GRPenalizedGiniImpurityScorer(0, new AttributeStats<>(null, a.add(b), "a"));
}
@Test
public void sameClassificationTest() {
ClassificationCounter a = new ClassificationCounter();
a.addClassification("a", 4);
ClassificationCounter b = new ClassificationCounter();
b.addClassification("a", 4);
Assert.assertEquals(scorer.scoreSplit(a, b), 0.0, 1E-7);
}
@Test
public void diffClassificationTest() {
ClassificationCounter a = new ClassificationCounter();
a.addClassification("a", 4);
ClassificationCounter b = new ClassificationCounter();
b.addClassification("b", 4);
scorer = new GRPenalizedGiniImpurityScorer(0, new AttributeStats<>(Arrays.asList(a,b), a.add(b), "a"));
Assert.assertEquals(scorer.scoreSplit(a, b), 0.5, 1E-7);
}
@Test
public void parentClassificationSameAsIdenticalChildTest() {
ClassificationCounter a = new ClassificationCounter();
a.addClassification("a", 4);
a.addClassification("b", 3);
ClassificationCounter b = new ClassificationCounter();
GRPenalizedGiniImpurityScorer scorer = new GRPenalizedGiniImpurityScorer(0, new AttributeStats<>(Arrays.asList(a,b), a.add(b), "a"));
Assert.assertEquals(0.0, scorer.scoreSplit(a, b), 1E-7);
}
}