/*
* File: VectorThresholdHellingerDistanceLearnerTest.java
* Authors: Justin Basilico
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright November 25, 2009, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government. Export
* of this program may require a license from the United States Government.
* See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.learning.algorithm.tree;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.VectorElementThresholdCategorizer;
import gov.sandia.cognition.math.matrix.mtj.Vector3;
import gov.sandia.cognition.statistics.distribution.DefaultDataDistribution;
import gov.sandia.cognition.util.DefaultPair;
import java.util.LinkedList;
import junit.framework.TestCase;
/**
 * Unit tests for class VectorThresholdHellingerDistanceLearner.
 *
 * @author Justin Basilico
 * @since 3.0
 */
public class VectorThresholdHellingerDistanceLearnerTest
    extends TestCase
{

    /**
     * Creates a new test.
     *
     * @param testName The test name.
     */
    public VectorThresholdHellingerDistanceLearnerTest(
        String testName)
    {
        super(testName);
    }

    /**
     * Appends a labeled three-dimensional example to the given data list.
     *
     * @param data The list of examples to append to.
     * @param x The value of vector element 0.
     * @param y The value of vector element 1.
     * @param z The value of vector element 2.
     * @param label The output label of the example.
     */
    private static void addExample(
        final LinkedList<InputOutputPair<Vector3, Boolean>> data,
        final double x,
        final double y,
        final double z,
        final boolean label)
    {
        data.add(new DefaultInputOutputPair<Vector3, Boolean>(
            new Vector3(x, y, z), label));
    }

    /**
     * Creates a Boolean count distribution by incrementing {@code true} by
     * the first count and then {@code false} by the second count. Both
     * increments are always performed, even when a count is zero.
     *
     * @param trueCount The amount to increment the {@code true} key by.
     * @param falseCount The amount to increment the {@code false} key by.
     * @return A new distribution holding the given counts.
     */
    private static DefaultDataDistribution<Boolean> createCounts(
        final int trueCount,
        final int falseCount)
    {
        final DefaultDataDistribution<Boolean> counts =
            new DefaultDataDistribution<>();
        counts.increment(true, trueCount);
        counts.increment(false, falseCount);
        return counts;
    }

    /**
     * Test of constructors of class VectorThresholdHellingerDistanceLearner.
     */
    public void testConstructors()
    {
        int minSplitSize = VectorThresholdHellingerDistanceLearner.DEFAULT_MIN_SPLIT_SIZE;
        VectorThresholdHellingerDistanceLearner<Boolean> instance =
            new VectorThresholdHellingerDistanceLearner<>();
        assertEquals(minSplitSize, instance.getMinSplitSize());

        minSplitSize = 6;
        instance = new VectorThresholdHellingerDistanceLearner<>(minSplitSize);
        assertEquals(minSplitSize, instance.getMinSplitSize());
    }

    /**
     * Test of learn method, of class VectorThresholdHellingerDistanceLearner.
     */
    public void testLearn()
    {
        final VectorThresholdHellingerDistanceLearner<Boolean> instance =
            new VectorThresholdHellingerDistanceLearner<>();

        // Null and empty data produce no split.
        VectorElementThresholdCategorizer result = instance.learn(null);
        assertNull(result);

        final LinkedList<InputOutputPair<Vector3, Boolean>> data =
            new LinkedList<>();
        result = instance.learn(data);
        assertNull(result);

        // One example, then two identical examples: still no split.
        addExample(data, 1.0, 4.0, 2.0, true);
        result = instance.learn(data);
        assertNull(result);

        addExample(data, 1.0, 4.0, 2.0, true);
        result = instance.learn(data);
        assertNull(result);

        // Element 1 now has two distinct values (1.0 and 4.0); the learned
        // threshold is the midpoint between them.
        addExample(data, 1.0, 1.0, 2.0, true);
        result = instance.learn(data);
        assertNotNull(result);
        assertEquals(1, result.getIndex());
        assertEquals(2.5, result.getThreshold(), 0.0);

        addExample(data, 1.0, 2.0, 3.0, false);
        addExample(data, 1.0, 4.0, 4.0, false);
        result = instance.learn(data);
        assertNotNull(result);
        assertEquals(2, result.getIndex());
        assertEquals(2.5, result.getThreshold(), 0.0);

        addExample(data, 1.0, 3.0, 2.0, false);
        addExample(data, 1.0, 0.0, 2.0, true);
        addExample(data, 1.0, 5.0, 2.0, false);
        addExample(data, 1.0, 7.0, 2.0, false);
        addExample(data, 1.0, 8.0, 2.0, false);
        result = instance.learn(data);
        assertNotNull(result);
        assertEquals(1, result.getIndex());
        assertEquals(4.5, result.getThreshold(), 0.0);
    }

    /**
     * Test of computeBestGainAndThreshold method, of class VectorThresholdHellingerDistanceLearner.
     */
    public void testComputeBestThreshold()
    {
        final VectorThresholdHellingerDistanceLearner<Boolean> instance =
            new VectorThresholdHellingerDistanceLearner<>();
        final LinkedList<InputOutputPair<Vector3, Boolean>> data =
            new LinkedList<>();

        // Empty data: nothing to split on.
        DefaultDataDistribution<Boolean> baseCounts =
            CategorizationTreeLearner.getOutputCounts(data);
        DefaultPair<Double, Double> result =
            instance.computeBestGainAndThreshold(data, 0, baseCounts);
        assertNull(result);

        // A single example cannot be split on any dimension.
        addExample(data, 1.0, 4.0, 2.0, true);
        baseCounts = CategorizationTreeLearner.getOutputCounts(data);
        assertNull(instance.computeBestGainAndThreshold(data, 0, baseCounts));
        assertNull(instance.computeBestGainAndThreshold(data, 1, baseCounts));
        assertNull(instance.computeBestGainAndThreshold(data, 2, baseCounts));

        // Only element 1 has distinct values; both examples share one label,
        // so the gain is zero.
        addExample(data, 1.0, 1.0, 2.0, true);
        baseCounts = CategorizationTreeLearner.getOutputCounts(data);
        assertNull(instance.computeBestGainAndThreshold(data, 0, baseCounts));

        result = instance.computeBestGainAndThreshold(data, 1, baseCounts);
        assertNotNull(result);
        assertEquals(0.0, result.getFirst(), 0.0);
        assertEquals(2.5, result.getSecond(), 0.0);

        assertNull(instance.computeBestGainAndThreshold(data, 2, baseCounts));

        addExample(data, 1.0, 2.0, 3.0, false);
        addExample(data, 1.0, 4.0, 4.0, false);
        addExample(data, 1.0, 3.0, 5.0, false);
        addExample(data, 1.0, 0.0, 2.0, true);
        addExample(data, 1.0, 5.0, 2.0, false);
        addExample(data, 1.0, 7.0, 2.0, false);
        addExample(data, 1.0, 8.0, 2.0, false);
        baseCounts = CategorizationTreeLearner.getOutputCounts(data);

        // Element 0 is constant (always 1.0), so it can never be split.
        assertNull(instance.computeBestGainAndThreshold(data, 0, baseCounts));

        result = instance.computeBestGainAndThreshold(data, 1, baseCounts);
        assertNotNull(result);
        assertEquals(0.919, result.getFirst(), 0.001);
        assertEquals(1.5, result.getSecond(), 0.0);

        result = instance.computeBestGainAndThreshold(data, 2, baseCounts);
        assertNotNull(result);
        assertEquals(0.765, result.getFirst(), 0.001);
        assertEquals(2.5, result.getSecond(), 0.0);
    }

    /**
     * Test of computeSplitGain method, of class VectorThresholdHellingerDistanceLearner.
     *
     * Each case is described by a table of counts. Rows are the output
     * categories (T = true, F = false). Columns are the positive (P) and
     * negative (N) sides of the split and the category total. The expected
     * values agree with
     * sqrt(sum over b in {P, N} of (sqrt(T_b / T) - sqrt(F_b / F))^2),
     * where T and F are the category totals.
     */
    public void testComputeGain()
    {
        final VectorThresholdHellingerDistanceLearner<Boolean> instance =
            new VectorThresholdHellingerDistanceLearner<>();
        double result;

        // Test case: Zeros (completely empty distributions)
        //    P  N  Total
        // T  0  0  0
        // F  0  0  0
        // Expected Hellinger distance = 0.000
        result = instance.computeSplitGain(
            new DefaultDataDistribution<Boolean>(),
            new DefaultDataDistribution<Boolean>(),
            new DefaultDataDistribution<Boolean>());
        assertEquals(0.0, result, 0.0);

        // Test case: 1/1 split
        //    P  N  Total
        // T  1  0  1
        // F  0  1  1
        // Expected Hellinger distance = 1.414 = sqrt(2)
        result = instance.computeSplitGain(
            createCounts(1, 1), createCounts(1, 0), createCounts(0, 1));
        assertEquals(Math.sqrt(2.0), result, 0.001);

        // Test case: Zero distance
        //    P  N  Total
        // T  1  1  2
        // F  1  1  2
        // Expected Hellinger distance = 0.000
        result = instance.computeSplitGain(
            createCounts(2, 2), createCounts(1, 1), createCounts(1, 1));
        assertEquals(0.0, result, 0.001);

        // Test case: Example 1
        //    P  N  Total
        // T  2  2  4
        // F  6  0  6
        // Expected Hellinger distance = 0.765
        result = instance.computeSplitGain(
            createCounts(4, 6), createCounts(2, 6), createCounts(2, 0));
        assertEquals(0.765, result, 0.001);

        // Test case: Example 2
        //    P  N  Total
        // T  0  4  4
        // F  3  3  6
        // Expected Hellinger distance = 0.765
        result = instance.computeSplitGain(
            createCounts(4, 6), createCounts(0, 3), createCounts(4, 3));
        assertEquals(0.765, result, 0.001);

        // Test case: Example 3
        //    P  N  Total
        // T  1  2  3
        // F  6  0  6
        // Expected Hellinger distance = 0.919
        result = instance.computeSplitGain(
            createCounts(3, 6), createCounts(1, 6), createCounts(2, 0));
        assertEquals(0.919, result, 0.001);

        // Test case: Example 4
        //    P  N  Total
        // T  0  3  3
        // F  3  3  6
        // Expected Hellinger distance = 0.765
        result = instance.computeSplitGain(
            createCounts(3, 6), createCounts(0, 3), createCounts(3, 3));
        assertEquals(0.765, result, 0.001);

        // Test case: Example 5
        //    P  N   Total
        // T  1  1   2
        // F  1  10  11
        // Expected Hellinger distance = 0.474
        result = instance.computeSplitGain(
            createCounts(2, 11), createCounts(1, 1), createCounts(1, 10));
        assertEquals(0.474, result, 0.001);

        // Test case: Example 6
        //    P   N  Total
        // T  10  1  11
        // F  1   1  2
        // Expected Hellinger distance = 0.474
        result = instance.computeSplitGain(
            createCounts(11, 2), createCounts(10, 1), createCounts(1, 1));
        assertEquals(0.474, result, 0.001);
    }

}