/*
* File: VectorThresholdGiniImpurityLearnerTest.java
* Authors: Justin Basilico
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright January 14, 2010, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government. Export
* of this program may require a license from the United States Government.
* See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.learning.algorithm.tree;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.VectorElementThresholdCategorizer;
import gov.sandia.cognition.math.matrix.mtj.Vector3;
import gov.sandia.cognition.statistics.distribution.DefaultDataDistribution;
import java.util.LinkedList;
import junit.framework.TestCase;
/**
* Unit tests for class VectorThresholdGiniImpurityLearner.
*
* @author Justin Basilico
* @since 3.0
*/
public class VectorThresholdGiniImpurityLearnerTest
extends TestCase
{
/**
* Creates a new test.
*
* @param testName The test name.
*/
public VectorThresholdGiniImpurityLearnerTest(
String testName)
{
super(testName);
}
/**
* Test of constructors of class VectorThresholdGiniImpurityLearner.
*/
public void testConstructors()
{
int minSplitSize = VectorThresholdGiniImpurityLearner.DEFAULT_MIN_SPLIT_SIZE;
VectorThresholdGiniImpurityLearner<Boolean> instance =
new VectorThresholdGiniImpurityLearner<>();
assertEquals(minSplitSize, instance.getMinSplitSize());
minSplitSize = 6;
instance = new VectorThresholdGiniImpurityLearner<>(minSplitSize);
assertEquals(minSplitSize, instance.getMinSplitSize());
}
/**
* Test of computeSplitGain method, of class VectorThresholdGiniImpurityLearner.
*/
public void testComputeSplitGain()
{
final double epsilon = 0.0001;
VectorThresholdGiniImpurityLearner<String> instance =
new VectorThresholdGiniImpurityLearner<String>();
DefaultDataDistribution<String> empty = new DefaultDataDistribution<String>();
DefaultDataDistribution<String> baseCounts = new DefaultDataDistribution<String>();
baseCounts.increment("a", 50);
baseCounts.increment("b", 50);
assertEquals(0.0, instance.computeSplitGain(baseCounts, baseCounts, empty), epsilon);
assertEquals(0.0, instance.computeSplitGain(baseCounts, empty, baseCounts), epsilon);
DefaultDataDistribution<String> as = new DefaultDataDistribution<String>();
as.increment("a", 50);
DefaultDataDistribution<String> bs = new DefaultDataDistribution<String>();
bs.increment("b", 50);
assertEquals(0.5, instance.computeSplitGain(baseCounts, as, bs), epsilon);
assertEquals(0.5, instance.computeSplitGain(baseCounts, bs, as), epsilon);
DefaultDataDistribution<String> mixed = new DefaultDataDistribution<String>();
mixed.increment("a", 25);
mixed.increment("b", 25);
assertEquals(0.0, instance.computeSplitGain(baseCounts, mixed, mixed), epsilon);
DefaultDataDistribution<String> tenAs = new DefaultDataDistribution<String>();
tenAs.increment("a", 10);
DefaultDataDistribution<String> moreBs = new DefaultDataDistribution<String>();
moreBs.increment("a", 15);
moreBs.increment("b", 25);
assertEquals(0.3125, instance.computeSplitGain(baseCounts, tenAs, moreBs), epsilon);
assertEquals(0.3125, instance.computeSplitGain(baseCounts, moreBs, tenAs), epsilon);
}
/**
* Test of giniImpurity method, of class VectorThresholdGiniImpurityLearner.
*/
public void testGiniImpurity()
{
final double epsilon = 0.00001;
DefaultDataDistribution<String> empty = new DefaultDataDistribution<String>();
assertEquals(0.0, VectorThresholdGiniImpurityLearner.giniImpurity(empty), epsilon);
DefaultDataDistribution<String> pure = new DefaultDataDistribution<String>();
pure.increment("a", 100);
assertEquals(0.0, VectorThresholdGiniImpurityLearner.giniImpurity(pure), epsilon);
DefaultDataDistribution<String> impure = new DefaultDataDistribution<String>();
impure.increment("a", 50);
impure.increment("b", 50);
assertEquals(0.5, VectorThresholdGiniImpurityLearner.giniImpurity(impure), epsilon);
DefaultDataDistribution<String> almostPure = new DefaultDataDistribution<String>();
almostPure.increment("a", 1);
almostPure.increment("b", 99);
assertEquals(0.0198, VectorThresholdGiniImpurityLearner.giniImpurity(almostPure), epsilon);
DefaultDataDistribution<Integer> lots = new DefaultDataDistribution<Integer>();
for (int i = 0; i < 100; i++)
{
lots.increment(i, 1);
}
assertEquals(0.99, VectorThresholdGiniImpurityLearner.giniImpurity(lots), epsilon);
}
public void testLearnWithMinSplitSize()
{
VectorElementThresholdCategorizer result;
LinkedList<InputOutputPair<Vector3, String>> data = new LinkedList<>();
data.add(new DefaultInputOutputPair<>(new Vector3(1.0, 4.0, 2.0), "c"));
data.add(new DefaultInputOutputPair<>(new Vector3(1.0, 3.0, 2.0), "b"));
data.add(new DefaultInputOutputPair<>(new Vector3(1.0, 1.0, 2.0), "b"));
data.add(new DefaultInputOutputPair<>(new Vector3(1.0, 1.5, 3.0), "a"));
data.add(new DefaultInputOutputPair<>(new Vector3(1.0, 2.5, 4.0), "a"));
data.add(new DefaultInputOutputPair<>(new Vector3(1.0, 2.5, 5.0), "b"));
VectorThresholdGiniImpurityLearner<String> instance =
new VectorThresholdGiniImpurityLearner<>();
result = instance.learn(data);
assertNotNull(result);
assertEquals(1, result.getIndex());
assertEquals(3.5, result.getThreshold());
instance = new VectorThresholdGiniImpurityLearner<>(2);
result = instance.learn(data);
assertNotNull(result);
assertEquals(2, result.getIndex());
assertEquals(2.5, result.getThreshold());
instance = new VectorThresholdGiniImpurityLearner<>(4);
result = instance.learn(data);
assertNull(result);
// Now create a dataset that cannot possibly be split into two children of at least two members
data = new LinkedList<>();
data.add(new DefaultInputOutputPair<>(new Vector3(1.0, 1.0, 1.0), "a"));
data.add(new DefaultInputOutputPair<>(new Vector3(1.0, 2.0, 2.0), "b"));
data.add(new DefaultInputOutputPair<>(new Vector3(1.0, 2.0, 2.0), "b"));
data.add(new DefaultInputOutputPair<>(new Vector3(1.0, 2.0, 2.0), "a"));
data.add(new DefaultInputOutputPair<>(new Vector3(1.0, 3.0, 2.0), "a"));
instance = new VectorThresholdGiniImpurityLearner<>(1);
result = instance.learn(data);
assertNotNull(result);
assertEquals(1, result.getIndex());
assertEquals(1.5, result.getThreshold());
instance = new VectorThresholdGiniImpurityLearner<>(2);
result = instance.learn(data);
assertNull(result);
// Setting childLeafCount to zero should be overridden inside so it becomes 1.
boolean exceptionThrown = false;
try
{
instance = new VectorThresholdGiniImpurityLearner<>(0);
}
catch (IllegalArgumentException e)
{
exceptionThrown = true;
}
finally
{
assertTrue(exceptionThrown);
}
}
}