/******************************************************************************* * Copyright 2014 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.test.solvers.gibbs; import static java.util.Objects.*; import static org.junit.Assert.*; import org.junit.Test; import com.analog.lyric.dimple.model.core.FactorGraph; import com.analog.lyric.dimple.model.variables.Bit; import com.analog.lyric.dimple.model.variables.Real; import com.analog.lyric.dimple.model.variables.RealJoint; import com.analog.lyric.dimple.solvers.gibbs.GibbsOptions; import com.analog.lyric.dimple.solvers.gibbs.GibbsSolver; import com.analog.lyric.dimple.solvers.gibbs.GibbsSolverGraph; import com.analog.lyric.dimple.solvers.gibbs.GibbsDiscrete; import com.analog.lyric.dimple.solvers.gibbs.GibbsRealJoint; import com.analog.lyric.dimple.solvers.gibbs.GibbsReal; import com.analog.lyric.dimple.test.DimpleTestBase; /** * Tests setting of {@link GibbsOptions} * @since 0.07 * @author Christopher Barber */ @SuppressWarnings({"null", "deprecation"}) public class TestGibbsOptions extends DimpleTestBase { @Test public void test() { // Test default values assertEquals(1, GibbsOptions.numSamples.defaultIntValue()); assertEquals(0, GibbsOptions.numRandomRestarts.defaultIntValue()); assertEquals(1, GibbsOptions.scansPerSample.defaultIntValue()); assertEquals(0, GibbsOptions.burnInScans.defaultIntValue()); assertFalse(GibbsOptions.saveAllSamples.defaultBooleanValue()); assertFalse(GibbsOptions.saveAllScores.defaultBooleanValue()); assertFalse(GibbsOptions.enableAnnealing.defaultValue()); assertEquals(1.0, GibbsOptions.initialTemperature.defaultDoubleValue(), 1.0); assertEquals(1.0, GibbsOptions.annealingHalfLife.defaultDoubleValue(), 1.0); // Build test graph FactorGraph fg = new FactorGraph(); Bit b1 = new Bit(); Bit b2 = new Bit(); fg.addVariables(b1, b2); Real r1 = new Real(); Real r2 = new Real(); fg.addVariables(r1, r2); RealJoint j1 = new RealJoint(2); RealJoint j2 = new RealJoint(2); fg.addVariables(j1, j2); int nVars = fg.getVariableCount(); // Test default initialization GibbsSolverGraph sfg = requireNonNull(fg.setSolverFactory(new GibbsSolver())); GibbsDiscrete sb1 = (GibbsDiscrete)sfg.getSolverVariable(b1); GibbsDiscrete sb2 = (GibbsDiscrete)sfg.getSolverVariable(b2); GibbsReal sr1 = (GibbsReal)sfg.getReal(r1); GibbsReal sr2 = (GibbsReal)sfg.getReal(r2); GibbsRealJoint sj1 = (GibbsRealJoint)sfg.getSolverVariable(j1); GibbsRealJoint sj2 = (GibbsRealJoint)sfg.getSolverVariable(j2); assertEquals(GibbsOptions.numSamples.defaultIntValue(), sfg.getNumSamples()); assertEquals(GibbsOptions.numRandomRestarts.defaultIntValue(), sfg.getNumRestarts()); assertEquals(nVars * GibbsOptions.burnInScans.defaultIntValue(), sfg.getBurnInUpdates()); sfg.initialize(); assertEquals(GibbsOptions.numSamples.defaultIntValue(), sfg.getNumSamples()); assertEquals(GibbsOptions.numRandomRestarts.defaultIntValue(), sfg.getNumRestarts()); assertEquals(nVars * GibbsOptions.scansPerSample.defaultIntValue(), sfg.getUpdatesPerSample()); assertEquals(nVars * GibbsOptions.burnInScans.defaultIntValue(), sfg.getBurnInUpdates()); assertFalse(sfg.isTemperingEnabled()); assertEquals(GibbsOptions.initialTemperature.defaultDoubleValue(), sfg.getInitialTemperature(), 0.0); assertEquals(GibbsOptions.annealingHalfLife.defaultDoubleValue(), sfg.getTemperingHalfLifeInSamples(), 1e-9); // Test initialization from options fg.setSolverFactory(null); sfg = requireNonNull(fg.setSolverFactory(new GibbsSolver())); sb1 = requireNonNull((GibbsDiscrete)sfg.getSolverVariable(b1)); sb2 = requireNonNull((GibbsDiscrete)sfg.getSolverVariable(b2)); sr1 = requireNonNull((GibbsReal)sfg.getReal(r1)); sr2 = requireNonNull((GibbsReal)sfg.getReal(r2)); sj1 = requireNonNull((GibbsRealJoint)sfg.getSolverVariable(j1)); sj2 = requireNonNull((GibbsRealJoint)sfg.getSolverVariable(j2)); fg.setOption(GibbsOptions.numSamples, 3); fg.setOption(GibbsOptions.numRandomRestarts, 2); fg.setOption(GibbsOptions.scansPerSample, 2); fg.setOption(GibbsOptions.burnInScans, 4); fg.setOption(GibbsOptions.saveAllSamples, true); b2.setOption(GibbsOptions.saveAllSamples, false); r2.setOption(GibbsOptions.saveAllSamples, false); j2.setOption(GibbsOptions.saveAllSamples, false); fg.setOption(GibbsOptions.enableAnnealing, true); fg.setOption(GibbsOptions.initialTemperature, Math.PI); fg.setOption(GibbsOptions.annealingHalfLife, 3.1); // These do not take effect until after initialization assertEquals(GibbsOptions.numSamples.defaultIntValue(), sfg.getNumSamples()); assertEquals(GibbsOptions.numRandomRestarts.defaultIntValue(), sfg.getNumRestarts()); assertFalse(sfg.isTemperingEnabled()); assertEquals(0.0, sfg.getInitialTemperature(), 0.0); assertEquals(Math.log(2), sfg.getTemperingHalfLifeInSamples(), 1e-9); sfg.initialize(); assertEquals(3, sfg.getNumSamples()); assertEquals(2, sfg.getNumRestarts()); assertEquals(nVars * 2 /*scansPerSample */, sfg.getUpdatesPerSample()); assertEquals(nVars * 4 /* burnInScans */, sfg.getBurnInUpdates()); assertEquals(true, sb1.getOptionOrDefault(GibbsOptions.saveAllSamples)); assertEquals(false, sb2.getOptionOrDefault(GibbsOptions.saveAllSamples)); assertEquals(true, sr1.getOptionOrDefault(GibbsOptions.saveAllSamples)); assertEquals(false, sr2.getOptionOrDefault(GibbsOptions.saveAllSamples)); assertEquals(true, sj1.getOptionOrDefault(GibbsOptions.saveAllSamples)); assertEquals(false, sj2.getOptionOrDefault(GibbsOptions.saveAllSamples)); assertTrue(sfg.isTemperingEnabled()); assertEquals(Math.PI, sfg.getInitialTemperature(), 0.0); assertEquals(3.1, sfg.getTemperingHalfLifeInSamples(), 1e-9); // Test set methods sfg.setNumSamples(4); assertEquals(4, sfg.getNumSamples()); assertEquals((Integer)4, sfg.getLocalOption(GibbsOptions.numSamples)); sfg.setNumRestarts(5); assertEquals(5, sfg.getNumRestarts()); assertEquals((Integer)5, sfg.getLocalOption(GibbsOptions.numRandomRestarts)); sfg.setUpdatesPerSample(6); assertEquals(new Integer(-1), sfg.getLocalOption(GibbsOptions.scansPerSample)); sfg.setScansPerSample(3); assertEquals(3 * nVars, sfg.getUpdatesPerSample()); assertEquals((Integer)3, sfg.getLocalOption(GibbsOptions.scansPerSample)); sfg.setBurnInScans(5); assertEquals((Integer)5, sfg.getLocalOption(GibbsOptions.burnInScans)); sfg.initialize(); assertEquals(5 * nVars, sfg.getBurnInUpdates()); sfg.setUpdatesPerSample(23); assertEquals(23, sfg.getUpdatesPerSample()); assertEquals(new Integer(-1), sfg.getLocalOption(GibbsOptions.scansPerSample)); sfg.setBurnInUpdates(12); assertEquals(12, sfg.getBurnInUpdates()); assertEquals(new Integer(-1), sfg.getLocalOption(GibbsOptions.burnInScans)); sfg.unsetOption(GibbsOptions.saveAllSamples); b2.unsetOption(GibbsOptions.saveAllSamples); sfg.saveAllSamples(); assertEquals(true, sfg.getLocalOption(GibbsOptions.saveAllSamples)); sfg.disableSavingAllSamples(); assertEquals(false, sfg.getLocalOption(GibbsOptions.saveAllSamples)); sb1.saveAllSamples(); assertEquals(true, sb1.getLocalOption(GibbsOptions.saveAllSamples)); sb1.disableSavingAllSamples(); assertEquals(false, sb1.getLocalOption(GibbsOptions.saveAllSamples)); sr1.saveAllSamples(); assertEquals(true, sr1.getLocalOption(GibbsOptions.saveAllSamples)); sr1.disableSavingAllSamples(); assertEquals(false, sr1.getLocalOption(GibbsOptions.saveAllSamples)); sj1.saveAllSamples(); assertEquals(true, sj1.getLocalOption(GibbsOptions.saveAllSamples)); sj1.disableSavingAllSamples(); assertEquals(false, sj1.getLocalOption(GibbsOptions.saveAllSamples)); sfg.unsetOption(GibbsOptions.saveAllScores); sfg.saveAllScores(); assertEquals(true, sfg.getLocalOption(GibbsOptions.saveAllScores)); sfg.disableSavingAllScores(); assertEquals(false, sfg.getLocalOption(GibbsOptions.saveAllScores)); sfg.disableTempering(); assertFalse(sfg.isTemperingEnabled()); assertEquals(false, sfg.getLocalOption(GibbsOptions.enableAnnealing)); sfg.enableTempering(); assertTrue(sfg.isTemperingEnabled()); assertEquals(true, sfg.getLocalOption(GibbsOptions.enableAnnealing)); sfg.disableTempering(); sfg.unsetOption(GibbsOptions.enableAnnealing); sfg.setInitialTemperature(2.345); assertEquals((Double)2.345, sfg.getLocalOption(GibbsOptions.initialTemperature)); assertTrue(sfg.isTemperingEnabled()); // tempering implicitly enabled when setting initial temperature sfg.disableTempering(); sfg.unsetOption(GibbsOptions.enableAnnealing); sfg.setTemperingHalfLifeInSamples(4); assertEquals(4, sfg.getTemperingHalfLifeInSamples(), 1e-9); assertEquals(4.0, sfg.getLocalOption(GibbsOptions.annealingHalfLife), 0.0); assertTrue(sfg.isTemperingEnabled()); // tempering implicitly enabled when setting tempering half life } }