/*
* File: CategorizationTreeLearnerTest.java
* Authors: Justin Basilico
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright November 16, 2007, Sandia Corporation. Under the terms of Contract
* DE-AC04-94AL85000, there is a non-exclusive license for use of this work by
* or on behalf of the U.S. Government. Export of this program may require a
* license from the United States Government. See CopyrightHistory.txt for
* complete details.
*
*/
package gov.sandia.cognition.learning.algorithm.tree;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.function.categorization.VectorElementThresholdCategorizer;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.math.matrix.mtj.Vector3;
import gov.sandia.cognition.statistics.distribution.DefaultDataDistribution;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.HashMap;
import java.util.Map;
import junit.framework.TestCase;
/**
* This class implements JUnit tests for the following classes:
*
* CategorizationTreeLearner
*
* @author Justin Basilico
* @since 2.0
*/
public class CategorizationTreeLearnerTest
extends TestCase
{
public CategorizationTreeLearnerTest(
String testName)
{
super(testName);
}
public void testConstructors()
{
CategorizationTreeLearner<Vector3, String> instance =
new CategorizationTreeLearner<Vector3, String>();
assertNull(instance.getDeciderLearner());
VectorThresholdInformationGainLearner<String> deciderLearner =
new VectorThresholdInformationGainLearner<String>();
instance = new CategorizationTreeLearner<Vector3, String>(deciderLearner);
assertSame(deciderLearner, instance.getDeciderLearner());
}
/**
* Test of learn method, of class gov.sandia.cognition.learning.algorithm.tree.CategorizationTreeLearner.
*/
public void testLearn()
{
VectorThresholdInformationGainLearner<String> deciderLearner =
new VectorThresholdInformationGainLearner<String>();
CategorizationTreeLearner<Vector3, String> instance =
new CategorizationTreeLearner<Vector3, String>(deciderLearner);
instance.setDeciderLearner(new VectorThresholdHellingerDistanceLearner<String>());
CategorizationTree<Vector3, String> result = instance.learn(null);
assertNull(result);
ArrayList<InputOutputPair<Vector3, String>> data =
new ArrayList<InputOutputPair<Vector3, String>>();
result = instance.learn(data);
assertNull(result.getRootNode());
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(1.0, 4.0, 2.0), "a"));
result = instance.learn(data);
assertNotNull(result.getRootNode());
assertTrue(result.getRootNode().isLeaf());
assertEquals("a", result.evaluate(data.get(0).getInput()));
assertEquals(1, result.getCategories().size());
assertTrue(result.getCategories().contains("a"));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(1.0, 1.0, 2.0), "a"));
result = instance.learn(data);
assertNotNull(result.getRootNode());
assertTrue(result.getRootNode().isLeaf());
assertEquals("a", result.evaluate(data.get(0).getInput()));
assertEquals("a", result.evaluate(data.get(1).getInput()));
assertEquals(1, result.getCategories().size());
assertTrue(result.getCategories().contains("a"));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(1.0, 2.0, 3.0), "b"));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(1.0, 4.0, 4.0), "b"));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(1.0, 3.0, 2.0), "b"));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(1.0, 0.0, 2.0), "a"));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(1.0, 5.0, 2.0), "b"));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(1.0, 7.0, 2.0), "b"));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(1.0, 8.0, 2.0), "b"));
result = instance.learn(data);
assertNotNull(result.getRootNode());
assertFalse(result.getRootNode().isLeaf());
for ( InputOutputPair<Vector3, String> example : data )
{
assertEquals(example.getOutput(), result.evaluate(example.getInput()));
}
assertEquals(2, result.getCategories().size());
assertTrue(result.getCategories().contains("a"));
assertTrue(result.getCategories().contains("b"));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(1.0, 1.0, 1.0), "a"));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(1.0, 1.0, 1.0), "b"));
result = instance.learn(data);
assertEquals("a", result.evaluate(new Vector3(1.0, 1.0, 1.0)));
assertEquals(2, result.getCategories().size());
assertTrue(result.getCategories().contains("a"));
assertTrue(result.getCategories().contains("b"));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(1.0, 1.0, 1.0), "b"));
result = instance.learn(data);
assertEquals("b", result.evaluate(new Vector3(1.0, 1.0, 1.0)));
assertEquals(2, result.getCategories().size());
assertTrue(result.getCategories().contains("a"));
assertTrue(result.getCategories().contains("b"));
data.clear();
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(1.0, 1.0, 1.0), "a"));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(1.0, 1.0, 1.0), "b"));
result = instance.learn(data);
assertEquals("a", result.evaluate(new Vector3(1.0, 1.0, 1.0)));
assertEquals(2, result.getCategories().size());
assertTrue(result.getCategories().contains("a"));
assertTrue(result.getCategories().contains("b"));
data.clear();
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(0.0, 0.0, 0.0), "a"));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(1.0, 0.0, 0.0), "b"));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(0.0, 1.0, 0.0), "b"));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(1.0, 1.0, 0.0), "a"));
result = instance.learn(data);
for ( InputOutputPair<Vector3, String> example : data )
{
assertEquals(example.getOutput(), result.evaluate(example.getInput()));
}
assertEquals(2, result.getCategories().size());
assertTrue(result.getCategories().contains("a"));
assertTrue(result.getCategories().contains("b"));
}
/**
* Test of learn method, when using manual prior weights.
*/
public void testLearnWithPriors()
{
HashMap<String, Double> priors = new HashMap<String, Double>();
priors.put("a", 0.8);
priors.put("b", 0.2);
VectorThresholdInformationGainLearner<String> deciderLearner =
new VectorThresholdInformationGainLearner<String>();
CategorizationTreeLearner<Vector3, String> instance =
new CategorizationTreeLearner<Vector3, String>(deciderLearner);
instance.setCategoryPriors(priors);
CategorizationTree<Vector3, String> result = instance.learn(null);
assertNull(result);
ArrayList<InputOutputPair<Vector3, String>> data =
new ArrayList<InputOutputPair<Vector3, String>>();
result = instance.learn(data);
assertNull(result.getRootNode());
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(1.0, 4.0, 2.0), "a"));
result = instance.learn(data);
assertNotNull(result.getRootNode());
assertTrue(result.getRootNode().isLeaf());
assertEquals("a", result.evaluate(data.get(0).getInput()));
assertEquals(1, result.getCategories().size());
assertTrue(result.getCategories().contains("a"));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(1.0, 1.0, 2.0), "a"));
result = instance.learn(data);
assertNotNull(result.getRootNode());
assertTrue(result.getRootNode().isLeaf());
assertEquals("a", result.evaluate(data.get(0).getInput()));
assertEquals("a", result.evaluate(data.get(1).getInput()));
assertEquals(1, result.getCategories().size());
assertTrue(result.getCategories().contains("a"));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(1.0, 2.0, 3.0), "b"));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(1.0, 4.0, 4.0), "b"));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(1.0, 3.0, 2.0), "b"));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(1.0, 0.0, 2.0), "a"));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(1.0, 5.0, 2.0), "b"));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(1.0, 7.0, 2.0), "b"));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(1.0, 8.0, 2.0), "b"));
result = instance.learn(data);
assertNotNull(result.getRootNode());
assertFalse(result.getRootNode().isLeaf());
for ( InputOutputPair<Vector3, String> example : data )
{
assertEquals(example.getOutput(), result.evaluate(example.getInput()));
}
assertEquals(2, result.getCategories().size());
assertTrue(result.getCategories().contains("a"));
assertTrue(result.getCategories().contains("b"));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(1.0, 1.0, 1.0), "a"));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(1.0, 1.0, 1.0), "b"));
result = instance.learn(data);
assertEquals("a", result.evaluate(new Vector3(1.0, 1.0, 1.0)));
assertEquals(2, result.getCategories().size());
assertTrue(result.getCategories().contains("a"));
assertTrue(result.getCategories().contains("b"));
// Because class a is weighted 4x as much as b by prior, tree
// should predict "a" for pattern (1,1,1) despite seeing two
// examples of class a and two examples of class b.
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(1.0, 1.0, 1.0), "b"));
result = instance.learn(data);
assertEquals("a", result.evaluate(new Vector3(1.0, 1.0, 1.0)));
assertEquals(2, result.getCategories().size());
assertTrue(result.getCategories().contains("a"));
assertTrue(result.getCategories().contains("b"));
data.clear();
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(1.0, 1.0, 1.0), "a"));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(1.0, 1.0, 1.0), "b"));
result = instance.learn(data);
assertEquals("a", result.evaluate(new Vector3(1.0, 1.0, 1.0)));
assertEquals(2, result.getCategories().size());
assertTrue(result.getCategories().contains("a"));
assertTrue(result.getCategories().contains("b"));
data.clear();
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(0.0, 0.0, 0.0), "a"));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(1.0, 0.0, 0.0), "b"));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(0.0, 1.0, 0.0), "b"));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(1.0, 1.0, 0.0), "a"));
result = instance.learn(data);
for ( InputOutputPair<Vector3, String> example : data )
{
assertEquals(example.getOutput(), result.evaluate(example.getInput()));
}
assertEquals(2, result.getCategories().size());
assertTrue(result.getCategories().contains("a"));
assertTrue(result.getCategories().contains("b"));
}
/**
* Test of areAllOutputsEqual method, of class gov.sandia.cognition.learning.algorithm.tree.CategorizationTreeLearner.
*/
public void testAreAllOutputsEqual()
{
CategorizationTreeLearner<Vector3, String> instance =
new CategorizationTreeLearner<Vector3, String>();
LinkedList<InputOutputPair<Vector3, String>> data =
new LinkedList<InputOutputPair<Vector3, String>>();
assertTrue(instance.areAllOutputsEqual(data));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(1.0, 0.0, 0.0), "a"));
assertTrue(instance.areAllOutputsEqual(data));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(0.0, 1.0, 0.0), "a"));
assertTrue(instance.areAllOutputsEqual(data));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(0.0, 0.0, 1.0), "a"));
assertTrue(instance.areAllOutputsEqual(data));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(1.0, 0.0, 1.0), "b"));
assertFalse(instance.areAllOutputsEqual(data));
}
/**
* Test of getOutputCounts method, of class gov.sandia.cognition.learning.algorithm.tree.CategorizationTreeLearner.
*/
public void testGetOutputCounts()
{
DefaultDataDistribution<String> result =
CategorizationTreeLearner.getOutputCounts(null);
assertEquals(0.0, result.getTotal());
LinkedList<InputOutputPair<Vector3, String>> data =
new LinkedList<InputOutputPair<Vector3, String>>();
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(), "a"));
result = CategorizationTreeLearner.getOutputCounts(data);
assertEquals(1.0, result.getTotal());
assertEquals(1.0, result.get("a"));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(), "b"));
result = CategorizationTreeLearner.getOutputCounts(data);
assertEquals(2.0, result.getTotal());
assertEquals(1.0, result.get("a"));
assertEquals(1.0, result.get("b"));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(), "b"));
result = CategorizationTreeLearner.getOutputCounts(data);
assertEquals(3.0, result.getTotal());
assertEquals(1.0, result.get("a"));
assertEquals(2.0, result.get("b"));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(), "c"));
result = CategorizationTreeLearner.getOutputCounts(data);
assertEquals(4.0, result.getTotal());
assertEquals(1.0, result.get("a"));
assertEquals(2.0, result.get("b"));
assertEquals(1.0, result.get("c"));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(), "a"));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(), "b"));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(), "c"));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(), "b"));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(), "a"));
result = CategorizationTreeLearner.getOutputCounts(data);
assertEquals(9.0, result.getTotal());
assertEquals(3.0, result.get("a"));
assertEquals(4.0, result.get("b"));
assertEquals(2.0, result.get("c"));
}
/**
* Test of splitData method, of class gov.sandia.cognition.learning.algorithm.tree.CategorizationTreeLearner.
*/
public void testSplitData()
{
CategorizationTreeLearner<Vector3, String> instance =
new CategorizationTreeLearner<Vector3, String>();
ArrayList<InputOutputPair<Vector3, String>> data =
new ArrayList<InputOutputPair<Vector3, String>>();
VectorElementThresholdCategorizer decider =
new VectorElementThresholdCategorizer(1, 2.5);
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(1.0, 4.0, 2.0), "a"));
Map<Boolean, LinkedList<InputOutputPair<? extends Vector3, String>>>
result = instance.splitData(data, decider);
assertEquals(1, result.size());
assertEquals(1, result.get(true).size());
assertTrue(result.get(true).contains(data.get(0)));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(1.0, 1.0, 2.0), "a"));
result = instance.splitData(data, decider);
assertEquals(2, result.size());
assertEquals(1, result.get(true).size());
assertEquals(1, result.get(false).size());
assertTrue(result.get(true).contains(data.get(0)));
assertTrue(result.get(false).contains(data.get(1)));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(1.0, 2.0, 3.0), "b"));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(1.0, 4.0, 4.0), "b"));
result = instance.splitData(data, decider);
assertEquals(2, result.size());
assertEquals(2, result.get(true).size());
assertEquals(2, result.get(false).size());
assertTrue(result.get(true).contains(data.get(0)));
assertTrue(result.get(false).contains(data.get(1)));
assertTrue(result.get(false).contains(data.get(2)));
assertTrue(result.get(true).contains(data.get(3)));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(1.0, 3.0, 2.0), "b"));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(1.0, 0.0, 2.0), "a"));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(1.0, 5.0, 2.0), "b"));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(1.0, 7.0, 2.0), "b"));
data.add(new DefaultInputOutputPair<Vector3, String>(new Vector3(1.0, 8.0, 2.0), "b"));
result = instance.splitData(data, decider);
assertEquals(2, result.size());
assertEquals(6, result.get(true).size());
assertEquals(3, result.get(false).size());
assertTrue(result.get(true).contains(data.get(0)));
assertTrue(result.get(false).contains(data.get(1)));
assertTrue(result.get(false).contains(data.get(2)));
assertTrue(result.get(true).contains(data.get(3)));
assertTrue(result.get(true).contains(data.get(4)));
assertTrue(result.get(false).contains(data.get(5)));
assertTrue(result.get(true).contains(data.get(6)));
assertTrue(result.get(true).contains(data.get(7)));
assertTrue(result.get(true).contains(data.get(8)));
}
/**
* Test of getDeciderLearner method, of class gov.sandia.cognition.learning.algorithm.tree.CategorizationTreeLearner.
*/
public void testGetDeciderLearner()
{
this.testSetDeciderLearner();
}
/**
* Test of setDeciderLearner method, of class gov.sandia.cognition.learning.algorithm.tree.CategorizationTreeLearner.
*/
public void testSetDeciderLearner()
{
CategorizationTreeLearner<Vector3, String> instance =
new CategorizationTreeLearner<Vector3, String>();
assertNull(instance.getDeciderLearner());
VectorThresholdInformationGainLearner<String> deciderLearner =
new VectorThresholdInformationGainLearner<String>();
instance.setDeciderLearner(deciderLearner);
assertSame(deciderLearner, instance.getDeciderLearner());
instance.setDeciderLearner(null);
assertNull(instance.getDeciderLearner());
}
public void testGetLeafCountThreshold()
{
this.testSetLeafCountThreshold();
}
public void testSetLeafCountThreshold()
{
int leafCountThreshold = CategorizationTreeLearner.DEFAULT_LEAF_COUNT_THRESHOLD;
CategorizationTreeLearner<Vector3, String> instance =
new CategorizationTreeLearner<Vector3, String>();
assertEquals(leafCountThreshold, instance.getLeafCountThreshold());
leafCountThreshold = 1;
instance.setLeafCountThreshold(leafCountThreshold);
assertEquals(leafCountThreshold, instance.getLeafCountThreshold());
leafCountThreshold = 2;
instance.setLeafCountThreshold(leafCountThreshold);
assertEquals(leafCountThreshold, instance.getLeafCountThreshold());
leafCountThreshold = 0;
instance.setLeafCountThreshold(leafCountThreshold);
assertEquals(leafCountThreshold, instance.getLeafCountThreshold());
leafCountThreshold = 10;
instance.setLeafCountThreshold(leafCountThreshold);
assertEquals(leafCountThreshold, instance.getLeafCountThreshold());
boolean exceptionThrown = false;
try
{
instance.setLeafCountThreshold(-1);
}
catch (IllegalArgumentException e)
{
exceptionThrown = true;
}
finally
{
assertTrue(exceptionThrown);
}
assertEquals(leafCountThreshold, instance.getLeafCountThreshold());
}
public void testSetMaxDepth()
{
int maxDepth = CategorizationTreeLearner.DEFAULT_MAX_DEPTH;
CategorizationTreeLearner<Vector3, String> instance =
new CategorizationTreeLearner<Vector3, String>();
assertEquals(maxDepth, instance.getMaxDepth());
maxDepth = 1;
instance.setMaxDepth(maxDepth);
assertEquals(maxDepth, instance.getMaxDepth());
maxDepth = 2;
instance.setMaxDepth(maxDepth);
assertEquals(maxDepth, instance.getMaxDepth());
maxDepth = 10;
instance.setMaxDepth(maxDepth);
assertEquals(maxDepth, instance.getMaxDepth());
maxDepth = 0;
instance.setMaxDepth(maxDepth);
assertEquals(maxDepth, instance.getMaxDepth());
maxDepth = -1;
instance.setMaxDepth(maxDepth);
assertEquals(maxDepth, instance.getMaxDepth());
}
}