/******************************************************************************* * Copyright 2015 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.test.solvers.sumproduct; import static com.analog.lyric.dimple.model.sugar.ModelSyntacticSugar.*; import static com.analog.lyric.math.Utilities.*; import static java.util.Objects.*; import static org.junit.Assert.*; import org.junit.Test; import com.analog.lyric.dimple.data.ValueDataLayer; import com.analog.lyric.dimple.events.DimpleEventLogger; import com.analog.lyric.dimple.factorfunctions.core.FactorTable; import com.analog.lyric.dimple.factorfunctions.core.IFactorTable; import com.analog.lyric.dimple.model.core.FactorGraph; import com.analog.lyric.dimple.model.core.FactorGraphIterables; import com.analog.lyric.dimple.model.domains.DiscreteDomain; import com.analog.lyric.dimple.model.factors.Factor; import com.analog.lyric.dimple.model.sugar.ModelSyntacticSugar.CurrentModel; import com.analog.lyric.dimple.model.values.DiscreteValue; import com.analog.lyric.dimple.model.values.Value; import com.analog.lyric.dimple.model.variables.Bit; import com.analog.lyric.dimple.model.variables.Discrete; import com.analog.lyric.dimple.model.variables.Variable; import com.analog.lyric.dimple.solvers.core.FactorToVariableMessageEvent; import com.analog.lyric.dimple.solvers.core.VariableToFactorMessageEvent; import 
com.analog.lyric.dimple.solvers.sumproduct.SumProductSolver;
import com.analog.lyric.dimple.solvers.sumproduct.SumProductSolverGraph;
import com.analog.lyric.dimple.solvers.sumproduct.SumProductTableFactor;
import com.analog.lyric.dimple.test.DimpleTestBase;
import com.analog.lyric.dimple.test.model.RandomGraphGenerator;
import com.analog.lyric.util.misc.IMapList;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;

import cern.colt.list.IntArrayList;

/**
 * Checks that, for random full assignments to a graph's variables, the value
 * {@code getScore() - logZ} reported by the sum-product solver agrees with the energy of the
 * same assignment in the belief table of an equivalent graph whose factors have all been joined
 * into a single factor.
 *
 * @since 0.08
 * @author Christopher Barber
 */
public class TestDiscreteLikelihood extends DimpleTestBase
{
	/** Number of random assignments checked per graph by {@link #testGraph}. */
	private static final int SAMPLES_PER_GRAPH = 100;

	/** Absolute tolerance used when comparing likelihoods and energies. */
	private static final double TOLERANCE = 1e-10;

	@Test
	public void testSingleFactor()
	{
		FactorGraph fg = new FactorGraph();
		try (CurrentModel cur = using(fg))
		{
			Bit a = bit("a"), b = bit("b");
			IFactorTable table = createBitTable();
			Factor factor = name("F", fg.addFactor(table, a, b));
			factor.setDirectedTo(new int[] { 1 });
			testGraph(fg, false);
		}
	}

	@Test
	public void testTwoFactors()
	{
		FactorGraph fg = new FactorGraph();
		try (CurrentModel cur = using(fg))
		{
			Bit a = bit("a"), b = bit("b"), c = bit("c");
			// The same table instance is deliberately shared by both factors.
			IFactorTable table = createBitTable();
			Factor fab = name("F(a,b)", fg.addFactor(table, a, b));
			fab.setDirectedTo(new int[] { 1 });
			Factor fbc = name("F(b,c)", fg.addFactor(table, b, c));
			fbc.setDirectedTo(new int[] { 1 });
			testGraph(fg, false);
		}
	}

	@Test
	public void testRandomTrees()
	{
		testRand.setSeed(0x8474da1a9a0f86ddL);
		FactorGraph fg = new RandomGraphGenerator(testRand).maxBranches(2).buildRandomTree(4);
		testGraph(fg, false);
		fg = new RandomGraphGenerator(testRand).maxBranches(5).buildRandomTree(20);
		testGraph(fg, false);
	}

	/**
	 * Creates a table over two bit domains with weights
	 * {@code [0,0]=.2, [0,1]=.8, [1,0]=.7, [1,1]=.3}, i.e. each row (first index) sums to one.
	 */
	private static IFactorTable createBitTable()
	{
		IFactorTable table = FactorTable.create(DiscreteDomain.bit(), DiscreteDomain.bit());
		table.setWeightForIndices(.2, 0, 0);
		table.setWeightForIndices(.8, 0, 1);
		table.setWeightForIndices(.7, 1, 0);
		table.setWeightForIndices(.3, 1, 1);
		return table;
	}

	/**
	 * Compares solver-derived likelihoods against a joined-factor reference for
	 * {@link #SAMPLES_PER_GRAPH} random assignments.
	 * <p>
	 * For each sample, {@code getScore() - logZ} on {@code fg} (and its weight) must match the
	 * belief-table energy (and weight) of the same assignment in a copy of {@code fg} whose
	 * factors are joined into one and solved with sum-product.
	 *
	 * @param fg graph under test; its solver factory is set to {@link SumProductSolver} and the
	 * guess of every variable is overwritten.
	 * @param debug if true, registers an event logger for solver messages and prints intermediate
	 * values to {@code System.out}.
	 */
	private void testGraph(FactorGraph fg, boolean debug)
	{
		SumProductSolverGraph sfg = requireNonNull(fg.setSolverFactory(new SumProductSolver()));

		if (debug)
		{
			// NOTE(review): the logger is never closed (hence the suppression); presumably intentional
			// so messages keep being logged during the solves below -- confirm it does not leak into
			// later tests.
			@SuppressWarnings("resource")
			DimpleEventLogger logger = new DimpleEventLogger();
			logger.log(FactorToVariableMessageEvent.class, sfg);
			logger.log(VariableToFactorMessageEvent.class, sfg);
		}

		final double logZ = sfg.computeLogPartitionFunction();
		if (debug)
		{
			System.out.format("log Z=%g, Z=%g%n", logZ, energyToWeight(-logZ));
		}

		// Do solve again on a copy of the graph with all factors merged into single giant factor.
		final BiMap<Object,Object> old2new = HashBiMap.create();
		final BiMap<Object,Object> new2old = old2new.inverse();
		FactorGraph model2 = fg.copyRoot(old2new);
		IMapList<Factor> factors2 = model2.getFactors();
		Factor factor2 = model2.join(factors2.toArray(new Factor[factors2.size()]));
		SumProductSolverGraph sfg2 = requireNonNull(model2.setSolverFactory(new SumProductSolver()));
		SumProductTableFactor sfactor2 = (SumProductTableFactor)sfg2.getSolverFactor(factor2);
		model2.solve();
		IFactorTable beliefs = sfactor2.getBeliefTable();

		ValueDataLayer dataLayer = new ValueDataLayer(fg);
		for (int i = 0; i < SAMPLES_PER_GRAPH; ++i)
		{
			assignRandomValues(fg, dataLayer, debug);

			@SuppressWarnings("deprecation")
			final double ll = sfg.getScore() - logZ;
			final double likelihood = energyToWeight(ll);

			final double expectedll = beliefs.getEnergyForIndices(jointIndices(factor2, dataLayer, new2old));
			final double expectedLikelihood = energyToWeight(expectedll);

			assertEquals("likelihood for sample " + i, expectedLikelihood, likelihood, TOLERANCE);
			assertEquals("energy for sample " + i, expectedll, ll, TOLERANCE);
		}
	}

	/**
	 * Picks a uniformly random domain index for every variable of {@code fg}, storing it both as a
	 * value in {@code dataLayer} and as the variable's guess index.
	 */
	private void assignRandomValues(FactorGraph fg, ValueDataLayer dataLayer, boolean debug)
	{
		for (Variable var : FactorGraphIterables.variables(fg))
		{
			Discrete discrete = var.asDiscreteVariable();
			DiscreteValue value = Value.create(discrete.getDomain());
			final int index = testRand.nextInt(discrete.getDomain().size());
			value.setIndex(index);
			dataLayer.put(var, value);
			discrete.setGuessIndex(index);
			if (debug)
			{
				System.out.format("%s=%d%n", discrete, index);
			}
		}
	}

	/**
	 * Returns, in the sibling order of {@code joinedFactor}, the index stored in {@code dataLayer}
	 * for the original-graph variable that each copied sibling maps back to via {@code new2old}.
	 */
	private static int[] jointIndices(Factor joinedFactor, ValueDataLayer dataLayer, BiMap<Object,Object> new2old)
	{
		IntArrayList indices = new IntArrayList(joinedFactor.getSiblingCount());
		for (Variable var : joinedFactor.getSiblings())
		{
			indices.add(requireNonNull(dataLayer.get(new2old.get(var))).getIndex());
		}
		// elements() exposes the backing array, so trim it to exactly the number of siblings.
		indices.trimToSize();
		return indices.elements();
	}
}