/* * File: WinnerTakeAllCategorizerTest.java * Authors: Justin Basilico * Company: Sandia National Laboratories * Project: Cognitive Foundry * * Copyright July 22, 2009, Sandia Corporation. * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive * license for use of this work by or on behalf of the U.S. Government. Export * of this program may require a license from the United States Government. * See CopyrightHistory.txt for complete details. * */ package gov.sandia.cognition.learning.function.categorization; import gov.sandia.cognition.evaluator.Evaluator; import gov.sandia.cognition.learning.algorithm.regression.MultivariateLinearRegression; import gov.sandia.cognition.learning.data.DefaultInputOutputPair; import gov.sandia.cognition.learning.data.InputOutputPair; import gov.sandia.cognition.learning.function.vector.MultivariateDiscriminant; import gov.sandia.cognition.learning.function.vector.VectorizableVectorConverter; import gov.sandia.cognition.math.matrix.Vector; import gov.sandia.cognition.math.matrix.Vectorizable; import gov.sandia.cognition.math.matrix.mtj.Vector3; import gov.sandia.cognition.util.WeightedValue; import java.util.ArrayList; import java.util.HashSet; import java.util.LinkedHashSet; import java.util.Set; import junit.framework.TestCase; /** * Unit tests for class WinnerTakeAllCategorizer. * * @author Justin Basilico * @since 3.0 */ public class WinnerTakeAllCategorizerTest extends TestCase { /** * Creates a new test. * * @param testName The test name. */ public WinnerTakeAllCategorizerTest( String testName) { super(testName); } public void testConstructors() { Evaluator<Vectorizable, Vector> evaluator = null; WinnerTakeAllCategorizer<Vector, String> instance = new WinnerTakeAllCategorizer<Vector, String>(); assertSame(evaluator, instance.getEvaluator()); assertTrue(instance.getCategories().isEmpty()); evaluator = new VectorizableVectorConverter(); Set<String> categories = new HashSet<String>(); categories.add("a"); instance = new WinnerTakeAllCategorizer<Vector, String>(evaluator, categories); assertSame(evaluator, instance.getEvaluator()); assertSame(categories, instance.getCategories()); } /** * Test of evaluate method, of class WinnerTakeAllCategorizer. */ public void testEvaluate() {Set<String> categories = new LinkedHashSet<String>(); categories.add("a"); categories.add("b"); categories.add("c"); WinnerTakeAllCategorizer<Vector, String> instance = new WinnerTakeAllCategorizer<Vector, String>( new VectorizableVectorConverter(), categories); assertEquals("a", instance.evaluate(new Vector3(1.0, 0.0, 0.0))); assertEquals("b", instance.evaluate(new Vector3(0.0, 0.1, 0.0))); assertEquals("c", instance.evaluate(new Vector3(0.0, 0.1, 20.0))); assertEquals("a", instance.evaluate(new Vector3(0.0, 0.0, 0.0))); } /** * Test of evaluateWithDiscriminant method, of class WinnerTakeAllCategorizer. */ public void testEvaluateWithDiscriminant() { Set<String> categories = new LinkedHashSet<String>(); categories.add("a"); categories.add("b"); categories.add("c"); WinnerTakeAllCategorizer<Vector, String> instance = new WinnerTakeAllCategorizer<Vector, String>( new VectorizableVectorConverter(), categories); WeightedValue<String> result = instance.evaluateWithDiscriminant(new Vector3(1.0, 0.0, 0.0)); assertEquals("a", result.getValue()); assertEquals(1.0, result.getWeight()); result = instance.evaluateWithDiscriminant(new Vector3(0.0, 0.1, 0.0)); assertEquals("b", result.getValue()); assertEquals(0.1, result.getWeight()); result = instance.evaluateWithDiscriminant(new Vector3(0.0, 0.1, 0.2)); assertEquals("c", result.getValue()); assertEquals(0.2, result.getWeight()); result = instance.evaluateWithDiscriminant(new Vector3(0.0, 0.0, 0.0)); assertEquals("a", result.getValue()); assertEquals(0.0, result.getWeight()); } /** * Test of findBestCategory method, of class WinnerTakeAllCategorizer. */ public void testFindBestCategory() { Set<String> categories = new LinkedHashSet<String>(); categories.add("a"); categories.add("b"); categories.add("c"); WinnerTakeAllCategorizer<Vector, String> instance = new WinnerTakeAllCategorizer<Vector, String>( null, categories); WeightedValue<String> result = instance.findBestCategory(new Vector3(1.0, 0.0, 0.0)); assertEquals("a", result.getValue()); assertEquals(1.0, result.getWeight()); result = instance.findBestCategory(new Vector3(0.0, 0.1, 0.0)); assertEquals("b", result.getValue()); assertEquals(0.1, result.getWeight()); result = instance.findBestCategory(new Vector3(0.0, 0.1, 0.2)); assertEquals("c", result.getValue()); assertEquals(0.2, result.getWeight()); result = instance.findBestCategory(new Vector3(0.0, 0.0, 0.0)); assertEquals("a", result.getValue()); assertEquals(0.0, result.getWeight()); } /** * Test of getEvaluator method, of class WinnerTakeAllCategorizer. */ public void testGetEvaluator() { this.testSetEvaluator(); } /** * Test of setEvaluator method, of class WinnerTakeAllCategorizer. */ public void testSetEvaluator() { Evaluator<Vectorizable, Vector> evaluator = null; WinnerTakeAllCategorizer<Vector, String> instance = new WinnerTakeAllCategorizer<Vector, String>(); assertSame(evaluator, instance.getEvaluator()); evaluator = new VectorizableVectorConverter(); instance.setEvaluator(evaluator); assertSame(evaluator, instance.getEvaluator()); evaluator = new VectorizableVectorConverter(); instance.setEvaluator(evaluator); assertSame(evaluator, instance.getEvaluator()); evaluator = null; instance.setEvaluator(evaluator); assertSame(evaluator, instance.getEvaluator()); } /** * Test of setCategories method, of class WinnerTakeAllCategorizer. */ public void testSetCategories() { Set<String> categories = null; WinnerTakeAllCategorizer<Vector, String> instance = new WinnerTakeAllCategorizer<Vector, String>(); assertTrue(instance.getCategories().isEmpty()); categories = new HashSet<String>(); instance.setCategories(categories); assertSame(categories, instance.getCategories()); categories = new LinkedHashSet<String>(); categories.add("a"); instance.setCategories(categories); assertSame(categories, instance.getCategories()); categories = null; instance.setCategories(categories); assertSame(categories, instance.getCategories()); } /** * Test of Learner class of class WinnerTakeAllCategorizer. */ public void testLearner() { WinnerTakeAllCategorizer.Learner<Vector, String> learner = new WinnerTakeAllCategorizer.Learner<Vector, String>(); learner.setLearner(new MultivariateLinearRegression()); ArrayList<InputOutputPair<Vector, String>> training = new ArrayList<InputOutputPair<Vector, String>>(); training.add(new DefaultInputOutputPair<Vector, String>(new Vector3(1.0, 0.0, 0.0), "a")); training.add(new DefaultInputOutputPair<Vector, String>(new Vector3(0.0, 2.0, 0.0), "b")); training.add(new DefaultInputOutputPair<Vector, String>(new Vector3(0.0, 0.0, 3.0), "c")); WinnerTakeAllCategorizer<Vector, String> instance = learner.learn(training); assertTrue(instance.getEvaluator() instanceof MultivariateDiscriminant); assertEquals(3, instance.getCategories().size()); assertTrue(instance.getCategories().contains("a")); assertTrue(instance.getCategories().contains("b")); assertTrue(instance.getCategories().contains("c")); assertEquals("a", instance.evaluate(new Vector3(1.0, 0.0, 0.0))); assertEquals("b", instance.evaluate(new Vector3(0.0, 1.0, 0.0))); assertEquals("c", instance.evaluate(new Vector3(0.0, 0.0, 1.0))); assertEquals("a", instance.evaluate(new Vector3(0.0, 0.0, 0.0))); assertEquals("a", instance.evaluate(new Vector3(0.0, -1.0, -1.0))); assertEquals("a", instance.evaluate(new Vector3(1.0, 1.0, 1.0))); } }