/* * File: LinearMultiCategorizerTest.java * Authors: Justin Basilico * Company: Sandia National Laboratories * Project: Cognitive Foundry Learning Core * * Copyright April 21, 2011, Sandia Corporation. * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive * license for use of this work by or on behalf of the U.S. Government. Export * of this program may require a license from the United States Government. */ package gov.sandia.cognition.learning.function.categorization; import gov.sandia.cognition.learning.data.DefaultWeightedValueDiscriminant; import gov.sandia.cognition.math.matrix.Vector; import gov.sandia.cognition.math.matrix.mtj.Vector2; import java.util.LinkedHashMap; import gov.sandia.cognition.math.matrix.VectorFactory; import java.util.Random; import java.util.Arrays; import java.util.HashSet; import java.util.Map; import java.util.TreeMap; import org.junit.Test; import static org.junit.Assert.*; /** * Unit tests for class LinearMultiCategorizer. * * @author Justin Basilico * @since 3.2.0 */ public class LinearMultiCategorizerTest { /** Random number generator. */ protected Random random = new Random(211); /** * Creates a new test. */ public LinearMultiCategorizerTest() { super(); } /** * Test of constructors of class LinearMultiCategorizer. */ @Test public void testConstructors() { LinearMultiCategorizer<String> instance = new LinearMultiCategorizer<String>(); assertNotNull(instance.getPrototypes()); assertTrue(instance.getPrototypes().isEmpty()); Map<String, LinearBinaryCategorizer> prototypes = new TreeMap<String, LinearBinaryCategorizer>(); instance = new LinearMultiCategorizer<String>(prototypes); assertSame(prototypes, instance.getPrototypes()); } /** * Test of clone method, of class LinearMultiCategorizer. */ @Test public void testClone() { LinearMultiCategorizer<String> instance = new LinearMultiCategorizer<String>(); LinearMultiCategorizer<String> clone = instance.clone(); assertNotSame(instance, clone); assertNotSame(clone, instance.clone()); assertNotSame(instance.getPrototypes(), clone.getPrototypes()); Vector w = new Vector2(); double b = random.nextDouble(); LinearBinaryCategorizer a = new LinearBinaryCategorizer(w, b); instance.getPrototypes().put("a", a); clone = instance.clone(); assertNotSame(instance, clone); assertNotSame(instance.getPrototypes(), clone.getPrototypes()); assertNotNull(instance.getPrototypes().get("a")); assertNotSame(a, clone.getPrototypes().get("a")); assertEquals(w, clone.getPrototypes().get("a").getWeights()); assertNotSame(w, clone.getPrototypes().get("a").getWeights()); assertEquals(b, clone.getPrototypes().get("a").getBias(), 0.0); } /** * Test of evaluate method, of class LinearMultiCategorizer. */ @Test public void testEvaluate() { Vector x = new Vector2(1.0, -1.0); LinearMultiCategorizer<String> instance = new LinearMultiCategorizer<String>(); assertEquals(null, instance.evaluate(x)); instance.getPrototypes().put("a", new LinearBinaryCategorizer( new Vector2(1.0, 0.0), 1.0)); assertEquals("a", instance.evaluate(x)); instance.getPrototypes().put("b", new LinearBinaryCategorizer( new Vector2(-1.0, 0.0), 0.0)); instance.getPrototypes().put("c", new LinearBinaryCategorizer( new Vector2(-1.0, 4.0), -5.0)); instance.getPrototypes().put("d", new LinearBinaryCategorizer( new Vector2(0.0, 0.0), 0.0)); assertEquals("a", instance.evaluate(x)); assertEquals("b", instance.evaluate(new Vector2(-1.0, 1.0))); assertEquals("c", instance.evaluate(new Vector2(-1.0, 10.0))); } /** * Test of evaluateWithDiscriminant method, of class LinearMultiCategorizer. */ @Test public void testEvaluateWithDiscriminant() { Vector x = new Vector2(1.0, -1.0); LinearMultiCategorizer<String> instance = new LinearMultiCategorizer<String>(); DefaultWeightedValueDiscriminant<String> result = instance.evaluateWithDiscriminant(x); assertEquals(null, result.getValue()); assertEquals(0.0, result.getDiscriminant(), 0.0); instance.getPrototypes().put("a", new LinearBinaryCategorizer( new Vector2(1.0, 0.0), 1.0)); result = instance.evaluateWithDiscriminant(x); assertEquals("a", result.getValue()); assertEquals(2.0, result.getDiscriminant(), 0.0); instance.getPrototypes().put("b", new LinearBinaryCategorizer( new Vector2(-1.0, 0.0), 0.0)); instance.getPrototypes().put("c", new LinearBinaryCategorizer( new Vector2(-1.0, 4.0), -5.0)); instance.getPrototypes().put("d", new LinearBinaryCategorizer( new Vector2(0.0, 0.0), 0.0)); result = instance.evaluateWithDiscriminant(x); assertEquals("a", result.getValue()); assertEquals(2.0, result.getDiscriminant(), 0.0); result = instance.evaluateWithDiscriminant(new Vector2(-1.0, 1.0)); assertEquals("b", result.getValue()); assertEquals(1.0, result.getDiscriminant(), 0.0); result = instance.evaluateWithDiscriminant(new Vector2(-1.0, 10.0)); assertEquals("c", result.getValue()); assertEquals(36.0, result.getDiscriminant(), 0.0); } /** * Test of evaluateAsDouble method, of class LinearMultiCategorizer. */ @Test public void testEvaluateAsDouble() { Vector x = new Vector2(1.0, -1.0); LinearMultiCategorizer<String> instance = new LinearMultiCategorizer<String>(); assertEquals(0.0, instance.evaluateAsDouble(x, "a"), 0.0); instance.getPrototypes().put("a", new LinearBinaryCategorizer( new Vector2(1.0, 0.0), 1.0)); assertEquals(2.0, instance.evaluateAsDouble(x, "a"), 0.0); instance.getPrototypes().put("b", new LinearBinaryCategorizer( new Vector2(-1.0, 0.0), 0.0)); instance.getPrototypes().put("c", new LinearBinaryCategorizer( new Vector2(-1.0, 4.0), -5.0)); instance.getPrototypes().put("d", new LinearBinaryCategorizer( new Vector2(0.0, 0.0), 0.0)); assertEquals(2.0, instance.evaluateAsDouble(x, "a"), 0.0); assertEquals(-1.0, instance.evaluateAsDouble(x, "b"), 0.0); assertEquals(-10.0, instance.evaluateAsDouble(x, "c"), 0.0); assertEquals(0.0, instance.evaluateAsDouble(x, "d"), 0.0); } /** * Test of getCategories method, of class LinearMultiCategorizer. */ @Test public void testGetCategories() { LinearMultiCategorizer<String> instance = new LinearMultiCategorizer<String>(); assertTrue(instance.getCategories().isEmpty()); instance.getPrototypes().put("a", new LinearBinaryCategorizer()); instance.getPrototypes().put("b", new LinearBinaryCategorizer()); instance.getPrototypes().put("c", new LinearBinaryCategorizer()); assertEquals(new HashSet<String>(Arrays.asList("a", "b", "c")), instance.getCategories()); } /** * Test of getInputDimensionality method, of class LinearMultiCategorizer. */ @Test public void testGetInputDimensionality() { LinearMultiCategorizer<String> instance = new LinearMultiCategorizer<String>(); assertEquals(-1, instance.getInputDimensionality()); instance.getPrototypes().put("a", new LinearBinaryCategorizer()); assertEquals(-1, instance.getInputDimensionality()); int d = 1 + random.nextInt(100); instance.getPrototypes().get("a").setWeights( VectorFactory.getDefault().createVector(d)); assertEquals(d, instance.getInputDimensionality()); } /** * Test of getPrototypes method, of class LinearMultiCategorizer. */ @Test public void testGetPrototypes() { this.testSetPrototypes(); } /** * Test of setPrototypes method, of class LinearMultiCategorizer. */ @Test public void testSetPrototypes() { LinearMultiCategorizer<String> instance = new LinearMultiCategorizer<String>(); assertNotNull(instance.getPrototypes()); assertTrue(instance.getPrototypes().isEmpty()); Map<String, LinearBinaryCategorizer> prototypes = new LinkedHashMap<String, LinearBinaryCategorizer>(); instance.setPrototypes(prototypes); assertSame(prototypes, instance.getPrototypes()); prototypes = null; instance.setPrototypes(prototypes); assertSame(prototypes, instance.getPrototypes()); prototypes = new TreeMap<String, LinearBinaryCategorizer>(); instance.setPrototypes(prototypes); assertSame(prototypes, instance.getPrototypes()); } }