/*
* File: DirichletProcessMixtureModelTest.java
* Authors: Kevin R. Dixon
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright May 2, 2010, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government.
* Export of this program may require a license from the United States
* Government. See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.statistics.bayesian;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.MatrixFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.statistics.DataDistribution;
import gov.sandia.cognition.statistics.bayesian.conjugate.MultivariateGaussianMeanBayesianEstimator;
import gov.sandia.cognition.statistics.distribution.MultivariateGaussian;
import gov.sandia.cognition.statistics.distribution.ScalarDataDistribution;
import java.util.ArrayList;
import java.util.Iterator;
import junit.framework.TestCase;
import java.util.Random;
/**
* Unit tests for DirichletProcessMixtureModelTest.
*
* @author krdixon
*/
public class DirichletProcessMixtureModelTest
extends TestCase
{
/**
* Random number generator to use for a fixed random seed.
*/
public final Random RANDOM = new Random( 1 );
/**
* Default tolerance of the regression tests, {@value}.
*/
public final double TOLERANCE = 1e-5;
/**
* Number of samples
*/
public int NUM_SAMPLES = 5;
/**
* Dimensionality
*/
public int DIM = 2;
/**
* Tests for class DirichletProcessMixtureModelTest.
* @param testName Name of the test.
*/
public DirichletProcessMixtureModelTest(
String testName)
{
super(testName);
}
/**
* Instance
* @return
* Instance
*/
public DirichletProcessMixtureModel<Vector> createInstance()
{
DirichletProcessMixtureModel<Vector> instance =
new DirichletProcessMixtureModel<Vector>();
instance.setUpdater( new DirichletProcessMixtureModel.MultivariateMeanCovarianceUpdater(DIM) );
instance.setMaxIterations(100);
instance.setBurnInIterations(instance.getMaxIterations());
instance.setIterationsPerSample(2);
instance.setRandom(RANDOM);
return instance;
}
/**
* Tests the constructors of class DirichletProcessMixtureModelTest.
*/
public void testConstructors()
{
System.out.println( "Constructors" );
DirichletProcessMixtureModel<Vector> instance =
new DirichletProcessMixtureModel<Vector>();
assertNotNull( instance );
assertNull( instance.getUpdater() );
assertNull( instance.getRandom() );
assertTrue( instance.getIterationsPerSample() >= 1 );
assertTrue( instance.getNumInitialClusters() > 1 );
}
/**
* Test of clone method, of class DirichletProcessMixtureModel.
*/
public void testClone()
{
System.out.println("clone");
DirichletProcessMixtureModel<Vector> instance = this.createInstance();
instance.setMaxIterations(10);
instance.setIterationsPerSample(1);
instance.setBurnInIterations(1);
DirichletProcessMixtureModel<Vector> clone = instance.clone();
assertNotSame( instance, clone );
assertNotNull( clone );
assertNotSame( instance.getUpdater(), clone.getUpdater() );
double m = 1.0;
double s = 0.1;
int N = 2;
ArrayList<Vector> samples = new ArrayList<Vector>( NUM_SAMPLES*N );
for( int n = 0; n < N; n++ )
{
Vector mean = VectorFactory.getDefault().createVector(DIM, m*n);
Matrix C = MatrixFactory.getDefault().createIdentity(DIM, DIM).scale(s*s);
MultivariateGaussian g = new MultivariateGaussian(mean, C);
samples.addAll( g.sample(RANDOM, 2 ) );
}
Random r1 = new Random(1);
Random r2 = new Random(1);
instance.setRandom(r1);
clone.setRandom(r2);
DataDistribution<DirichletProcessMixtureModel.Sample<Vector>> d1 =
instance.learn(samples);
DataDistribution<DirichletProcessMixtureModel.Sample<Vector>> d2 =
clone.learn(samples);
assertEquals( d1.getTotal(), d2.getTotal() );
Iterator<? extends DirichletProcessMixtureModel.Sample<Vector>> i1 =
d1.getDomain().iterator();
Iterator<? extends DirichletProcessMixtureModel.Sample<Vector>> i2 =
d2.getDomain().iterator();
while( i1.hasNext() )
{
DirichletProcessMixtureModel.Sample<Vector> s1 = i1.next();
DirichletProcessMixtureModel.Sample<Vector> s2 = i2.next();
assertNotSame( s1, s2 );
assertEquals( s1.getAlpha(), s2.getAlpha(), TOLERANCE );
assertEquals( s1.getNumClusters(), s2.getNumClusters() );
}
}
/**
* Tests learn
*/
public void testLearn()
{
System.out.println( "Learn" );
DirichletProcessMixtureModel<Vector> instance = this.createInstance();
// Serial / Parallel (Thread=1)
// Best: ll = 1768.1222563043032, k = 4, alpha = 0.04826570450192344
double m = 1.0;
double s = 0.1;
int N = 4;
ArrayList<Vector> samples = new ArrayList<Vector>( NUM_SAMPLES*N );
for( int n = 0; n < N; n++ )
{
Vector mean = VectorFactory.getDefault().createVector(DIM, m*n);
Matrix C = MatrixFactory.getDefault().createIdentity(DIM, DIM).scale(s*s);
MultivariateGaussian g = new MultivariateGaussian(mean, C);
samples.addAll( g.sample(RANDOM, NUM_SAMPLES ) );
}
long start = System.currentTimeMillis();
DataDistribution<DirichletProcessMixtureModel.Sample<Vector>> results =
instance.learn(samples);
long stop = System.currentTimeMillis();
System.out.println( "Time taken: " + (stop-start)/1000.0);
ScalarDataDistribution.PMF ks = new ScalarDataDistribution.PMF();
DirichletProcessMixtureModel.Sample<Vector> bestSample = null;
double maxLL = Double.NEGATIVE_INFINITY;
int maxIndex = -1;
int index = 0;
for( DirichletProcessMixtureModel.Sample<Vector> result : results.getDomain() )
{
ks.increment( (double) result.getNumClusters() );
Double ll = result.getPosteriorLogLikelihood();
double actualLL = result.computePosteriorLogLikelihood(samples);
if( ll != null )
{
assertEquals( index + ": expected " + actualLL + ", got: " + ll, actualLL, ll, TOLERANCE );
}
if( (ll != null) && (maxLL < ll) )
{
maxIndex = index;
maxLL = ll;
bestSample = result;
}
index++;
}
for( Double k : ks.getDomain() )
{
double pk = ks.evaluate(k);
if( pk > 0.0 )
{
System.out.println( "p(" + k + "):" + pk );
}
}
System.out.println( "Mean k = " + ks.getMean() );
System.out.println( "Best: " + maxIndex + ": ll = " + maxLL + ", k = " + bestSample.getNumClusters() + ", alpha = " + bestSample.getAlpha() );
for( int i = 0; i < bestSample.getNumClusters(); i++ )
{
System.out.println( "Members = " + bestSample.getClusters().get(i).getMembers().size() );
System.out.println( "PDF =\n" + bestSample.getClusters().get(i).getProbabilityFunction() );
}
}
/**
* Tests learn
*/
public void testLearnConstantVariance()
{
System.out.println( "Learn Constant Variance" );
double m = 1.0;
double s = 0.1;
int N = 4;
DirichletProcessMixtureModel<Vector> instance = this.createInstance();
Matrix Ci = MatrixFactory.getDefault().createIdentity(DIM, DIM).scale(s*s).inverse();
DirichletProcessMixtureModel.MultivariateMeanUpdater updater =
new DirichletProcessMixtureModel.MultivariateMeanUpdater(
new MultivariateGaussianMeanBayesianEstimator( Ci ) );
instance.setUpdater( updater );
ArrayList<Vector> samples = new ArrayList<Vector>( NUM_SAMPLES*N );
for( int n = 0; n < N; n++ )
{
Vector mean = VectorFactory.getDefault().createVector(DIM, m*n);
Matrix C = MatrixFactory.getDefault().createIdentity(DIM, DIM).scale(s*s);
MultivariateGaussian g = new MultivariateGaussian(mean, C);
samples.addAll( g.sample(RANDOM, NUM_SAMPLES ) );
}
long start = System.currentTimeMillis();
DataDistribution<DirichletProcessMixtureModel.Sample<Vector>> results =
instance.learn(samples);
long stop = System.currentTimeMillis();
System.out.println( "Time taken: " + (stop-start)/1000.0);
ScalarDataDistribution.PMF ks = new ScalarDataDistribution.PMF();
DirichletProcessMixtureModel.Sample<Vector> bestSample = null;
double maxLL = Double.NEGATIVE_INFINITY;
int maxIndex = -1;
int index = 0;
for( DirichletProcessMixtureModel.Sample<Vector> result : results.getDomain() )
{
ks.increment( (double) result.getNumClusters() );
Double ll = result.getPosteriorLogLikelihood();
if( (ll != null) && (maxLL < ll) )
{
maxIndex = index;
maxLL = ll;
bestSample = result;
}
index++;
}
for( Double k : ks.getDomain() )
{
double pk = ks.evaluate(k);
if( pk > 0.0 )
{
System.out.println( "p(" + k + "):" + pk );
}
}
System.out.println( "Mean k = " + ks.getMean() );
System.out.println( "Best: " + maxIndex + ", ll = " + maxLL + ", k = " + bestSample.getNumClusters() + ", alpha = " + bestSample.getAlpha() );
for( int i = 0; i < bestSample.getNumClusters(); i++ )
{
System.out.println( "Members = " + bestSample.getClusters().get(i).getMembers().size() );
System.out.println( "PDF =\n" + bestSample.getClusters().get(i).getProbabilityFunction() );
}
}
}