/*
* File: DiscreteNaiveBayesCategorizerTest.java
* Authors: Kevin R. Dixon
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright Oct 21, 2009, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government.
* Export of this program may require a license from the United States
* Government. See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.learning.algorithm.bayes;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.statistics.distribution.DefaultDataDistribution;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import junit.framework.TestCase;
import java.util.Random;
/**
* Unit tests for DiscreteNaiveBayesCategorizerTest.
*
* @author krdixon
*/
@PublicationReference(
author="Raymond J. Mooney",
title="CS 391L: Machine Learning: Bayesian Learning: Naive Bayes",
type=PublicationType.Misc,
year=2009,
url="http://www.cs.utexas.edu/~mooney/cs391L/slides/naive-bayes.pdf",
notes="Undergrad course lecture notes."
)
public class DiscreteNaiveBayesCategorizerTest
extends TestCase
{
/**
* Random number generator to use for a fixed random seed.
*/
public final Random RANDOM = new Random( 1 );
/**
* Default tolerance of the regression tests, {@value}.
*/
public final double TOLERANCE = 1e-5;
/**
* Class 1
*/
public static final String CLASS1 = "positive";
/**
* Class 2
*/
public static final String CLASS2 = "negative";
/**
* Tests for class DiscreteNaiveBayesCategorizerTest.
* @param testName Name of the test.
*/
public DiscreteNaiveBayesCategorizerTest(
String testName)
{
super(testName);
}
/**
* Create instance
* @return
* instance
*/
@SuppressWarnings("unchecked")
public DiscreteNaiveBayesCategorizer<Boolean,String> createInstance()
{
// From
Map<String,List<DefaultDataDistribution<Boolean>>> conditionalPMFs =
new LinkedHashMap<String, List<DefaultDataDistribution<Boolean>>>();
DefaultDataDistribution<Boolean> p0 = new DefaultDataDistribution<Boolean>();
p0.increment( Boolean.TRUE, 10 );
p0.increment( Boolean.FALSE, 90 );
DefaultDataDistribution<Boolean> p1 = new DefaultDataDistribution<Boolean>();
p1.increment( Boolean.TRUE, 90 );
p1.increment( Boolean.FALSE, 10 );
DefaultDataDistribution<Boolean> p2 = new DefaultDataDistribution<Boolean>();
p2.increment( Boolean.TRUE, 90 );
p2.increment( Boolean.FALSE, 10 );
DefaultDataDistribution<Boolean> n0 = new DefaultDataDistribution<Boolean>();
n0.increment( Boolean.TRUE, 20 );
n0.increment( Boolean.FALSE, 80 );
DefaultDataDistribution<Boolean> n1 = new DefaultDataDistribution<Boolean>();
n1.increment( Boolean.TRUE, 30 );
n1.increment( Boolean.FALSE, 70 );
DefaultDataDistribution<Boolean> n2 = new DefaultDataDistribution<Boolean>();
n2.increment( Boolean.TRUE, 30 );
n2.increment( Boolean.FALSE, 70 );
conditionalPMFs.put(CLASS1, Arrays.asList( p0, p1, p2 ) );
conditionalPMFs.put(CLASS2, Arrays.asList( n0, n1, n2 ) );
DefaultDataDistribution<String> priors =
new DefaultDataDistribution<String>();
priors.increment( CLASS1, 300 );
priors.increment( CLASS2, 300 );
return new DiscreteNaiveBayesCategorizer<Boolean, String>(
3, priors, conditionalPMFs );
}
/**
* createInstance
*/
public void testCreateInstance()
{
System.out.println( "createInstance" );
DiscreteNaiveBayesCategorizer<Boolean, String> instance =
this.createInstance();
assertEquals( 3, instance.getInputDimensionality() );
assertEquals( 0.5, instance.getPriorProbability(CLASS1) );
assertEquals( 0.5, instance.getPriorProbability(CLASS2) );
assertEquals( 0.1, instance.getConditionalProbability(0, Boolean.TRUE, CLASS1) );
assertEquals( 0.9, instance.getConditionalProbability(0, Boolean.FALSE, CLASS1) );
assertEquals( 0.9, instance.getConditionalProbability(1, Boolean.TRUE, CLASS1) );
assertEquals( 0.1, instance.getConditionalProbability(1, Boolean.FALSE, CLASS1) );
assertEquals( 0.9, instance.getConditionalProbability(2, Boolean.TRUE, CLASS1) );
assertEquals( 0.1, instance.getConditionalProbability(2, Boolean.FALSE, CLASS1) );
assertEquals( 0.2, instance.getConditionalProbability(0, Boolean.TRUE, CLASS2) );
assertEquals( 0.8, instance.getConditionalProbability(0, Boolean.FALSE, CLASS2) );
assertEquals( 0.3, instance.getConditionalProbability(1, Boolean.TRUE, CLASS2) );
assertEquals( 0.7, instance.getConditionalProbability(1, Boolean.FALSE, CLASS2) );
assertEquals( 0.3, instance.getConditionalProbability(2, Boolean.TRUE, CLASS2) );
assertEquals( 0.7, instance.getConditionalProbability(2, Boolean.FALSE, CLASS2) );
}
/**
* Tests the constructors of class DiscreteNaiveBayesCategorizerTest.
*/
public void testConstructors()
{
System.out.println( "Constructors" );
DiscreteNaiveBayesCategorizer<Boolean, String> instance =
new DiscreteNaiveBayesCategorizer<Boolean, String>();
assertEquals( 0, instance.getInputDimensionality() );
assertEquals( 0, instance.getCategories().size() );
}
/**
* Test of clone method, of class DiscreteNaiveBayesCategorizer.
*/
public void testClone()
{
System.out.println("clone");
DiscreteNaiveBayesCategorizer<Boolean,String> instance = this.createInstance();
DiscreteNaiveBayesCategorizer<Boolean,String> clone = instance.clone();
assertNotNull( clone );
assertNotSame( instance, clone );
assertEquals( instance.getInputDimensionality(), clone.getInputDimensionality() );
double prior = instance.getPriorProbability(CLASS1);
assertEquals( prior, clone.getPriorProbability(CLASS1) );
clone.update( Arrays.asList(true,false,true), CLASS1);
assertEquals( prior, instance.getPriorProbability(CLASS1) );
assertEquals( 301.0/601.0, clone.getPriorProbability(CLASS1) );
}
/**
* Test of getCategories method, of class DiscreteNaiveBayesCategorizer.
*/
public void testGetCategories()
{
System.out.println("getCategories");
DiscreteNaiveBayesCategorizer<Boolean,String> instance = this.createInstance();
assertEquals( 2, instance.getCategories().size() );
assertEquals( CLASS1, CollectionUtil.getElement( instance.getCategories(), 0 ) );
assertEquals( CLASS2, CollectionUtil.getElement( instance.getCategories(), 1 ) );
}
/**
* computeConjuctiveProbability
*/
public void testComputeConjuctiveProbability()
{
System.out.println( "computeConjuctiveProbability" );
DiscreteNaiveBayesCategorizer<Boolean, String> instance =
this.createInstance();
List<Boolean> inputs = Arrays.asList( Boolean.TRUE, Boolean.TRUE, Boolean.TRUE );
assertEquals( 0.0405, instance.computeConjuctiveProbability( inputs, CLASS1), TOLERANCE );
assertEquals( 0.0090, instance.computeConjuctiveProbability( inputs, CLASS2), TOLERANCE );
assertEquals( 0.0, instance.computeConjuctiveProbability( inputs, "zero"), TOLERANCE );
}
/**
* computePosterior
*/
public void testComputePosterior()
{
System.out.println( "computePosterior" );
DiscreteNaiveBayesCategorizer<Boolean, String> instance =
this.createInstance();
List<Boolean> inputs = Arrays.asList( Boolean.TRUE, Boolean.TRUE, Boolean.TRUE );
assertEquals( 0.818181, instance.computePosterior( inputs, CLASS1), TOLERANCE );
assertEquals( 0.181818, instance.computePosterior( inputs, CLASS2), TOLERANCE );
}
/**
* computeEvidenceProbability
*/
public void testComputeEvidenceProbability()
{
System.out.println( "computeEvidenceProbability" );
DiscreteNaiveBayesCategorizer<Boolean, String> instance =
this.createInstance();
List<Boolean> inputs = Arrays.asList( Boolean.TRUE, Boolean.TRUE, Boolean.TRUE );
assertEquals( 0.0495, instance.computeEvidenceProbabilty( inputs ), TOLERANCE );
}
/**
* Test of evaluate method, of class DiscreteNaiveBayesCategorizer.
*/
public void testEvaluate()
{
System.out.println("evaluate");
DiscreteNaiveBayesCategorizer<Boolean, String> instance =
this.createInstance();
assertEquals( CLASS1, instance.evaluate( Arrays.asList( true, true, true ) ) );
try
{
instance.evaluate( Arrays.asList( true, true ) );
fail( "Input dimension doesn't match" );
}
catch (Exception e)
{
System.out.println( "Good: " + e );
}
}
/**
* Test of getInputDimensionality method, of class DiscreteNaiveBayesCategorizer.
*/
public void testGetInputDimensionality()
{
System.out.println("getInputDimensionality");
DiscreteNaiveBayesCategorizer<Boolean, String> instance = this.createInstance();
assertEquals( 3, instance.getInputDimensionality() );
}
/**
* Test of setInputDimensionality method, of class DiscreteNaiveBayesCategorizer.
*/
public void testSetInputDimensionality()
{
System.out.println("setInputDimensionality");
DiscreteNaiveBayesCategorizer<Boolean,String> instance = this.createInstance();
assertEquals( 3, instance.getInputDimensionality() );
instance.setInputDimensionality(1);
assertEquals( 0, instance.getCategories().size() );
}
/**
* Update
*/
public void testUpdate()
{
System.out.println( "update" );
DiscreteNaiveBayesCategorizer<String,Integer> instance =
new DiscreteNaiveBayesCategorizer<String,Integer>();
assertEquals( 0, instance.getInputDimensionality() );
instance.update( Arrays.asList( "small", "red", "circle" ), 1 );
assertEquals( 3, instance.getInputDimensionality() );
assertEquals( 1, instance.getCategories().size() );
instance.update( Arrays.asList( "large", "red", "circle" ), 1 );
assertEquals( 3, instance.getInputDimensionality() );
assertEquals( 1, instance.getCategories().size() );
instance.update( Arrays.asList( "small", "red", "circle" ), 2 );
assertEquals( 3, instance.getInputDimensionality() );
assertEquals( 2, instance.getCategories().size() );
assertEquals( 2.0/3.0, instance.getPriorProbability(1) );
assertEquals( 1.0/3.0, instance.getPriorProbability(2) );
instance.update( Arrays.asList( "large", "blue", "triangle" ), 2 );
assertEquals( 3, instance.getInputDimensionality() );
assertEquals( 2, instance.getCategories().size() );
assertEquals( 2.0/4.0, instance.getPriorProbability(1) );
assertEquals( 2.0/4.0, instance.getPriorProbability(2) );
// input 0
assertEquals( 0.5, instance.getConditionalProbability(0, "small", 1 ) );
assertEquals( 0.0, instance.getConditionalProbability(0, "medium", 1 ) );
assertEquals( 0.5, instance.getConditionalProbability(0, "large", 1 ) );
assertEquals( 0.5, instance.getConditionalProbability(0, "small", 2 ) );
assertEquals( 0.0, instance.getConditionalProbability(0, "medium", 2 ) );
assertEquals( 0.5, instance.getConditionalProbability(0, "large", 2 ) );
// input 1
assertEquals( 1.0, instance.getConditionalProbability(1, "red", 1 ) );
assertEquals( 0.0, instance.getConditionalProbability(1, "blue", 1 ) );
assertEquals( 0.0, instance.getConditionalProbability(1, "green", 1 ) );
assertEquals( 0.5, instance.getConditionalProbability(1, "red", 2 ) );
assertEquals( 0.5, instance.getConditionalProbability(1, "blue", 2 ) );
assertEquals( 0.0, instance.getConditionalProbability(1, "green", 2 ) );
// input 2
assertEquals( 0.0, instance.getConditionalProbability(2, "square", 1 ) );
assertEquals( 0.0, instance.getConditionalProbability(2, "triangle", 1 ) );
assertEquals( 1.0, instance.getConditionalProbability(2, "circle", 1 ) );
assertEquals( 0.0, instance.getConditionalProbability(2, "square", 2 ) );
assertEquals( 0.5, instance.getConditionalProbability(2, "triangle", 2 ) );
assertEquals( 0.5, instance.getConditionalProbability(2, "circle", 2 ) );
// Test a couple of thigs
List<String> i1 = Arrays.asList("medium", "red", "circle");
assertEquals( 0.0, instance.computeConjuctiveProbability( i1, 1 ) );
assertEquals( 0.0, instance.computeConjuctiveProbability( i1, 2 ) );
assertEquals( 0.0, instance.computeEvidenceProbabilty( i1 ) );
assertEquals( 0.0, instance.computePosterior( i1, 1 ) );
assertEquals( 0.0, instance.computePosterior( i1, 2 ) );
List<String> i2 = Arrays.asList("small",null,null);
assertEquals( 0.5, instance.computeEvidenceProbabilty(i2));
List<String> i3 = Arrays.asList(null,"red",null);
assertEquals( 0.75, instance.computeEvidenceProbabilty(i3));
List<String> i4 = Arrays.asList(null,"red","circle");
assertEquals( 1.0, instance.computeConditionalProbability(i4,1));
assertEquals( 0.25, instance.computeConditionalProbability(i4,2));
instance.update(i4, 1);
assertEquals( 0.6, instance.getPriorProbability(1) );
assertEquals( 0.857142857, instance.computePosterior(i4, 1), TOLERANCE );
try
{
instance.update( Arrays.asList("fail",null), 1 );
fail( "Input dimensionality doesn't match");
}
catch (Exception e)
{
System.out.println( "Good: " + e );
}
}
/**
* Learner
*/
public void testLearner()
{
System.out.println( "Learner" );
DiscreteNaiveBayesCategorizer.Learner<String,Integer> learner =
new DiscreteNaiveBayesCategorizer.Learner<String, Integer>();
LinkedList<InputOutputPair<List<String>,Integer>> data =
new LinkedList<InputOutputPair<List<String>, Integer>>();
data.add( DefaultInputOutputPair.create(Arrays.asList( "small", "red", "circle" ), 1 ));
data.add( DefaultInputOutputPair.create(Arrays.asList( "large", "red", "circle" ), 1 ));
data.add( DefaultInputOutputPair.create(Arrays.asList( "small", "red", "circle" ), 2 ));
data.add( DefaultInputOutputPair.create(Arrays.asList( "large", "blue", "triangle" ), 2) );
DiscreteNaiveBayesCategorizer<String,Integer> instance = learner.learn(data);
assertEquals( 3, instance.getInputDimensionality() );
assertEquals( 2, instance.getCategories().size() );
assertEquals( 2.0/4.0, instance.getPriorProbability(1) );
assertEquals( 2.0/4.0, instance.getPriorProbability(2) );
// input 0
assertEquals( 0.5, instance.getConditionalProbability(0, "small", 1 ) );
assertEquals( 0.0, instance.getConditionalProbability(0, "medium", 1 ) );
assertEquals( 0.5, instance.getConditionalProbability(0, "large", 1 ) );
assertEquals( 0.5, instance.getConditionalProbability(0, "small", 2 ) );
assertEquals( 0.0, instance.getConditionalProbability(0, "medium", 2 ) );
assertEquals( 0.5, instance.getConditionalProbability(0, "large", 2 ) );
// input 1
assertEquals( 1.0, instance.getConditionalProbability(1, "red", 1 ) );
assertEquals( 0.0, instance.getConditionalProbability(1, "blue", 1 ) );
assertEquals( 0.0, instance.getConditionalProbability(1, "green", 1 ) );
assertEquals( 0.5, instance.getConditionalProbability(1, "red", 2 ) );
assertEquals( 0.5, instance.getConditionalProbability(1, "blue", 2 ) );
assertEquals( 0.0, instance.getConditionalProbability(1, "green", 2 ) );
// input 2
assertEquals( 0.0, instance.getConditionalProbability(2, "square", 1 ) );
assertEquals( 0.0, instance.getConditionalProbability(2, "triangle", 1 ) );
assertEquals( 1.0, instance.getConditionalProbability(2, "circle", 1 ) );
assertEquals( 0.0, instance.getConditionalProbability(2, "square", 2 ) );
assertEquals( 0.5, instance.getConditionalProbability(2, "triangle", 2 ) );
assertEquals( 0.5, instance.getConditionalProbability(2, "circle", 2 ) );
// Test a couple of thigs
List<String> i1 = Arrays.asList("medium", "red", "circle");
assertEquals( 0.0, instance.computeConjuctiveProbability( i1, 1 ) );
assertEquals( 0.0, instance.computeConjuctiveProbability( i1, 2 ) );
assertEquals( 0.0, instance.computeEvidenceProbabilty( i1 ) );
List<String> i2 = Arrays.asList("small",null,null);
assertEquals( 0.5, instance.computeEvidenceProbabilty(i2));
List<String> i3 = Arrays.asList(null,"red",null);
assertEquals( 0.75, instance.computeEvidenceProbabilty(i3));
List<String> i4 = Arrays.asList(null,"red","circle");
assertEquals( 1.0, instance.computeConditionalProbability(i4,1));
assertEquals( 0.25, instance.computeConditionalProbability(i4,2));
assertNotNull( learner.clone() );
}
}