/*******************************************************************************
* Copyright 2014 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.test.solvers.junctiontree;
import static java.util.Objects.*;
import static org.junit.Assert.*;
import java.util.List;
import java.util.Map;
import org.junit.Test;
import com.analog.lyric.dimple.factorfunctions.core.FactorTable;
import com.analog.lyric.dimple.factorfunctions.core.IFactorTable;
import com.analog.lyric.dimple.model.core.FactorGraph;
import com.analog.lyric.dimple.model.domains.JointDomainIndexer;
import com.analog.lyric.dimple.model.domains.JointDomainReindexer;
import com.analog.lyric.dimple.model.factors.DiscreteFactor;
import com.analog.lyric.dimple.model.factors.Factor;
import com.analog.lyric.dimple.model.variables.Discrete;
import com.analog.lyric.dimple.model.variables.Variable;
import com.analog.lyric.dimple.model.variables.VariableList;
import com.analog.lyric.dimple.solvers.core.SolverBase;
import com.analog.lyric.dimple.solvers.junctiontree.JunctionTreeSolver;
import com.analog.lyric.dimple.solvers.junctiontree.JunctionTreeSolverGraphBase;
import com.analog.lyric.dimple.solvers.junctiontreemap.JunctionTreeMAPSolver;
import com.analog.lyric.dimple.test.DimpleTestBase;
import com.analog.lyric.dimple.test.model.RandomGraphGenerator;
import com.analog.lyric.dimple.test.model.TestJunctionTreeTransform;
import com.analog.lyric.util.misc.IMapList;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import com.google.common.math.DoubleMath;
/**
* Unit tests for {@link JunctionTreeSolver}
* <p>
* @since 0.05
* @author Christopher Barber
* @see TestJunctionTreeTransform
*/
public class TestJunctionTree extends DimpleTestBase
{
private final RandomGraphGenerator _graphGenerator = new RandomGraphGenerator(testRand);
@Test
public void testSolverEquality()
{
SolverBase<?> solver1 = new JunctionTreeSolver();
SolverBase<?> solver2 = new JunctionTreeSolver();
SolverBase<?> solver3 = new com.analog.lyric.dimple.solvers.junctiontree.Solver();
SolverBase<?> solver4 = new JunctionTreeMAPSolver();
SolverBase<?> solver5 = new JunctionTreeMAPSolver();
SolverBase<?> solver6 = new com.analog.lyric.dimple.solvers.junctiontreemap.Solver();
assertEquals(solver1, solver2);
assertEquals(solver1.hashCode(), solver2.hashCode());
assertEquals(solver2, solver3);
assertEquals(solver2.hashCode(), solver3.hashCode());
assertNotEquals(solver3, solver4);
assertNotEquals(solver3.hashCode(), solver4.hashCode());
assertEquals(solver4, solver5);
assertEquals(solver4.hashCode(), solver5.hashCode());
assertEquals(solver5, solver6);
assertEquals(solver5.hashCode(), solver6.hashCode());
}
@Test
public void testTriangle()
{
testGraph(_graphGenerator.buildTriangle());
}
@Test
public void testGrid2()
{
testGraph(_graphGenerator.buildGrid(2));
}
@Test
public void testGrid3()
{
testGraph(_graphGenerator.buildGrid(3));
}
@Test
public void testStudentNetwork()
{
testGraph(_graphGenerator.buildStudentNetwork());
}
@Test
public void testRandomGraphs()
{
final int nGraphs = 20;
final int maxSize = 1000;
RandomGraphGenerator gen = _graphGenerator.maxBranches(2).maxTreeWidth(3);
for (int i = 0; i < nGraphs; ++i)
{
testGraph(gen.buildRandomGraph(testRand.nextInt(maxSize) + 10));
}
}
private void testGraph(FactorGraph model)
{
try
{
testGraphImpl(model);
}
catch (Throwable ex)
{
long seed = testRand.getSeed();
String msg = String.format("%s. TestJunctionTreeTransform._seed==%dL", ex.toString(), seed);
ex.printStackTrace(System.err);
System.err.format(">>> TestJunctionTreeTransform._seed==%dL;<<<\n", seed);
throw new RuntimeException(msg, ex);
}
}
private void testGraphImpl(FactorGraph model)
{
// FIXME
testGraphImpl(model, false);
testGraphImpl(model, true);
}
private void testGraphImpl(FactorGraph model, boolean useMap)
{
testGraphImpl(model, useMap, false);
// Choose a variable at random and give it a fixed value.
final VariableList variables = model.getVariables();
final int varIndex = testRand.nextInt(variables.size());
final Discrete variable = variables.getByIndex(varIndex).asDiscreteVariable();
final int valueIndex = testRand.nextInt(variable.getDomain().size());
variable.asDiscreteVariable().setPriorIndex(valueIndex);
testGraphImpl(model, useMap, false);
testGraphImpl(model, useMap, true);
// Clear fixed value
variable.setPrior(null);
}
private void testGraphImpl(FactorGraph model, boolean useMap, boolean useConditioning)
{
JunctionTreeSolverGraphBase<?> jtgraph =
model.createSolver(useMap ? new JunctionTreeMAPSolver() : new JunctionTreeSolver());
jtgraph.useConditioning(useConditioning);
jtgraph.getTransformer().random(testRand); // set random generator so we can reproduce failures
model.solve();
FactorGraph transformedModel = requireNonNull(jtgraph.getDelegate()).getModelObject();
RandomGraphGenerator.labelFactors(transformedModel);
assertTrue(transformedModel.isForest());
// Do solve again on a copy of the graph with all factors merged into single giant factor.
final BiMap<Object,Object> old2new = HashBiMap.create();
final BiMap<Object,Object> new2old = old2new.inverse();
FactorGraph model2 = model.copyRoot(old2new);
IMapList<Factor> factors2 = model2.getFactors();
Factor factor2 = null;
if (factors2.size() > 0)
{
factor2 = model2.join(factors2.toArray(new Factor[factors2.size()]));
}
model2.setSolverFactory(jtgraph.getDelegateSolverFactory());
model2.solve();
// Compare marginal variable beliefs and scores
for (Variable variable : model.getVariables())
{
final Variable variable2 = (Variable)old2new.get(variable);
final Object belief1 = variable.getBeliefObject();
final Object belief2 = variable2.getBeliefObject();
if (belief1 instanceof double[])
{
assertArrayEquals((double[])belief2, (double[])belief1, 1e-10);
}
else
{
assertEquals(belief1, belief2);
}
// Compare scores
double score = variable.getScore();
double score2 = variable2.getScore();
assertEquals(score, score2, 1e-10);
if (!useMap)
{
// Compare entropy
double entropy = variable.getBetheEntropy();
double entropy1 = variable2.getBetheEntropy();
assertEquals(entropy, entropy1, 1e-10);
// Compare internal energy
double internalEnergy = variable.getInternalEnergy();
double internalEnergy2 = variable2.getInternalEnergy();
assertEquals(internalEnergy, internalEnergy2, 1e-10);
}
}
// Compare factor beliefs
if (!useMap && factor2 instanceof DiscreteFactor)
{
DiscreteFactor discreteFactor2 = (DiscreteFactor) factor2;
JointDomainIndexer fullDomains = discreteFactor2.getDomainList();
final double[] fullBeliefs = discreteFactor2.getBelief();
final int[][] fullBeliefIndices = discreteFactor2.getPossibleBeliefIndices();
final IFactorTable fullTable = FactorTable.create(fullDomains);
fullTable.setWeightsSparse(fullBeliefIndices, fullBeliefs);
final int nFullDomains = fullDomains.size();
final int[] fullToMarginal = new int[nFullDomains];
for (Factor factor : model.getFactors())
{
final DiscreteFactor discreteFactor = (DiscreteFactor)factor;
JointDomainIndexer factorDomains = discreteFactor.getDomainList();
final int nFactorDomains = factorDomains.size();
final double[] beliefs = discreteFactor.getBelief();
final int[][] beliefIndices = discreteFactor.getPossibleBeliefIndices();
// Marginalize corresponding beliefs from full table.
final List<? extends Variable> factorVars = discreteFactor.getSiblings();
for (int from = 0, remove = nFactorDomains; from < nFullDomains; ++from)
{
final Variable fromVar = discreteFactor2.getSibling(from);
final int to = factorVars.indexOf(new2old.get(fromVar));
fullToMarginal[from] = to >= 0 ? to : remove++;
}
final JointDomainReindexer marginalizer =
JointDomainReindexer.createPermuter(fullDomains, factorDomains, fullToMarginal);
final IFactorTable beliefTable2 = fullTable.convert(marginalizer);
double[] beliefs2 = beliefTable2.getWeightsSparseUnsafe();
int[][] beliefIndices2 = beliefTable2.getIndicesSparseUnsafe();
// BUG 27 can result in beliefs that are close to but not equal to zero so we have
// to filter out beliefs that are close to zero.
int i = 0, j = 0;
while (true)
{
while (i < beliefs.length && DoubleMath.fuzzyEquals(beliefs[i], 0.0, 1e-15))
{
++i;
}
while (j < beliefs2.length && DoubleMath.fuzzyEquals(beliefs2[j], 0.0, 1e-15))
{
++j;
}
if (i >= beliefs.length || j >= beliefs2.length)
{
break;
}
assertEquals(beliefs[i], beliefs2[j], 1e-12);
assertArrayEquals(beliefIndices[i], beliefIndices2[j]);
++i;
++j;
}
assertEquals(beliefs.length, i);
assertEquals(beliefs2.length, j);
}
}
// Compare scores for two versions with same guesses
for (int i = 0; i < 10; ++i)
{
// Randomly set guesses
for (Map.Entry<Object,Object> entry : old2new.entrySet())
{
Object node = entry.getKey();
if (node instanceof Discrete)
{
Discrete var = (Discrete)node;
Discrete var2 = (Discrete)entry.getValue();
if (!var.hasFixedValue())
{
int guessIndex = testRand.nextInt(var.getDomain().size());
var.setGuessIndex(guessIndex);
var2.setGuessIndex(guessIndex);
assertEquals(var.getScore(), var2.getScore(), 1e-14);
}
}
}
double score = model.getScore();
double score2 = model2.getScore();
assertEquals(score, score2, 1e-10);
}
if (!useMap)
{
double internalEnergy = model.getInternalEnergy();
double internalEnergy2 = model2.getInternalEnergy();
assertEquals(internalEnergy, internalEnergy2, 1e-10);
// The entropy and free energy depend on the factorization, and thus cannot be compared vs. the
// joint factor and cannot be easily compared with original model either because it depends on
// the beliefs....
}
}
}