/******************************************************************************* * Copyright 2014 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.test.solvers.core; import static java.util.Objects.*; import static org.junit.Assert.*; import java.util.ArrayList; import java.util.Iterator; import java.util.List; import java.util.Objects; import org.eclipse.jdt.annotation.Nullable; import com.analog.lyric.dimple.environment.DimpleEnvironment; import com.analog.lyric.dimple.events.DimpleEventHandler; import com.analog.lyric.dimple.events.DimpleEventListener; import com.analog.lyric.dimple.events.SolverEvent; import com.analog.lyric.dimple.model.core.FactorGraph; import com.analog.lyric.dimple.model.core.INode; import com.analog.lyric.dimple.model.factors.Factor; import com.analog.lyric.dimple.model.variables.Variable; import com.analog.lyric.dimple.schedulers.schedule.FixedSchedule; import com.analog.lyric.dimple.schedulers.schedule.ISchedule; import com.analog.lyric.dimple.schedulers.scheduleEntry.EdgeScheduleEntry; import com.analog.lyric.dimple.schedulers.scheduleEntry.IScheduleEntry; import com.analog.lyric.dimple.schedulers.scheduleEntry.NodeScheduleEntry; import com.analog.lyric.dimple.solvers.core.FactorToVariableMessageEvent; import com.analog.lyric.dimple.solvers.core.IMessageUpdateEvent; import com.analog.lyric.dimple.solvers.core.VariableToFactorMessageEvent; import com.analog.lyric.dimple.solvers.core.parameterizedMessages.IParameterizedMessage; import com.analog.lyric.dimple.solvers.interfaces.ISolverFactor; import com.analog.lyric.dimple.solvers.interfaces.ISolverFactorGraph; import com.analog.lyric.dimple.solvers.interfaces.ISolverNode; import com.analog.lyric.dimple.solvers.interfaces.ISolverVariable; /** * {@link DimpleEventHandler} for testing {@link IMessageUpdateEvent}. * * @since 0.06 * @author Christopher Barber */ public class TestMessageUpdateEventHandler extends DimpleEventHandler<SolverEvent> { public boolean printEvents = false; public final List<IMessageUpdateEvent> observedEvents = new ArrayList<IMessageUpdateEvent>(); public static TestMessageUpdateEventHandler setUpListener(ISolverFactorGraph solver) { FactorGraph model = solver.getModelObject(); DimpleEventListener listener = DimpleEnvironment.active().createEventListener(); TestMessageUpdateEventHandler handler = new TestMessageUpdateEventHandler(); listener.register(handler, VariableToFactorMessageEvent.class, false, model); listener.register(handler, FactorToVariableMessageEvent.class, false, model); for (Variable var : model.getVariables()) { ISolverVariable svar = requireNonNull(solver.getSolverVariable(var)); assertTrue(listener.isListeningFor(VariableToFactorMessageEvent.class, svar)); assertSame(listener, svar.getEventListener()); } for (Factor factor : model.getFactors()) { ISolverFactor sfactor = requireNonNull(solver.getSolverFactor(factor)); assertTrue(listener.isListeningFor(FactorToVariableMessageEvent.class, sfactor)); assertSame(listener, sfactor.getEventListener()); } solver.initialize(); return handler; } @Override public void handleEvent(SolverEvent event) { assertEquals(event.getModelObject(), event.getSolverObject().getModelEventSource()); assertTrue(event instanceof IMessageUpdateEvent); IMessageUpdateEvent messageEvent = (IMessageUpdateEvent)event; IParameterizedMessage oldMsg = messageEvent.getOldMessage(); IParameterizedMessage newMsg = messageEvent.getNewMessage(); observedEvents.add(messageEvent); assertNotNull(newMsg); assertNotSame(oldMsg, newMsg); if (oldMsg == null) { assertEquals(Double.POSITIVE_INFINITY, messageEvent.computeKLDivergence(), 0.0); } ISolverFactor factor = messageEvent.getFactor(); assertNotNull(factor); ISolverVariable variable = messageEvent.getVariable(); assertNotNull(variable); if (messageEvent.isToFactor()) { assertTrue(event instanceof VariableToFactorMessageEvent); assertSame(variable, event.getSolverObject()); assertSame(factor, variable.getSibling(messageEvent.getEdge())); } else { assertTrue(event instanceof FactorToVariableMessageEvent); assertSame(factor, event.getSolverObject()); assertSame(variable, factor.getSibling(messageEvent.getEdge())); } if (printEvents) { printEvent(messageEvent); } } /** * Asserts that the contents of {@link #observedEvents} corresponds to the given * {@code schedule}. * <p> * This only looks at entries of type {@link EdgeScheduleEntry} and {@link NodeScheduleEntry}. * <p> * @param schedule * @param solver is the root solver graph for use in mapping model to solver nodes. * @since 0.06 */ public void assertEventsFromSchedule(ISchedule schedule, ISolverFactorGraph solver) { Iterator<IMessageUpdateEvent> eventIter = observedEvents.iterator(); Iterator<IScheduleEntry> scheduleIter = schedule.iterator(); while (scheduleIter.hasNext()) { IScheduleEntry scheduleEntry = scheduleIter.next(); if (scheduleEntry instanceof EdgeScheduleEntry) { EdgeScheduleEntry edgeEntry = (EdgeScheduleEntry)scheduleEntry; INode edgeNode = edgeEntry.getNode(); int edge = edgeEntry.getPortNum(); if (edgeNode instanceof Variable) { Variable edgeVar = (Variable)edgeNode; ISolverVariable svar = solver.getSolverVariable(edgeVar); assertEdgeMessageEvent(svar, edge, eventIter); } else if (edgeNode instanceof Factor) { Factor edgeFactor = (Factor)edgeNode; ISolverFactor sfactor = solver.getSolverFactor(edgeFactor); assertEdgeMessageEvent(sfactor, edge, eventIter); } } else if (scheduleEntry instanceof NodeScheduleEntry) { NodeScheduleEntry nodeEntry = (NodeScheduleEntry)scheduleEntry; INode node = nodeEntry.getNode(); if (node instanceof Variable) { Variable edgeVar = (Variable)node; ISolverVariable svar = solver.getSolverVariable(edgeVar); for (int edge = 0, n = edgeVar.getSiblingCount(); edge < n; ++edge) { assertEdgeMessageEvent(svar, edge, eventIter); } } else if (node instanceof Factor) { Factor edgeFactor = (Factor)node; ISolverFactor sfactor = solver.getSolverFactor(edgeFactor); for (int edge = 0, n = edgeFactor.getSiblingCount(); edge < n; ++edge) { assertEdgeMessageEvent(sfactor, edge, eventIter); } } } } assertFalse(eventIter.hasNext()); } public void testNodeSchedule(ISolverFactorGraph solver) { FactorGraph model = solver.getModelObject(); // Create a fixed schedule to exerise full update messages FixedSchedule schedule = new FixedSchedule(model); for (INode node : model.getNodes()) { schedule.add(node); } solver.setSchedule(schedule); solver.iterate(); assertEventsFromSchedule(solver.getSchedule(), solver); observedEvents.clear(); } public void testEdgeSchedule(ISolverFactorGraph solver) { FactorGraph model = solver.getModelObject(); // Create a fixed edge schedule to exercise edge messages FixedSchedule schedule = new FixedSchedule(model); for (INode node : model.getNodes()) { for (int i = node.getSiblingCount(); --i>=0;) { schedule.add(node, i); } } solver.setSchedule(schedule); solver.iterate(); assertEventsFromSchedule(solver.getSchedule(), solver); observedEvents.clear(); } private void assertEdgeMessageEvent(@Nullable ISolverVariable svar, int edge, Iterator<IMessageUpdateEvent> eventIter) { if (DimpleEventListener.sourceHasListenerFor(Objects.requireNonNull(svar), VariableToFactorMessageEvent.class)) { assertTrue(eventIter.hasNext()); IMessageUpdateEvent event = eventIter.next(); assertTrue(event.isToFactor()); assertTrue(event instanceof VariableToFactorMessageEvent); assertSame(svar, event.getVariable()); assertEquals(edge, event.getEdge()); } } private void assertEdgeMessageEvent(@Nullable ISolverFactor sfactor, int edge, Iterator<IMessageUpdateEvent> eventIter) { if (DimpleEventListener.sourceHasListenerFor(Objects.requireNonNull(sfactor), FactorToVariableMessageEvent.class)) { assertTrue(eventIter.hasNext()); IMessageUpdateEvent event = eventIter.next(); assertFalse(event.isToFactor()); assertTrue(event instanceof FactorToVariableMessageEvent); assertSame(sfactor, event.getFactor()); assertEquals(edge, event.getEdge()); } } @SuppressWarnings("null") void printEvent(IMessageUpdateEvent event) { ISolverNode source, target; if (event.isToFactor()) { source = event.getVariable(); target = event.getFactor(); } else { source = event.getFactor(); target = event.getFactor(); } System.out.format("%s: %s to %s: KL = %g", event.getClass().getSimpleName(), source.getModelObject().getName(), target.getModelObject().getName(), event.computeKLDivergence() ); System.out.println(""); } }