/*
* File: VectorNaiveBayesCategorizerTest.java
* Authors: Justin Basilico
* Company: Sandia National Laboratories
* Project: Cognitive Foundry Learning Core
*
* Copyright November 24, 2010, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government. Export
* of this program may require a license from the United States Government.
*/
package gov.sandia.cognition.learning.algorithm.bayes;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.math.matrix.mtj.Vector2;
import gov.sandia.cognition.math.matrix.mtj.Vector3;
import gov.sandia.cognition.statistics.DataDistribution;
import gov.sandia.cognition.statistics.distribution.DefaultDataDistribution;
import gov.sandia.cognition.statistics.distribution.UnivariateGaussian;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import junit.framework.TestCase;
/**
 * Unit tests for class VectorNaiveBayesCategorizer.
 *
 * Exercises the constructors, clone, evaluate, the category and input
 * dimensionality accessors, the priors and conditionals getters and setters,
 * and the nested Learner class.
 *
 * @author  Justin Basilico
 * @since   3.1
 */
public class VectorNaiveBayesCategorizerTest
    extends TestCase
{

    /**
     * Creates a new test.
     *
     * @param   testName
     *      The test name.
     */
    public VectorNaiveBayesCategorizerTest(
        final String testName)
    {
        super(testName);
    }

    /**
     * Test of constructors for class VectorNaiveBayesCategorizer.
     *
     * The default constructor is expected to produce priors with a zero
     * total and an empty conditionals map; the two-argument constructor is
     * expected to store the exact objects it is given.
     */
    public void testConstructors()
    {
        VectorNaiveBayesCategorizer<String, UnivariateGaussian.PDF> instance =
            new VectorNaiveBayesCategorizer<String, UnivariateGaussian.PDF>();
        assertEquals(0.0, instance.getPriors().getTotal(), 0.0);
        assertTrue(instance.getConditionals().isEmpty());

        DataDistribution<String> priors =
            new DefaultDataDistribution<String>();
        Map<String, List<UnivariateGaussian.PDF>> conditionals =
            new LinkedHashMap<String, List<UnivariateGaussian.PDF>>();
        instance = new VectorNaiveBayesCategorizer<String, UnivariateGaussian.PDF>(
            priors, conditionals);
        assertSame(priors, instance.getPriors());
        assertSame(conditionals, instance.getConditionals());
    }

    /**
     * Test of clone method, of class VectorNaiveBayesCategorizer.
     *
     * Verifies that cloning yields a distinct, non-null object whose priors
     * and conditionals are not shared with the original.
     */
    public void testClone()
    {
        VectorNaiveBayesCategorizer<String, UnivariateGaussian.PDF> instance =
            new VectorNaiveBayesCategorizer<String, UnivariateGaussian.PDF>();
        VectorNaiveBayesCategorizer<String, UnivariateGaussian.PDF> clone =
            instance.clone();

        // The clone itself must exist and be a different object before the
        // per-field checks below are meaningful.
        assertNotNull(clone);
        assertNotSame(instance, clone);
        assertNotSame(instance.getPriors(), clone.getPriors());
        assertNotSame(instance.getConditionals(), clone.getConditionals());
    }

    /**
     * Test of evaluate method, of class VectorNaiveBayesCategorizer.
     *
     * Evaluates the same two-dimensional input against a categorizer with no
     * categories, then one category ("a"), then two categories ("a" and "b").
     */
    public void testEvaluate()
    {
        Vector2 input = new Vector2(1, 2);
        VectorNaiveBayesCategorizer<String, UnivariateGaussian.PDF> instance =
            new VectorNaiveBayesCategorizer<String, UnivariateGaussian.PDF>();

        // With no categories configured the expected result is null.
        assertNull(instance.evaluate(input));

        // With "a" as the only category it must be the answer.
        instance.getPriors().increment("a", 3);
        instance.getConditionals().put("a",
            new ArrayList<UnivariateGaussian.PDF>());
        instance.getConditionals().get("a").add(
            new UnivariateGaussian.PDF(3.0, 10.0));
        instance.getConditionals().get("a").add(
            new UnivariateGaussian.PDF(5.0, 1.0));
        assertEquals("a", instance.evaluate(input));

        // After adding "b" the test expects it to win for this input, even
        // though its prior weight (2) is lower than that of "a" (3).
        instance.getPriors().increment("b", 2);
        instance.getConditionals().put("b",
            new ArrayList<UnivariateGaussian.PDF>());
        instance.getConditionals().get("b").add(
            new UnivariateGaussian.PDF(0.0, 1.0));
        instance.getConditionals().get("b").add(
            new UnivariateGaussian.PDF(1.0, 1.0));
        assertEquals("b", instance.evaluate(input));
    }

    /**
     * Test of getCategories method, of class VectorNaiveBayesCategorizer.
     *
     * The category set is expected to be empty initially and to grow as
     * categories are added to the priors and conditionals.
     */
    public void testGetCategories()
    {
        VectorNaiveBayesCategorizer<String, UnivariateGaussian.PDF> instance =
            new VectorNaiveBayesCategorizer<String, UnivariateGaussian.PDF>();
        assertTrue(instance.getCategories().isEmpty());

        instance.getPriors().increment("a", 1);
        instance.getConditionals().put("a",
            new ArrayList<UnivariateGaussian.PDF>());
        assertEquals(1, instance.getCategories().size());
        assertTrue(instance.getCategories().contains("a"));

        instance.getPriors().increment("b", 1);
        instance.getConditionals().put("b",
            new ArrayList<UnivariateGaussian.PDF>());
        assertEquals(2, instance.getCategories().size());
        assertTrue(instance.getCategories().contains("a"));
        assertTrue(instance.getCategories().contains("b"));
    }

    /**
     * Test of getInputDimensionality method, of class
     * VectorNaiveBayesCategorizer.
     *
     * The dimensionality is expected to be 0 while there are no conditional
     * distributions and to increase by one for each distribution added to
     * the single category "a".
     */
    public void testGetInputDimensionality()
    {
        VectorNaiveBayesCategorizer<String, UnivariateGaussian.PDF> instance =
            new VectorNaiveBayesCategorizer<String, UnivariateGaussian.PDF>();
        assertEquals(0, instance.getInputDimensionality());

        // A category with an empty distribution list still has 0 dimensions.
        instance.getPriors().increment("a", 1);
        instance.getConditionals().put("a",
            new ArrayList<UnivariateGaussian.PDF>());
        assertEquals(0, instance.getInputDimensionality());

        instance.getConditionals().get("a").add(new UnivariateGaussian.PDF());
        assertEquals(1, instance.getInputDimensionality());

        instance.getConditionals().get("a").add(new UnivariateGaussian.PDF());
        assertEquals(2, instance.getInputDimensionality());
    }

    /**
     * Test of getPriors method, of class VectorNaiveBayesCategorizer.
     *
     * Delegates to the setter test, which checks the getter after every set.
     */
    public void testGetPriors()
    {
        this.testSetPriors();
    }

    /**
     * Test of setPriors method, of class VectorNaiveBayesCategorizer.
     *
     * Sets a fresh distribution, then null, then another fresh distribution,
     * checking after each call that the getter returns what was set.
     */
    public void testSetPriors()
    {
        VectorNaiveBayesCategorizer<String, UnivariateGaussian.PDF> instance =
            new VectorNaiveBayesCategorizer<String, UnivariateGaussian.PDF>();
        assertEquals(0.0, instance.getPriors().getTotal(), 0.0);

        DataDistribution<String> priors =
            new DefaultDataDistribution<String>();
        instance.setPriors(priors);
        assertSame(priors, instance.getPriors());

        // Null is expected to be accepted and returned as-is.
        priors = null;
        instance.setPriors(priors);
        assertNull(instance.getPriors());

        priors = new DefaultDataDistribution<String>();
        instance.setPriors(priors);
        assertSame(priors, instance.getPriors());
    }

    /**
     * Test of getConditionals method, of class VectorNaiveBayesCategorizer.
     *
     * Delegates to the setter test, which checks the getter after every set.
     */
    public void testGetConditionals()
    {
        this.testSetConditionals();
    }

    /**
     * Test of setConditionals method, of class VectorNaiveBayesCategorizer.
     *
     * Sets a fresh map, then null, then another fresh map, checking after
     * each call that the getter returns what was set.
     */
    public void testSetConditionals()
    {
        VectorNaiveBayesCategorizer<String, UnivariateGaussian.PDF> instance =
            new VectorNaiveBayesCategorizer<String, UnivariateGaussian.PDF>();
        assertTrue(instance.getConditionals().isEmpty());

        Map<String, List<UnivariateGaussian.PDF>> conditionals =
            new LinkedHashMap<String, List<UnivariateGaussian.PDF>>();
        instance.setConditionals(conditionals);
        assertSame(conditionals, instance.getConditionals());

        // Null is expected to be accepted and returned as-is.
        conditionals = null;
        instance.setConditionals(conditionals);
        assertNull(instance.getConditionals());

        conditionals = new LinkedHashMap<String, List<UnivariateGaussian.PDF>>();
        instance.setConditionals(conditionals);
        assertSame(conditionals, instance.getConditionals());
    }

    /**
     * Test of the VectorNaiveBayesCategorizer.Learner class.
     *
     * Learns from five labeled three-dimensional examples (three "a", two
     * "b") and checks the learned categories, prior weights, number of
     * conditional distributions, input dimensionality, and that every
     * training example is assigned its own label.
     */
    public void testLearner()
    {
        VectorNaiveBayesCategorizer.Learner<String, UnivariateGaussian.PDF> learner =
            new VectorNaiveBayesCategorizer.Learner<String, UnivariateGaussian.PDF>(
                new UnivariateGaussian.MaximumLikelihoodEstimator());

        List<InputOutputPair<Vector3, String>> data =
            new ArrayList<InputOutputPair<Vector3, String>>();
        data.add(DefaultInputOutputPair.create(new Vector3(-1.0, 0.0, 3.0), "a"));
        data.add(DefaultInputOutputPair.create(new Vector3(2.0, 1.0, 9.0), "b"));
        data.add(DefaultInputOutputPair.create(new Vector3(4.0, 0.0, 9.2), "b"));
        data.add(DefaultInputOutputPair.create(new Vector3(-5.0, 1.0, 2.0), "a"));
        data.add(DefaultInputOutputPair.create(new Vector3(-7.0, 1.0, 3.0), "a"));

        VectorNaiveBayesCategorizer<String, UnivariateGaussian.PDF> instance =
            learner.learn(data);

        assertEquals(2, instance.getCategories().size());
        assertTrue(instance.getCategories().contains("a"));
        assertTrue(instance.getCategories().contains("b"));

        // Use the explicit-delta overload so these are compared as doubles
        // rather than being boxed and routed to assertEquals(Object, Object).
        // A delta of 0.0 still requires the counts to match exactly.
        assertEquals(3.0, instance.getPriors().get("a"), 0.0);
        assertEquals(2.0, instance.getPriors().get("b"), 0.0);
        assertEquals(5.0, instance.getPriors().getTotal(), 0.0);

        // One conditional distribution per input dimension, per category.
        assertEquals(2, instance.getConditionals().size());
        assertEquals(3, instance.getConditionals().get("a").size());
        assertEquals(3, instance.getConditionals().get("b").size());
        assertEquals(3, instance.getInputDimensionality());

        // The learned categorizer should reproduce the training labels.
        for (InputOutputPair<Vector3, String> example : data)
        {
            assertEquals(example.getOutput(),
                instance.evaluate(example.getInput()));
        }
    }

}