/*******************************************************************************
* Copyright 2014 Analog Devices, Inc. Licensed under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance with the License. You may obtain a
* copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable
* law or agreed to in writing, software distributed under the License is distributed on an "AS IS"
* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License
* for the specific language governing permissions and limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.test.solvers.minsum;
import static com.analog.lyric.dimple.test.solvers.sumproduct.TestSumProductOptimizedUpdate.*;
import static org.junit.Assert.*;
import java.util.List;
import java.util.Random;
import org.junit.Test;
import com.analog.lyric.dimple.factorfunctions.core.FactorTable;
import com.analog.lyric.dimple.factorfunctions.core.FactorTableRepresentation;
import com.analog.lyric.dimple.factorfunctions.core.IFactorTable;
import com.analog.lyric.dimple.model.core.FactorGraph;
import com.analog.lyric.dimple.model.domains.DiscreteDomain;
import com.analog.lyric.dimple.model.domains.JointDomainIndexer;
import com.analog.lyric.dimple.model.factors.Factor;
import com.analog.lyric.dimple.model.variables.Discrete;
import com.analog.lyric.dimple.options.BPOptions;
import com.analog.lyric.dimple.schedulers.FloodingScheduler;
import com.analog.lyric.dimple.solvers.interfaces.ISolverFactorGraph;
import com.analog.lyric.dimple.solvers.minsum.MinSumSolver;
import com.analog.lyric.dimple.solvers.minsum.MinSumSolverGraph;
import com.analog.lyric.dimple.solvers.minsum.MinSumTableFactor;
import com.analog.lyric.dimple.solvers.optimizedupdate.UpdateApproach;
import com.analog.lyric.dimple.test.DimpleTestBase;
import com.analog.lyric.dimple.test.solvers.sumproduct.TestSumProductOptimizedUpdate.Graph;
import com.analog.lyric.dimple.test.solvers.sumproduct.TestSumProductOptimizedUpdate.Graph2;
import com.analog.lyric.options.IOptionHolder;
/**
* @since 0.07
* @author jking
*/
public class TestMinSumOptimizedUpdate extends DimpleTestBase
{
// Values this test class expects the optimized-update options to report when nothing
// has been set explicitly; checkDefaults() compares them against the BPOptions keys.

// Expected default for BPOptions.updateApproach.
static final UpdateApproach defaultApproach = UpdateApproach.AUTOMATIC;
// Expected default for BPOptions.automaticMemoryAllocationScalingFactor.
static final double defaultAllocationScale = 10.0;
// Expected default for BPOptions.automaticExecutionTimeScalingFactor.
static final double defaultExecutionTimeScale = 1.0;
// Expected default for BPOptions.optimizedUpdateSparseThreshold.
static final double defaultSparseThreshold = 1.0;
/**
 * Fetches the solver graph attached to {@code fg} as a {@link MinSumSolverGraph}.
 * <p>
 * Fails the calling test if the graph has no solver attached.
 *
 * @param fg factor graph whose solver is wanted
 * @return the non-null min-sum solver graph of {@code fg}
 */
private static MinSumSolverGraph getMinSumSolverGraph(FactorGraph fg)
{
    final ISolverFactorGraph solver = fg.getSolver();
    assertNotNull(solver);
    return (MinSumSolverGraph) solver;
}
/**
 * Fetches the solver object attached to factor {@code f} as a {@link MinSumTableFactor}.
 * <p>
 * Fails the calling test if the factor has no solver object attached.
 *
 * @param f model factor whose solver-side counterpart is wanted
 * @return the non-null min-sum table factor of {@code f}
 */
private static MinSumTableFactor getMinSumFactorTable(Factor f)
{
    final MinSumTableFactor tableFactor = (MinSumTableFactor) f.getSolver();
    assertNotNull(tableFactor);
    return tableFactor;
}
/**
 * Asserts that each optimized-update option seen through {@code optionHolder} has the
 * default value this test class expects.
 *
 * @param optionHolder object whose effective option values are inspected
 */
private void checkDefaults(IOptionHolder optionHolder)
{
    // Absolute tolerance used for the floating-point option comparisons.
    final double tolerance = 1.0e-9;

    assertEquals(defaultApproach,
        optionHolder.getOptionOrDefault(BPOptions.updateApproach));
    assertEquals(defaultAllocationScale,
        optionHolder.getOptionOrDefault(BPOptions.automaticMemoryAllocationScalingFactor),
        tolerance);
    assertEquals(defaultExecutionTimeScale,
        optionHolder.getOptionOrDefault(BPOptions.automaticExecutionTimeScalingFactor),
        tolerance);
    assertEquals(defaultSparseThreshold,
        optionHolder.getOptionOrDefault(BPOptions.optimizedUpdateSparseThreshold),
        tolerance);
}
/**
 * Checks that a freshly created min-sum solver graph reports the expected default
 * values for the optimized-update options.
 *
 * @since 0.07
 */
@Test
public void testGraphPropertiesDefaults()
{
    final FactorGraph graph = new FactorGraph();
    graph.setSolverFactory(new MinSumSolver());

    // No option has been set anywhere, so the solver graph must report the defaults.
    checkDefaults(getMinSumSolverGraph(graph));
}
/**
 * Verify that the solver factor has the correct default property values.
 *
 * @since 0.07
 */
@Test
public void testFactorPropertiesDefaults()
{
    // Fixed seed: the generated factor is the same on every run, keeping the test reproducible.
    final Random rand = new Random(0);

    FactorGraph fg = new FactorGraph();
    fg.setSolverFactory(new MinSumSolver());
    Factor f = add2BitFactor(rand, fg);

    // No option has been set on the factor or its graph, so defaults must be reported.
    MinSumTableFactor sft = getMinSumFactorTable(f);
    checkDefaults(sft);
}
/**
 * Verify that properties are properly inherited by factors from their graph.
 * <p>
 * The test walks through four stages: nothing set (factor reports defaults), options set
 * on the solver graph (factor reports the graph values), options set on the factor
 * (factor reports its own values while the graph keeps its values), and finally options
 * unset on the factor (factor reports the graph values again).
 *
 * @since 0.07
 */
@Test
public void testPropertyInheritance()
{
    // Absolute tolerance used for the floating-point option comparisons.
    final double tolerance = 1.0e-9;

    Random rand = new Random();
    rand.setSeed(0); // Don't be random
    FactorGraph fg = new FactorGraph();
    fg.setSolverFactory(new MinSumSolver());
    MinSumSolverGraph sfg = getMinSumSolverGraph(fg);
    Factor f = add2BitFactor(rand, fg);
    MinSumTableFactor sft = getMinSumFactorTable(f);

    // Make sure the factor has the default values initially
    checkDefaults(sft);

    // Set the properties at the graph to non-default values
    final UpdateApproach graphApproach = UpdateApproach.OPTIMIZED;
    final double graphAllocationScale = 2.0;
    final double graphExecutionTimeScale = 50.0;
    final double graphSparseThreshold = 0.6;
    sfg.setOption(BPOptions.updateApproach, graphApproach);
    sfg.setOption(BPOptions.automaticMemoryAllocationScalingFactor, graphAllocationScale);
    sfg.setOption(BPOptions.automaticExecutionTimeScalingFactor, graphExecutionTimeScale);
    sfg.setOption(BPOptions.optimizedUpdateSparseThreshold, graphSparseThreshold);

    // Check that the factor returns the values programmed on the graph
    assertEquals(graphApproach, sft.getOptionOrDefault(BPOptions.updateApproach));
    assertEquals(graphAllocationScale,
        sft.getOptionOrDefault(BPOptions.automaticMemoryAllocationScalingFactor), tolerance);
    assertEquals(graphExecutionTimeScale,
        sft.getOptionOrDefault(BPOptions.automaticExecutionTimeScalingFactor), tolerance);
    assertEquals(graphSparseThreshold,
        sft.getOptionOrDefault(BPOptions.optimizedUpdateSparseThreshold), tolerance);

    // Set the properties at the factor to yet different values. NORMAL is used for the
    // approach because it differs from both the default (AUTOMATIC) and the graph value
    // (OPTIMIZED); a factor-level lookup that wrongly fell through to either would fail.
    final UpdateApproach factorApproach = UpdateApproach.NORMAL;
    final double factorAllocationScale = 3.0;
    final double factorExecutionTimeScale = 60.0;
    final double factorSparseThreshold = 0.7;
    sft.setOption(BPOptions.updateApproach, factorApproach);
    sft.setOption(BPOptions.automaticMemoryAllocationScalingFactor, factorAllocationScale);
    sft.setOption(BPOptions.automaticExecutionTimeScalingFactor, factorExecutionTimeScale);
    sft.setOption(BPOptions.optimizedUpdateSparseThreshold, factorSparseThreshold);

    // Check that the factor returns the values programmed on the factor
    assertEquals(factorApproach, sft.getOptionOrDefault(BPOptions.updateApproach));
    assertEquals(factorAllocationScale,
        sft.getOptionOrDefault(BPOptions.automaticMemoryAllocationScalingFactor), tolerance);
    assertEquals(factorExecutionTimeScale,
        sft.getOptionOrDefault(BPOptions.automaticExecutionTimeScalingFactor), tolerance);
    assertEquals(factorSparseThreshold,
        sft.getOptionOrDefault(BPOptions.optimizedUpdateSparseThreshold), tolerance);

    // And that the graph returns its own values still
    assertEquals(graphApproach, sfg.getOptionOrDefault(BPOptions.updateApproach));
    assertEquals(graphAllocationScale,
        sfg.getOptionOrDefault(BPOptions.automaticMemoryAllocationScalingFactor), tolerance);
    assertEquals(graphExecutionTimeScale,
        sfg.getOptionOrDefault(BPOptions.automaticExecutionTimeScalingFactor), tolerance);
    assertEquals(graphSparseThreshold,
        sfg.getOptionOrDefault(BPOptions.optimizedUpdateSparseThreshold), tolerance);

    // "Unset" the factor properties
    sft.unsetOption(BPOptions.updateApproach);
    sft.unsetOption(BPOptions.automaticMemoryAllocationScalingFactor);
    sft.unsetOption(BPOptions.automaticExecutionTimeScalingFactor);
    sft.unsetOption(BPOptions.optimizedUpdateSparseThreshold);

    // And check that the factor again returns the graph values
    assertEquals(graphApproach, sft.getOptionOrDefault(BPOptions.updateApproach));
    assertEquals(graphAllocationScale,
        sft.getOptionOrDefault(BPOptions.automaticMemoryAllocationScalingFactor), tolerance);
    assertEquals(graphExecutionTimeScale,
        sft.getOptionOrDefault(BPOptions.automaticExecutionTimeScalingFactor), tolerance);
    assertEquals(graphSparseThreshold,
        sft.getOptionOrDefault(BPOptions.optimizedUpdateSparseThreshold), tolerance);
}
/**
 * Runs the min-sum solver on {@code fg} twice, first with the NORMAL update approach and
 * then with the OPTIMIZED one, and asserts that both runs produce the same beliefs.
 * <p>
 * Each run initializes the graph and performs 5 iterations. Both runs use the same damping
 * and the same multithreading setting; the sparse threshold is applied before the
 * optimized run only. The test fails if {@code fg} has no solver attached.
 *
 * @param fg graph to solve; its solver must be a {@link MinSumSolverGraph}
 * @param sparsityControl value for {@code BPOptions.optimizedUpdateSparseThreshold}
 * @param damping damping value passed to the solver graph for both runs
 * @param useMultithreading whether the solver runs multithreaded, for both runs
 */
public static void runSolver(FactorGraph fg,
    final double sparsityControl,
    final double damping,
    final boolean useMultithreading)
{
    final ISolverFactorGraph solver = fg.getSolver();
    if (solver == null)
    {
        fail("solver was null");
        return;
    }
    final MinSumSolverGraph ssolver = (MinSumSolverGraph) solver;

    // Baseline run with the normal update algorithm. It honors the caller's threading
    // choice so that single-threaded test cases really are single-threaded throughout.
    solver.useMultithreading(useMultithreading);
    ssolver.setOption(BPOptions.updateApproach, UpdateApproach.NORMAL);
    ssolver.setDamping(damping);
    fg.initialize();
    ssolver.iterate(5);
    final List<Object> normalBeliefs = getBeliefs(fg);

    // Same graph, same settings, but with the optimized update algorithm.
    solver.useMultithreading(useMultithreading);
    ssolver.setOption(BPOptions.optimizedUpdateSparseThreshold, sparsityControl);
    ssolver.setDamping(damping);
    ssolver.setOption(BPOptions.updateApproach, UpdateApproach.OPTIMIZED);
    fg.initialize();
    ssolver.iterate(5);
    final List<Object> optimizedBeliefs = getBeliefs(fg);

    assertEqualListDoubleArray(normalBeliefs, optimizedBeliefs);
}
/**
 * Builds a min-sum factor graph populated by {@code Graph} and compares the normal and
 * optimized update results on it via {@link #runSolver}.
 *
 * @param zeroControl forwarded unchanged to the {@code Graph} constructor
 * @param sparsityControl forwarded to {@code runSolver}
 * @param damping forwarded to {@code runSolver}
 * @param useMultithreading forwarded to {@code runSolver}
 */
private void doTest(final int zeroControl,
    final double sparsityControl,
    final double damping,
    final boolean useMultithreading)
{
    final FactorGraph graph = new FactorGraph();
    graph.setSolverFactory(new MinSumSolver());

    // Constructed only for its effect on the graph; the instance itself is not needed.
    new Graph(graph, zeroControl);

    runSolver(graph, sparsityControl, damping, useMultithreading);
}
/**
 * Compares normal and optimized updates single-threaded, with zeroControl 2000,
 * sparse threshold 0.9 and damping 0.9.
 */
@Test
public void testSparse()
{
    doTest(2000,    // zeroControl
           0.9,     // sparsityControl
           0.9,     // damping
           false);  // useMultithreading
}
/**
 * Compares normal and optimized updates with multithreading requested, with
 * zeroControl 2000, sparse threshold 0.9 and no damping.
 */
@Test
public void testSparseMultithreaded()
{
    doTest(2000,   // zeroControl
           0.9,    // sparsityControl
           0.0,    // damping
           true);  // useMultithreading
}
/**
 * Compares normal and optimized updates single-threaded, with zeroControl -10,
 * sparse threshold 1.0 and damping 0.9.
 */
@Test
public void testVerySparse()
{
    doTest(-10,     // zeroControl
           1.0,     // sparsityControl
           0.9,     // damping
           false);  // useMultithreading
}
/**
 * Compares normal and optimized updates single-threaded, with zeroControl 0,
 * sparse threshold 1.0 and no damping.
 */
@Test
public void testDense()
{
    doTest(0,       // zeroControl
           1.0,     // sparsityControl
           0.0,     // damping
           false);  // useMultithreading
}
/**
 * Compares normal and optimized updates with multithreading requested, with
 * zeroControl 0, sparse threshold 1.0 and no damping.
 */
@Test
public void testDenseMultithreaded()
{
    doTest(0,      // zeroControl
           1.0,    // sparsityControl
           0.0,    // damping
           true);  // useMultithreading
}
/**
 * Checks that, under the AUTOMATIC update approach, every solver factor has no
 * automatically chosen update approach before the solver graph is initialized and has
 * one afterwards.
 * <p>
 * The graph holds a single five-variable table factor. A number of weight assignments
 * equal to {@code density} times the table's joint cardinality (truncated to an int) is
 * made at randomly drawn coordinates, using a fixed random seed.
 *
 * @param density fraction of the joint cardinality used as the number of weight assignments
 * @param multithreaded whether the solver graph is asked to use multithreading
 * @since 0.07
 */
private void automaticHelper(final double density, final boolean multithreaded)
{
    final Random rand = new Random();
    rand.setSeed(0); // Don't be random

    final FactorGraph fg = new FactorGraph();
    fg.setSolverFactory(new MinSumSolver());

    // One constant sizes both arrays so that they cannot drift apart.
    final int N = 5;
    final DiscreteDomain[] domains = new DiscreteDomain[N];
    final Discrete[] vars = new Discrete[N];
    for (int i = 0; i < N; i++)
    {
        domains[i] = DiscreteDomain.range(1, Math.abs(2 - i) + 2);
        vars[i] = new Discrete(domains[i]);
    }

    final IFactorTable table = FactorTable.create(domains);
    table.setRepresentation(FactorTableRepresentation.DENSE_WEIGHT);
    final JointDomainIndexer domainIndexer = table.getDomainIndexer();
    final int d = (int) (domainIndexer.getCardinality() * density);
    int[] coordinates = null;
    for (int i = 0; i < d; i++)
    {
        coordinates = domainIndexer.randomIndices(rand, coordinates);
        table.setWeightForIndices(rand.nextDouble(), coordinates);
    }
    fg.addFactor(table, vars);

    // The helper fails the test if no solver graph is attached.
    final MinSumSolverGraph sfg = getMinSumSolverGraph(fg);
    sfg.setOption(BPOptions.updateApproach, UpdateApproach.AUTOMATIC);
    sfg.useMultithreading(multithreaded);

    // Before optimize, all factors should not have their optimize enable explicitly set.
    // getMinSumFactorTable() fails the test on a missing solver factor, so the check
    // cannot be skipped silently.
    for (Factor factor : fg.getFactors())
    {
        assertNull(getMinSumFactorTable(factor).getAutomaticUpdateApproach());
    }

    sfg.initialize();

    // Afterward, they should
    for (Factor factor : fg.getFactors())
    {
        assertNotNull(getMinSumFactorTable(factor).getAutomaticUpdateApproach());
    }
}
/**
 * Runs the automatic-approach check single-threaded with a density of 0.05.
 *
 * @since 0.07
 */
@Test
public void testAutomaticSparse()
{
    automaticHelper(0.05,    // density
                    false);  // multithreaded
}
/**
 * Runs the automatic-approach check single-threaded with a density of 1.0.
 *
 * @since 0.07
 */
@Test
public void testAutomaticDense()
{
    automaticHelper(1.0,     // density
                    false);  // multithreaded
}
/**
 * Runs the automatic-approach check with multithreading requested and a density of 1.0.
 *
 * @since 0.07
 */
@Test
public void testAutomaticMultithreaded()
{
    automaticHelper(1.0,    // density
                    true);  // multithreaded
}
/**
 * Builds a min-sum factor graph with a flooding scheduler, populates it through
 * {@code Graph2}, and compares the normal and optimized update results on it via
 * {@link #runSolver}.
 *
 * @param sparsityControl forwarded to {@code runSolver}
 * @param damping forwarded to {@code runSolver}
 * @param useMultithreading forwarded to {@code runSolver}
 * @param rnd random source handed to the {@code Graph2} constructor
 */
private void doTest2(final double sparsityControl,
    final double damping,
    final boolean useMultithreading,
    final Random rnd)
{
    final FactorGraph graph = new FactorGraph();
    graph.setSolverFactory(new MinSumSolver());
    graph.setOption(BPOptions.scheduler, new FloodingScheduler());

    // Constructed only for its effect on the graph; the instance itself is not needed.
    new Graph2(graph, rnd);

    runSolver(graph, sparsityControl, damping, useMultithreading);
}
/**
 * Runs the {@code Graph2}-based comparison three times, drawing from the shared
 * {@code testRand} source, single-threaded with sparse threshold 0.2 and no damping.
 * <p>
 * If any round fails, the seed of {@code testRand} is written to standard error so the
 * failure can be reproduced.
 */
@Test
public void TestMixedFactorTablesDense()
{
    final double sparsityControl = 0.2;
    final double damping = 0.0;
    final boolean useMultithreading = false;

    boolean completed = false;
    try
    {
        int remaining = 3;
        while (remaining-- > 0)
        {
            doTest2(sparsityControl, damping, useMultithreading, testRand);
        }
        completed = true;
    }
    finally
    {
        // Only reached with completed == false when one of the rounds threw.
        if (!completed)
        {
            System.err.println("seed="+testRand.getSeed());
        }
    }
}
}