/******************************************************************************* * Copyright 2014-2015 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.test.model; import static org.junit.Assert.*; import org.apache.commons.math3.stat.descriptive.moment.Variance; import org.eclipse.jdt.annotation.Nullable; import org.junit.Test; import com.analog.lyric.dimple.model.core.FactorGraph; import com.analog.lyric.dimple.model.domains.DiscreteDomain; import com.analog.lyric.dimple.model.factors.Factor; import com.analog.lyric.dimple.model.transform.JunctionTreeTransform; import com.analog.lyric.dimple.model.transform.JunctionTreeTransformMap; import com.analog.lyric.dimple.model.transform.JunctionTreeTransformMap.AddedJointVariable; import com.analog.lyric.dimple.model.values.Value; import com.analog.lyric.dimple.model.variables.Discrete; import com.analog.lyric.dimple.model.variables.Variable; import com.analog.lyric.dimple.model.variables.VariableList; import com.analog.lyric.dimple.solvers.gibbs.GibbsSolver; import com.analog.lyric.dimple.solvers.gibbs.GibbsSolverGraph; import com.analog.lyric.dimple.solvers.gibbs.ISolverVariableGibbs; import com.analog.lyric.dimple.test.DimpleTestBase; import com.analog.lyric.util.misc.Misc; /** * Tests for {@link JunctionTreeTransform} */ public class TestJunctionTreeTransform extends DimpleTestBase { private final RandomGraphGenerator _graphGenerator = new RandomGraphGenerator(testRand); private static final DiscreteDomain d2 = DiscreteDomain.range(0, 1); private static final DiscreteDomain d3 = DiscreteDomain.range(0, 2); private static final DiscreteDomain d4 = DiscreteDomain.range(0, 3); private static final DiscreteDomain d5 = DiscreteDomain.range(0, 5); @Test public void testTrivialLoop() { testGraph(_graphGenerator.buildTrivialLoop()); } @Test public void testTriangle() { testGraph(_graphGenerator.buildTriangle()); } @Test public void testGrid2() { testGraph(_graphGenerator.buildGrid(2)); } @Test public void testGrid3() { testGraph(_graphGenerator.domains( d2, d3, d4).buildGrid(3)); } @Test public void testGrid4() { testGraph(_graphGenerator.buildGrid(4)); } @Test public void testGrid2by20() { testGraph(_graphGenerator.domains(d2, d3, d5).buildGrid(2, 20)); } @Test public void testGrid1by100() { FactorGraph model = _graphGenerator.domains(d2, d3, d4).buildGrid(1, 100); assertTrue(model.isTree()); testTree(model); } @Test public void testRandomGraphs() { final int nGraphs = 20; final int maxSize = 1000; RandomGraphGenerator gen = _graphGenerator.maxBranches(2).maxTreeWidth(5); for (int i = 0; i < nGraphs; ++i) { testGraph(gen.buildRandomGraph(testRand.nextInt(maxSize) + 10), null); } } @Test public void testRandomTree() { FactorGraph tree = _graphGenerator.maxBranches(5).domains(d2, d3, d4, d5).buildRandomTree(500); assertTrue(tree.isTree()); testTree(tree); } /** * @see RandomGraphGenerator#buildStudentNetwork() */ @Test public void testStudentNetwork() { testGraph(_graphGenerator.buildStudentNetwork()); } /*----------------- * Helper methods */ /** * Assert that source and target graphs in {@code transformMap} represent the same * joint distribution down to some level of precision. * * @param transformMap */ @SuppressWarnings("null") private void assertModelsEquivalent(JunctionTreeTransformMap transformMap) { if (transformMap.isIdentity()) { return; } final FactorGraph source = transformMap.source(); final FactorGraph target = transformMap.target(); GibbsSolver gibbs = new GibbsSolver(); GibbsSolverGraph sourceGibbs = source.setSolverFactory(gibbs); GibbsSolverGraph targetGibbs = target.setSolverFactory(gibbs); targetGibbs.initialize(); final int nSamples = 100; final double[] differences = new double[nSamples]; for (int n = 0; n < nSamples; ++n) { // Generate a sample on the source graph source.solve(); // Copy sample values to new graph for (Variable sourceVar : source.getVariables()) { Variable targetVar = transformMap.sourceToTargetVariable(sourceVar); ISolverVariableGibbs sourceSVar = sourceGibbs.getSolverVariable(sourceVar); ISolverVariableGibbs targetSVar = targetGibbs.getSolverVariable(targetVar); targetSVar.setCurrentSample(sourceSVar.getCurrentSampleValue()); } // Update values of added variables for (AddedJointVariable<?> added : transformMap.addedJointVariables()) { final ISolverVariableGibbs addedSVar = targetGibbs.getSolverVariable(added.getVariable()); final Value value = addedSVar.getCurrentSampleValue(); final Value[] inputs = new Value[added.getInputCount()]; for (int i = inputs.length; --i>=0;) { final Variable inputVar = added.getInput(i); final ISolverVariableGibbs inputSVar = targetGibbs.getSolverVariable(inputVar); inputs[i] = inputSVar.getCurrentSampleValue(); } added.updateValue(value, inputs); } // Compare the joint likelihoods final double sourceEnergy = sourceGibbs.getSampleScore(); final double targetEnergy = targetGibbs.getSampleScore(); final double difference = sourceEnergy - targetEnergy; if (Math.abs(difference) > 1e-10) { Misc.breakpoint(); } differences[n] = difference; } double variance = new Variance().evaluate(differences); assertEquals(0.0, variance, 1e-10); } private void testGraph(FactorGraph model) { testGraph(model, false); } private void testTree(FactorGraph model) { testGraph(model, true); } private void testGraph(FactorGraph model, @Nullable Boolean expectIdentity) { try { testGraphImpl(model, expectIdentity); } catch (Throwable ex) { String msg = String.format("%s. TestJunctionTreeTransform._seed==%dL", ex.toString(), testRand.getSeed()); ex.printStackTrace(System.err); System.err.format(">>> TestJunctionTreeTransform._seed==%dL;<<<\n", testRand.getSeed()); throw new RuntimeException(msg, ex); } } private void testGraphImpl(FactorGraph model, @Nullable Boolean expectIdentity) { JunctionTreeTransform jt = new JunctionTreeTransform().random(testRand); assertSame(testRand, jt.random()); assertFalse(jt.useConditioning()); JunctionTreeTransformMap transformMap = jt.transform(model); if (expectIdentity != null) { assertEquals(expectIdentity, transformMap.isIdentity()); } if (transformMap.isIdentity()) { assertTrue(model.isForest()); } for (Factor factor : transformMap.target().getFactors()) { // Name target factors as a debugging aid RandomGraphGenerator.labelFactor(factor); } assertTrue(transformMap.target().isForest()); assertModelsEquivalent(transformMap); // Try with conditioning model.setSolverFactory(null); VariableList variables = model.getVariables(); for (int i = 0; i < 100000; ++i) { Variable variable = variables.getByIndex(testRand.nextInt(variables.size())); if (variable instanceof Discrete) { Discrete discrete = (Discrete)variable; discrete.setPriorIndex(testRand.nextInt(discrete.getDomain().size())); break; } } jt.useConditioning(true); assertTrue(jt.useConditioning()); transformMap = jt.transform(model); for (Factor factor : transformMap.target().getFactors()) { // Name target factors as a debugging aid RandomGraphGenerator.labelFactor(factor); } assertTrue(transformMap.target().isForest()); assertModelsEquivalent(transformMap); } }