/*******************************************************************************
* Copyright 2014 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.test.solvers.sumproduct;
import static java.util.Objects.*;
import static org.junit.Assert.*;
import org.junit.Test;
import com.analog.lyric.dimple.factorfunctions.And;
import com.analog.lyric.dimple.factorfunctions.Xor;
import com.analog.lyric.dimple.model.core.FactorGraph;
import com.analog.lyric.dimple.model.factors.Factor;
import com.analog.lyric.dimple.model.variables.Bit;
import com.analog.lyric.dimple.model.variables.Discrete;
import com.analog.lyric.dimple.options.BPOptions;
import com.analog.lyric.dimple.solvers.core.SolverBase;
import com.analog.lyric.dimple.solvers.gibbs.GibbsOptions;
import com.analog.lyric.dimple.solvers.minsum.MinSumSolver;
import com.analog.lyric.dimple.solvers.optimizedupdate.UpdateApproach;
import com.analog.lyric.dimple.solvers.sumproduct.Solver;
import com.analog.lyric.dimple.solvers.sumproduct.SumProductDiscrete;
import com.analog.lyric.dimple.solvers.sumproduct.SumProductSolver;
import com.analog.lyric.dimple.solvers.sumproduct.SumProductSolverGraph;
import com.analog.lyric.dimple.solvers.sumproduct.SumProductTableFactor;
import com.analog.lyric.dimple.solvers.sumproduct.sampledfactor.SampledFactor;
import com.analog.lyric.dimple.test.DimpleTestBase;
/**
 * Tests how the sum-product solver handles belief-propagation options ({@link BPOptions})
 * and the Gibbs options used by sampled factors. Also tests equality of sum-product
 * solver factory instances.
 *
 * @since 0.07
 * @author Christopher Barber
 */
public class TestSumProductOptions extends DimpleTestBase
{
	/**
	 * Verifies the following:
	 * <ul>
	 * <li>option default values;</li>
	 * <li>options set on the model are only pushed into the solver objects when the solver
	 * graph is initialized;</li>
	 * <li>solver-level setters write through to the corresponding local options.</li>
	 * </ul>
	 */
	@SuppressWarnings({ "deprecation", "null" })
	@Test
	public void test()
	{
		// Test default values
		assertEquals(0.0, BPOptions.damping.defaultValue(), 0.0);
		assertTrue(BPOptions.nodeSpecificDamping.defaultValue().isEmpty());
		assertEquals(Integer.MAX_VALUE, (int)BPOptions.maxMessageSize.defaultValue());
		assertEquals(UpdateApproach.AUTOMATIC, BPOptions.updateApproach.defaultValue());
		assertEquals(1.0, BPOptions.automaticExecutionTimeScalingFactor.defaultValue(), 1.0e-9);
		assertEquals(10.0, BPOptions.automaticMemoryAllocationScalingFactor.defaultValue(), 1.0e-9);
		assertEquals(1.0, BPOptions.optimizedUpdateSparseThreshold.defaultValue(), 1.0e-9);

		final int nVars = 4;
		FactorGraph fg = new FactorGraph();
		Discrete[] vars = new Discrete[nVars];
		for (int i = 0; i < nVars; ++i)
		{
			vars[i] = new Bit();
		}
		// has custom factor (NOTE(review): the cast below expects a SumProductTableFactor
		// under sum-product - confirm this comment is still accurate)
		Factor f1 = fg.addFactor(new Xor(), vars);
		Factor f2 = fg.addFactor(new And(), vars);

		// Check initial defaults
		SumProductSolverGraph sfg = requireNonNull(fg.setSolverFactory(new SumProductSolver()));
		assertEquals(0.0, sfg.getDamping(), 0.0);
		SumProductTableFactor sf1 = (SumProductTableFactor)requireNonNull(f1.getSolver());
		assertEquals(0, sf1.getK());
		assertEquals(0.0, sf1.getDamping(0), 0.0);
		SumProductTableFactor sf2 = (SumProductTableFactor)requireNonNull(f2.getSolver());
		assertEquals(0.0, sf2.getDamping(0), 0.0);
		assertEquals(0, sf2.getK());
		assertEquals(UpdateApproach.AUTOMATIC, sf1.getOptionOrDefault(BPOptions.updateApproach));
		assertEquals(SampledFactor.DEFAULT_BURN_IN_SCANS_PER_UPDATE, sfg.getSampledFactorBurnInScansPerUpdate());
		assertEquals(SampledFactor.DEFAULT_SAMPLES_PER_UPDATE, sfg.getSampledFactorSamplesPerUpdate());
		assertEquals(SampledFactor.DEFAULT_SCANS_PER_SAMPLE, sfg.getSampledFactorScansPerSample());
		assertNull(fg.setSolverFactory(null));

		// Set initial options on model
		fg.setOption(BPOptions.damping, .9);
		fg.setOption(BPOptions.maxMessageSize, 10);
		fg.setOption(GibbsOptions.burnInScans, 42); // will be overridden by default option in solver graph
		fg.setOption(GibbsOptions.scansPerSample, 23); // will be overridden by default option in solver graph
		fg.setOption(GibbsOptions.numSamples, 12); // will be overridden by default option in solver graph
		BPOptions.nodeSpecificDamping.set(f1, .4, .5, .6, .7);
		BPOptions.nodeSpecificDamping.set(f2, .3, .4, .5, .6);
		fg.setOption(BPOptions.updateApproach, UpdateApproach.AUTOMATIC);
		f2.setOption(BPOptions.updateApproach, UpdateApproach.NORMAL);

		// Test options that are updated on initialize()
		sfg = requireNonNull(fg.setSolverFactory(new SumProductSolver()));
		assertEquals(0.0, sfg.getDamping(), 0.0);
		// Re-fetch the solver factors before checking them. The previous sf1/sf2 references
		// belong to the solver graph that was discarded by setSolverFactory(null) above.
		sf1 = (SumProductTableFactor)requireNonNull(f1.getSolver());
		assertEquals(0.0, sf1.getDamping(0), 0.0);
		assertEquals(0, sf1.getK());
		sf2 = (SumProductTableFactor)requireNonNull(f2.getSolver());
		assertEquals(0.0, sf2.getDamping(0), 0.0);
		assertEquals(0, sf2.getK());
		assertEquals(SampledFactor.DEFAULT_BURN_IN_SCANS_PER_UPDATE, sfg.getSampledFactorBurnInScansPerUpdate());
		assertEquals(SampledFactor.DEFAULT_SAMPLES_PER_UPDATE, sfg.getSampledFactorSamplesPerUpdate());
		assertEquals(SampledFactor.DEFAULT_SCANS_PER_SAMPLE, sfg.getSampledFactorScansPerSample());
		// The solver graph stores its own local Gibbs options, shadowing the values set on the model.
		assertEquals((Integer)SampledFactor.DEFAULT_SAMPLES_PER_UPDATE, sfg.getLocalOption(GibbsOptions.numSamples));
		assertEquals((Integer)SampledFactor.DEFAULT_SCANS_PER_SAMPLE, sfg.getLocalOption(GibbsOptions.scansPerSample));
		assertEquals((Integer)SampledFactor.DEFAULT_BURN_IN_SCANS_PER_UPDATE,
			sfg.getLocalOption(GibbsOptions.burnInScans));
		SumProductDiscrete sv1 = (SumProductDiscrete)vars[0].getSolver();
		assertEquals(0.0, sv1.getDamping(0), 0.0);

		// After initialize() the model-level options should be reflected in the solver objects.
		sfg.initialize();
		assertEquals(.9, sfg.getDamping(), 0.0);
		assertEquals(.4, sf1.getDamping(0), 0.0);
		assertEquals(.5, sf1.getDamping(1), 0.0);
		assertEquals(.6, sf1.getDamping(2), 0.0);
		assertEquals(.7, sf1.getDamping(3), 0.0);
		assertEquals(10, sf1.getK());
		assertEquals(.3, sf2.getDamping(0), 0.0);
		assertEquals(.4, sf2.getDamping(1), 0.0);
		assertEquals(.5, sf2.getDamping(2), 0.0);
		assertEquals(.6, sf2.getDamping(3), 0.0);
		assertEquals(10, sf2.getK());
		assertEquals(.9, sv1.getDamping(0), 0.0);
		assertEquals(UpdateApproach.OPTIMIZED, sf1.getEffectiveUpdateApproach());
		assertEquals(UpdateApproach.NORMAL, sf2.getEffectiveUpdateApproach());

		// Test using set methods; each setter should also update the matching local option.
		sfg.setDamping(.5);
		assertEquals(.5, sfg.getDamping(), 0.0);
		assertEquals(.5, requireNonNull(sfg.getLocalOption(BPOptions.damping)), 0.0);
		sf1.setK(3);
		assertEquals(3, sf1.getK());
		assertEquals((Integer)3, sf1.getLocalOption(BPOptions.maxMessageSize));
		sf1.setDamping(1, .23);
		assertEquals(.4, sf1.getDamping(0), 0.0);
		assertEquals(.23, sf1.getDamping(1), 0.0);
		assertEquals(.6, sf1.getDamping(2), 0.0);
		assertEquals(.7, sf1.getDamping(3), 0.0);
		assertArrayEquals(new double[] { .4,.23,.6,.7},
			BPOptions.nodeSpecificDamping.get(sf1).toPrimitiveArray(), 0.0);
		sfg.setSampledFactorSamplesPerUpdate(142);
		sfg.setSampledFactorScansPerSample(24);
		sfg.setSampledFactorBurnInScansPerUpdate(11);
		assertEquals(142, sfg.getSampledFactorSamplesPerUpdate());
		assertEquals(24, sfg.getSampledFactorScansPerSample());
		assertEquals(11, sfg.getSampledFactorBurnInScansPerUpdate());
		assertEquals((Integer)142, sfg.getLocalOption(GibbsOptions.numSamples));
		assertEquals((Integer)24, sfg.getLocalOption(GibbsOptions.scansPerSample));
		assertEquals(UpdateApproach.AUTOMATIC, sfg.getOption(BPOptions.updateApproach));
		assertNull(sf1.getLocalOption(BPOptions.updateApproach));
		sf1.setOption(BPOptions.updateApproach, UpdateApproach.OPTIMIZED);
		assertEquals(UpdateApproach.OPTIMIZED, sf1.getOption(BPOptions.updateApproach));
		assertEquals(UpdateApproach.OPTIMIZED, sf1.getLocalOption(BPOptions.updateApproach));
		sf1.setOption(BPOptions.updateApproach, UpdateApproach.NORMAL);
		assertEquals(UpdateApproach.NORMAL, sf1.getOption(BPOptions.updateApproach));
		assertEquals(UpdateApproach.NORMAL, sf1.getLocalOption(BPOptions.updateApproach));
		assertEquals((Integer)11, sfg.getLocalOption(GibbsOptions.burnInScans));
	}

	/**
	 * Verifies that {@link SumProductSolver}, the sum-product {@link Solver} and the gaussian
	 * {@code Solver} compare equal and have equal hash codes. Also verifies that
	 * {@link MinSumSolver} is not equal to them and has a different hash code.
	 */
	@Test
	public void testSolverEquality()
	{
		SolverBase<?> solver1 = new SumProductSolver();
		SolverBase<?> solver2 = new Solver();
		SolverBase<?> solver3 = new MinSumSolver();
		SolverBase<?> solver4 = new com.analog.lyric.dimple.solvers.gaussian.Solver();
		assertEquals(solver1, solver2);
		assertEquals(solver1.hashCode(), solver2.hashCode());
		assertNotEquals(solver1, solver3);
		assertNotEquals(solver1.hashCode(), solver3.hashCode());
		assertEquals(solver1, solver4);
		assertEquals(solver1.hashCode(), solver4.hashCode());
	}
}