/*******************************************************************************
* Copyright 2014 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.test.solvers.core;
import static com.analog.lyric.util.test.ExceptionTester.*;
import static java.util.Objects.*;
import static org.junit.Assert.*;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.eclipse.jdt.annotation.Nullable;
import org.junit.Test;
import com.analog.lyric.dimple.environment.DimpleEnvironment;
import com.analog.lyric.dimple.events.DimpleEvent;
import com.analog.lyric.dimple.events.DimpleEventHandler;
import com.analog.lyric.dimple.events.DimpleEventListener;
import com.analog.lyric.dimple.events.IModelEventSource;
import com.analog.lyric.dimple.events.ISolverEventSource;
import com.analog.lyric.dimple.events.SolverEvent;
import com.analog.lyric.dimple.exceptions.DimpleException;
import com.analog.lyric.dimple.factorfunctions.Normal;
import com.analog.lyric.dimple.model.core.FactorGraph;
import com.analog.lyric.dimple.model.core.Node;
import com.analog.lyric.dimple.model.domains.DiscreteDomain;
import com.analog.lyric.dimple.model.factors.Factor;
import com.analog.lyric.dimple.model.values.Value;
import com.analog.lyric.dimple.model.variables.Discrete;
import com.analog.lyric.dimple.solvers.core.SNode;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.IParameterizedMessage;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.ParameterizedMessageBase;
import com.analog.lyric.dimple.solvers.interfaces.ISolverFactorGraph;
import com.analog.lyric.dimple.solvers.interfaces.SolverNodeMapping;
import com.analog.lyric.dimple.solvers.sumproduct.SumProductSolver;
import com.analog.lyric.dimple.test.DimpleTestBase;
/**
 * Unit test for the {@link SNode} base class.
 * <p>
 * Exercises sibling access, flag handling, edge updates, unsupported message setters and
 * the conditions under which {@link SNode} generates solver message events.
 *
 * @since 0.06
 * @author Christopher Barber
 */
public class TestSNode extends DimpleTestBase
{
    /**
     * Minimal concrete {@link SNode} used to exercise the base-class logic.
     * <p>
     * Records every edge passed to {@link #doUpdateEdge(int)}. Keeps one {@link TestMessage}
     * per edge, created with counter 0 on the first update of that edge and incremented on
     * every later update. Records every non-null event returned by {@link #createMessageEvent}.
     * Message-event support is toggled through {@link #_supportsMessageEvents}.
     */
    private static class TestNode extends SNode<Node>
    {
        /** When true, this node clones messages and produces {@link TestEvent}s for edge updates. */
        private boolean _supportsMessageEvents = false;

        /** Edges passed to {@link #doUpdateEdge(int)} since the last {@link #initialize()}. */
        final private Set<Integer> _updatedEdges = new HashSet<Integer>();

        /** Current output message per edge. Not cleared by {@link #initialize()}. */
        final private Map<Integer,TestMessage> _messages = new HashMap<Integer, TestMessage>();

        /** Every non-null event returned from {@link #createMessageEvent}. */
        final private List<DimpleEvent> _createdEvents = new ArrayList<DimpleEvent>();

        /** Solver graph reported as both the parent and the containing graph. */
        final private ISolverFactorGraph _parent;

        public TestNode(Node n, ISolverFactorGraph parent)
        {
            super(n);
            _parent = parent;
        }

        /** Runs base-class initialization, then forgets which edges were updated. */
        @Override
        public void initialize()
        {
            super.initialize();
            _updatedEdges.clear();
        }

        @Override
        public ISolverFactorGraph getContainingSolverGraph()
        {
            return _parent;
        }

        @Deprecated
        @Override
        public double getScore()
        {
            return 0;
        }

        @Override
        public double getInternalEnergy()
        {
            return 0;
        }

        @Override
        public double getBetheEntropy()
        {
            return 0;
        }

        /** This stub keeps no input messages. */
        @Override
        public @Nullable Object getInputMsg(int portIndex)
        {
            return null;
        }

        /**
         * Returns the current message for {@code edge}.
         * <p>
         * This is {@code null} if the edge has never been updated, because it comes straight
         * from a {@link HashMap} lookup.
         */
        @Override
        public IParameterizedMessage getOutputMsg(int edge)
        {
            return _messages.get(edge);
        }

        @Override
        public ISolverFactorGraph getParentGraph()
        {
            return _parent;
        }

        @Override
        public SolverNodeMapping getSolverMapping()
        {
            return _parent.getSolverMapping();
        }

        /*---------------
         * SNode methods
         */

        /**
         * Records the update of {@code edge}.
         * <p>
         * The edge's message is created with counter 0 the first time, and its counter is
         * bumped on each later update.
         */
        @Override
        protected void doUpdateEdge(int edge)
        {
            _updatedEdges.add(edge);
            final TestMessage existing = _messages.get(edge);
            if (existing != null)
            {
                ++existing._counter;
            }
            else
            {
                _messages.put(edge, new TestMessage(0));
            }
        }

        /**
         * Returns a copy of the current message when events are enabled, otherwise defers to
         * the base class.
         * <p>
         * Presumably {@link SNode} uses this to capture the pre-update message for events.
         * Returns {@code null} when the edge has no message yet.
         */
        @Override
        protected @Nullable IParameterizedMessage cloneMessage(int edge)
        {
            if (!_supportsMessageEvents)
            {
                return super.cloneMessage(edge);
            }
            final IParameterizedMessage current = _messages.get(edge);
            return current != null ? current.clone() : null;
        }

        /**
         * Builds a {@link TestEvent} when events are enabled, otherwise defers to the base
         * class. Any non-null result is also recorded in {@link #_createdEvents}.
         */
        @Override
        public @Nullable SolverEvent createMessageEvent(int edge,
            @Nullable IParameterizedMessage oldMsg, IParameterizedMessage newMsg)
        {
            final SolverEvent event = _supportsMessageEvents
                ? new TestEvent(this, oldMsg, newMsg)
                : super.createMessageEvent(edge, oldMsg, newMsg);
            if (event != null)
            {
                _createdEvents.add(event);
            }
            return event;
        }

        @Override
        protected @Nullable Class<? extends SolverEvent> messageEventType()
        {
            return _supportsMessageEvents ? TestEvent.class : super.messageEventType();
        }

        @Override
        protected boolean supportsMessageEvents()
        {
            return _supportsMessageEvents || super.supportsMessageEvents();
        }
    }

    /**
     * Trivial parameterized message whose only state is an integer counter.
     * <p>
     * A counter of 0 is treated as the null/uniform message.
     */
    private static class TestMessage extends ParameterizedMessageBase
    {
        private static final long serialVersionUID = 1L;

        private int _counter;

        private TestMessage(int counter)
        {
            _counter = counter;
        }

        @Override
        public TestMessage clone()
        {
            return new TestMessage(_counter);
        }

        @Override
        public boolean objectEquals(@Nullable Object other)
        {
            return other instanceof TestMessage && ((TestMessage)other)._counter == _counter;
        }

        @Override
        public void print(PrintStream out, int verbosity)
        {
            out.format("TestMessage(counter=%d)", _counter);
        }

        /** Absolute counter difference. Assumes {@code that} is a {@link TestMessage}. */
        @Override
        public double computeKLDivergence(IParameterizedMessage that)
        {
            return Math.abs(_counter - ((TestMessage)that)._counter);
        }

        @Override
        public double evalEnergy(Value value)
        {
            return _counter * value.getDouble();
        }

        @Override
        public boolean isNull()
        {
            return _counter == 0;
        }

        /** Copies the counter. Assumes {@code other} is a {@link TestMessage}. */
        @Override
        public void setFrom(IParameterizedMessage other)
        {
            _counter = ((TestMessage)other)._counter;
        }

        @Override
        public void setUniform()
        {
            _counter = 0;
        }

        @Override
        protected double computeNormalizationEnergy()
        {
            return 0;
        }
    }

    /** Solver event carrying the messages before and after an edge update. */
    private static class TestEvent extends SolverEvent
    {
        private static final long serialVersionUID = 1L;

        private final @Nullable IParameterizedMessage _oldMsg;
        private final @Nullable IParameterizedMessage _newMsg;

        protected TestEvent(ISolverEventSource source, @Nullable IParameterizedMessage oldMsg,
            @Nullable IParameterizedMessage newMsg)
        {
            super(source);
            _oldMsg = oldMsg;
            _newMsg = newMsg;
        }

        @Override
        public @Nullable IModelEventSource getModelObject()
        {
            return getSource().getModelEventSource();
        }

        @Override
        protected void printDetails(PrintStream out, int verbosity)
        {
        }
    }

    /**
     * Handler that checks each {@link TestEvent}.
     * <p>
     * The new message must be present. When an old message exists, the counter must have
     * advanced by exactly one.
     */
    private static class TestHandler extends DimpleEventHandler<TestEvent>
    {
        @Override
        public void handleEvent(TestEvent event)
        {
            TestMessage oldMsg = (TestMessage)event._oldMsg;
            TestMessage newMsg = (TestMessage)event._newMsg;
            requireNonNull(newMsg);
            if (oldMsg != null)
            {
                assertEquals(oldMsg._counter + 1, newMsg._counter);
            }
        }
    }

    @Test
    public void test()
    {
        FactorGraph fg = new FactorGraph();
        ISolverFactorGraph sfg = requireNonNull(fg.setSolverFactory(new SumProductSolver()));

        Discrete d1 = new Discrete(DiscreteDomain.bit());
        Discrete d2 = new Discrete(DiscreteDomain.bit());
        Discrete d3 = new Discrete(DiscreteDomain.bit());
        fg.addVariables(d1, d2, d3);

        //
        // Basic state of a node with no siblings
        //

        TestNode n1 = new TestNode(d1, sfg);
        assertSame(d1, n1.getModelObject());
        assertEquals(0, n1.getSiblingCount());
        assertEquals(0, n1.getFlagValue(-1));
        assertSame(sfg, n1.getParentGraph());
        assertSame(sfg, n1.getOptionParent());
        assertFalse(n1.supportsMessageEvents());

        expectThrow(IndexOutOfBoundsException.class, n1, "getSibling", 0);
        expectThrow(DimpleException.class, "Not supported.*", n1, "setInputMsg", 42, null);
        expectThrow(DimpleException.class, "Not supported.*", n1, "setOutputMsg", 42, null);
        expectThrow(DimpleException.class, "Not supported.*", n1, "setInputMsgValues", 42, null);
        expectThrow(DimpleException.class, "Not supported.*", n1, "setOutputMsgValues", 42, null);

        //
        // Siblings, updates and flags
        //

        Factor f13 = fg.addFactor(new Normal(0.0, 1.0), d1, d3);
        Factor f12 = fg.addFactor(new Normal(0.0, 1.0), d1, d2);

        assertEquals(2, n1.getSiblingCount());
        assertSame(f13.getSolver(), n1.getSibling(0));
        assertSame(f12.getSolver(), n1.getSibling(1));

        // update() visits every edge
        assertTrue(n1._updatedEdges.isEmpty());
        n1.update();
        assertEquals(2, n1._updatedEdges.size());
        assertTrue(n1._updatedEdges.contains(0));
        assertTrue(n1._updatedEdges.contains(1));

        // initialize() resets flags
        n1.setFlagValue(-1, -1);
        assertEquals(-1, n1.getFlagValue(-1));
        n1.initialize();
        assertEquals(0, n1.getFlagValue(-1));
        assertTrue(n1._updatedEdges.isEmpty());

        // updateEdge(i) visits only edge i
        for (int i = 0, nSiblings = n1.getSiblingCount(); i < nSiblings; ++i)
        {
            n1.updateEdge(i);
            assertEquals(1, n1._updatedEdges.size());
            assertTrue(n1._updatedEdges.contains(i));
            n1._updatedEdges.clear();
        }

        assertTrue(n1._createdEvents.isEmpty());

        //
        // Test message events
        //

        // Creates a listener on the active environment. The assertions below show the graph
        // and node pick it up. It must be removed again even if an assertion fails, so that
        // it cannot leak into other tests.
        DimpleEventListener listener = DimpleEnvironment.active().createEventListener();
        try
        {
            assertSame(listener, fg.getEventListener());
            assertSame(listener, n1.getEventListener());
            assertFalse(listener.isListeningFor(TestEvent.class, n1));

            // A listener with no registered handler does not trigger event creation.
            n1.initialize();
            n1.update();
            assertTrue(n1._createdEvents.isEmpty());

            TestHandler handler = new TestHandler();
            listener.register(handler, TestEvent.class, false, fg);
            assertTrue(listener.isListeningFor(TestEvent.class, n1));

            // Still no events, because the node does not yet support message events.
            n1.initialize();
            n1.update();
            assertTrue(n1._createdEvents.isEmpty());

            // Still no events: enabling support only takes effect after initialize().
            n1._supportsMessageEvents = true;
            n1.update();
            assertTrue(n1._createdEvents.isEmpty());

            // After initialize(), every edge update produces one event sourced from the node.
            n1.initialize();
            n1.update();
            assertEquals(2, n1._createdEvents.size());
            assertSame(n1, n1._createdEvents.get(0).getSource());
            assertSame(n1, n1._createdEvents.get(1).getSource());
            n1._createdEvents.clear();

            n1.updateEdge(0);
            assertEquals(1, n1._createdEvents.size());
            assertSame(n1, n1._createdEvents.get(0).getSource());
            n1._createdEvents.clear();

            // Removing the listener and notifying the node stops event creation.
            DimpleEnvironment.active().setEventListener(null);
            n1.notifyListenerChanged();
            n1.updateEdge(0);
            n1.updateEdge(1);
            assertTrue(n1._createdEvents.isEmpty());
        }
        finally
        {
            // Idempotent cleanup: guarantees the shared environment is left without a listener.
            DimpleEnvironment.active().setEventListener(null);
        }
    }
}