/* * File: SamplingImportanceResamplingParticleFilterTest.java * Authors: Kevin R. Dixon * Company: Sandia National Laboratories * Project: Cognitive Foundry * * Copyright Feb 24, 2010, Sandia Corporation. * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive * license for use of this work by or on behalf of the U.S. Government. * Export of this program may require a license from the United States * Government. See CopyrightHistory.txt for complete details. * */ package gov.sandia.cognition.statistics.bayesian; import gov.sandia.cognition.statistics.bayesian.conjugate.BernoulliBayesianEstimator; import gov.sandia.cognition.statistics.Distribution; import gov.sandia.cognition.statistics.DataDistribution; import gov.sandia.cognition.statistics.distribution.BernoulliDistribution; import gov.sandia.cognition.statistics.distribution.BernoulliDistribution.PMF; import gov.sandia.cognition.statistics.distribution.BetaDistribution; import gov.sandia.cognition.statistics.distribution.GammaDistribution; import gov.sandia.cognition.statistics.distribution.LogNormalDistribution; import gov.sandia.cognition.statistics.distribution.DefaultDataDistribution; import gov.sandia.cognition.statistics.distribution.UnivariateGaussian; import gov.sandia.cognition.util.AbstractCloneableSerializable; import gov.sandia.cognition.util.DefaultWeightedValue; import gov.sandia.cognition.util.WeightedValue; import java.util.ArrayList; import java.util.Random; import junit.framework.TestCase; /** * Unit tests for SamplingImportanceResamplingParticleFilterTest. * * @author krdixon */ public class SamplingImportanceResamplingParticleFilterTest extends TestCase // extends RecursiveBayesianEstimatorTestHarness<Double, Double, ScalarDataDistribution> { /** * Random number generator to use for a fixed random seed. */ public Random RANDOM = new Random( 1 ); /** * Default tolerance of the regression tests, {@value}. */ public double TOLERANCE = 1e-5; /** * Default number of samples to draw, {@value}. */ public int NUM_SAMPLES = 100; /** * Tests for class SamplingImportanceResamplingParticleFilterTest. * @param testName Name of the test. */ public SamplingImportanceResamplingParticleFilterTest( String testName) { super(testName); } /** * Tests the constructors of class SamplingImportanceResamplingParticleFilterTest. */ public void testConstructors() { System.out.println( "Constructors" ); SamplingImportanceResamplingParticleFilter<Double,GammaDistribution.PDF> particleFilter = new SamplingImportanceResamplingParticleFilter<Double, GammaDistribution.PDF>(); assertNotNull( particleFilter ); } /** * clone */ public void testClone() { System.out.println( "Clone" ); SamplingImportanceResamplingParticleFilter<Double,GammaDistribution.PDF> particleFilter = new SamplingImportanceResamplingParticleFilter<Double, GammaDistribution.PDF>(); particleFilter.setRandom(RANDOM); particleFilter.setNumParticles(200); particleFilter.setParticlePctThreadhold(0.5); particleFilter.setUpdater(new GammaUpdater() ); SamplingImportanceResamplingParticleFilter<Double,GammaDistribution.PDF> clone = (SamplingImportanceResamplingParticleFilter<Double, GammaDistribution.PDF>) particleFilter.clone(); assertNotNull( clone ); assertNotSame( clone, particleFilter ); assertSame( particleFilter.getRandom(), clone.getRandom() ); assertNotSame( particleFilter.getUpdater(), clone.getUpdater() ); assertEquals( particleFilter.getParticlePctThreadhold(), clone.getParticlePctThreadhold() ); assertEquals( particleFilter.getNumParticles(), clone.getNumParticles() ); } /** * Test Gamma Distribution */ public void testGammaInference() { System.out.println( "Gamma Distribution Inference" ); double shape = 5.0; double scale = 2.0; GammaDistribution.PDF target = new GammaDistribution.PDF( shape, scale ); final int numSamples = 1000; ArrayList<Double> samples = target.sample(RANDOM, numSamples); SamplingImportanceResamplingParticleFilter<Double,GammaDistribution.PDF> particleFilter = new SamplingImportanceResamplingParticleFilter<Double, GammaDistribution.PDF>(); particleFilter.setRandom(RANDOM); particleFilter.setNumParticles(200); particleFilter.setParticlePctThreadhold(0.5); particleFilter.setUpdater(new GammaUpdater() ); DataDistribution<GammaDistribution.PDF> particles = particleFilter.learn(samples); ArrayList<WeightedValue<Double>> shapes = new ArrayList<WeightedValue<Double>>( particles.getDomain().size() ); ArrayList<WeightedValue<Double>> scales = new ArrayList<WeightedValue<Double>>( particles.getDomain().size() ); for( GammaDistribution.PDF gamma : particles.getDomain() ) { shapes.add( new DefaultWeightedValue<Double>( gamma.getShape(), particles.get(gamma) ) ); scales.add( new DefaultWeightedValue<Double>( gamma.getScale(), particles.get(gamma) ) ); } UnivariateGaussian shapeResult = UnivariateGaussian.WeightedMaximumLikelihoodEstimator.learn(shapes, 0.0); UnivariateGaussian scaleResult = UnivariateGaussian.WeightedMaximumLikelihoodEstimator.learn(scales, 0.0); System.out.println( "Shape: " + shapeResult ); System.out.println( "Scale: " + scaleResult ); System.out.println( "Target: " + target ); } public SamplingImportanceResamplingParticleFilter<Double,Double> createInstance() { SamplingImportanceResamplingParticleFilter<Double,Double> particleFilter = new SamplingImportanceResamplingParticleFilter<Double,Double>(); particleFilter.setRandom(RANDOM); particleFilter.setNumParticles(100); particleFilter.setParticlePctThreadhold(0.5); particleFilter.setUpdater( new GaussianUpdater() ); return particleFilter; } public UnivariateGaussian createConditionalDistribution() { double mean = RANDOM.nextGaussian(); double variance = RANDOM.nextDouble() * 2.0 + 1.0; return new UnivariateGaussian( mean, variance ); } public class GammaUpdater extends AbstractCloneableSerializable implements ParticleFilter.Updater<Double,GammaDistribution.PDF> { private GammaDistribution.PDF initialDistribution; private Distribution<Double> tweaker; public GammaUpdater() { this.initialDistribution = new GammaDistribution.PDF( 2.0, 1.0 ); this.tweaker = new LogNormalDistribution(0.0, 1e-4); } public GammaDistribution.PDF update( GammaDistribution.PDF previousParameter) { double sf1 = this.tweaker.sample(RANDOM); double sf2 = this.tweaker.sample(RANDOM); return new GammaDistribution.PDF( sf1 * previousParameter.getShape(), sf2 * previousParameter.getScale() ); } public DataDistribution<GammaDistribution.PDF> createInitialParticles( int numParticles) { DataDistribution<GammaDistribution.PDF> distribution = new DefaultDataDistribution<GammaDistribution.PDF>(); final double uniformWeight = 1.0/numParticles; for( int i = 0; i < numParticles; i++ ) { distribution.increment( this.update(this.initialDistribution), uniformWeight ); } return distribution; } public double computeLogLikelihood( GammaDistribution.PDF particle, Double observation) { return particle.logEvaluate(observation); } } public void testKnownValues() { System.out.println( "Bernoulli Inference" ); double p = 0.75; BernoulliDistribution.PMF target = new BernoulliDistribution.PMF(p); final int numSamples = 1000; ArrayList<Number> samples = target.sample(RANDOM, numSamples); SamplingImportanceResamplingParticleFilter<Number,BernoulliDistribution.PMF> particleFilter = new SamplingImportanceResamplingParticleFilter<Number,BernoulliDistribution.PMF>(); particleFilter.setRandom(RANDOM); particleFilter.setNumParticles(200); particleFilter.setParticlePctThreadhold(0.5); particleFilter.setUpdater(new BernoulliUpdater() ); DataDistribution<BernoulliDistribution.PMF> particles = particleFilter.learn(samples); ArrayList<WeightedValue<Double>> ps = new ArrayList<WeightedValue<Double>>( particles.getDomain().size() ); for( BernoulliDistribution.PMF b : particles.getDomain() ) { ps.add( new DefaultWeightedValue<Double>( b.getP(), particles.get(b) ) ); } UnivariateGaussian presult = UnivariateGaussian.WeightedMaximumLikelihoodEstimator.learn(ps, 0.0); System.out.println( "Presult: " + presult ); BernoulliBayesianEstimator bbe = new BernoulliBayesianEstimator(); BetaDistribution posterior = bbe.learn(samples); System.out.println( "Beta: Mean = " + posterior.getMean() + ", Variance = " + posterior.getVariance() ); } public void testBernoulliInference2() { System.out.println( "Bernoulli Inference2" ); double p = 0.75; BernoulliDistribution.PMF target = new BernoulliDistribution.PMF(p); final int numSamples = 100; ArrayList<Number> samples = target.sample(RANDOM, numSamples); SamplingImportanceResamplingParticleFilter<Number,BernoulliDistribution.PMF> particleFilter = new SamplingImportanceResamplingParticleFilter<Number,BernoulliDistribution.PMF>(); particleFilter.setRandom(RANDOM); particleFilter.setNumParticles(100); particleFilter.setParticlePctThreadhold(1.0); particleFilter.setUpdater(new BernoulliUpdater() ); DataDistribution<BernoulliDistribution.PMF> particles = particleFilter.learn(samples); ArrayList<WeightedValue<Double>> ps = new ArrayList<WeightedValue<Double>>( particles.getDomain().size() ); for( BernoulliDistribution.PMF b : particles.getDomain() ) { ps.add( new DefaultWeightedValue<Double>( b.getP(), particles.get(b) ) ); } UnivariateGaussian presult = UnivariateGaussian.WeightedMaximumLikelihoodEstimator.learn(ps, 0.0); System.out.println( "Presult: " + presult ); BernoulliBayesianEstimator bbe = new BernoulliBayesianEstimator(); BetaDistribution posterior = bbe.learn(samples); System.out.println( "Beta: Mean = " + posterior.getMean() + ", Variance = " + posterior.getVariance() ); } public class GaussianUpdater extends AbstractCloneableSerializable implements ParticleFilter.Updater<Double,Double> { UnivariateGaussian tweaker; double variance; public GaussianUpdater() { this.tweaker = new UnivariateGaussian(); this.variance = 2.0; } public Double update( Double previousParameter) { return tweaker.sample(RANDOM) + previousParameter; } public DataDistribution<Double> createInitialParticles( int numParticles) { return new DefaultDataDistribution<Double>( this.tweaker.sample(RANDOM, numParticles) ); } public double computeLogLikelihood( Double particle, Double observation) { return UnivariateGaussian.PDF.logEvaluate( observation, particle, this.variance ); } } public class BernoulliUpdater extends AbstractCloneableSerializable implements ParticleFilter.Updater<Number,BernoulliDistribution.PMF> { private Distribution<Double> tweaker; public BernoulliUpdater() { this.tweaker = new LogNormalDistribution(0.0, 1e-4); } public BernoulliDistribution.PMF update( BernoulliDistribution.PMF previousParameter) { double sf1 = this.tweaker.sample(RANDOM); double pinv = 1.0 / previousParameter.getP() - 1.0; pinv *= sf1; return new BernoulliDistribution.PMF( 1.0/(pinv+1.0) ); } public DataDistribution<BernoulliDistribution.PMF> createInitialParticles( int numParticles) { DataDistribution<BernoulliDistribution.PMF> particles = new DefaultDataDistribution<PMF>(); final double uniformMass = 1.0/numParticles; for( int i = 0; i < numParticles; i++ ) { double p = RANDOM.nextDouble(); particles.increment( new BernoulliDistribution.PMF( p ), uniformMass ); } return particles; } public double computeLogLikelihood( BernoulliDistribution.PMF particle, Number observation) { return particle.logEvaluate(observation); } } }