/* * File: RejectionSamplingTest.java * Authors: Kevin R. Dixon * Company: Sandia National Laboratories * Project: Cognitive Foundry * * Copyright Mar 3, 2010, Sandia Corporation. * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive * license for use of this work by or on behalf of the U.S. Government. * Export of this program may require a license from the United States * Government. See CopyrightHistory.txt for complete details. * */ package gov.sandia.cognition.statistics.bayesian; import gov.sandia.cognition.statistics.bayesian.conjugate.BernoulliBayesianEstimator; import gov.sandia.cognition.statistics.DataDistribution; import gov.sandia.cognition.statistics.distribution.BernoulliDistribution; import gov.sandia.cognition.statistics.distribution.BetaDistribution; import gov.sandia.cognition.statistics.distribution.UniformDistribution; import gov.sandia.cognition.statistics.distribution.UnivariateGaussian; import java.util.ArrayList; import junit.framework.TestCase; import java.util.Random; /** * Unit tests for RejectionSamplingTest. * * @author krdixon */ public class RejectionSamplingTest extends TestCase { /** * Random number generator to use for a fixed random seed. */ public final Random RANDOM = new Random( 2 ); /** * Default tolerance of the regression tests, {@value}. */ public final double TOLERANCE = 1e-5; /** * Tests for class RejectionSamplingTest. * @param testName Name of the test. */ public RejectionSamplingTest( String testName) { super(testName); } /** * Tests the constructors of class RejectionSamplingTest. */ public void testConstructors() { System.out.println( "Constructors" ); RejectionSampling<Integer,Double> instance = new RejectionSampling<Integer,Double>(); assertEquals( RejectionSampling.DEFAULT_NUM_SAMPLES, instance.getNumSamples() ); assertNull( instance.getRandom() ); assertNull( instance.getUpdater() ); } /** * Test of clone method, of class RejectionSampling. */ public void testClone() { System.out.println("clone"); RejectionSampling<Number,Double> instance = new RejectionSampling<Number, Double>(); instance.setNumSamples(11); UniformDistribution.PDF prior = new UniformDistribution.PDF( 0.0, 1.0 ); BernoulliDistribution.PMF conditional = new BernoulliDistribution.PMF(); instance.setUpdater( new RejectionSampling.DefaultUpdater<Number, Double>( new DefaultBayesianParameter<Double,BernoulliDistribution.PMF,UniformDistribution.PDF>( conditional, "p", prior ) ) ); instance.setRandom(RANDOM); RejectionSampling<Number,Double> clone = instance.clone(); assertNotSame( instance, clone ); assertNotSame( instance.getRandom(), clone.getRandom() ); assertNotNull( clone.getUpdater() ); assertNotSame( instance.getUpdater(), clone.getUpdater() ); assertEquals( instance.getNumSamples(), clone.getNumSamples() ); } public void testBernoulliInference() { System.out.println( "Bernoulli Inference" ); double p = 0.75; BernoulliDistribution.PMF target = new BernoulliDistribution.PMF(p); final int numSamples = 100; ArrayList<Number> samples = target.sample(RANDOM, numSamples); RejectionSampling<Number,Double> instance = new RejectionSampling<Number, Double>(); instance.setNumSamples(numSamples); UniformDistribution.PDF prior = new UniformDistribution.PDF( 0.0, 1.0 ); BernoulliDistribution.PMF conditional = new BernoulliDistribution.PMF(); instance.setUpdater( new RejectionSampling.DefaultUpdater<Number, Double>( new DefaultBayesianParameter<Double,BernoulliDistribution.PMF,UniformDistribution.PDF>( conditional, "p", prior ) ) ); instance.setRandom(RANDOM); DataDistribution<Double> berns = instance.learn(samples); ArrayList<Double> ps = new ArrayList<Double>( berns.getDomain().size() ); for( Double b : berns.getDomain() ) { ps.add( b ); } UnivariateGaussian presult = UnivariateGaussian.MaximumLikelihoodEstimator.learn(ps, 0.0); System.out.println( "Proposals: " + ((RejectionSampling.DefaultUpdater) instance.getUpdater()).getProposals() ); System.out.println( "P: " + presult ); // Run this through a Conjugate Prior Bayesian Estimator to see // what the answer "should" (with uniform prior) BernoulliBayesianEstimator bbe = new BernoulliBayesianEstimator(); BetaDistribution posterior = bbe.learn(samples); System.out.println( "Beta: Mean = " + posterior.getMean() + ", Variance = " + posterior.getVariance() ); } /** * computeGaussianSampler */ public void testComputeGaussianSampler() { System.out.println( "computeGaussianSampler" ); double p = 0.75; BernoulliDistribution.PMF target = new BernoulliDistribution.PMF(p); final int numSamples = 100; ArrayList<Number> samples = target.sample(RANDOM, numSamples); RejectionSampling<Number,Double> instance = new RejectionSampling<Number, Double>(); instance.setNumSamples(numSamples); UniformDistribution.PDF prior = new UniformDistribution.PDF( 0.0, 1.0 ); BernoulliDistribution.PMF conditional = new BernoulliDistribution.PMF(); BayesianParameter<Double,BernoulliDistribution.PMF,UniformDistribution.PDF> conjunctive = new DefaultBayesianParameter<Double,BernoulliDistribution.PMF,UniformDistribution.PDF>( conditional, "p", prior ); RejectionSampling.DefaultUpdater<Number,Double> updater = new RejectionSampling.DefaultUpdater<Number, Double>( conjunctive ); UnivariateGaussian.PDF sampler = updater.computeGaussianSampler(samples, RANDOM, 100); System.out.println( "Sampler = " + sampler ); updater.setSampler(sampler); instance.setUpdater(updater); instance.setRandom(RANDOM); DataDistribution<Double> berns = instance.learn(samples); ArrayList<Double> ps = new ArrayList<Double>( berns.getDomain().size() ); for( Double b : berns.getDomain() ) { ps.add( b ); } UnivariateGaussian presult = UnivariateGaussian.MaximumLikelihoodEstimator.learn(ps, 0.0); System.out.println( "Proposals: " + ((RejectionSampling.DefaultUpdater) instance.getUpdater()).getProposals() ); System.out.println( "P: " + presult ); // Run this through a Conjugate Prior Bayesian Estimator to see // what the answer "should" (with uniform prior) BernoulliBayesianEstimator bbe = new BernoulliBayesianEstimator(); BetaDistribution posterior = bbe.learn(samples); System.out.println( "Beta: Mean = " + posterior.getMean() + ", Variance = " + posterior.getVariance() ); } }