package dr.inference.operators; import dr.inference.prior.Prior; import dr.inference.model.*; import dr.inference.distribution.NormalDistributionModel; import dr.inference.distribution.DistributionLikelihood; import dr.inference.distribution.ParametricDistributionModel; import dr.util.Attribute; import dr.math.distributions.NormalDistribution; import dr.math.MathUtils; import java.util.ArrayList; import java.util.List; /** * Implements a generic univariate slice sampler. * * See: RM Neal (2003) Slice Sampling, Annals of Statistics, 31, 705-767 (with discussion) * * @author Marc A. Suchard */ public class SliceOperator extends SimpleMetropolizedGibbsOperator { public SliceOperator(Variable<Double> variable) { this(new SliceInterval.SteppingOut(), variable); } public SliceOperator(SliceInterval sliceInterval, Variable<Double> variable) { this.sliceInterval = sliceInterval; if (variable.getSize() != 1) { throw new RuntimeException("Generic slice sampler is currently for univariate parameters only"); } this.variable = variable; sliceInterval.setSliceSampler(this); } public Variable<Double> getVariable() { return variable; } public double doOperation(Prior prior, Likelihood likelihood) throws OperatorFailedException { double logPosterior = evaluate(likelihood, prior); double cutoffDensity = logPosterior + MathUtils.randomLogDouble(); sliceInterval.drawFromInterval(prior, likelihood, cutoffDensity, width); // No need to set variable, as SliceInterval has already done this (and recomputed posterior) return 0; } public int getStepCount() { return 1; } public String getOperatorName() { return "genericSliceSampler"; } public static void main(String[] arg) { // Define normal model Parameter meanParameter = new Parameter.Default(1.0); // Starting value Variable<Double> stdev = new Variable.D(1.0, 1); // Fixed value ParametricDistributionModel densityModel = new NormalDistributionModel(meanParameter, stdev); DistributionLikelihood likelihood = new DistributionLikelihood(densityModel); // Define prior DistributionLikelihood prior = new DistributionLikelihood(new NormalDistribution(0.0, 1.0)); // Hyper-priors prior.addData(meanParameter); // Define data likelihood.addData(new Attribute.Default<double[]>("Data", new double[] {0.0, 1.0, 2.0})); List<Likelihood> list = new ArrayList<Likelihood>(); list.add(likelihood); list.add(prior); CompoundLikelihood posterior = new CompoundLikelihood(0, list); SliceOperator sliceSampler = new SliceOperator(meanParameter); final int length = 10000; double mean = 0; double variance = 0; for(int i = 0; i < length; i++) { try { sliceSampler.doOperation(null, posterior); } catch (OperatorFailedException e) { System.err.println(e.getMessage()); } double x = meanParameter.getValue(0); mean += x; variance += x*x; } mean /= length; variance /= length; variance -= mean*mean; System.out.println("E(x) = "+mean); System.out.println("V(x) = "+variance); } private final SliceInterval sliceInterval; private final double width = 1.0; private final Variable<Double> variable; }