/*
 * File:                NRCMCMCTest.java
 * Authors:             Kevin R. Dixon
 * Company:             Sandia National Laboratories
 * Project:             Cognitive Foundry
 *
 * Copyright Feb 23, 2010, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
 * license for use of this work by or on behalf of the U.S. Government.
 * Export of this program may require a license from the United States
 * Government. See CopyrightHistory.txt for complete details.
 *
 */
package gov.sandia.cognition.statistics.bayesian;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.statistics.AbstractDistribution;
import gov.sandia.cognition.statistics.DataDistribution;
import gov.sandia.cognition.statistics.Distribution;
import gov.sandia.cognition.statistics.distribution.GammaDistribution;
import gov.sandia.cognition.statistics.distribution.LogNormalDistribution;
import gov.sandia.cognition.statistics.distribution.UnivariateGaussian;
import gov.sandia.cognition.statistics.method.GaussianConfidence;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.DefaultWeightedValue;
import gov.sandia.cognition.util.ObjectUtil;
import gov.sandia.cognition.util.WeightedValue;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Random;
import junit.framework.TestCase;

/**
 * Unit test reproducing the MCMC change-point example from Numerical Recipes:
 * events arrive with Gamma-distributed inter-arrival times, with the Gamma
 * parameters switching at an unknown time "tc". Metropolis-Hastings is used
 * to recover the rates, shapes, and switch time from sampled event times.
 *
 * @author krdixon
 */
@PublicationReference(
    author={
        "William H. Press",
        "Saul A. Teukolsky",
        "William T. Vetterling",
        "Brian P. Flannery"
    },
    title="Numerical Recipes, Third Edition",
    type=PublicationType.Book,
    year=2007,
    pages={829,835}
)
public class NRCMCMCTest
    extends TestCase
{

    /**
     * Random number generator to use for a fixed random seed.
     */
    public final static Random RANDOM = new Random( 1 );

    /**
     * Default tolerance of the regression tests, {@value}.
     */
    public final double TOLERANCE = 1e-5;

    /**
     * Tests for class NRCMCMCTest.
     * @param testName Name of the test.
     */
    public NRCMCMCTest(
        String testName )
    {
        super( testName );
    }

    /**
     * Runs Metropolis-Hastings on data sampled from a known target State and
     * asserts that the 95% confidence interval of each recovered parameter
     * covers the true value.
     */
    public void testMCMC()
    {
        System.out.println( "NRC MCMC Example" );

        // Target on page 833
        State target = new State();
        target.block1 = new GammaDistribution.PDF( 1, 1.0/3.0 );
        target.block2 = new GammaDistribution.PDF( 2, 1.0/2.0 );
        target.tc = 200.0;
        ArrayList<Double> data = target.sample( RANDOM, 1000 );

        int numSamples = 1000;
        MetropolisHastingsAlgorithm<Double,State> mcmc =
            new MetropolisHastingsAlgorithm<Double,State>();
        mcmc.setBurnInIterations( 1000 );
        mcmc.setIterationsPerSample( 10 );
        mcmc.setMaxIterations( numSamples );
        mcmc.setRandom( RANDOM );
        mcmc.setUpdater( new StateProposer() );

        DataDistribution<State> result = mcmc.learn( data );

        // Collect the marginal samples of each parameter from the posterior.
        final int domainSize = result.getDomain().size();
        ArrayList<Double> rates1 = new ArrayList<Double>( domainSize );
        ArrayList<Double> rates2 = new ArrayList<Double>( domainSize );
        ArrayList<Double> shapes1 = new ArrayList<Double>( domainSize );
        ArrayList<Double> shapes2 = new ArrayList<Double>( domainSize );
        ArrayList<Double> switchTimes = new ArrayList<Double>( domainSize );
        for( State state : result.getDomain() )
        {
            // Rate lambda is the reciprocal of the Gamma scale parameter.
            rates1.add( 1.0 / state.block1.getScale() );
            shapes1.add( state.block1.getShape() );
            rates2.add( 1.0 / state.block2.getScale() );
            shapes2.add( state.block2.getShape() );
            switchTimes.add( state.tc );
        }

        UnivariateGaussian.MaximumLikelihoodEstimator estimator =
            new UnivariateGaussian.MaximumLikelihoodEstimator();
        UnivariateGaussian l1 = estimator.learn( rates1 );
        UnivariateGaussian l2 = estimator.learn( rates2 );
        UnivariateGaussian sample1 = estimator.learn( shapes1 );
        UnivariateGaussian sample2 = estimator.learn( shapes2 );
        UnivariateGaussian switchTime = estimator.learn( switchTimes );

        System.out.println( "Num Samples: " + domainSize );
        System.out.println( "Lambda1: " + l1 + " (" + 1.0/target.block1.getScale() + ")" );
        System.out.println( "Lambda2: " + l2 + " (" + 1.0/target.block2.getScale() + ")" );
        System.out.println( "K1: " + sample1 + " (" + target.block1.getShape() + ")" );
        System.out.println( "K2: " + sample2 + " (" + target.block2.getShape() + ")" );
        System.out.println( "TC: " + switchTime + " (" + target.tc + ")" );
        System.out.println( "Proposals: "
            + ((StateProposer) mcmc.getUpdater()).numProposals );

        // Each true parameter should lie inside the 95% confidence interval
        // of the posterior's Gaussian fit.
        final double confidence = 0.95;
        assertTrue( GaussianConfidence.computeConfidenceInterval(
            switchTime, 1, confidence ).withinInterval( target.tc ) );
        assertTrue( GaussianConfidence.computeConfidenceInterval(
            l1, 1, confidence ).withinInterval( 1.0/target.block1.getScale() ) );
        assertTrue( GaussianConfidence.computeConfidenceInterval(
            l2, 1, confidence ).withinInterval( 1.0/target.block2.getScale() ) );
        assertTrue( GaussianConfidence.computeConfidenceInterval(
            sample1, 1, confidence ).withinInterval( target.block1.getShape() ) );
        assertTrue( GaussianConfidence.computeConfidenceInterval(
            sample2, 1, confidence ).withinInterval( target.block2.getShape() ) );
    }

    /**
     * MCMC state: two Gamma inter-arrival distributions and the time at
     * which the process switches from the first to the second.
     */
    public static class State
        extends AbstractDistribution<Double>
    {

        // Distribution as Gamma(k1,1/lam1): inter-arrival times before tc.
        GammaDistribution.PDF block1;

        // Inter-arrival time distribution after tc.
        GammaDistribution.PDF block2;

        // Switch time between block1 and block2.
        double tc;

        State()
        {
        }

        @Override
        public State clone()
        {
            State clone = (State) super.clone();
            clone.block1 = ObjectUtil.cloneSafe( this.block1 );
            clone.block2 = ObjectUtil.cloneSafe( this.block2 );
            return clone;
        }

        @Override
        public String toString()
        {
            StringBuilder builder = new StringBuilder( 1000 );
            builder.append( "L1: " ).append( 1.0 / this.block1.getScale() );
            builder.append( ", L2: " ).append( 1.0 / this.block2.getScale() );
            builder.append( ", k1: " ).append( this.block1.getShape() );
            builder.append( ", k2: " ).append( this.block2.getShape() );
            builder.append( ", TC: " ).append( this.tc );
            return builder.toString();
        }

        public Double getMean()
        {
            throw new UnsupportedOperationException("Not supported yet.");
        }

        @Override
        public void sampleInto(
            final Random random,
            final int sampleCount,
            final Collection<? super Double> output)
        {
            // Accumulate absolute event times; the first draw seeds the clock
            // and is intentionally not added to the output.
            double eventTime = this.block1.sample( random );
            int generated = 0;

            // Inter-arrival times follow block1 until the clock crosses tc.
            while( (eventTime < this.tc) && (generated < sampleCount) )
            {
                eventTime += this.block1.sample( random );
                output.add( eventTime );
                generated++;
            }

            // After tc, the remaining inter-arrival times follow block2.
            while( generated < sampleCount )
            {
                eventTime += this.block2.sample( random );
                output.add( eventTime );
                generated++;
            }
        }

    }

    /**
     * Metropolis-Hastings updater for State: usually perturbs the rates and
     * switch time multiplicatively; occasionally steps the integer shapes.
     */
    public static class StateProposer
        extends AbstractCloneableSerializable
        implements MetropolisHastingsAlgorithm.Updater<Double,State>
    {

        // Number of proposals generated, for diagnostics.
        int numProposals;

        // Multiplicative step-size distribution for rates and switch time.
        Distribution<Double> rateDistribution;

        public StateProposer()
        {
            // Tuning log for the log-normal proposal variance:
            //
            // Variance = 1e-2
            //   NRC MCMC Example, Num Samples: 1000
            //   Lambda1: Mean: 3.046970925969783 Variance: 0.01682961943657213
            //   Lambda2: Mean: 1.956034800811245 Variance: 0.006137517773549317
            //   K1: Mean: 1.0 Variance: 1.0E-5
            //   K2: Mean: 2.0 Variance: 1.0E-5
            //   TC: Mean: 204.05029025923912 Variance: 6.079847500181761
            //   Proposals: 189790
            //
            // Variance = 1e-3
            //   NRC MCMC Example, Num Samples: 1000
            //   Lambda1: Mean: 3.0500469358025795 Variance: 0.015369884721472172
            //   Lambda2: Mean: 1.9544424909132538 Variance: 0.005578737285304441
            //   K1: Mean: 1.0 Variance: 1.0E-5
            //   K2: Mean: 2.0 Variance: 1.0E-5
            //   TC: Mean: 203.68103081793976 Variance: 6.182053680254045
            //   Proposals: 56832
            //
            // Variance = 1e-4
            //   NRC MCMC Example, Num Samples: 1000
            //   Lambda1: Mean: 3.049272758347756 Variance: 0.014315273259578792
            //   Lambda2: Mean: 1.9582259130893622 Variance: 0.005027429314720611
            //   K1: Mean: 1.0 Variance: 1.0E-5
            //   K2: Mean: 2.0 Variance: 1.0E-5
            //   TC: Mean: 203.85986467803403 Variance: 5.229608836301206
            //   Proposals: 25666
            //
            // Variance = 1e-5
            //   NRC MCMC Example, Num Samples: 1000
            //   Lambda1: Mean: 2.9129840668142233 Variance: 0.12450411895211685
            //   Lambda2: Mean: 1.4253602585644953 Variance: 0.08850209147018964
            //   K1: Mean: 1.0 Variance: 1.0E-5
            //   K2: Mean: 1.181 Variance: 0.1483973873873874
            //   TC: Mean: 142.983944933737 Variance: 1502.8535284551501
            //   Proposals: 13953
            //
            // Variance = 1e-6
            //   NRC MCMC Example, Num Samples: 1000
            //   Lambda1: Mean: 1.4832103096822644 Variance: 0.03894186531760483
            //   Lambda2: Mean: 1.6441403574137285 Variance: 0.04866942131302686
            //   K1: Mean: 1.0 Variance: 1.0E-5
            //   K2: Mean: 1.0 Variance: 1.0E-5
            //   TC: Mean: 90.75214296664389 Variance: 4.61211278684217
            //   Proposals: 12499
            this.rateDistribution = new LogNormalDistribution( 0.0, 1e-4 );
            this.numProposals = 0;
        }

        public WeightedValue<State> makeProposal(
            State location)
        {
            this.numProposals++;
            State proposal = location.clone();
            double qratio;

            // 90% of the time, fiddle with the rates
            double p = RANDOM.nextDouble();
            if( p < 0.9 )
            {
                double rate1 = 1.0 / location.block1.getScale();
                double step1 = this.rateDistribution.sample( RANDOM );
                proposal.block1.setScale( 1.0 / (rate1 * step1) );

                double rate2 = 1.0 / location.block2.getScale();
                double step2 = this.rateDistribution.sample( RANDOM );
                proposal.block2.setScale( 1.0 / (rate2 * step2) );

                double tcStep = this.rateDistribution.sample( RANDOM );
                proposal.tc = location.tc * tcStep;

                // Hastings correction for the asymmetric multiplicative steps.
                qratio = step1 * step2 * tcStep;
            }
            else
            {
                // Step the integer shapes, rescaling so the mean is preserved.
                int k1 = (int) location.block1.getShape();
                double k1hat = proposeShape( k1 );
                double ratio1 = k1hat / k1;
                proposal.block1.setScale( proposal.block1.getScale() / ratio1 );
                proposal.block1.setShape( k1hat );

                int k2 = (int) location.block2.getShape();
                double k2hat = proposeShape( k2 );
                double ratio2 = k2hat / k2;
                proposal.block2.setScale( proposal.block2.getScale() / ratio2 );
                proposal.block2.setShape( k2hat );

                // Up/down moves each have probability 0.25, so the shape
                // proposal is symmetric.
                qratio = 1.0;
            }

            return new DefaultWeightedValue<State>( proposal, qratio );
        }

        /**
         * Randomly steps an integer shape parameter: stay 50%, up 25%,
         * down 25% when above one; at one, increment 25% of the time.
         * @param shape Current integer shape, at least 1.
         * @return Proposed shape value.
         */
        private static double proposeShape(
            int shape )
        {
            double p = RANDOM.nextDouble();
            if( shape > 1 )
            {
                if( p < 0.5 )
                {
                    return shape;
                }
                else if( p < 0.75 )
                {
                    return shape + 1;
                }
                else
                {
                    return shape - 1;
                }
            }
            else
            {
                // If we're already at k=1, then increment 25% of the time.
                return (p < 0.75) ? shape : (shape + 1);
            }
        }

        public State createInitialParameter()
        {
            State initial = new State();
            initial.block1 = new GammaDistribution.PDF( 1.0, 1.0 );
            initial.block2 = new GammaDistribution.PDF( 1.0, 1.0/3.0 );
            initial.tc = 100.0;
            return initial;
        }

        public double computeLogLikelihood(
            State first,
            Iterable<? extends Double> second)
        {
            ArrayList<? extends Double> times = CollectionUtil.asArrayList( second );
            final int num = times.size();
            double logSum = 0.0;
            for( int n = 1; n < num; n++ )
            {
                // Events at or before tc are scored by block1; later ones by
                // block2. The likelihood is over inter-arrival deltas.
                final double tn = times.get( n );
                GammaDistribution.PDF block =
                    (tn <= first.tc) ? first.block1 : first.block2;
                double delta = tn - times.get( n - 1 );
                logSum += Math.log( block.evaluate( delta ) );
            }
            return logSum;
        }

    }

}