/* * File: CategoricalDistributionTest.java * Authors: Kevin R. Dixon * Company: Sandia National Laboratories * Project: Cognitive Foundry * * Copyright May 24, 2011, Sandia Corporation. * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive * license for use of this work by or on behalf of the U.S. Government. * Export of this program may require a license from the United States * Government. See CopyrightHistory.txt for complete details. * */ package gov.sandia.cognition.statistics.distribution; import gov.sandia.cognition.collection.CollectionUtil; import java.util.Set; import gov.sandia.cognition.statistics.MultivariateClosedFormComputableDiscreteDistributionTestHarness; import gov.sandia.cognition.math.matrix.Vector; import gov.sandia.cognition.math.matrix.VectorFactory; import org.junit.Test; import static org.junit.Assert.*; /** * Tests for class CategoricalDistributionTest. * @author krdixon */ public class CategoricalDistributionTest extends MultivariateClosedFormComputableDiscreteDistributionTestHarness<Vector> { /** * Default Constructor */ public CategoricalDistributionTest() { super( "test" ); } @Override public void testConstructors() { System.out.println( "Constructors" ); CategoricalDistribution instance = new CategoricalDistribution(); assertEquals( CategoricalDistribution.DEFAULT_NUM_CLASSES, instance.getInputDimensionality() ); assertNotNull( instance.getParameters() ); int dim = 3; instance = new CategoricalDistribution( dim ); assertEquals( dim, instance.getInputDimensionality() ); assertTrue( instance.getParameters().norm1() > 0.0 ); Vector p = VectorFactory.getDefault().createVector(dim, RANDOM.nextDouble()); instance = new CategoricalDistribution(p); assertSame( p, instance.getParameters() ); CategoricalDistribution i2 = new CategoricalDistribution( instance ); assertNotSame( i2, instance ); assertNotSame( instance.getParameters(), i2.getParameters() ); assertEquals( instance.getParameters(), i2.getParameters() ); } @Override public void testProbabilityFunctionConstructors() { System.out.println( "PMF constructors" ); CategoricalDistribution.PMF instance = new CategoricalDistribution.PMF(); assertEquals( CategoricalDistribution.DEFAULT_NUM_CLASSES, instance.getInputDimensionality() ); assertNotNull( instance.getParameters() ); int dim = 3; instance = new CategoricalDistribution.PMF( dim ); assertEquals( dim, instance.getInputDimensionality() ); assertTrue( instance.getParameters().norm1() > 0.0 ); Vector p = VectorFactory.getDefault().createVector(dim, RANDOM.nextDouble()); instance = new CategoricalDistribution.PMF(p); assertSame( p, instance.getParameters() ); CategoricalDistribution.PMF i2 = new CategoricalDistribution.PMF( instance ); assertNotSame( i2, instance ); assertNotSame( instance.getParameters(), i2.getParameters() ); assertEquals( instance.getParameters(), i2.getParameters() ); } /** * Test of getParameters method, of class CategoricalDistribution. */ @Test public void testGetParameters() { System.out.println("getParameters"); Vector p = VectorFactory.getDefault().createVector(3,RANDOM.nextDouble()); CategoricalDistribution instance = new CategoricalDistribution(p); assertSame( p, instance.getParameters() ); } /** * Test of setParameters method, of class CategoricalDistribution. */ @Test public void testSetParameters() { System.out.println("setParameters"); CategoricalDistribution instance = this.createInstance(); Vector p = instance.getParameters().clone(); instance.setParameters(p); assertSame( p, instance.getParameters() ); Vector scale = p.scale(-1.0); try { instance.setParameters(scale); fail( "Elements must be positive" ); } catch (Exception e) { System.out.println( "Good: " + e ); } Vector p1 = VectorFactory.getDefault().createVector(1,1.0); try { instance.setParameters(p1); fail( "dim must be > 1" ); } catch (Exception e) { System.out.println( "Good: " + e ); } } @Override public CategoricalDistribution createInstance() { int dim = 3; DirichletDistribution d = new DirichletDistribution( dim ); return new CategoricalDistribution( d.sample(RANDOM) ); } @Override public void testKnownGetDomain() { int dim = 3; CategoricalDistribution instance = new CategoricalDistribution( dim ); Set<Vector> domain = instance.getDomain(); assertEquals( dim, domain.size() ); for( int i = 0; i < dim; i++ ) { Vector v = VectorFactory.getDefault().createVector(dim); v.setElement(i, 1.0); assertTrue( domain.contains(v) ); } } @Override public void testProbabilityFunctionKnownValues() { System.out.println( "PMF Known Values" ); CategoricalDistribution.PMF instance = this.createInstance().getProbabilityFunction(); Vector p = instance.getParameters(); double psum = p.norm1(); int dim = instance.getInputDimensionality(); for( int i = 0; i < dim; i++ ) { Vector x = VectorFactory.getDefault().createVector(dim); x.setElement(i, 1.0); assertEquals( p.getElement(i)/psum, instance.evaluate(x) ); } Vector x = VectorFactory.getDefault().createVector(dim+1); try { instance.evaluate(x); fail( "Vector input is wrong size!" ); } catch (Exception e) { System.out.println( "Good: " + e ); } x = VectorFactory.getDefault().createVector(dim); try { instance.evaluate(x); fail( "input is all zeros!" ); } catch (Exception e) { System.out.println( "Good: " + e ); } x.setElement(0, RANDOM.nextDouble() ); try { instance.evaluate(x); fail( "input isn't 1.0 or 0.0" ); } catch (Exception e) { System.out.println( "Good: " + e ); } x.setElement(0, 1.0); x.setElement(1, 1.0); try { instance.evaluate(x); fail( "input has more than one 1.0" ); } catch (Exception e) { System.out.println( "Good: " + e ); } } @Override public void testKnownConvertToVector() { System.out.println( "Known ConvertToVector" ); CategoricalDistribution instance = this.createInstance(); Vector p = instance.getParameters(); assertEquals( p, instance.convertToVector() ); assertNotSame( p, instance.convertToVector() ); } @Override public void testKnownValues() { System.out.println( "known values" ); this.testProbabilityFunctionKnownValues(); } }