/*
* File: VectorThresholdVarianceLearnerTest.java
* Authors: Justin Basilico
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright November 30, 2007, Sandia Corporation. Under the terms of Contract
* DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
* or on behalf of the U.S. Government. Export of this program may require a
* license from the United States Government. See CopyrightHistory.txt for
* complete details.
*
*
*/
package gov.sandia.cognition.learning.algorithm.tree;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import junit.framework.*;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.VectorElementThresholdCategorizer;
import gov.sandia.cognition.math.matrix.mtj.Vector3;
import gov.sandia.cognition.util.DefaultPair;
import java.util.LinkedList;
/**
* This class implements JUnit tests for the following classes:
*
* VectorThresholdVarianceLearner
*
* @author Justin Basilico
* @since 2.0
*/
public class VectorThresholdVarianceLearnerTest
extends TestCase
{
public VectorThresholdVarianceLearnerTest(
String testName)
{
super(testName);
}
/**
* Test of constructors of class VectorThresholdVarianceLearner.
*/
public void testConstructors()
{
int minSplitSize = VectorThresholdVarianceLearner.DEFAULT_MIN_SPLIT_SIZE;
int[] dimensionsToConsider = null;
VectorThresholdVarianceLearner instance = new VectorThresholdVarianceLearner();
assertEquals(minSplitSize, instance.minSplitSize);
assertSame(dimensionsToConsider, instance.getDimensionsToConsider());
minSplitSize = 5;
instance = new VectorThresholdVarianceLearner(minSplitSize);
assertEquals(minSplitSize, instance.minSplitSize);
assertSame(dimensionsToConsider, instance.getDimensionsToConsider());
minSplitSize = 11;
dimensionsToConsider = new int[] { 3, 4 };
instance = new VectorThresholdVarianceLearner(minSplitSize, dimensionsToConsider);
assertEquals(minSplitSize, instance.minSplitSize);
assertSame(dimensionsToConsider, instance.getDimensionsToConsider());
}
/**
* Test of learn method, of class gov.sandia.cognition.learning.algorithm.tree.VectorThresholdVarianceLearner.
*/
public void testLearn()
{
VectorThresholdVarianceLearner instance =
new VectorThresholdVarianceLearner();
VectorElementThresholdCategorizer result = instance.learn(null);
assertNull(result);
LinkedList<InputOutputPair<Vector3, Double>> data =
new LinkedList<InputOutputPair<Vector3, Double>>();
result = instance.learn(data);
assertNull(result);
data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 4.0, 2.0), 4.0));
result = instance.learn(data);
assertNull(result);
data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 4.0, 2.0), 4.0));
result = instance.learn(data);
assertNull(result);
data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 1.0, 2.0), 4.0));
result = instance.learn(data);
assertNotNull(result);
assertEquals(1, result.getIndex());
assertEquals(2.5, result.getThreshold());
data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 2.0, 3.0), 1.0));
data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 4.0, 4.0), 1.5));
result = instance.learn(data);
assertNotNull(result);
assertEquals(2, result.getIndex());
assertEquals(2.5, result.getThreshold());
data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 3.0, 5.0), 0.0));
data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 0.0, 2.0), 4.5));
data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 5.0, 2.0), 0.5));
data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 7.0, 2.0), 2.0));
data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 8.0, 2.0), 1.5));
result = instance.learn(data);
assertNotNull(result);
assertEquals(1, result.getIndex());
assertEquals(1.5, result.getThreshold());
// Try to only use some dimensions.
instance.setDimensionsToConsider(0, 2);
result = instance.learn(data);
assertNotNull(result);
assertEquals(2, result.getIndex());
assertEquals(2.5, result.getThreshold());
}
/**
* Test of computeBestGainThreshold method, of class gov.sandia.cognition.learning.algorithm.tree.VectorThresholdVarianceLearner.
*/
public void testComputeBestGainThreshold()
{
VectorThresholdVarianceLearner instance =
new VectorThresholdVarianceLearner();
double baseVariance = 0.0;
DefaultPair<Double, Double> result = null;
LinkedList<InputOutputPair<Vector3, Double>> data =
new LinkedList<InputOutputPair<Vector3, Double>>();
baseVariance = DatasetUtil.computeOutputVariance(data);
result = instance.computeBestGainThreshold(data, 0, baseVariance);
assertNull(result);
data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 4.0, 2.0), 4.0));
baseVariance = DatasetUtil.computeOutputVariance(data);
result = instance.computeBestGainThreshold(data, 0, baseVariance);
assertNull(result);
result = instance.computeBestGainThreshold(data, 1, baseVariance);
assertNull(result);
result = instance.computeBestGainThreshold(data, 2, baseVariance);
assertNull(result);
data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 1.0, 2.0), 4.0));
baseVariance = DatasetUtil.computeOutputVariance(data);
result = instance.computeBestGainThreshold(data, 0, baseVariance);
assertNull(result);
result = instance.computeBestGainThreshold(data, 1, baseVariance);
assertEquals(0.0, result.getFirst());
assertEquals(2.5, result.getSecond());
result = instance.computeBestGainThreshold(data, 2, baseVariance);
assertNull(result);
data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 2.0, 3.0), 1.0));
data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 4.0, 4.0), 0.5));
data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 3.0, 5.0), 0.0));
data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 0.0, 2.0), 4.5));
data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 5.0, 2.0), 1.5));
data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 7.0, 2.0), 2.0));
data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 8.0, 2.0), 1.5));
baseVariance = DatasetUtil.computeOutputVariance(data);
result = instance.computeBestGainThreshold(data, 0, baseVariance);
assertNull(result);
result = instance.computeBestGainThreshold(data, 1, baseVariance);
assertEquals(1.307, result.getFirst(), 0.001);
assertEquals(1.5, result.getSecond());
result = instance.computeBestGainThreshold(data, 2, baseVariance);
assertEquals(1.297, result.getFirst(), 0.001);
assertEquals(2.5, result.getSecond());
}
public void testGetDimensionsToConsider()
{
this.testSetDimensionsToConsider();
}
public void testSetDimensionsToConsider()
{
int[] dimensionsToConsider = null;
VectorThresholdVarianceLearner instance = new VectorThresholdVarianceLearner();
assertSame(dimensionsToConsider, instance.getDimensionsToConsider());
dimensionsToConsider = new int[] {1,2,5};
instance.setDimensionsToConsider(dimensionsToConsider);
assertSame(dimensionsToConsider, instance.getDimensionsToConsider());
dimensionsToConsider = new int[] {0, 9, 12};
instance.setDimensionsToConsider(dimensionsToConsider);
assertSame(dimensionsToConsider, instance.getDimensionsToConsider());
dimensionsToConsider = null;
instance.setDimensionsToConsider(dimensionsToConsider);
assertSame(dimensionsToConsider, instance.getDimensionsToConsider());
}
/**
* Test of learn method with childCountThreshold.
*/
public void testLearnChildThreshold()
{
VectorElementThresholdCategorizer result;
LinkedList<InputOutputPair<Vector3, Double>> data =
new LinkedList<InputOutputPair<Vector3, Double>>();
data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 4.0, 2.0), 10.0));
data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 3.0, 2.0), 2.0));
data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 1.0, 2.0), 3.0));
data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 1.5, 3.0), 1.0));
data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 3.0, 4.0), 2.5));
VectorThresholdVarianceLearner instance =
new VectorThresholdVarianceLearner();
result = instance.learn(data);
assertNotNull(result);
assertEquals(1, result.getIndex());
assertEquals(3.5, result.getThreshold());
instance = new VectorThresholdVarianceLearner(2);
result = instance.learn(data);
assertNotNull(result);
assertEquals(2, result.getIndex());
assertEquals(2.5, result.getThreshold());
instance = new VectorThresholdVarianceLearner(3);
result = instance.learn(data);
assertNull(result);
// Now create a dataset that cannot possibly be split into two children of at least two members
data = new LinkedList<InputOutputPair<Vector3, Double>>();
data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 1.0, 1.0), 2.0));
data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 2.0, 2.0), 2.0));
data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 2.0, 2.0), 3.0));
data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 2.0, 2.0), 4.0));
data.add(new DefaultInputOutputPair<Vector3, Double>(new Vector3(1.0, 3.0, 2.0), 5.0));
instance = new VectorThresholdVarianceLearner(1);
result = instance.learn(data);
assertNotNull(result);
assertEquals(1, result.getIndex());
assertEquals(2.5, result.getThreshold());
instance = new VectorThresholdVarianceLearner(2);
result = instance.learn(data);
assertNull(result);
// Setting childLeafCount to zero should be overridden inside so it becomes 1.
boolean exceptionThrown = false;
try
{
instance = new VectorThresholdVarianceLearner(0);
}
catch (IllegalArgumentException e)
{
exceptionThrown = true;
}
finally
{
assertTrue(exceptionThrown);
}
}
}