/*
* File: BinaryVersusCategorizerTest.java
* Authors: Justin Basilico
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright April 22, 2010, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government. Export
* of this program may require a license from the United States Government.
* See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.learning.function.categorization;
import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.algorithm.tree.CategorizationTreeLearner;
import gov.sandia.cognition.learning.algorithm.tree.VectorThresholdInformationGainLearner;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.mtj.Vector3;
import gov.sandia.cognition.util.DefaultPair;
import gov.sandia.cognition.util.Pair;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import junit.framework.TestCase;
/**
* Unit tests for class {@code BinaryVersusCategorizer}.
*
* @author Justin Basilico
* @version 3.0
*/
public class BinaryVersusCategorizerTest
extends TestCase
{
/**
 * Constructs this test case, forwarding the name to the JUnit
 * {@code TestCase} superclass.
 *
 * @param   testName
 *      The name of the test.
 */
public BinaryVersusCategorizerTest(final String testName)
{
    super(testName);
}
/**
 * Exercises the three constructors of BinaryVersusCategorizer: the default
 * one, the one taking a category set, and the one taking both a category
 * set and a pair-to-evaluator map.
 */
public void testConstructors()
{
    // Default constructor: nothing is populated.
    final BinaryVersusCategorizer<Vector, String> byDefault =
        new BinaryVersusCategorizer<Vector, String>();
    assertTrue(byDefault.getCategories().isEmpty());
    assertTrue(byDefault.getCategoryPairsToEvaluatorMap().isEmpty());

    // Category-set constructor: the very same set object is exposed and
    // the evaluator map starts out empty.
    final Set<String> labels = new HashSet<String>();
    for (final String label : new String[] {"a", "b", "c"})
    {
        labels.add(label);
    }
    final BinaryVersusCategorizer<Vector, String> fromLabels =
        new BinaryVersusCategorizer<Vector, String>(labels);
    assertSame(labels, fromLabels.getCategories());
    assertTrue(fromLabels.getCategoryPairsToEvaluatorMap().isEmpty());

    // Full constructor: both arguments are exposed by identity.
    final LinkedHashMap<Pair<String, String>, Evaluator<? super Vector, Boolean>> evaluators =
        new LinkedHashMap<Pair<String, String>, Evaluator<? super Vector, Boolean>>();
    final BinaryVersusCategorizer<Vector, String> fromBoth =
        new BinaryVersusCategorizer<Vector, String>(labels, evaluators);
    assertSame(labels, fromBoth.getCategories());
    assertSame(evaluators, fromBoth.getCategoryPairsToEvaluatorMap());
}
/**
 * Checks the clone method of BinaryVersusCategorizer, first on an empty
 * categorizer and then on one holding two categories and one evaluator.
 */
public void testClone()
{
    final BinaryVersusCategorizer<Vector, String> original =
        new BinaryVersusCategorizer<Vector, String>();

    // Cloning an empty categorizer must yield a distinct, non-null object.
    BinaryVersusCategorizer<Vector, String> copy = original.clone();
    assertNotNull(copy);
    assertNotSame(copy, original);

    // Populate the original with categories "a" and "b" and a single
    // evaluator for that pair, then clone again.
    original.getCategories().add("a");
    original.getCategories().add("b");
    final DefaultPair<String, String> aVersusB =
        new DefaultPair<String, String>("a", "b");
    original.getCategoryPairsToEvaluatorMap().put(
        aVersusB, new VectorElementThresholdCategorizer(0, 1.0));

    copy = original.clone();
    assertNotNull(copy);
    assertNotSame(copy, original);

    // The clone must own a separate map object whose key set equals the
    // original's.
    assertNotSame(copy.categoryPairsToEvaluatorMap,
        original.categoryPairsToEvaluatorMap);
    assertEquals(copy.categoryPairsToEvaluatorMap.keySet(),
        original.categoryPairsToEvaluatorMap.keySet());
}
/**
 * Checks the evaluate method of BinaryVersusCategorizer with no categories,
 * with two categories, and with three categories.
 */
public void testEvaluate()
{
    final BinaryVersusCategorizer<Vector, String> categorizer =
        new BinaryVersusCategorizer<Vector, String>();

    // With nothing configured the expected answer is null.
    assertNull(categorizer.evaluate(new Vector3()));

    // Two categories, separated by a threshold on element 0.
    categorizer.getCategories().add("a");
    categorizer.getCategories().add("b");
    categorizer.getCategoryPairsToEvaluatorMap().put(
        new DefaultPair<String, String>("a", "b"),
        new VectorElementThresholdCategorizer(0, 1.0));
    assertEvaluations(categorizer,
        new String[] {"a", "b"},
        new double[][] {
            {0.0, 0.0, 0.0},
            {2.0, 0.0, 0.0}});

    // Third category "c": both pairs involving it use a threshold on
    // element 1.
    categorizer.getCategories().add("c");
    categorizer.getCategoryPairsToEvaluatorMap().put(
        new DefaultPair<String, String>("a", "c"),
        new VectorElementThresholdCategorizer(1, 1.0));
    categorizer.getCategoryPairsToEvaluatorMap().put(
        new DefaultPair<String, String>("b", "c"),
        new VectorElementThresholdCategorizer(1, 1.0));
    assertEvaluations(categorizer,
        new String[] {"a", "b", "c", "c"},
        new double[][] {
            {0.0, 0.0, 0.0},
            {2.0, 0.0, 0.0},
            {0.0, 2.0, 0.0},
            {2.0, 2.0, 0.0}});
}

/**
 * Asserts, in order, that evaluating the i-th point yields the i-th
 * expected category. A new Vector3 is built for every evaluation.
 *
 * @param   categorizer
 *      The categorizer to evaluate.
 * @param   expected
 *      The expected category for each point.
 * @param   points
 *      The three coordinates of each input vector.
 */
private static void assertEvaluations(
    final BinaryVersusCategorizer<Vector, String> categorizer,
    final String[] expected,
    final double[][] points)
{
    for (int i = 0; i < expected.length; i++)
    {
        final double[] p = points[i];
        assertEquals(expected[i],
            categorizer.evaluate(new Vector3(p[0], p[1], p[2])));
    }
}
/**
 * Test of the getCategoryPairsToEvaluatorMap method of class
 * BinaryVersusCategorizer. The getter is exercised together with the
 * setter, so this simply runs the setter test.
 */
public void testGetCategoryPairsToEvaluatorMap()
{
    testSetCategoryPairsToEvaluatorMap();
}
/**
 * Checks that setCategoryPairsToEvaluatorMap of BinaryVersusCategorizer
 * stores exactly the object it is given, including null, as observed
 * through the matching getter.
 */
public void testSetCategoryPairsToEvaluatorMap()
{
    final BinaryVersusCategorizer<Vector, String> target =
        new BinaryVersusCategorizer<Vector, String>();
    assertTrue(target.getCategoryPairsToEvaluatorMap().isEmpty());

    // A fresh map must come back by identity.
    final LinkedHashMap<Pair<String, String>, Evaluator<? super Vector, Boolean>> replacement =
        new LinkedHashMap<Pair<String, String>, Evaluator<? super Vector, Boolean>>();
    target.setCategoryPairsToEvaluatorMap(replacement);
    assertSame(replacement, target.getCategoryPairsToEvaluatorMap());

    // A null map is accepted and returned as-is. The null goes through a
    // typed variable so the same setter signature is selected.
    final LinkedHashMap<Pair<String, String>, Evaluator<? super Vector, Boolean>> absent =
        null;
    target.setCategoryPairsToEvaluatorMap(absent);
    assertSame(absent, target.getCategoryPairsToEvaluatorMap());
}
/**
 * Test of the Learner class of BinaryVersusCategorizer. A learner backed
 * by a categorization tree is trained on a data set that grows step by
 * step, and the learned categorizer is checked after each step.
 */
public void testLearner()
{
    final BinaryVersusCategorizer.Learner<Vector3, String> learner =
        new BinaryVersusCategorizer.Learner<Vector3, String>();
    learner.setLearner(
        new CategorizationTreeLearner<Vector3, Boolean>(
            new VectorThresholdInformationGainLearner<Boolean>()));

    final List<InputOutputPair<Vector3, String>> examples =
        new LinkedList<InputOutputPair<Vector3, String>>();
    BinaryVersusCategorizer<Vector3, String> learned;

    // No data: a categorizer is still produced, with no categories.
    learned = learner.learn(examples);
    assertNotNull(learned);
    assertTrue(learned.getCategories().isEmpty());

    // One example of "a".
    examples.add(labeled(1.0, 4.0, 2.0, "a"));
    learned = learner.learn(examples);
    assertEquals("a", learned.evaluate(examples.get(0).getInput()));
    assertCategoriesAre(learned, "a");

    // Two examples, both "a".
    examples.add(labeled(1.0, 1.0, 2.0, "a"));
    learned = learner.learn(examples);
    assertEquals("a", learned.evaluate(examples.get(0).getInput()));
    assertEquals("a", learned.evaluate(examples.get(1).getInput()));
    assertCategoriesAre(learned, "a");

    // Mix in "b" examples; every training example is expected to be
    // classified with its own label.
    examples.add(labeled(1.0, 2.0, 3.0, "b"));
    examples.add(labeled(1.0, 4.0, 4.0, "b"));
    examples.add(labeled(1.0, 3.0, 2.0, "b"));
    examples.add(labeled(1.0, 0.0, 2.0, "a"));
    examples.add(labeled(1.0, 5.0, 2.0, "b"));
    examples.add(labeled(1.0, 7.0, 2.0, "b"));
    examples.add(labeled(1.0, 8.0, 2.0, "b"));
    learned = learner.learn(examples);
    for (final InputOutputPair<Vector3, String> pair : examples)
    {
        assertEquals(pair.getOutput(), learned.evaluate(pair.getInput()));
    }
    assertCategoriesAre(learned, "a", "b");

    // The same point appears once as "a" and once as "b"; the expected
    // answer for it is "a".
    examples.add(labeled(1.0, 1.0, 1.0, "a"));
    examples.add(labeled(1.0, 1.0, 1.0, "b"));
    learned = learner.learn(examples);
    assertEquals("a", learned.evaluate(new Vector3(1.0, 1.0, 1.0)));
    assertCategoriesAre(learned, "a", "b");

    // A second "b" at that point; the expected answer becomes "b".
    examples.add(labeled(1.0, 1.0, 1.0, "b"));
    learned = learner.learn(examples);
    assertEquals("b", learned.evaluate(new Vector3(1.0, 1.0, 1.0)));
    assertCategoriesAre(learned, "a", "b");
}

/**
 * Builds one labeled training example from three vector coordinates.
 *
 * @param   x
 *      The first coordinate.
 * @param   y
 *      The second coordinate.
 * @param   z
 *      The third coordinate.
 * @param   category
 *      The output label.
 * @return
 *      A new input-output pair of a new Vector3 and the label.
 */
private static InputOutputPair<Vector3, String> labeled(
    final double x,
    final double y,
    final double z,
    final String category)
{
    return new DefaultInputOutputPair<Vector3, String>(
        new Vector3(x, y, z), category);
}

/**
 * Asserts that the categorizer reports as many categories as are listed
 * and that each listed category is contained in its category set.
 *
 * @param   categorizer
 *      The categorizer whose categories are checked.
 * @param   expected
 *      The categories that must be present.
 */
private static void assertCategoriesAre(
    final BinaryVersusCategorizer<Vector3, String> categorizer,
    final String... expected)
{
    assertEquals(expected.length, categorizer.getCategories().size());
    for (final String category : expected)
    {
        assertTrue(categorizer.getCategories().contains(category));
    }
}
}