/******************************************************************************* * Copyright 2014 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.test.solvers.gibbs; import static java.util.Objects.*; import static org.junit.Assert.*; import java.util.Random; import org.junit.Test; import com.analog.lyric.dimple.factorfunctions.Normal; import com.analog.lyric.dimple.model.core.FactorGraph; import com.analog.lyric.dimple.model.domains.RealDomain; import com.analog.lyric.dimple.model.variables.Real; import com.analog.lyric.dimple.solvers.gibbs.GibbsReal; import com.analog.lyric.dimple.solvers.gibbs.GibbsSolver; import com.analog.lyric.dimple.solvers.gibbs.GibbsSolverGraph; import com.analog.lyric.dimple.test.DimpleTestBase; @SuppressWarnings("deprecation") public class GibbsTestParameterizedNormal extends DimpleTestBase { protected static boolean debugPrint = false; protected static boolean repeatable = true; @Test public void test1() { if (debugPrint) System.out.println("== test1 =="); int numNormalVariables = 1000; int numSamples = 1000; int updatesPerSample = 10; int burnInUpdates = 1000; FactorGraph graph = new FactorGraph(); GibbsSolverGraph solver = requireNonNull(graph.setSolverFactory(new GibbsSolver())); solver.setNumSamples(numSamples); solver.setUpdatesPerSample(updatesPerSample); solver.setBurnInUpdates(burnInUpdates); // Generate data int seed = 1; Random r; if (repeatable) r = new Random(seed); else r = new Random(); double modelMean = 27; double modelSigma = 14; double modelInverseVariance = 1/(modelSigma*modelSigma); double[] normalValues = new double[numNormalVariables]; for (int i =0; i < numNormalVariables; i++) normalValues[i] = modelSigma*r.nextGaussian() + modelMean; if (debugPrint) System.out.println("ModelMean: " + (Double)modelMean); if (debugPrint) System.out.println("ModelInverseVariance: " + (Double)modelInverseVariance); Real vModelMean = new Real(); Real vModelInverseVariance = new Real(new RealDomain(0, Double.POSITIVE_INFINITY)); vModelMean.setName("Mean"); vModelInverseVariance.setName("InverseVariance"); Object[] vars = new Object[numNormalVariables + 2]; int index = 0; vars[index++] = vModelMean; vars[index++] = vModelInverseVariance; for (int i = 0; i < numNormalVariables; i++) vars[index++] = normalValues[i]; graph.addFactor(new Normal(), vars); GibbsReal svModelMean = requireNonNull((GibbsReal)vModelMean.getSolver()); GibbsReal svModelInverseVariance = requireNonNull((GibbsReal)vModelInverseVariance.getSolver()); if (repeatable) solver.setSeed(1); // Make this repeatable graph.solve(); // Best should be the same as the mean in this case if (debugPrint) System.out.println("vModelMeanBest: " + (Double)svModelMean.getBestSample()); if (debugPrint) System.out.println("vModelInverseVarianceBest: " + (Double)svModelInverseVariance.getBestSample()); assertEquals(1.0, svModelMean.getBestSample() / modelMean, .05); assertEquals(1.0, svModelInverseVariance.getBestSample() / modelInverseVariance, .05); if (repeatable) { // These values are arbitrary and based on the random number generator, sampler and seed assertEquals(27.029282787511672,svModelMean.getBestSample(),1e-12); assertEquals(0.005320791975254845,svModelInverseVariance.getBestSample(),1e-12); } } }