/*
* File: ChineseRestaurantProcessTest.java
* Authors: Kevin R. Dixon
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright Feb 10, 2010, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government.
* Export of this program may require a license from the United States
* Government. See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.statistics.distribution;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.statistics.MultivariateClosedFormComputableDiscreteDistributionTestHarness;
import java.util.ArrayList;
import java.util.Set;
/**
* Unit tests for ChineseRestaurantProcess.
*
* @author krdixon
*/
public class ChineseRestaurantProcessTest
    extends MultivariateClosedFormComputableDiscreteDistributionTestHarness<Vector>
{

    /**
     * Smallest concentration parameter drawn by the randomized tests.
     * Random.nextDouble() may return exactly 0.0, which the distribution
     * rejects, so random alphas are shifted by this amount to stay > 0.
     */
    private static final double MIN_RANDOM_ALPHA = 1e-3;

    /**
     * Creates the test case.
     * @param testName Name of the test.
     */
    public ChineseRestaurantProcessTest(
        String testName)
    {
        super(testName);
    }

    /**
     * Creates the default instance used by the tests: alpha = 1.0 and
     * 5 customers.
     * @return
     * New ChineseRestaurantProcess instance.
     */
    public ChineseRestaurantProcess createInstance()
    {
        double alpha = 1.0;
        int numCustomers = 5;
        return new ChineseRestaurantProcess(alpha, numCustomers);
    }

    /**
     * Tests the constructors of class ChineseRestaurantProcess: the values
     * passed in must be reported back by the getters.
     */
    public void testConstructors()
    {
        System.out.println( "Constructors" );
        double alpha = 2.5;
        int numCustomers = 7;
        ChineseRestaurantProcess instance =
            new ChineseRestaurantProcess( alpha, numCustomers );
        assertEquals( alpha, instance.getAlpha(), 0.0 );
        assertEquals( numCustomers, instance.getNumCustomers() );
    }

    /**
     * Intentionally disabled harness test.
     * NOTE(review): the reason is not recorded here; presumably the generic
     * harness check does not fit this distribution's variable-length
     * samples -- confirm before re-enabling.
     */
    @Override
    public void testProbabilityFunctionConstructors()
    {
    }

    /**
     * Intentionally disabled harness test.
     * NOTE(review): the reason is not recorded here; presumably the generic
     * harness check does not fit this distribution's variable-length
     * samples -- confirm before re-enabling.
     */
    @Override
    public void testPMFEvaluate()
    {
    }

    /**
     * Intentionally disabled harness test.
     * NOTE(review): the reason is not recorded here; presumably the generic
     * harness check does not fit this distribution's variable-length
     * samples -- confirm before re-enabling.
     */
    @Override
    public void testPMFChiSquare()
    {
    }

    /**
     * Test of getMean method, of class ChineseRestaurantProcess.
     * The mean is expected to be unsupported, so calling it must throw.
     */
    @Override
    public void testGetMean()
    {
        System.out.println("getMean");
        ChineseRestaurantProcess instance = this.createInstance();
        try
        {
            instance.getMean();
            // fail() throws an Error, so it is not swallowed by the
            // catch (Exception) below.
            fail( "Mean throws Exception" );
        }
        catch (Exception e)
        {
            System.out.println( "Good: " + e );
        }
    }

    /**
     * Known-values test, exercising the sample method of class
     * ChineseRestaurantProcess.
     * Each sample is a vector of per-table customer counts, so it has between
     * 1 and numCustomers entries and its entries sum to numCustomers.
     */
    @Override
    public void testKnownValues()
    {
        System.out.println("sample");
        ChineseRestaurantProcess instance = this.createInstance();
        final int numSamples = 20;
        final int numCustomers = instance.getNumCustomers();
        ArrayList<Vector> samples = instance.sample(RANDOM, numSamples);
        assertEquals( numSamples, samples.size() );
        for( int n = 0; n < samples.size(); n++ )
        {
            Vector sample = samples.get(n);
            System.out.println( n + ": Tables = " + sample.getDimensionality()
                + " Assignments = " + sample );

            int numTables = sample.getDimensionality();
            assertTrue( numTables >= 1 );
            assertTrue( numTables <= numCustomers );

            double seated = 0.0;
            for( int i = 0; i < numTables; i++ )
            {
                seated += sample.getElement(i);
            }
            assertEquals( (double) numCustomers, seated, 1e-10 );
        }
    }

    /**
     * Test of getAlpha method, of class ChineseRestaurantProcess.
     */
    public void testGetAlpha()
    {
        System.out.println("getAlpha");
        ChineseRestaurantProcess instance = this.createInstance();
        assertTrue( instance.getAlpha() > 0.0 );
    }

    /**
     * Test of setAlpha method, of class ChineseRestaurantProcess.
     * A positive value must round-trip and 0.0 must be rejected.
     */
    public void testSetAlpha()
    {
        System.out.println("setAlpha");
        ChineseRestaurantProcess instance = this.createInstance();
        // Shifted away from zero: nextDouble() can return exactly 0.0, which
        // is the invalid value checked below.
        double alpha = MIN_RANDOM_ALPHA + RANDOM.nextDouble() * 10.0;
        instance.setAlpha(alpha);
        assertEquals( alpha, instance.getAlpha(), 0.0 );
        try
        {
            instance.setAlpha(0.0);
            fail( "Alpha must be > 0.0" );
        }
        catch (Exception e)
        {
            System.out.println( "Good: " + e );
        }
    }

    /**
     * Test of getNumCustomers method, of class ChineseRestaurantProcess.
     */
    public void testGetNumCustomers()
    {
        System.out.println("getNumCustomers");
        ChineseRestaurantProcess instance = this.createInstance();
        assertTrue( instance.getNumCustomers() > 0 );
    }

    /**
     * Test of setNumCustomers method, of class ChineseRestaurantProcess.
     * A positive value must round-trip and 0 must be rejected.
     */
    public void testSetNumCustomers()
    {
        System.out.println("setNumCustomers");
        ChineseRestaurantProcess instance = this.createInstance();
        int numCustomers = instance.getNumCustomers() + 1;
        instance.setNumCustomers(numCustomers);
        assertEquals( numCustomers, instance.getNumCustomers() );
        try
        {
            instance.setNumCustomers(0);
            fail( "numCustomers must be > 0 " );
        }
        catch (Exception e)
        {
            System.out.println( "Good: " + e );
        }
    }

    /**
     * Disabled: samples have a variable number of tables, so there is no
     * single fixed input dimensionality to check.
     */
    @Override
    public void testPMFGetInputDimensionality()
    {
    }

    /**
     * Checks the PMF against hand-computed seating probabilities.
     * Each factor is the probability of the next customer either opening a
     * new table (alpha / (i + alpha)) or joining an existing table with m
     * occupants (m / (i + alpha)), where i customers are already seated.
     * Reference:
     * http://www.cs.princeton.edu/courses/archive/fall07/cos597C/scribe/20070921.pdf
     */
    @Override
    public void testProbabilityFunctionKnownValues()
    {
        System.out.println( "ProbabilityFunctionKnownValues" );

        // Shifted away from zero: the constructor requires alpha > 0 and
        // nextDouble() can return exactly 0.0.
        double alpha = MIN_RANDOM_ALPHA + RANDOM.nextDouble();
        ChineseRestaurantProcess.PMF instance =
            new ChineseRestaurantProcess( alpha, 10 ).getProbabilityFunction();

        // Ten customers seated at three tables of sizes 3, 4, 3.
        Vector x = VectorFactory.getDefault().copyValues( 3.0, 4.0, 3.0 );
        double expected = 1.0 * alpha/(1.0+alpha) * 1.0/(2.0+alpha) * alpha/(3.0+alpha) * 1.0/(4.0+alpha) * 1.0/(5.0+alpha)
            * 2.0/(6.0+alpha) * 2.0/(7.0+alpha) * 2.0/(8.0+alpha) * 3.0/(9.0+alpha);
        assertEquals( expected, instance.evaluate(x), 1e-10 );

        // Four customers, all at one table.
        instance.setNumCustomers(4);
        expected = 1.0 * 1.0/(1.0+alpha) * 2.0/(2.0+alpha) * 3.0/(3.0+alpha);
        x = VectorFactory.getDefault().copyValues(4.0);
        assertEquals( expected, instance.evaluate(x), 1e-10 );

        // Four customers, each at a separate table.
        expected = 1.0 * alpha/(1.0+alpha) * alpha/(2.0+alpha) * alpha/(3.0+alpha);
        x = VectorFactory.getDefault().copyValues( 1.0, 1.0, 1.0, 1.0 );
        assertEquals( expected, instance.evaluate(x), 1e-10 );

        // Near-zero alpha: opening a second table is very unlikely.
        alpha = 1e-10;
        instance.setAlpha(alpha);
        instance.setNumCustomers(4);
        x = VectorFactory.getDefault().copyValues( 3.0, 1.0 );
        expected = 1.0 * 1.0/(1.0+alpha) * 2.0/(2.0+alpha) * alpha/(3.0+alpha);
        System.out.println( "Expected: " + expected + ", x=" + x );
        assertEquals( expected, instance.evaluate(x), 1e-10 );
    }

    /**
     * Known values of convertToVector: the parameter vector is
     * (alpha, numCustomers).
     */
    @Override
    public void testKnownConvertToVector()
    {
        System.out.println( "Known convertToVector" );
        ChineseRestaurantProcess instance = this.createInstance();
        Vector p = instance.convertToVector();
        assertEquals( 2, p.getDimensionality() );
        assertEquals( instance.getAlpha(), p.getElement(0), 0.0 );
        assertEquals( (double) instance.getNumCustomers(), p.getElement(1), 0.0 );
    }

    /**
     * Compares the PMF over its domain with an empirical histogram of
     * samples, and checks that the total PMF mass over the domain is a valid
     * probability.
     * Each PMF value is the probability of one specific seating sequence
     * (see testProbabilityFunctionKnownValues). Distinct domain vectors
     * correspond to disjoint sequences, so the total is in (0, 1].
     */
    @Override
    public void testKnownGetDomain()
    {
        System.out.println( "Known Domain" );
        ChineseRestaurantProcess.PMF instance =
            this.createInstance().getProbabilityFunction();
        instance.setAlpha(2.0);
        Set<Vector> domain = instance.getDomain();

        ArrayList<Vector> samples = instance.sample(RANDOM, 1000);
        DefaultDataDistribution<Vector> hist =
            new DefaultDataDistribution<Vector>( instance.getNumCustomers() );
        for( Vector sample : samples )
        {
            hist.increment( sample );
        }

        double sum = 0.0;
        for( Vector d : domain )
        {
            double p = instance.evaluate(d);
            // Only report domain elements with non-zero mass.
            if( p > 0.0 )
            {
                System.out.printf( "SAMPLE = %.4e, PMF = %.4e, ratio = %.1f ",
                    hist.getFraction(d), p, hist.getFraction(d)/p );
                System.out.println( d );
            }
            sum += p;
        }
        System.out.println( "Sum = " + sum );
        assertTrue( sum > 0.0 );
        assertTrue( sum <= 1.0 + 1e-10 );
    }

}