/*******************************************************************************
* Copyright 2014 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.test.solvers.gibbs;
import static java.util.Objects.*;
import static org.junit.Assert.*;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import org.eclipse.jdt.annotation.Nullable;
import org.junit.Test;
import com.analog.lyric.dimple.events.DimpleEvent;
import com.analog.lyric.dimple.events.DimpleEventHandler;
import com.analog.lyric.dimple.events.DimpleEventListener;
import com.analog.lyric.dimple.events.IDimpleEventSource;
import com.analog.lyric.dimple.factorfunctions.core.FactorFunction;
import com.analog.lyric.dimple.model.core.FactorGraph;
import com.analog.lyric.dimple.model.domains.DiscreteDomain;
import com.analog.lyric.dimple.model.domains.Domain;
import com.analog.lyric.dimple.model.domains.RealJointDomain;
import com.analog.lyric.dimple.model.factors.Factor;
import com.analog.lyric.dimple.model.values.Value;
import com.analog.lyric.dimple.model.variables.Complex;
import com.analog.lyric.dimple.model.variables.Discrete;
import com.analog.lyric.dimple.model.variables.Real;
import com.analog.lyric.dimple.schedulers.GibbsSequentialScanScheduler;
import com.analog.lyric.dimple.solvers.gibbs.GibbsScoredVariableUpdateEvent;
import com.analog.lyric.dimple.solvers.gibbs.GibbsSolver;
import com.analog.lyric.dimple.solvers.gibbs.GibbsSolverGraph;
import com.analog.lyric.dimple.solvers.gibbs.GibbsVariableUpdateEvent;
import com.analog.lyric.dimple.solvers.gibbs.ISolverVariableGibbs;
import com.analog.lyric.dimple.test.DimpleTestBase;
/**
* Test generation of {@link GibbsVariableUpdateEvent}s.
*
* @since 0.06
* @author Christopher Barber
*/
@SuppressWarnings("deprecation")
public class TestGibbsVariableUpdateEvent extends DimpleTestBase
{
/**
 * Factor function used only by this test. Its energy depends solely on the magnitudes of its
 * arguments, each of which must wrap either a {@link Number} or a {@code double[]}.
 */
static class BogoFunction extends FactorFunction
{
    /**
     * Accumulates, for every argument in order, a constant {@code 1.0} followed by the absolute
     * value of the argument (for a {@link Number}) or the absolute value of each of its elements
     * (for a {@code double[]}).
     *
     * @param arguments the values to score.
     * @return the accumulated energy.
     * @throws NullPointerException if the object held by an argument is null.
     * @throws Error if the object held by an argument is neither a {@link Number} nor a {@code double[]}.
     */
    @Override
    public final double evalEnergy(Value[] arguments)
    {
        double total = 0.0;

        for (int i = 0; i < arguments.length; ++i)
        {
            final Object payload = requireNonNull(arguments[i].getObject());

            if (payload instanceof double[])
            {
                total += 1.0;
                for (double element : (double[])payload)
                {
                    total += Math.abs(element);
                }
            }
            else if (payload instanceof Number)
            {
                total += 1.0;
                total += Math.abs(((Number)payload).doubleValue());
            }
            else
            {
                // Any other payload type means the test graph was built incorrectly.
                throw new Error("die");
            }
        }

        return total;
    }
}
/**
 * Event handler that records every {@link GibbsVariableUpdateEvent} it receives and checks each
 * one against the current state of the solver variable that raised it.
 */
static class VariableUpdateHandler extends DimpleEventHandler<GibbsVariableUpdateEvent>
{
    /** Events received so far, in arrival order. */
    List<GibbsVariableUpdateEvent> events = new ArrayList<>();

    /**
     * Records {@code event} and asserts that its new value equals the source variable's current
     * sample value. For a {@link GibbsScoredVariableUpdateEvent} it additionally asserts that the
     * event's new sample score equals the variable's current sample score exactly, and that the
     * reported score difference equals new score minus old score to within {@code 1e-15}.
     */
    @Override
    public void handleEvent(GibbsVariableUpdateEvent event)
    {
        events.add(event);

        // When debugging, call printEvent(event) here to trace each update.

        final ISolverVariableGibbs source = event.getSource();
        assertTrue(event.getNewValue().valueEquals(source.getCurrentSampleValue()));

        if (!(event instanceof GibbsScoredVariableUpdateEvent))
        {
            return;
        }

        final GibbsScoredVariableUpdateEvent scored = (GibbsScoredVariableUpdateEvent)event;
        final double currentScore = source.getCurrentSampleScore();
        assertEquals(currentScore, scored.getNewSampleScore(), 0.0);

        final double expectedDifference = scored.getNewSampleScore() - scored.getOldSampleScore();
        assertEquals(expectedDifference, scored.getScoreDifference(), 1e-15);
    }

    /**
     * Debugging aid that writes a one-line summary of {@code event} to standard output: the event
     * class's simple name, the model object's name, the old and new values and, for scored events,
     * the signed score difference.
     */
    @SuppressWarnings("null")
    void printEvent(GibbsVariableUpdateEvent event)
    {
        System.out.printf("%s: %s %s => %s",
            event.getClass().getSimpleName(), event.getModelObject().getName(),
            event.getOldValue(), event.getNewValue());

        final boolean hasScore = event instanceof GibbsScoredVariableUpdateEvent;
        if (hasScore)
        {
            System.out.printf(" score %+f", ((GibbsScoredVariableUpdateEvent)event).getScoreDifference());
        }

        System.out.println("");
    }
}
/**
 * Runs the Gibbs solver on a small graph with one discrete, one real and one complex variable and
 * verifies which variable update events reach a registered handler as the listener configuration
 * is changed between runs.
 */
@SuppressWarnings("unused")
@Test
public void test()
{
    //
    // Build the model: three variables connected pairwise by factors that share one function.
    //
    final FactorFunction bogo = new BogoFunction();
    final FactorGraph graph = new FactorGraph();

    final Discrete discrete = new Discrete(DiscreteDomain.range(0, 9));
    discrete.setName("d1");
    final Real real = new Real();
    real.setName("r1");
    final Complex complex = new Complex();
    complex.setName("c1");
    graph.addVariables(discrete, real, complex);

    final Factor discreteRealFactor = graph.addFactor(bogo, discrete, real);
    final Factor realComplexFactor = graph.addFactor(bogo, real, complex);
    final Factor complexDiscreteFactor = graph.addFactor(bogo, complex, discrete);

    //
    // Attach the Gibbs solver and configure it to take a single sample with no burn-in.
    //
    final GibbsSolverGraph solverGraph = requireNonNull(graph.setSolverFactory(new GibbsSolver()));
    final ISolverVariableGibbs solverDiscrete = requireNonNull(solverGraph.getSolverVariable(discrete));
    final ISolverVariableGibbs solverReal = requireNonNull(solverGraph.getReal(real));
    final ISolverVariableGibbs solverComplex = requireNonNull(solverGraph.getSolverVariable(complex));
    solverGraph.setBurnInScans(0);
    solverGraph.setNumSamples(1);
    solverGraph.setTemperature(1.0);
    graph.setScheduler(new GibbsSequentialScanScheduler());

    //
    // Register the handler on the model for plain update events. The assertions below check that
    // the registration also covers the solver variable, but not the scored event subclass.
    //
    final DimpleEventListener eventListener = new DimpleEventListener();
    final VariableUpdateHandler recorder = new VariableUpdateHandler();
    eventListener.register(recorder, GibbsVariableUpdateEvent.class, false, graph);
    assertTrue(eventListener.isListeningFor(GibbsVariableUpdateEvent.class, graph));
    assertTrue(eventListener.isListeningFor(GibbsVariableUpdateEvent.class, solverComplex));
    assertFalse(eventListener.isListeningFor(GibbsScoredVariableUpdateEvent.class, solverComplex));

    // Once installed on the environment, the listener is reported by both model and solver variable.
    graph.getEnvironment().setEventListener(eventListener);
    assertSame(eventListener, graph.getEventListener());
    assertSame(eventListener, solverDiscrete.getEventListener());

    // With the listener installed, one solve yields one plain event per variable, in this order.
    graph.solve();
    assertEvents(recorder, GibbsVariableUpdateEvent.class, solverDiscrete, solverReal, solverComplex);

    // With no listener installed, no events are delivered.
    graph.getEnvironment().setEventListener(null);
    graph.solve();
    assertEvents(recorder, GibbsVariableUpdateEvent.class);

    // Blocking the event type on the real variable suppresses only that variable's event.
    graph.getEnvironment().setEventListener(eventListener);
    eventListener.block(GibbsVariableUpdateEvent.class, false, solverReal);
    graph.solve();
    assertEvents(recorder, GibbsVariableUpdateEvent.class, solverDiscrete, solverComplex);

    // Unblocking restores delivery for all three variables.
    eventListener.unblock(GibbsVariableUpdateEvent.class, solverReal);
    graph.solve();
    assertEvents(recorder, GibbsVariableUpdateEvent.class, solverDiscrete, solverReal, solverComplex);

    // Re-registering with the flag set to true results in scored events being delivered instead.
    eventListener.register(recorder, GibbsVariableUpdateEvent.class, true, graph);
    graph.solve();
    assertEvents(recorder, GibbsScoredVariableUpdateEvent.class, solverDiscrete, solverReal, solverComplex);

    // The score differences reported by the events of one sample must add up to the change in the
    // graph's overall sample score.
    final double scoreBefore = solverGraph.getSampleScore();
    solverGraph.sample();
    final double scoreAfter = solverGraph.getSampleScore();
    final double reportedDifference =
        assertEvents(recorder, GibbsScoredVariableUpdateEvent.class, solverDiscrete, solverReal, solverComplex);
    assertEquals(scoreAfter - scoreBefore, reportedDifference, 1e-14);

    // After unregistering everything and notifying each solver variable, sampling produces no events.
    eventListener.unregisterAll();
    solverDiscrete.notifyListenerChanged();
    solverReal.notifyListenerChanged();
    solverComplex.notifyListenerChanged();
    solverGraph.sample();
    assertEvents(recorder, null);
}
/**
 * Asserts that events with given {@code expectedClass} have occurred on the {@code handler} on the
 * specified {@code sources} in order. This method clears the handler's event list.
 * <p>
 * For each event it also checks that the reject count is non-negative and no larger than the number
 * of dimensions of the source's domain (for a real-joint domain) or one (otherwise). When the
 * reject count reaches that bound, the old and new values must be equal and, for scored events,
 * the score difference must be exactly zero.
 *
 * @param handler the handler whose recorded events are checked and then cleared.
 * @param expectedClass the exact class every recorded event must have; only compared when at least
 * one source is given.
 * @param sources the expected event sources, in the order the events should have been received.
 * @return the cumulative {@link GibbsScoredVariableUpdateEvent#getScoreDifference()} if available, otherwise zero.
 */
private double assertEvents(VariableUpdateHandler handler, @Nullable Class<? extends DimpleEvent> expectedClass,
    IDimpleEventSource ... sources)
{
    double scoreDifference = 0.0;

    final int nSources = sources.length;
    assertEquals(nSources, handler.events.size());

    for (int i = 0; i < nSources; ++i)
    {
        final GibbsVariableUpdateEvent event = handler.events.get(i);
        assertSame(expectedClass, event.getClass());
        assertSame(sources[i], event.getSource());

        final int rejectCount = event.getRejectCount();
        final Domain domain = event.getSource().getDomain();
        final RealJointDomain jointDomain = domain.asRealJoint();
        assertTrue(rejectCount >= 0);

        // Upper bound on rejections: one per dimension for real-joint domains, otherwise one.
        final int maxRejects = jointDomain != null ? jointDomain.getDimensions() : 1;
        assertTrue(rejectCount <= maxRejects);
        final boolean fullyRejected = rejectCount == maxRejects;

        if (fullyRejected)
        {
            assertTrue(event.getOldValue().valueEquals(event.getNewValue()));
        }

        if (event instanceof GibbsScoredVariableUpdateEvent)
        {
            final double eventScoreDifference = ((GibbsScoredVariableUpdateEvent)event).getScoreDifference();
            scoreDifference += eventScoreDifference;
            if (fullyRejected)
            {
                // Expected value goes first so that a failure message reports it correctly.
                assertEquals(0.0, eventScoreDifference, 0.0);
            }
        }
    }

    handler.events.clear();
    return scoreDifference;
}
}