/******************************************************************************* * Copyright 2012 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.test.solvers.gibbs; import static com.analog.lyric.dimple.model.sugar.ModelSyntacticSugar.*; import static org.junit.Assert.*; import org.junit.Test; import com.analog.lyric.dimple.factorfunctions.core.FactorFunction; import com.analog.lyric.dimple.model.core.FactorGraph; import com.analog.lyric.dimple.model.factors.Factor; import com.analog.lyric.dimple.model.sugar.ModelSyntacticSugar.CurrentModel; import com.analog.lyric.dimple.model.values.Value; import com.analog.lyric.dimple.model.variables.Discrete; import com.analog.lyric.dimple.model.variables.Real; import com.analog.lyric.dimple.options.SolverOptions; import com.analog.lyric.dimple.solvers.gibbs.GibbsDiscrete; import com.analog.lyric.dimple.solvers.gibbs.GibbsOptions; import com.analog.lyric.dimple.solvers.gibbs.GibbsReal; import com.analog.lyric.dimple.solvers.gibbs.GibbsSolver; import com.analog.lyric.dimple.solvers.gibbs.GibbsSolverGraph; import com.analog.lyric.dimple.solvers.gibbs.GibbsTableFactor; import com.analog.lyric.dimple.solvers.gibbs.Solver; import com.analog.lyric.dimple.solvers.sumproduct.SumProductSolver; import com.analog.lyric.dimple.test.DimpleTestBase; @SuppressWarnings({"null", "deprecation"}) public class GibbsTest extends DimpleTestBase { protected static boolean debugPrint = false; @Test public void basicTest() { // Test solver equality Solver solver1 = new Solver(); Solver solver2 = new GibbsSolver(); Solver solver3 = new GibbsSolver(); assertEquals(solver1, solver2); assertEquals(solver3, solver2); assertEquals(solver1.hashCode(), solver2.hashCode()); assertEquals(solver2.hashCode(), solver3.hashCode()); assertNotEquals(solver1, new SumProductSolver()); basicDiscreteCase(true); basicDiscreteCase(false); // make sure it still works without factor tables } @Test public void testFixedSample() { // Adapted from MATLAB test of same name FactorGraph fg = new FactorGraph(); Real a,b,c; try (CurrentModel current = using(fg)) { b = fixed("b", 2.0); c = fixed("c", 3.0); a = name("a", sum(b,c)); } GibbsSolverGraph sfg = fg.setSolverFactory(new GibbsSolver()); GibbsReal sa = sfg.getReal(a); fg.initialize(); assertEquals(5.0, sa.getCurrentSample(), 0.0); // TODO add more cases... } private void basicDiscreteCase(boolean useFactorTable) { int numSamples = 10000; int updatesPerSample = 2; int burnInUpdates = 2000; FactorGraph graph = new FactorGraph(); GibbsSolverGraph solver = graph.setSolverFactory(new GibbsSolver()); graph.setOption(GibbsOptions.numSamples, numSamples); solver.setUpdatesPerSample(updatesPerSample); solver.setBurnInUpdates(burnInUpdates); if (useFactorTable) { graph.unsetOption(SolverOptions.maxAutomaticFactorTableSize); } else { graph.setOption(SolverOptions.maxAutomaticFactorTableSize, 0); } Discrete a = new Discrete(1,0); Discrete b = new Discrete(1,0); a.setName("a"); b.setName("b"); Factor pA = graph.addFactor(new PA(), a); Factor pBA = graph.addFactor(new PBA(), b, a); solver.setSeed(1); // Make this repeatable graph.setOption(GibbsOptions.saveAllSamples, true); graph.solve(); GibbsDiscrete sa = (GibbsDiscrete)a.getSolver(); GibbsDiscrete sb = (GibbsDiscrete)b.getSolver(); GibbsTableFactor sA = (GibbsTableFactor)pA.getSolver(); GibbsTableFactor sBA = (GibbsTableFactor)pBA.getSolver(); Object[] aSamples = sa.getAllSamples(); Object[] bSamples = sb.getAllSamples(); int aSum = 0; for (Object s : aSamples) aSum += (Integer)s; double aMean = (double)aSum/(double)aSamples.length; if (debugPrint) System.out.println("sai: " + aMean); int bSum = 0; for (Object s : bSamples) bSum += (Integer)s; double bMean = (double)bSum/(double)bSamples.length; if (debugPrint) System.out.println("sbi: " + bMean); if (debugPrint) System.out.println("aBest: " + sa.getBestSample()); if (debugPrint) System.out.println("bBest: " + sb.getBestSample()); double totalPotential = 0; totalPotential += sA.getPotential(new int[]{sa.getBestSampleIndex()}); totalPotential += sBA.getPotential(new int[]{sb.getBestSampleIndex(),sa.getBestSampleIndex()}); if (debugPrint) System.out.println("Min potential: " + totalPotential + " (" + Math.exp(-totalPotential) + ")"); if (debugPrint) System.out.println("a: " + a.getBelief()[0]); if (debugPrint) System.out.println("b: " + b.getBelief()[0]); double pa1 = 0.2; double pa0 = 1 - pa1; double pb1Ia1 = 0.1; double pb0Ia1 = 1 - pb1Ia1; double pb1Ia0 = 0.75; double pb0Ia0 = 1 - pb1Ia0; double pa1b1 = pa1*pb1Ia1; double ba1b0 = pa1*pb0Ia1; double pa0b1 = pa0*pb1Ia0; @SuppressWarnings("unused") double pa0b0 = pa0*pb0Ia0; double pa1m = pa1b1 + ba1b0; double pb1m = pa1b1 + pa0b1; if (debugPrint) System.out.println("pa1: " + pa1m); if (debugPrint) System.out.println("pb1: " + pb1m); assertTrue((Integer)sa.getBestSample() == 0); assertTrue((Integer)sb.getBestSample() == 1); assertTrue(nearlyEquals(a.getBelief()[0],aMean)); assertTrue(nearlyEquals(b.getBelief()[0],bMean)); assertTrue(nearlyEquals(a.getBelief()[0],0.2047)); assertTrue(nearlyEquals(b.getBelief()[0],0.6115)); assertTrue(nearlyEquals(Math.exp(-totalPotential),0.6)); } public static class PA extends FactorFunction { public PA() {super("PA");} @Override public final double evalEnergy(Value[] input) { double value = 0; int a = input[0].getInt(); if (a == 1) value = 0.2; else value = 0.8; return -Math.log(value); } } public static class PBA extends FactorFunction { public PBA() {super("PBA");} @Override public final double evalEnergy(Value[] input) { double value = 0; int b = input[0].getInt(); int a = input[1].getInt(); if ((b == 1) && (a == 1)) value = 0.1; else if ((b == 0) && (a == 1)) value = 0.9; else if ((b == 1) && (a == 0)) value = 0.75; else value = 0.25; return -Math.log(value); } } private static double TOLERANCE = 1e-12; private boolean nearlyEquals(double a, double b) { double diff = a - b; if (diff > TOLERANCE) return false; if (diff < -TOLERANCE) return false; return true; } }