/*******************************************************************************
* Copyright 2015 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.test.solvers.gibbs;
import static com.analog.lyric.dimple.model.sugar.ModelSyntacticSugar.*;
import static java.util.Objects.*;
import static org.junit.Assert.*;
import java.util.Random;
import org.apache.commons.math3.stat.StatUtils;
import org.junit.Test;
import com.analog.lyric.dimple.factorfunctions.Normal;
import com.analog.lyric.dimple.model.core.FactorGraph;
import com.analog.lyric.dimple.model.sugar.ModelSyntacticSugar.CurrentModel;
import com.analog.lyric.dimple.model.variables.Real;
import com.analog.lyric.dimple.options.DimpleOptions;
import com.analog.lyric.dimple.solvers.gibbs.GibbsOptions;
import com.analog.lyric.dimple.solvers.gibbs.GibbsReal;
import com.analog.lyric.dimple.solvers.gibbs.GibbsSolver;
import com.analog.lyric.dimple.solvers.gibbs.GibbsSolverGraph;
import com.analog.lyric.dimple.solvers.gibbs.samplers.generic.MHSampler;
import com.analog.lyric.dimple.solvers.gibbs.samplers.generic.SliceSampler;
import com.analog.lyric.dimple.test.DimpleTestBase;
/**
*
* @since 0.08
* @author Christopher Barber
*/
public class TestGibbsConjugateSampling extends DimpleTestBase
{
@Test
public void test3()
{
// Adapted from MATLAB testConjugateSampling.m/test3
Random rand = new Random(42);
final double priorMean = 3.0;
final double priorPrecision = 0.01;
final double dataMean = 10;
final double dataPrecision = 0.001;
final int numDatapoints = 100;
double[] data = new double[numDatapoints];
double dataSum = 0.0;
for (int i = 0; i < numDatapoints; ++i)
{
dataSum += data[i] = dataMean + rand.nextGaussian() * dataPrecision;
}
final double expectedPrecision = priorPrecision + numDatapoints * dataPrecision;
final double expectedStd = 1 / Math.sqrt(expectedPrecision);
final double expectedMean = (priorMean * priorPrecision + dataSum * dataPrecision) / expectedPrecision;
final FactorGraph fg = new FactorGraph();
Real mean;
Real[] x;
try (CurrentModel currrent = using(fg))
{
mean = real("mean");
mean.setPrior(new Normal(priorMean, priorPrecision));
x = fixed("x", data);
addFactor(new Normal(), mean, dataPrecision, x);
}
fg.setOption(DimpleOptions.randomSeed, 1L);
fg.setOption(GibbsOptions.numSamples, 1000);
fg.setOption(GibbsOptions.saveAllSamples, true);
fg.setOption(GibbsOptions.saveAllScores, true);
GibbsSolverGraph sfg = requireNonNull(fg.setSolverFactory(new GibbsSolver()));
GibbsReal smean = sfg.getReal(mean);
assertEquals("NormalSampler", smean.getSamplerName());
fg.solve();
double[] means = smean.getAllSamples();
assertEquals(expectedMean, StatUtils.mean(means), 0.01);
assertEquals(expectedStd, Math.sqrt(StatUtils.variance(means)), .05);
// Try again with slice sampler
mean.setOption(GibbsOptions.realSampler, SliceSampler.class);
fg.setOption(GibbsOptions.numSamples, 2000);
fg.solve();
double[] means2 = smean.getAllSamples();
assertEquals("SliceSampler", smean.getSamplerName());
assertEquals("NormalSampler", sfg.getReal(x[0]).getSamplerName());
assertEquals(expectedMean, StatUtils.mean(means2), 0.1);
assertEquals(expectedStd, Math.sqrt(StatUtils.variance(means2)), 0.1);
// Try again with MH
mean.setOption(GibbsOptions.realSampler, MHSampler.class);
fg.setOption(GibbsOptions.scansPerSample, 10);
fg.setOption(GibbsOptions.numSamples, 1000);
fg.solve();
double[] means3 = smean.getAllSamples();
assertEquals("MHSampler", smean.getSamplerName());
assertEquals("NormalSampler", sfg.getReal(x[0]).getSamplerName());
assertEquals(expectedMean, StatUtils.mean(means3), 0.25);
assertEquals(expectedStd, Math.sqrt(StatUtils.variance(means3)), 0.25);
}
}