/******************************************************************************* * Copyright 2012 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.test.solvers.gibbs; import static com.analog.lyric.dimple.model.sugar.ModelSyntacticSugar.*; import static java.util.Objects.*; import static org.junit.Assert.*; import org.apache.commons.math3.stat.StatUtils; import org.junit.Test; import com.analog.lyric.dimple.environment.DimpleEnvironment; import com.analog.lyric.dimple.factorfunctions.AdditiveNoise; import com.analog.lyric.dimple.factorfunctions.MixedNormal; import com.analog.lyric.dimple.factorfunctions.Normal; import com.analog.lyric.dimple.model.core.FactorGraph; import com.analog.lyric.dimple.model.repeated.FactorFunctionDataSource; import com.analog.lyric.dimple.model.repeated.FactorGraphStream; import com.analog.lyric.dimple.model.repeated.RealStream; import com.analog.lyric.dimple.model.sugar.ModelSyntacticSugar.CurrentModel; import com.analog.lyric.dimple.model.variables.Discrete; import com.analog.lyric.dimple.model.variables.Real; import com.analog.lyric.dimple.options.DimpleOptions; import com.analog.lyric.dimple.solvers.gibbs.GibbsDiscrete; import com.analog.lyric.dimple.solvers.gibbs.GibbsOptions; import com.analog.lyric.dimple.solvers.gibbs.GibbsReal; import com.analog.lyric.dimple.solvers.gibbs.GibbsSolver; import com.analog.lyric.dimple.solvers.gibbs.GibbsSolverGraph; import com.analog.lyric.dimple.solvers.gibbs.ISolverFactorGibbs; import com.analog.lyric.dimple.test.DimpleTestBase; public class RealVariableGibbsTest extends DimpleTestBase { protected static boolean debugPrint = false; protected static boolean repeatable = true; @SuppressWarnings({ "deprecation", "null" }) @Test public void basicTest1() { if (debugPrint) System.out.println("== basicTest1 =="); int numSamples = 10000; int updatesPerSample = 10; int burnInUpdates = 1000; FactorGraph graph = new FactorGraph(); graph.setSolverFactory(new com.analog.lyric.dimple.solvers.gibbs.Solver()); GibbsSolverGraph solver = (GibbsSolverGraph)graph.getSolver(); solver.setNumSamples(numSamples); solver.setUpdatesPerSample(updatesPerSample); solver.setBurnInUpdates(burnInUpdates); double aPriorMean = 1; double aPriorSigma = 0.5; double aPriorR = 1/(aPriorSigma*aPriorSigma); double bPriorMean = -1; double bPriorSigma = 2.; double bPriorR = 1/(bPriorSigma*bPriorSigma); Real a = new Real(); Real b = new Real(); a.setInputObject(new Normal(aPriorMean,aPriorR)); b.setInputObject(new Normal(bPriorMean, bPriorR)); a.setName("a"); b.setName("b"); double abMean = 0; double abSigma = 1; double abR = 1/(abSigma*abSigma); graph.addFactor(new Normal(abMean,abR), a, b); GibbsReal sa = (GibbsReal)a.getSolver(); GibbsReal sb = (GibbsReal)b.getSolver(); sa.setProposalStandardDeviation(0.1); sb.setProposalStandardDeviation(0.1); if (repeatable) solver.setSeed(1); // Make this repeatable solver.saveAllSamples(); graph.solve(); double[] aSamples = sa.getAllSamples(); double[] bSamples = sb.getAllSamples(); double aSum = 0; for (Object s : aSamples) aSum += (Double)s; double aMean = aSum/aSamples.length; if (debugPrint) System.out.println("aSampleMean: " + aMean); double bSum = 0; for (Object s : bSamples) bSum += (Double)s; double bMean = bSum/bSamples.length; if (debugPrint) System.out.println("bSampleMean: " + bMean); double aExpectedMean = (aPriorMean*aPriorR + abMean*abR)/(aPriorR + abR); double bExpectedMean = (bPriorMean*bPriorR + abMean*abR)/(bPriorR + abR); if (debugPrint) System.out.println("aExpectedMean: " + aExpectedMean); if (debugPrint) System.out.println("bExpectedMean: " + bExpectedMean); // Best should be the same as the mean in this case if (debugPrint) System.out.println("aBest: " + (Double)sa.getBestSample()); if (debugPrint) System.out.println("bBest: " + (Double)sb.getBestSample()); assertEquals(aMean,0.8050875226168582,1e-12); assertEquals(bMean,-0.1921312702232493,1e-12); assertEquals(sa.getBestSample(),0.8043550661413381,1e-12); assertEquals(sb.getBestSample(),-0.20700427734616236,1e-12); } @SuppressWarnings("deprecation") @Test public void basicTest2() { if (debugPrint) System.out.println("== basicTest2 =="); int numSamples = 10000; int updatesPerSample = 10; int burnInUpdates = 1000; FactorGraph graph = new FactorGraph(); graph.setSolverFactory(new com.analog.lyric.dimple.solvers.gibbs.Solver()); GibbsSolverGraph solver = requireNonNull((GibbsSolverGraph)graph.getSolver()); solver.setNumSamples(numSamples); solver.setUpdatesPerSample(updatesPerSample); solver.setBurnInUpdates(burnInUpdates); double aPriorMean = 0; double aPriorSigma = 5; double aPriorR = 1/(aPriorSigma*aPriorSigma); double bProb1 = 0.6; double bProb0 = 1 - bProb1; Real a = new Real(); a.setInputObject(new Normal(aPriorMean,aPriorR)); Discrete b = new Discrete(0,1); b.setInput(bProb0, bProb1); a.setName("a"); b.setName("b"); double fMean0 = -1; double fSigma0 = 0.75; double fMean1 = 1; double fSigma1 = 0.75; double fR0 = 1/(fSigma0*fSigma0); double fR1 = 1/(fSigma1*fSigma1); graph.addFactor(new MixedNormal(fMean0, fR0, fMean1, fR1), a, b); GibbsReal sa = requireNonNull((GibbsReal)a.getSolver()); GibbsDiscrete sb = requireNonNull((GibbsDiscrete)b.getSolver()); sa.setProposalStandardDeviation(1.0); if (repeatable) solver.setSeed(1); // Make this repeatable solver.saveAllSamples(); graph.solve(); double[] aSamples = sa.getAllSamples(); Object[] bSamples = sb.getAllSamples(); double aSum = 0; for (Object s : aSamples) aSum += (Double)s; double aMean = aSum/aSamples.length; if (debugPrint) System.out.println("aSampleMean: " + aMean); double bSum = 0; for (Object s : bSamples) bSum += (Integer)s; double bMean = bSum/bSamples.length; if (debugPrint) System.out.println("bSampleMean: " + bMean); if (debugPrint) { System.out.print("a = ["); for (Object s : aSamples) System.out.print(s + " "); System.out.print("];\n"); } double aExpectedMean = bProb0*(aPriorMean*aPriorR + fMean0*fR0)/(aPriorR + fR0) + bProb1*(aPriorMean*aPriorR + fMean1*fR1)/(aPriorR + fR1); if (debugPrint) System.out.println("aExpectedMean: " + aExpectedMean); if (debugPrint) System.out.println("bExpectedMean: " + bProb1); if (debugPrint) System.out.println("aBest: " + (Double)sa.getBestSample()); if (debugPrint) System.out.println("bBest: " + sb.getBestSample()); assertEquals(aMean,0.20867216566185906, 1e-12); assertEquals(bMean,0.6055,1e-12); assertEquals(sa.getBestSample(),0.977986266650138,1e-12); assertTrue((Integer)sb.getBestSample() == 1); } @Test public void testBeliefMoments() { // Java version of MATLAB testBeliefMoments/test1 // Construct model final FactorGraph fg = new FactorGraph(); Real a, b, x, y; try (CurrentModel current = using(fg)) { a = name("a", normal(0, 1)); b = name("b", gamma(1, 1)); x = name("x", square(sum(a, b))); y = name("y", sum(x, square(log(lognormal(2,7))))); } // Set data y.setPrior(5); // Configure Gibbs options // TODO: test is sensitive to choice of seed! Perhaps we should increase numSamples or something... fg.setOption(DimpleOptions.randomSeed, 1L); fg.setOption(GibbsOptions.numSamples, 100); fg.setOption(GibbsOptions.burnInScans, 10); // Run the solver without saving samples. GibbsSolverGraph sfg = requireNonNull(fg.setSolverFactory(new GibbsSolver())); fg.setOption(GibbsOptions.saveAllSamples, false); fg.solve(); GibbsReal sa = sfg.getReal(a); GibbsReal sb = sfg.getReal(b); GibbsReal sx = sfg.getReal(x); GibbsReal sy = sfg.getReal(y); double aMean = sa.getSampleMean(); double aVariance = sa.getSampleVariance(); double bMean = sb.getSampleMean(); double bVariance = sb.getSampleVariance(); double xMean = sx.getSampleMean(); double xVariance = sx.getSampleVariance(); double yMean = sy.getSampleMean(); double yVariance = sy.getSampleVariance(); // Run the solver again, this time saving all samples fg.setOption(GibbsOptions.saveAllSamples, true); fg.solve(); double[] aSamples = sa.getAllSamples(); double[] bSamples = sb.getAllSamples(); double[] xSamples = sx.getAllSamples(); double[] ySamples = sy.getAllSamples(); assertEquals(aMean, StatUtils.mean(aSamples), 1e-13); assertEquals(bMean, StatUtils.mean(bSamples), 1e-13); assertEquals(xMean, StatUtils.mean(xSamples), 1e-13); assertEquals(yMean, StatUtils.mean(ySamples), 1e-13); assertEquals(aVariance, StatUtils.variance(aSamples), 1e-13); assertEquals(bVariance, StatUtils.variance(bSamples), 1e-13); assertEquals(xVariance, StatUtils.variance(xSamples), 1e-13); assertEquals(yVariance, StatUtils.variance(ySamples), 1e-13); assertEquals(5.0, yMean, 0.0); assertEquals(0.0, yVariance, 0.0); // Make sure moments are the same the next time assertEquals(aMean, sa.getSampleMean(), 0.0); assertEquals(aVariance, sa.getSampleVariance(), 0.0); } @Test public void testRealRolledUp() { // Java version of MATLAB testRealRolledUp.m // FIXME - test is highly dependent on value of seed! // Graph parameters final boolean useSeed = true; final long seed = 45L; final int hmmLength = 20; final int bufferSize = 10; // Gibbs parameters DimpleEnvironment env = DimpleEnvironment.active(); env.setOption(GibbsOptions.numSamples, 10000); env.setOption(GibbsOptions.burnInScans, 100); env.setOption(GibbsOptions.numRandomRestarts, 0); if (useSeed) { testRand.setSeed(seed); env.setOption(DimpleOptions.randomSeed, seed); } // Model parameters final double initialMean = 0.0; final double initialSigma = 20.0; final double transitionMean = 0.0; final double transitionSigma = 0.1; final double obsMean = 0.0; final double obsSigma = 1.0; // Sample from system to be estimated final double[] x = new double[hmmLength]; x[0] = testRand.nextGaussian() * initialSigma + initialMean; for (int i = 1; i < hmmLength; ++i) { x[i] = x[i-1] + testRand.nextGaussian() * transitionSigma + transitionMean; } final double[] obsNoise = new double[hmmLength]; final double[] o = x.clone(); for (int i = 0; i < hmmLength; ++i) { o[i] += obsNoise[i] = testRand.nextGaussian() * obsSigma + obsMean; } // Solve using Gibbs final FactorGraph sg = new FactorGraph(); Real Xo, Xi, Ob; try (CurrentModel cur = using(sg)) { Xo = boundary(real("Xo")); Xi = boundary(real("Xi")); Ob = boundary(real("Ob")); name("transitionNoise", addFactor(new AdditiveNoise(transitionSigma), Xo, Xi)); name("observationNoise", addFactor(new AdditiveNoise(obsSigma), Ob, Xi)); } FactorGraph fg = name("fg", new FactorGraph()); RealStream X = new RealStream("X"), O = new RealStream("O"); FactorGraphStream f = fg.addRepeatedFactor(sg, X, X.getSlice(1), O); f.setBufferSize(bufferSize); // Solve final GibbsSolverGraph sfg = requireNonNull(fg.setSolverFactory(new GibbsSolver())); int inputIndex = 0, outputIndex = 0; final double[] output = new double[hmmLength]; for (int j = 0, end = O.size(); j < end; ++j) { O.get(j).setPrior(o[inputIndex++]); } fg.initialize(); fg.setNumSteps(0); final GibbsReal X0 = sfg.getReal(X.get(0)); final ISolverFactorGibbs X0first = X0.getSibling(0); final ISolverFactorGibbs X0next = X0.getSibling(1); final GibbsReal Olast = sfg.getReal(O.get(O.size()-1)); final int ln = hmmLength - bufferSize; for (int i = 0; i < ln; ++i) { fg.solveOneStep(); output[outputIndex++] = X0.getBestSample(); if (!fg.hasNext()) { break; } final double tmp = X0first.getPotential(); fg.advance(); assertEquals(tmp, X0next.getPotential(), 0.0); Olast.setAndHoldSampleValue(o[inputIndex++]); } final double[] actualdiff = new double[hmmLength]; final double[] obsdiff = new double[hmmLength]; double actualnorm = 0.0, obsnorm = 0.0; for (int i = 0; i < ln; ++i) { double diff = actualdiff[i] = x[i] - output[i]; actualnorm += diff * diff; diff = obsdiff[i] = x[i] - o[i]; obsnorm += diff * diff; } actualnorm = Math.sqrt(actualnorm); obsnorm = Math.sqrt(obsnorm); assertTrue(actualnorm < 1.0); assertTrue(obsnorm > 3.0); } @Test public void testRolledUpBeliefMoments() { // Java version of MATLAB testBeliefMoments/test3 // Construct model final int numDataPoints = 10; final double dataPrecision = 1e4; final double transitionPrecision = 10; final FactorGraph fg = new FactorGraph(); fg.setName("root"); final GibbsSolverGraph sfg = requireNonNull(fg.setSolverFactory(new GibbsSolver())); final FactorGraph nfg = new FactorGraph(); nfg.setName("nested"); try (CurrentModel current = using(nfg)) { Real x = boundary(real("x")); /*Real y = */ boundary(name("y",normal(name("x11",product(x, 1.1)), transitionPrecision))); } final RealStream vars = new RealStream("r"); fg.addRepeatedFactor(nfg, vars, vars.getSlice(1)); FactorFunctionDataSource dataSource = new FactorFunctionDataSource(); for (int i = 0; i < numDataPoints; ++i) { dataSource.add(new Normal(1.0, dataPrecision)); } vars.setDataSource(dataSource); // Configure Gibbs DimpleEnvironment env = DimpleEnvironment.active(); env.setOption(DimpleOptions.randomSeed, 2L); env.setOption(GibbsOptions.numSamples, 3000); env.setOption(GibbsOptions.burnInScans, 10); fg.initialize(); // Construct second model final FactorGraph fg2 = new FactorGraph(); Real[] r = new Real[numDataPoints]; try (CurrentModel current = using(fg2)) { r[0] = real("r0"); for (int i = 1; i < numDataPoints; ++i) { r[i] = name("r"+i, normal(name("r"+i+"x11", product(r[i-1], 1.1)), transitionPrecision)); } } // Configure Gibbs final GibbsSolverGraph sfg2 = requireNonNull(fg2.setSolverFactory(new GibbsSolver())); // Options inherited from environment // Run fg.setNumSteps(0); for (int i = 0; fg.hasNext(); fg.advance(), ++i) { r[i].setPrior(new Normal(1, dataPrecision)); r[i+1].setPrior(new Normal(1, dataPrecision)); fg.solveOneStep(); fg2.solve(); GibbsReal a = sfg.getReal(vars.get(0)); GibbsReal b = sfg.getReal(vars.get(1)); GibbsReal a2 = sfg2.getReal(r[i]); GibbsReal b2 = sfg2.getReal(r[i+1]); double a2mean = a2.getSampleMean(); double amean = a.getSampleMean(); double b2mean = b2.getSampleMean(); double bmean = b.getSampleMean(); double a2variance = a2.getSampleVariance(); double avariance = a.getSampleVariance(); double b2variance = b2.getSampleVariance(); double bvariance = b.getSampleVariance(); // System.out.format("%d: a2 %f/%f\n", i, a2mean, a2variance); // System.out.format("%d: a %f/%f\n", i, amean, avariance); // // System.out.format("%d: b2 %f/%f\n", i, b2mean, b2variance); // System.out.format("%d: b %f/%f\n", i, bmean, bvariance); assertEquals(0.0, a2mean - amean, 0.01); assertEquals(0.0, 1.0 - a2variance / avariance, 0.1); assertEquals(0.0, b2mean - bmean, 0.01); assertEquals(0.0, 1.0 - b2variance / bvariance, 0.1); } } }