/*******************************************************************************
 * Copyright 2014 Analog Devices, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 ********************************************************************************/

package com.analog.lyric.dimple.test.solvers.gibbs;

import static java.util.Objects.*;
import static org.hamcrest.CoreMatchers.*;
import static org.junit.Assert.*;

import java.util.ArrayList;
import java.util.List;

import org.junit.Test;

import com.analog.lyric.dimple.environment.DimpleEnvironment;
import com.analog.lyric.dimple.events.DimpleEventHandler;
import com.analog.lyric.dimple.events.DimpleEventListener;
import com.analog.lyric.dimple.factorfunctions.Bernoulli;
import com.analog.lyric.dimple.model.core.FactorGraph;
import com.analog.lyric.dimple.model.variables.Bit;
import com.analog.lyric.dimple.solvers.gibbs.GibbsBurnInEvent;
import com.analog.lyric.dimple.solvers.gibbs.GibbsOptions;
import com.analog.lyric.dimple.solvers.gibbs.GibbsSampleStatisticsEvent;
import com.analog.lyric.dimple.solvers.gibbs.GibbsSolver;
import com.analog.lyric.dimple.solvers.gibbs.GibbsSolverGraph;
import com.analog.lyric.dimple.solvers.gibbs.GibbsSolverGraphEvent;
import com.analog.lyric.dimple.test.DimpleTestBase;

/**
 * Tests the {@link GibbsSolverGraphEvent}s delivered to a registered handler while solving a
 * small factor graph with the Gibbs solver: {@link GibbsBurnInEvent}s (with and without
 * annealing enabled) and {@link GibbsSampleStatisticsEvent}s.
 *
 * @since 0.08
 * @author Christopher Barber
 */
public class TestGibbsSolverGraphEvents extends DimpleTestBase
{
    /**
     * Solves the same three-bit graph several times, changing the handler registration and
     * solver options between solves, and checks the events collected by the handler each time:
     * <ol>
     * <li>No handler registered: nothing is collected.
     * <li>Handler registered for {@link GibbsBurnInEvent}: one event per restart (including the
     * initial run), with a NaN temperature because annealing has not been enabled.
     * <li>Annealing enabled: burn-in events report the initial temperature.
     * <li>Handler registered for all {@link GibbsSolverGraphEvent}s with burn-in events blocked:
     * the collected events are sample statistics whose scores match the saved scores.
     * </ol>
     */
    @Test
    public void test()
    {
        // Named so that the expected event counts below can be derived instead of hard-coded.
        final int numSamples = 4;
        final int overriddenRestarts = 1;

        FactorGraph fg = new FactorGraph();
        Bit a = new Bit();
        Bit b = new Bit();
        Bit c = new Bit();
        fg.addFactor(new Bernoulli(.4), a, b, c);

        GibbsSolverGraph sfg = requireNonNull(fg.setSolverFactory(new GibbsSolver()));
        fg.setOption(GibbsOptions.numSamples, numSamples);
        fg.setOption(GibbsOptions.numRandomRestarts, 2);
        fg.setOption(GibbsOptions.burnInScans, 3);

        //
        // Set up listener
        //

        DimpleEnvironment env = fg.getEnvironment();
        DimpleEventListener listener = env.createEventListener();
        GibbsEventHandler handler = new GibbsEventHandler();

        // Phase 1: the handler has not been registered yet, so it must not see anything.
        fg.solve();
        assertTrue(handler.events.isEmpty());

        // Phase 2: burn-in events only. The test expects one event for the initial run
        // plus one for each random restart.
        listener.register(handler, GibbsBurnInEvent.class, env);
        fg.solve();

        int expectedSize = GibbsOptions.numRandomRestarts.getOrDefault(sfg) + 1;
        assertEquals(expectedSize, handler.events.size());
        for (int i = 0; i < expectedSize; ++i)
        {
            GibbsSolverGraphEvent event = handler.events.get(i);
            // JUnit convention: expected value first, so failure messages read correctly.
            assertSame(sfg, event.getSolverObject());
            assertSame(sfg, event.getSource());
            assertSame(fg, event.getModelObject());
            assertTrue(event instanceof GibbsBurnInEvent);

            GibbsBurnInEvent burnInEvent = (GibbsBurnInEvent)event;
            assertEquals(i, burnInEvent.restartCount());
            // Annealing has not been enabled at this point; temperature is expected to be NaN.
            assertTrue(Double.isNaN(burnInEvent.temperature()));
            assertThat(burnInEvent.toString(1), containsString("burn-in restart " + i));
        }

        // Phase 3: with annealing enabled the burn-in events carry the initial temperature,
        // which shows up in the verbosity-1 string but not in the verbosity-0 string.
        handler.events.clear();
        fg.setOption(GibbsOptions.enableAnnealing, true);
        fg.solve();

        assertEquals(expectedSize, handler.events.size());
        for (int i = 0; i < expectedSize; ++i)
        {
            GibbsSolverGraphEvent event = handler.events.get(i);
            GibbsBurnInEvent burnInEvent = (GibbsBurnInEvent)event;
            double temperature = burnInEvent.temperature();
            assertEquals(GibbsOptions.initialTemperature.getOrDefault(sfg), temperature, 0.0);

            String tempString = String.format("temperature %f", temperature);
            assertThat(burnInEvent.toString(1), containsString(tempString));
            assertThat(burnInEvent.toString(0), not(containsString(tempString)));
        }

        // Phase 4: listen for every solver graph event but block burn-in events on the solver
        // graph, so every collected event is expected to be a GibbsSampleStatisticsEvent
        // (the cast below fails otherwise).
        handler.events.clear();
        listener.register(handler, GibbsSolverGraphEvent.class, env);
        listener.block(GibbsBurnInEvent.class, false, sfg);
        sfg.setOption(GibbsOptions.saveAllScores, true);
        sfg.setOption(GibbsOptions.numRandomRestarts, overriddenRestarts);
        fg.solve();

        // One statistics event per sample, for the initial run and for each restart.
        expectedSize = (overriddenRestarts + 1) * numSamples;
        assertEquals(expectedSize, handler.events.size());

        double[] scores = requireNonNull(sfg.getAllScores());
        assertEquals(expectedSize, scores.length);
        for (int i = 0; i < expectedSize; ++i)
        {
            GibbsSampleStatisticsEvent event = (GibbsSampleStatisticsEvent)handler.events.get(i);
            assertEquals(scores[i], event.sampleScore(), 0.0);
        }
    }

    /**
     * Event handler that records every {@link GibbsSolverGraphEvent} it receives, in arrival
     * order, so the test can inspect them afterwards.
     */
    static class GibbsEventHandler extends DimpleEventHandler<GibbsSolverGraphEvent>
    {
        /** Events received so far; cleared by the test between phases. */
        final List<GibbsSolverGraphEvent> events = new ArrayList<>();

        @Override
        public void handleEvent(GibbsSolverGraphEvent event)
        {
            events.add(event);
        }
    }
}