/*******************************************************************************
* Copyright 2012-2015 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.solvers.gibbs;
import static com.analog.lyric.dimple.environment.DimpleEnvironment.*;
import static java.util.Objects.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.eclipse.jdt.annotation.Nullable;
import com.analog.lyric.collect.ArrayUtil;
import com.analog.lyric.collect.KeyedPriorityQueue;
import com.analog.lyric.dimple.data.ValueDataLayer;
import com.analog.lyric.dimple.exceptions.DimpleException;
import com.analog.lyric.dimple.model.core.DirectedNodeSorter;
import com.analog.lyric.dimple.model.core.EdgeState;
import com.analog.lyric.dimple.model.core.FactorGraph;
import com.analog.lyric.dimple.model.core.FactorGraphIterables;
import com.analog.lyric.dimple.model.core.Node;
import com.analog.lyric.dimple.model.factors.Factor;
import com.analog.lyric.dimple.model.repeated.BlastFromThePastFactor;
import com.analog.lyric.dimple.model.repeated.FactorGraphStream;
import com.analog.lyric.dimple.model.values.IndexedValue;
import com.analog.lyric.dimple.model.values.Value;
import com.analog.lyric.dimple.model.variables.Discrete;
import com.analog.lyric.dimple.model.variables.Real;
import com.analog.lyric.dimple.model.variables.RealJoint;
import com.analog.lyric.dimple.model.variables.Variable;
import com.analog.lyric.dimple.model.variables.VariableBlock;
import com.analog.lyric.dimple.options.DimpleOptions;
import com.analog.lyric.dimple.schedulers.SchedulerOptionKey;
import com.analog.lyric.dimple.schedulers.schedule.IGibbsSchedule;
import com.analog.lyric.dimple.schedulers.schedule.ISchedule;
import com.analog.lyric.dimple.schedulers.scheduleEntry.IScheduleEntry;
import com.analog.lyric.dimple.solvers.core.SFactorGraphBase;
import com.analog.lyric.dimple.solvers.gibbs.samplers.block.IBlockInitializer;
import com.analog.lyric.dimple.solvers.interfaces.ISolverBlastFromThePastFactor;
import com.analog.lyric.dimple.solvers.interfaces.ISolverFactorGraph;
import com.analog.lyric.dimple.solvers.interfaces.ISolverVariable;
import com.analog.lyric.dimple.solvers.interfaces.SolverNodeMapping;
import com.analog.lyric.util.misc.Matlab;
import cern.colt.list.DoubleArrayList;
/**
* Solver-specific factor graph for Gibbs solver.
* <p>
* <em>Previously was com.analog.lyric.dimple.solvers.gibbs.SFactorGraph</em>
* <p>
* @since 0.07
*/
public class GibbsSolverGraph
extends SFactorGraphBase<ISolverFactorGibbs, ISolverVariableGibbs, GibbsSolverEdge<?>, GibbsVariableBlock>
{
/*
* Constants
*/
/**
* Natural logarithm of 2. Used to convert an annealing half-life (in samples) to and from
* the per-sample temperature decay constant.
*/
private static final double LOG2 = Math.log(2);
/**
* Bits in {@link #_flags} reserved by this class and its superclasses.
* @see GibbsSolverGraphEvent
*/
@SuppressWarnings("hiding")
protected static final int RESERVED_FLAGS = 0xFFFFF000;
/*-------
* State
*/
/** Current position in the schedule; assigned by {@link #initialize()} and replaced by {@link #iterate(int)} on wrap-around. */
private @Nullable Iterator<IScheduleEntry> _scheduleIterator;
/** Registered block initializers; reset to null at the start of {@link #initialize()}. Lazily created. */
private @Nullable ArrayList<IBlockInitializer> _blockInitializers;
/** Number of samples generated per restart by {@link #solveOneStep()}. */
private int _numSamples = GibbsOptions.numSamples.defaultIntValue();
/** Number of schedule entries run by {@link #iterate(int)} for each sample. */
private int _updatesPerSample = GibbsOptions.scansPerSample.defaultIntValue();
/** Number of schedule entries run by {@link #iterate(int)} during each burn-in. */
private int _burnInUpdates = 0;
/** Scans per sample; set to -1 by {@link #setUpdatesPerSample(int)} when the sample size is given in updates instead. */
private int _scansPerSample = 1;
/** Burn-in scans; set to -1 by {@link #setBurnInUpdates(int)} when burn-in is given in updates instead. */
private int _burnInScans = GibbsOptions.burnInScans.defaultIntValue();
/** Number of random restarts performed by {@link #solveOneStep()} after the initial pass. */
private int _numRandomRestarts = GibbsOptions.numRandomRestarts.defaultIntValue();
/** Whether tempering (annealing) is enabled. */
private boolean _temper = false;
/** Temperature applied at initialization and on every random restart when tempering. */
private double _initialTemperature;
/** Factor the temperature is multiplied by after each sample when tempering. */
private double _temperingDecayConstant;
/** Most recently set temperature. */
private double _temperature;
/** Lowest sample score recorded since the last {@link #initialize()} or {@link #solveOneStep()}. */
private double _minPotential = Double.MAX_VALUE;
/** True until the first sample has been recorded; forces that sample to be saved as the best so far. */
private boolean _firstSample = true;
/** Score of every sample, or null when scores are not being saved. */
private @Nullable DoubleArrayList _scoreArray;
/**
* Priority queue of deterministic factors whose outputs should be
* reevaluated. Lazily created.
*/
private @Nullable KeyedPriorityQueue<ISolverFactorGibbs, SFactorUpdate> _deferredDeterministicFactorUpdates = null;
/**
* The number of requests to defer update of deterministic directed factor outputs.
* If greater than zero, {@link #scheduleDeterministicDirectedUpdate(ISolverFactorGibbs, int, Value)} will
* defer execution until later. This counter may be greater than one as a result of recursive
* calls.
*/
private int _deferDeterministicFactorUpdatesCounter = 0;
/*--------------
* Construction
*/
/**
* Constructs a Gibbs solver graph for the given model graph.
* <p>
* Both arguments are forwarded unchanged to the superclass constructor.
* <p>
* @param factorGraph is the model graph to be solved.
* @param parent is the parent solver graph, or null.
*/
protected GibbsSolverGraph(FactorGraph factorGraph, @Nullable ISolverFactorGraph parent)
{
super(factorGraph, parent);
}
/*--------------------------
* IVariableToValue methods
*/
/**
 * Returns the current sample value of the Gibbs solver variable for the given model variable.
 *
 * @param var is a model variable that has a solver variable in this graph.
 */
@Override
public Value varToValue(Variable var)
{
    final ISolverVariableGibbs solverVariable = getSolverVariable(var);
    return solverVariable.getCurrentSampleValue();
}
/*----------------------------
* ISolverFactorGraph methods
*/
/**
* {@inheritDoc}
* <p>
* The Gibbs solver uses its own scheduler option key.
* <p>
* @return {@link GibbsOptions#scheduler}.
*/
@Override
public SchedulerOptionKey getSchedulerKey()
{
return GibbsOptions.scheduler;
}
// TODO - rearrange methods
/**
* Always returns true for the Gibbs solver graph.
* <p>
* NOTE(review): presumably this tells the base class that {@link #createEdgeState(EdgeState)}
* supplies per-edge solver state - confirm against the superclass contract.
*/
@Override
public boolean hasEdgeState()
{
return true;
}
/**
 * Creates the solver edge state for a model edge.
 * <p>
 * The solver factor attached to the edge is asked first. If it returns null, a
 * {@link GibbsDiscreteEdge} is created when the edge's variable is {@link Discrete};
 * otherwise the shared {@link GibbsNullEdge#INSTANCE} is used.
 *
 * @param edge is the model edge state for which solver state is needed.
 */
@Override
public GibbsSolverEdge<?> createEdgeState(EdgeState edge)
{
    final ISolverFactorGibbs solverFactor = getSolverFactor(edge.getFactor(_model));
    final GibbsSolverEdge<?> factorSuppliedEdge = solverFactor.createEdge(edge);
    if (factorSuppliedEdge != null)
    {
        return factorSuppliedEdge;
    }

    // The factor had no opinion: fall back on an edge chosen by variable type.
    final Variable variable = edge.getVariable(_model);
    if (variable instanceof Discrete)
    {
        return new GibbsDiscreteEdge((Discrete)variable);
    }
    return GibbsNullEdge.INSTANCE;
}
/**
 * Creates the Gibbs solver variable matching the concrete type of the model variable.
 * <p>
 * {@link RealJoint} is tested first, then {@link Real}, then {@link Discrete}.
 *
 * @param var is the model variable needing a solver counterpart.
 * @throws RuntimeException as produced by {@code unsupportedVariableType} if the variable is none of those types.
 */
@SuppressWarnings("deprecation")
@Override
public ISolverVariableGibbs createVariable(Variable var)
{
    if (var instanceof RealJoint)
    {
        final RealJoint realJoint = (RealJoint)var;
        return new SRealJointVariable(realJoint, this);
    }
    if (var instanceof Real)
    {
        final Real real = (Real)var;
        return new SRealVariable(real, this);
    }
    if (var instanceof Discrete)
    {
        final Discrete discrete = (Discrete)var;
        return new SDiscreteVariable(discrete, this);
    }
    throw unsupportedVariableType(var);
}
/**
* Creates the solver graph for a nested model subgraph, with this graph as its parent.
* <p>
* Currently instantiates the deprecated {@code SFactorGraph} class, passing null as its
* third constructor argument.
*/
@SuppressWarnings("deprecation") // TODO: remove when SFactorGraph is removed.
@Override
public ISolverFactorGraph createSubgraph(FactorGraph subgraph)
{
return new SFactorGraph(subgraph, this, null);
}
// Note, customFactorExists is intentionally not overridden and therefore returns false
// This is because all of the custom factors for this solver also exist as FactorFunctions,
// and therefore we still want the MATLAB code to create a factor with the specified FactorFunctions.
/**
* Creates the Gibbs solver factor for a model factor by delegating to
* {@link GibbsOptions#customFactors}.
*/
@Override
public ISolverFactorGibbs createFactor(Factor factor)
{
return GibbsOptions.customFactors.createFactor(factor, this);
}
/**
 * Creates the solver object for a blast-from-the-past factor.
 * <p>
 * Directed factors, and factors whose function is deterministic directed, are rejected.
 * Otherwise a table-based implementation is used for discrete factors and a real-valued
 * implementation for all others.
 *
 * @throws DimpleException if the factor is directed or deterministic directed.
 */
@Override
public ISolverBlastFromThePastFactor createBlastFromThePast(BlastFromThePastFactor factor)
{
    final boolean directed = factor.isDirected() || factor.getFactorFunction().isDeterministicDirected();
    if (directed)
    {
        throw new DimpleException("not yet supported");
    }

    if (!factor.isDiscrete())
    {
        return new GibbsRealFactorBlastFromThePast(factor, this);
    }
    return new GibbsTableFactorBlastFromThePast(factor, this);
}
/**
* Creates the Gibbs solver object for a model variable block, owned by this solver graph.
*/
@Override
public GibbsVariableBlock createVariableBlock(VariableBlock block)
{
return new GibbsVariableBlock(block, this);
}
/**
* Returns the superclass schedule cast to {@link IGibbsSchedule}.
* <p>
* @throws ClassCastException if the schedule is not an {@link IGibbsSchedule}.
*/
@Override
public IGibbsSchedule getSchedule()
{
return (IGibbsSchedule)super.getSchedule();
}
/**
* Returns the superclass's solver factor for the given model factor, cast to {@link ISolverFactorGibbs}.
*/
@Override
public ISolverFactorGibbs getSolverFactor(Factor factor)
{
return (ISolverFactorGibbs)super.getSolverFactor(factor);
}
/**
* Returns the superclass's solver variable for the given model variable, cast to {@link ISolverVariableGibbs}.
*/
@Override
public ISolverVariableGibbs getSolverVariable(Variable variable)
{
return (ISolverVariableGibbs)super.getSolverVariable(variable);
}
/**
 * Looks up the {@link GibbsDiscrete} solver variable for a discrete model variable.
 *
 * @param variable is a variable contained in corresponding model graph.
 * @throws NullPointerException if there is no such solver variable.
 * @throws ClassCastException if the solver variable is not a {@link GibbsDiscrete}.
 * @since 0.08
 * @see #getSolverVariable(Variable)
 * @see #getReal(Real)
 */
public GibbsDiscrete getDiscrete(Discrete variable)
{
    final GibbsDiscrete solverVariable = (GibbsDiscrete)super.getSolverVariable(variable);
    return requireNonNull(solverVariable);
}
/**
 * Looks up the {@link GibbsReal} solver variable for a real model variable.
 *
 * @param variable is a variable contained in corresponding model graph.
 * @throws NullPointerException if there is no such solver variable.
 * @throws ClassCastException if the solver variable is not a {@link GibbsReal}.
 * @since 0.08
 * @see #getSolverVariable(Variable)
 * @see #getDiscrete(Discrete)
 */
public GibbsReal getReal(Real variable)
{
    final GibbsReal solverVariable = (GibbsReal)super.getSolverVariable(variable);
    return requireNonNull(solverVariable);
}
/**
* Prepares this solver graph for sampling.
* <p>
* In order, this:
* <ol>
* <li>caches the Gibbs option values in fields and, if a random seed option is available, applies it;
* <li>obtains and validates the schedule;
* <li>assigns each factor its topological order from {@link DirectedNodeSorter};
* <li>initializes owned variables, boundary variables (root graph only), non-graph factors and
* owned subgraphs, with deterministic updates deferred while variables and factors are initialized;
* <li>runs any block initializers registered during the previous step;
* <li>resets the schedule iterator, best-score tracking, update counts, temperature and score array.
* </ol>
*/
@Override
public void initialize()
{
// Cache the current option values in the fields used by the sampling methods.
_numSamples = getOptionOrDefault(GibbsOptions.numSamples);
_numRandomRestarts = getOptionOrDefault(GibbsOptions.numRandomRestarts);
_scansPerSample = getOptionOrDefault(GibbsOptions.scansPerSample);
_burnInScans = getOptionOrDefault(GibbsOptions.burnInScans);
final boolean saveAllScores = getOptionOrDefault(GibbsOptions.saveAllScores);
_temper = getOptionOrDefault(GibbsOptions.enableAnnealing);
_initialTemperature = getOptionOrDefault(GibbsOptions.initialTemperature);
// Convert the half-life (in samples) to the factor the temperature is multiplied by after each sample.
_temperingDecayConstant = 1 - LOG2/getOptionOrDefault(GibbsOptions.annealingHalfLife);
// Only reseed when the randomSeed option lookup yields a value.
Long seed = getOption(DimpleOptions.randomSeed);
if (seed != null)
{
setSeed(seed);
}
// Make sure the schedule is created before factor initialization to allow custom factors to modify the schedule if needed
final ISchedule schedule = getSchedule();
validateSchedule(schedule);
FactorGraph fg = _model;
// Give every factor its position in the directed ordering; factors absent from the ordering get 0.
Map<Node,Integer> nodeOrder = DirectedNodeSorter.orderDirectedNodes(fg);
for (Factor factor : fg.getFactors())
{
ISolverFactorGibbs sfactor = getSolverFactor(factor);
Integer order = nodeOrder.get(factor);
sfactor.setTopologicalOrder(order != null ? order : 0);
}
// Same as SFactorGraphBase.initialize() but with deferral of deterministic updates
// Block initializers are discarded here and may be re-registered while variables/factors initialize below.
_blockInitializers = null;
deferDeterministicUpdates();
for (Variable variable : fg.getOwnedVariables())
{
getSolverVariable(variable).initialize();
}
// Boundary variables are initialized here only when this is the root graph.
if (!fg.hasParentGraph())
{
for (int i = 0, end = fg.getBoundaryVariableCount(); i <end; ++i)
{
getSolverVariable(fg.getBoundaryVariable(i)).initialize();
}
}
for (Factor f : fg.getNonGraphFactorsTop())
{
getSolverFactor(f).initialize();
}
processDeferredDeterministicUpdates();
// Subgraphs are initialized outside the deferral window above.
for (FactorGraph g : fg.getOwnedGraphs())
{
getSolverSubgraph(g).initialize();
}
deferDeterministicUpdates();
final ArrayList<IBlockInitializer> blockInitializers = _blockInitializers;
if (blockInitializers != null)
{
for (IBlockInitializer b : blockInitializers) // After initializing all variables and factors, invoke any block initializers
{
b.initialize();
}
}
processDeferredDeterministicUpdates();
// Start the schedule from the beginning and forget any previous best sample.
_scheduleIterator = schedule.iterator();
_minPotential = Double.POSITIVE_INFINITY;
_firstSample = true;
// Translate scan counts into update counts now that the schedule is known.
setUpdatesPerSampleFromScans();
setBurnInUpdatesFromScans();
if (_temper) setTemperature(_initialTemperature);
// Reuse an existing score array (cleared) when saving scores; otherwise drop it.
DoubleArrayList scoreArray = null;
if (saveAllScores)
{
scoreArray = _scoreArray;
if (scoreArray == null)
{
scoreArray = new DoubleArrayList();
}
else
{
scoreArray.clear();
}
}
_scoreArray = scoreArray;
}
/**
 * Does one round of Gibbs sampling.
 * <p>
 * Resets the best-sample tracking and then performs the equivalent of:
 * <blockquote>
 * <pre>
 * for (int restart = 0; restart &lt;= {@link #getNumRestarts()}; ++restart)
 * {
 *     {@linkplain #burnIn(int) burnIn(restart)};
 *     for (int i = 0; i &lt; {@link #getNumSamples()}; ++i)
 *     {
 *         {@link #sample()};
 *     }
 * }
 * </pre>
 * </blockquote>
 * Note that the burn-in and sampling pass always runs at least once, even with zero restarts.
 */
@Override
public void solveOneStep()
{
    _minPotential = Double.POSITIVE_INFINITY;
    _firstSample = true;

    int restart = 0;
    while (restart <= _numRandomRestarts)
    {
        burnIn(restart);
        for (int drawn = 0; drawn < _numSamples; ++drawn)
        {
            oneSample();
        }
        ++restart;
    }
}
/**
* Perform initial burn in.
* <p>
* This invokes {@link #burnIn(int)} with value zero, i.e. as the burn-in for the
* initial (non-restart) pass.
*/
@Matlab
public final void burnIn()
{
burnIn(0);
}
/**
 * Perform burn-in phase.
 * <p>
 * This invokes {@link #randomRestart(int)} and then performs {@link #getBurnInUpdates()}
 * updates via {@link #iterate(int)}. Afterwards, if burn-in events are enabled, a
 * {@link GibbsBurnInEvent} is raised carrying the current temperature (or NaN when
 * tempering is disabled).
 * <p>
 * Burn-in is required for most graphs to ensure that the samples will be closer to the
 * real distribution.
 * <p>
 * @param restartCount is a non-negative number indicating which random restart is
 * executing. This will be zero for the initial burn-in phase.
 */
public final void burnIn(int restartCount)
{
    randomRestart(restartCount);
    iterate(_burnInUpdates);

    if (!GibbsSolverGraphEvent.raiseBurnInEvent(this))
    {
        return;
    }

    final double reportedTemperature = _temper ? _temperature : Double.NaN;
    raiseEvent(new GibbsBurnInEvent(this, restartCount, reportedTemperature));
}
/**
* Generate one sample.
* <p>
* Simply invokes {@link #sample(int)} with value one; no initialization, burn-in or
* random restart is performed here.
*/
public void sample()
{
sample(1);
}
/**
 * Run more samples without initializing, burn-in, or random-restarts.
 * <p>
 * This is like {@link #iterate}, except that while iterate runs a specified number
 * of single-variable updates, this runs a specified number of entire samples, where the size of
 * a sample has already been defined in terms of number of either updates or scans.
 * <p>
 * @param numSamples is the number of samples to generate; nothing happens if it is not positive.
 */
@Matlab
public void sample(int numSamples)
{
    int remaining = numSamples;
    while (remaining > 0)
    {
        oneSample();
        --remaining;
    }
}
/**
 * Performs specified number of single variable updates.
 * <p>
 * Note that the iterate() method for the Gibbs solver means do the
 * specified number of single-variable updates, regardless of other parameter settings.
 * The iterate() method behaves differently than for other solvers due to the fact that the
 * {@link #update()} method for Gibbs-specific schedules will update only a single variable.
 * Also, multithreaded operation for Gibbs is not supported.
 * <p>
 * Schedule entries are consumed from the iterator saved by {@link #initialize()}; when the
 * iterator is exhausted a fresh one is taken from the schedule, so consecutive calls continue
 * where the previous call stopped.
 * <p>
 * After all updates have run, an interrupt check is made. If it reports an interruption the
 * thread's interrupt status is restored before returning, so that callers (for example a
 * thread running the solver) can observe the request instead of having it silently discarded.
 * Interruption is only detected between calls, never in the middle of the update loop.
 *
 * @param numUpdates is the number of schedule entries to run; nothing is run if not positive.
 * @throws NullPointerException if the schedule iterator has not been created yet, i.e.
 * {@link #initialize()} has not run.
 */
@Override
public void iterate(int numUpdates)
{
    Iterator<IScheduleEntry> scheduleIterator = Objects.requireNonNull(_scheduleIterator);
    final ISchedule schedule = getSchedule();
    for (int iterNum = 0; iterNum < numUpdates; iterNum++)
    {
        if (!scheduleIterator.hasNext())
        {
            // Wrap-around the schedule if reached the end
            scheduleIterator = _scheduleIterator = schedule.iterator();
        }
        runScheduleEntry(scheduleIterator.next());
    }

    try
    {
        interruptCheck();
    }
    catch (InterruptedException e)
    {
        // This method cannot propagate the checked exception, so re-assert the interrupt
        // status rather than swallowing the request.
        Thread.currentThread().interrupt();
    }
}
/**
 * Generates a single sample and records it.
 * <p>
 * Runs the configured number of updates per sample, then for every variable updates its
 * belief and saves the current sample. The sample's score is compared against the best
 * score so far (the first sample always counts as best), optionally appended to the saved
 * scores, and, when tempering, the temperature is multiplied by the decay constant.
 * Finally a {@link GibbsSampleStatisticsEvent} is raised if such events are enabled.
 */
@SuppressWarnings("null")
protected void oneSample()
{
    iterate(_updatesPerSample);

    // Note that the first sample saved is one full sample after burn in, not immediately
    // after burn in (in case the burn in is zero).
    for (Variable variable : _model.getVariables())
    {
        final ISolverVariableGibbs solverVariable = getSolverVariable(variable);
        solverVariable.updateBelief();
        solverVariable.saveCurrentSample();
    }

    // Remember the best sample value seen so far.
    final double score = getSampleScore();
    final boolean isNewBest = _firstSample || score < _minPotential;
    if (isNewBest)
    {
        for (Variable variable : _model.getVariables())
        {
            getSolverVariable(variable).saveBestSample();
        }
        _minPotential = score;
        _firstSample = false;
    }

    // Record the score of every sample if that was requested.
    final DoubleArrayList savedScores = _scoreArray;
    if (savedScores != null)
    {
        savedScores.add(score);
    }

    // When tempering, cool down by one step.
    double temperatureBefore = Double.NaN;
    double temperatureAfter = Double.NaN;
    if (_temper)
    {
        temperatureBefore = _temperature;
        temperatureAfter = temperatureBefore * _temperingDecayConstant;
        _temperature = temperatureAfter;
        setTemperature(temperatureAfter);
    }

    if (GibbsSolverGraphEvent.raiseSampleStatsEvent(this))
    {
        raiseEvent(new GibbsSampleStatisticsEvent(this, score, isNewBest, temperatureBefore, temperatureAfter));
    }
}
/**
 * Randomly restarts the variables that were added to the end of the chain.
 * <p>
 * For rolled up graphs: for each factor graph stream of the model, takes the last nested
 * graph and invokes {@code randomRestart(0)} on the solver variable of each of its
 * boundary variables.
 */
@SuppressWarnings("null")
@Override
public void postAdvance()
{
    for (FactorGraphStream stream : getModel().getFactorGraphStreams())
    {
        final int lastIndex = stream.getNestedGraphs().size() - 1;
        final FactorGraph newestGraph = stream.getNestedGraphs().get(lastIndex);
        for (Variable boundaryVariable : FactorGraphIterables.boundary(newestGraph))
        {
            final ISolverVariableGibbs solverVariable = getSolverVariable(boundaryVariable);
            solverVariable.randomRestart(0);
        }
    }
}
/**
 * Randomly restarts every variable in the model graph.
 * <p>
 * With deterministic updates deferred, invokes {@code randomRestart} on each solver variable
 * and then runs any registered block initializers. Once the deferred updates have been
 * processed, the temperature is reset to its initial value if tempering is enabled.
 *
 * @param restartCount identifies which restart is executing; passed through to each variable.
 */
@SuppressWarnings("null")
public void randomRestart(int restartCount)
{
    deferDeterministicUpdates();

    for (Variable variable : _model.getVariables())
    {
        getSolverVariable(variable).randomRestart(restartCount);
    }

    // Also invoke any block initializers
    final ArrayList<IBlockInitializer> initializers = _blockInitializers;
    if (initializers != null)
    {
        for (IBlockInitializer initializer : initializers)
        {
            initializer.initialize();
        }
    }

    processDeferredDeterministicUpdates();

    // Reset the temperature, if tempering
    if (_temper)
    {
        setTemperature(_initialTemperature);
    }
}
/**
* Returns the same value as {@link #getSampleScore()}.
* <p>
* @deprecated use {@link #getSampleScore()} instead.
*/
@Matlab
@Deprecated
@SuppressWarnings("null")
public double getTotalPotential()
{
return getSampleScore();
}
/**
* Returns data layer view of sample values for graph tree.
* <p>
* A new {@code GibbsSampleLayer} backed by this solver graph is created on every call.
* @since 0.08
*/
public ValueDataLayer getSampleLayer()
{
return new GibbsSampleLayer(this);
}
/**
 * Computes the total energy or "score" over the entire graph given current sample values.
 * <p>
 * The score is the sum of the potentials of all non-graph factors plus the potentials of all
 * variables. Variables contribute too because they have inputs, which are factors.
 * <p>
 * @throws NullPointerException if a factor or variable has no solver object.
 * @since 0.08
 * @see #getBestSampleScore()
 */
public double getSampleScore()
{
    double score = 0;

    for (Factor factor : _model.getNonGraphFactors())
    {
        final ISolverFactorGibbs solverFactor = requireNonNull(getSolverFactor(factor));
        score += solverFactor.getPotential();
    }

    for (Variable variable : _model.getVariables())
    {
        final ISolverVariableGibbs solverVariable = requireNonNull(getSolverVariable(variable));
        score += solverVariable.getPotential();
    }

    return score;
}
/**
* Returns the lowest value of {@link #getSampleScore()} discovered since initialization.
* <p>
* The tracked value is also reset at the start of {@link #solveOneStep()}, and is positive
* infinity after such a reset until a sample has been generated.
* <p>
* @since 0.08
* @see #getAllScores()
*/
public double getBestSampleScore()
{
return _minPotential;
}
/**
* Sets the {@link GibbsOptions#saveAllSamples} option to true on this object.
* <p>
* @deprecated Instead set {@link GibbsOptions#saveAllSamples} to true using {@link #setOption}.
*/
@Deprecated
@SuppressWarnings("null")
public void saveAllSamples()
{
setOption(GibbsOptions.saveAllSamples, true);
}
/**
* Sets the {@link GibbsOptions#saveAllSamples} option to false on this object.
* <p>
* @deprecated Instead set {@link GibbsOptions#saveAllSamples} to false using {@link #setOption}.
*/
@Deprecated
@SuppressWarnings("null")
public void disableSavingAllSamples()
{
setOption(GibbsOptions.saveAllSamples, false);
}
/**
* Starts saving the score of every sample.
* <p>
* Takes effect immediately by installing a fresh, empty score array (discarding any scores
* saved so far) and also sets the {@link GibbsOptions#saveAllScores} option to true.
* <p>
* @deprecated Instead set {@link GibbsOptions#saveAllScores} to true using {@link #setOption}.
*/
@Deprecated
public void saveAllScores()
{
_scoreArray = new DoubleArrayList();
setOption(GibbsOptions.saveAllScores, true);
}
/**
* Stops saving the score of every sample.
* <p>
* Takes effect immediately by discarding the score array and also sets the
* {@link GibbsOptions#saveAllScores} option to false.
* <p>
* @deprecated Instead set {@link GibbsOptions#saveAllScores} to false using {@link #setOption}.
*/
@Deprecated
public void disableSavingAllScores()
{
_scoreArray = null;
setOption(GibbsOptions.saveAllScores, false);
}
/**
 * If the score had been saved, return the array of score values, otherwise null.
 * <p>
 * The returned array is a copy trimmed to the number of saved scores, so callers may
 * modify it freely.
 */
@Matlab
public final @Nullable double[] getAllScores()
{
    final DoubleArrayList savedScores = _scoreArray;
    if (savedScores == null)
    {
        return null;
    }

    final double[] backingArray = savedScores.elements();
    return Arrays.copyOf(backingArray, savedScores.size());
}
/**
 * Get the rejection rate of the sampler for variables and block entries for which it applies.
 * <p>
 * The rate is the total rejection count divided by the total update count, summed over all
 * variables in the model and all solver variable blocks.
 *
 * @return rejection rate, or zero if no updates have been counted.
 * @since 0.07
 */
@Matlab
public final double getRejectionRate()
{
    long updates = 0;
    long rejections = 0;

    // Per-variable statistics
    for (Variable variable : _model.getVariables())
    {
        final ISolverVariableGibbs solverVariable = requireNonNull(getSolverVariable(variable));
        updates += solverVariable.getUpdateCount();
        rejections += solverVariable.getRejectionCount();
    }

    // Statistics of any variable blocks in the graph
    for (GibbsVariableBlock block : getSolverVariableBlocks())
    {
        updates += block.getUpdateCount();
        rejections += block.getRejectionCount();
    }

    if (updates <= 0)
    {
        return 0;
    }
    return (double)rejections / (double)updates;
}
/**
 * Clear the rejection rate statistics of all variables in the model and of all solver
 * variable blocks.
 *
 * @since 0.07
 */
@Matlab
public final void resetRejectionRateStats()
{
    for (Variable variable : _model.getVariables())
    {
        final ISolverVariableGibbs solverVariable = requireNonNull(getSolverVariable(variable));
        solverVariable.resetRejectionRateStats();
    }

    for (GibbsVariableBlock block : getSolverVariableBlocks())
    {
        block.resetCounts();
    }
}
/**
 * Sets the current temperature for all variables in the graph (for tempering).
 * <p>
 * Records the temperature and passes its reciprocal (beta) to every solver variable.
 *
 * @param T is the new temperature.
 */
@Matlab
@SuppressWarnings("null")
public void setTemperature(double T)
{
    _temperature = T;

    final double inverseTemperature = 1/T;
    for (Variable variable : _model.getVariables())
    {
        final ISolverVariableGibbs solverVariable = getSolverVariable(variable);
        solverVariable.setBeta(inverseTemperature);
    }
}
/**
* Returns the temperature most recently set through {@link #setTemperature(double)}.
*/
@Matlab
public double getTemperature() {return _temperature;}
/**
* Sets the random seed for the Gibbs solver. This allows runs of the solver to be repeatable.
* <p>
* The seed is applied to the generator returned by {@code activeRandom()}.
*/
public void setSeed(long seed)
{
activeRandom().setSeed(seed);
}
/**
* Sets the number of samples to generate per restart.
* <p>
* Sets the value of {@link #getNumSamples()} and the corresponding {@link GibbsOptions#numSamples}
* option to the specified value. The new value takes effect immediately, without requiring
* {@link #initialize}.
* <p>
* @param numSamples must be a positive integer.
* @deprecated Instead set {@link GibbsOptions#numSamples} option on this object or its corresponding
* model object using {@link #setOption}.
*/
@Deprecated
public void setNumSamples(int numSamples)
{
setOption(GibbsOptions.numSamples, numSamples);
_numSamples = numSamples;
}
/**
* Number of samples to generate per restart.
* <p>
* Set automatically from the {@link GibbsOptions#numSamples} option during {@link #initialize},
* and also directly by the deprecated {@link #setNumSamples(int)}.
*/
public int getNumSamples()
{
return _numSamples;
}
/**
* Returns the number of updates performed for each sample.
* <p>
* @deprecated This method will be removed in a future release.
*/
@Deprecated
public int getUpdatesPerSample()
{
return _updatesPerSample;
}
/**
* Specifies the size of a sample directly as a number of updates rather than scans.
* <p>
* As a side effect the {@link GibbsOptions#scansPerSample} option and the cached scan count
* are both set to -1, which marks the sample size as being specified in updates.
* <p>
* @deprecated This method will be removed in a future release.
*/
@Deprecated
public void setUpdatesPerSample(int updatesPerSample)
{
// TODO: when this method is removed, change the range of scansPerSample to [0,max]
_updatesPerSample = updatesPerSample;
setOption(GibbsOptions.scansPerSample, -1);
_scansPerSample = -1; // Samples specified in updates rather than scans
}
/**
* Sets the number of full updates of all of the variables to perform for each sample.
* <p>
* This sets the value of the corresponding {@link GibbsOptions#scansPerSample} option and
* immediately recomputes the number of updates per sample from it.
* <p>
* @param scansPerSample must be a positive integer.
* @throws DimpleException if {@code scansPerSample} is less than one.
* @deprecated Instead set {@link GibbsOptions#scansPerSample} option on this object or its
* corresponding model graph using {@link #setOption}.
*/
@Deprecated
public void setScansPerSample(int scansPerSample)
{
if (scansPerSample < 1)
throw new DimpleException("Scans per sample must be greater than 0.");
setOption(GibbsOptions.scansPerSample, scansPerSample);
_scansPerSample = scansPerSample;
setUpdatesPerSampleFromScans();
}
/**
 * Updates the value of {@link #_updatesPerSample} based on {@link #_scansPerSample}.
 * <p>
 * One scan is taken to be the size of the current schedule if there is one, and otherwise
 * the number of variables in the model graph. Nothing is changed unless the scan count is
 * positive; in particular, a value of -1 (sample size specified in updates) leaves the
 * update count untouched.
 */
private void setUpdatesPerSampleFromScans()
{
    if (_scansPerSample <= 0)
    {
        return;
    }

    final IGibbsSchedule schedule = (IGibbsSchedule)_schedule;
    final int updatesPerScan = schedule != null ? schedule.size() : _model.getVariableCount();
    _updatesPerSample = _scansPerSample * updatesPerScan;
}
/**
* Returns the number of updates performed during each burn-in phase.
* <p>
* @deprecated This method will be removed in a future release.
*/
@Deprecated
public int getBurnInUpdates()
{
return _burnInUpdates;
}
/**
* Specifies the length of the burn-in phase directly as a number of updates rather than scans.
* <p>
* As a side effect the cached scan count and the {@link GibbsOptions#burnInScans} option
* are both set to -1, which marks burn-in as being specified in updates.
* <p>
* @deprecated This method will be removed in a future release.
*/
@Deprecated
public void setBurnInUpdates(int burnInUpdates)
{
// TODO: when this method is removed, change the range of burnInScans to [0,max]
_burnInUpdates = burnInUpdates;
_burnInScans = -1; // Burn-in specified in updates rather than scans
setOption(GibbsOptions.burnInScans, -1);
}
/**
* Sets the number of updates of all of the variables to perform during the burn-in period.
* <p>
* This is an alternative to {@link #setBurnInUpdates(int)} for specifying the burn-in period.
* It sets the value of the {@link GibbsOptions#burnInScans} option on this object, caches the
* value, and then recomputes the burn-in update count from it.
* <p>
* @param burnInScans is a non-negative number.
* @deprecated Instead set {@link GibbsOptions#burnInScans} option on this object or its corresponding
* model graph using {@link #setOption}.
*/
@Deprecated
public void setBurnInScans(int burnInScans)
{
setOption(GibbsOptions.burnInScans, burnInScans);
_burnInScans = burnInScans;
setBurnInUpdatesFromScans();
}
/**
 * Updates the value of {@link #_burnInUpdates} based on {@link #_burnInScans}.
 * <p>
 * One scan is taken to be the size of the current schedule if there is one, and otherwise
 * the number of variables in the model graph.
 * <p>
 * A negative scan count means the burn-in length was specified directly in updates
 * (see {@link #setBurnInUpdates(int)}), in which case the update count is left untouched.
 * Zero is a legitimate scan count meaning "no burn-in" and must reset the update count to
 * zero; otherwise a burn-in length computed from an earlier, larger scan count would
 * silently remain in effect.
 */
private void setBurnInUpdatesFromScans()
{
    if (_burnInScans >= 0)
    {
        final IGibbsSchedule schedule = (IGibbsSchedule)_schedule;
        _burnInUpdates = _burnInScans * (schedule != null ? schedule.size() : _model.getVariableCount());
    }
}
/**
* Sets number of random restarts.
* <p>
* Sets the value of {@link #getNumRestarts()} and the corresponding {@link GibbsOptions#numRandomRestarts}
* option. The new value takes effect immediately, without requiring {@link #initialize}.
* <p>
* NOTE(review): {@link #solveOneStep()} performs one burn-in/sampling pass plus one more per
* restart, so zero appears to be a meaningful value - confirm against the option's valid range.
* @param numRestarts must be a positive integer.
* @deprecated Instead set {@link GibbsOptions#numRandomRestarts} option on this object or its corresponding
* model graph using {@link #setOption}.
*/
@Deprecated
public void setNumRestarts(int numRestarts)
{
setOption(GibbsOptions.numRandomRestarts, numRestarts);
_numRandomRestarts = numRestarts;
}
/**
* Number of random restarts to perform during solve.
* <p>
* This is automatically set from {@link GibbsOptions#numRandomRestarts} option during
* {@link #initialize}, and also directly by the deprecated {@link #setNumRestarts(int)}.
*/
public int getNumRestarts()
{
return _numRandomRestarts;
}
/**
* Set the default sampler for Real (and RealJoint) variables.
* <p>
* Converts the given name and stores it as the {@link GibbsOptions#realSampler} option on this object.
* <p>
* @param samplerName is the name of the sampler, in a form accepted by the option's conversion.
*/
public void setDefaultRealSampler(String samplerName)
{
GibbsOptions.realSampler.convertAndSet(this, samplerName);
}
/**
* Returns the simple name of the current (or default) value of the
* {@link GibbsOptions#realSampler} option.
*/
public String getDefaultRealSampler()
{
return getOptionOrDefault(GibbsOptions.realSampler).getSimpleName();
}
/**
* Set the default sampler for Discrete variables.
* <p>
* Converts the given name and stores it as the {@link GibbsOptions#discreteSampler} option on this object.
* <p>
* @param samplerName is the name of the sampler, in a form accepted by the option's conversion.
*/
public void setDefaultDiscreteSampler(String samplerName)
{
GibbsOptions.discreteSampler.convertAndSet(this, samplerName);
}
/**
* Returns the simple name of the current (or default) value of the
* {@link GibbsOptions#discreteSampler} option.
*/
public String getDefaultDiscreteSampler()
{
return getOptionOrDefault(GibbsOptions.discreteSampler).getSimpleName();
}
/**
* Sets the initial temperature used for tempering.
* <p>
* Besides setting the {@link GibbsOptions#initialTemperature} option and the cached value,
* this also enables tempering as a side effect.
* <p>
* @deprecated Instead set {@link GibbsOptions#initialTemperature} option using {@link #setOption}.
*/
@Deprecated
public void setInitialTemperature(double initialTemperature)
{
setOption(GibbsOptions.initialTemperature, initialTemperature);
setTempering(true);
_initialTemperature = initialTemperature;
}
/**
* Returns the cached initial temperature, as last read from options by {@link #initialize}
* or set by {@link #setInitialTemperature(double)}.
* <p>
* @deprecated Instead get {@link GibbsOptions#initialTemperature} option using {@link #getOption}.
*/
@Deprecated
public double getInitialTemperature() {return _initialTemperature;}
/**
* Sets the tempering half-life, expressed in samples.
* <p>
* Besides setting the {@link GibbsOptions#annealingHalfLife} option, this enables tempering
* as a side effect and recomputes the per-sample decay constant as
* {@code 1 - ln(2)/halfLife}.
* <p>
* @deprecated Instead set {@link GibbsOptions#annealingHalfLife} option using {@link #setOption}.
*/
@Deprecated
public void setTemperingHalfLifeInSamples(double temperingHalfLifeInSamples)
{
setOption(GibbsOptions.annealingHalfLife, temperingHalfLifeInSamples);
setTempering(true);
_temperingDecayConstant = 1 - LOG2/temperingHalfLifeInSamples;
}
/**
* Returns the tempering half-life in samples, recovered from the cached decay constant as
* {@code ln(2)/(1 - decayConstant)}.
* <p>
* @deprecated Instead get {@link GibbsOptions#annealingHalfLife} option using {@link #getOption}.
*/
@Deprecated
public double getTemperingHalfLifeInSamples() {return LOG2/(1 - _temperingDecayConstant);}
/**
* Enables or disables tempering by setting both the {@link GibbsOptions#enableAnnealing}
* option and the cached flag.
* <p>
* @param temper is true to enable tempering, false to disable it.
* @deprecated Instead set {@link GibbsOptions#enableAnnealing} option using {@link #setOption}.
*/
@Deprecated
protected void setTempering(boolean temper)
{
setOption(GibbsOptions.enableAnnealing, temper);
_temper = temper;
}
/**
* Enables tempering; equivalent to {@code setTempering(true)}.
* <p>
* @deprecated Instead set {@link GibbsOptions#enableAnnealing} option to true using {@link #setOption}.
*/
@Deprecated
public final void enableTempering()
{
setTempering(true);
}
/**
* Disables tempering; equivalent to {@code setTempering(false)}.
* <p>
* @deprecated Instead set {@link GibbsOptions#enableAnnealing} option to false using {@link #setOption}.
*/
@Deprecated
public final void disableTempering()
{
setTempering(false);
}
/**
* Returns the cached tempering flag, as last read from options by {@link #initialize} or
* set through {@link #setTempering(boolean)}.
* <p>
* @deprecated Instead get {@link GibbsOptions#enableAnnealing} option using {@link #getOption}.
*/
@Deprecated
public boolean isTemperingEnabled()
{
return _temper;
}
// Helpers for operating on pre-specified groups of variables in the graph
/**
 * Returns the current sample values of the variables in the specified variable block.
 *
 * @param variableBlockLocalId is the local id of a variable block in the model graph.
 * @return one value per variable in block order, or an empty array if there is no such block.
 * @throws DimpleException if a variable's solver object is neither {@link GibbsDiscrete} nor {@link GibbsReal}.
 * @throws ClassCastException if a discrete variable's current sample is not a {@link Double}.
 */
public double[] getVariableSampleValues(int variableBlockLocalId)
{
    final List<Variable> blockVariables = _model.getVariableBlockByLocalId(variableBlockLocalId);
    if (blockVariables == null)
    {
        return ArrayUtil.EMPTY_DOUBLE_ARRAY;
    }

    final SolverNodeMapping mapping = getSolverMapping();
    final int count = blockVariables.size();
    final double[] samples = new double[count];

    int index = 0;
    while (index < count)
    {
        final ISolverVariable solverVariable = mapping.getSolverVariable(blockVariables.get(index));
        if (solverVariable instanceof GibbsDiscrete)
        {
            samples[index] = (Double)((GibbsDiscrete)solverVariable).getCurrentSample();
        }
        else if (solverVariable instanceof GibbsReal)
        {
            samples[index] = ((GibbsReal)solverVariable).getCurrentSample();
        }
        else
        {
            throw new DimpleException("Invalid variable class");
        }
        ++index;
    }

    return samples;
}
/**
 * Variant of {@link #setAndHoldVariableSampleValues(int, double[])} that takes the values
 * wrapped as the first element of an object array.
 *
 * @param variableBlockLocalId is the local id of a variable block in the model graph.
 * @param values is an array whose first element is the {@code double[]} of values.
 */
public void setAndHoldVariableSampleValues(int variableBlockLocalId, Object[] values)
{
    // Due to the way MATLAB passes objects
    final double[] unwrappedValues = (double[])values[0];
    setAndHoldVariableSampleValues(variableBlockLocalId, unwrappedValues);
}
/**
 * Sets and holds the sample value of each variable in the specified variable block.
 * <p>
 * Does nothing if there is no such block.
 *
 * @param variableBlockLocalId is the local id of a variable block in the model graph.
 * @param values holds one value per variable, in block order.
 * @throws DimpleException if the number of values differs from the number of variables, or if a
 * variable's solver object is neither {@link GibbsDiscrete} nor {@link GibbsReal}.
 */
public void setAndHoldVariableSampleValues(int variableBlockLocalId, double[] values)
{
    final List<Variable> blockVariables = _model.getVariableBlockByLocalId(variableBlockLocalId);
    if (blockVariables == null)
    {
        return;
    }

    final int count = blockVariables.size();
    if (count != values.length)
    {
        throw new DimpleException("Number of values must match the number of variables");
    }

    final SolverNodeMapping mapping = getSolverMapping();
    for (int index = 0; index < count; ++index)
    {
        final ISolverVariable solverVariable = mapping.getSolverVariable(blockVariables.get(index));
        if (solverVariable instanceof GibbsDiscrete)
        {
            ((GibbsDiscrete)solverVariable).setAndHoldSampleValue(values[index]);
        }
        else if (solverVariable instanceof GibbsReal)
        {
            ((GibbsReal)solverVariable).setAndHoldSampleValue(values[index]);
        }
        else
        {
            throw new DimpleException("Invalid variable class");
        }
    }
}
/**
 * Holds the current sample value of each variable in the specified variable block.
 * <p>
 * Does nothing if there is no such block.
 *
 * @param variableBlockLocalId is the local id of a variable block in the model graph.
 * @throws DimpleException if a variable's solver object is neither {@link GibbsDiscrete} nor {@link GibbsReal}.
 */
public void holdVariableSampleValues(int variableBlockLocalId)
{
    final List<Variable> blockVariables = _model.getVariableBlockByLocalId(variableBlockLocalId);
    if (blockVariables == null)
    {
        return;
    }

    final SolverNodeMapping mapping = getSolverMapping();
    final int count = blockVariables.size();
    for (int index = 0; index < count; ++index)
    {
        final ISolverVariable solverVariable = mapping.getSolverVariable(blockVariables.get(index));
        if (solverVariable instanceof GibbsDiscrete)
        {
            ((GibbsDiscrete)solverVariable).holdSampleValue();
        }
        else if (solverVariable instanceof GibbsReal)
        {
            ((GibbsReal)solverVariable).holdSampleValue();
        }
        else
        {
            throw new DimpleException("Invalid variable class");
        }
    }
}
/**
 * Releases the held sample value of each variable in the specified variable block.
 * <p>
 * Does nothing if there is no such block.
 *
 * @param variableBlockLocalId is the local id of a variable block in the model graph.
 * @throws DimpleException if a variable's solver object is neither {@link GibbsDiscrete} nor {@link GibbsReal}.
 */
public void releaseVariableSampleValues(int variableBlockLocalId)
{
    final List<Variable> blockVariables = _model.getVariableBlockByLocalId(variableBlockLocalId);
    if (blockVariables == null)
    {
        return;
    }

    final SolverNodeMapping mapping = getSolverMapping();
    final int count = blockVariables.size();
    for (int index = 0; index < count; ++index)
    {
        final ISolverVariable solverVariable = mapping.getSolverVariable(blockVariables.get(index));
        if (solverVariable instanceof GibbsDiscrete)
        {
            ((GibbsDiscrete)solverVariable).releaseSampleValue();
        }
        else if (solverVariable instanceof GibbsReal)
        {
            ((GibbsReal)solverVariable).releaseSampleValue();
        }
        else
        {
            throw new DimpleException("Invalid variable class");
        }
    }
}
/**
* Not supported by the Gibbs solver.
* <p>
* 'Iterations' are not defined for Gibbs since that term is ambiguous. Instead, set the number
* of samples using setNumSamples().
* <p>
* @throws DimpleException always.
*/
@Override
public void setNumIterations(int numIter)
{
throw new DimpleException("The length of a run in the Gibbs solver is not specified by a number of 'iterations', but by the number of 'samples'");
}
/**
 * Notifies the solver variable of every sibling of a newly added factor.
 * <p>
 * Deterministic updates are deferred while the notifications are made and processed afterwards.
 *
 * @param f is the factor that was added.
 */
@SuppressWarnings("null")
@Override
public void postAddFactor(Factor f)
{
    deferDeterministicUpdates();

    final int siblingCount = f.getSiblingCount();
    int siblingIndex = 0;
    while (siblingIndex < siblingCount)
    {
        getSolverVariable(f.getSibling(siblingIndex)).postAddFactor(f);
        ++siblingIndex;
    }

    processDeferredDeterministicUpdates();
}
/**
 * Invokes {@code postAddFactor(null)} on the solver variable of every variable returned by
 * the model's {@code getVariablesFlat()}.
 * <p>
 * Deterministic updates are deferred while the notifications are made and processed afterwards.
 */
@SuppressWarnings("null")
@Override
public void postSetSolverFactory()
{
    deferDeterministicUpdates();

    for (Variable variable : getModel().getVariablesFlat())
    {
        final ISolverVariableGibbs solverVariable = getSolverVariable(variable);
        solverVariable.postAddFactor(null);
    }

    processDeferredDeterministicUpdates();
}
/**
* Always returns null.
* <p>
* NOTE(review): presumably null means that no MATLAB-side solve wrapper is used for this
* solver - confirm against the interface documentation.
*/
@Override
public @Nullable String getMatlabSolveWrapper()
{
return null;
}
/**
* Requests that the outputs of a deterministic directed factor be recomputed after one of
* its arguments changed.
* <p>
* If deterministic updates are currently deferred, the change is recorded in the pending
* update for the factor (the queue and the per-factor entry are created on demand) and this
* method returns without updating anything. Otherwise the factor's neighboring variable
* values are updated immediately; deferral is switched on around that call so that any
* further updates it triggers are queued and then processed before this method returns.
*
* @param sfactor is the solver factor whose outputs need to be reevaluated.
* @param changedArgIndex is the index of the changed factor argument. This may be different
* than the sibling index if the factor has constants.
* @param oldValue is the value the argument had before the change.
* @since 0.08
*/
void scheduleDeterministicDirectedUpdate(ISolverFactorGibbs sfactor, int changedArgIndex, Value oldValue)
{
if (_deferDeterministicFactorUpdatesCounter > 0)
{
// Deferred: accumulate the change on the factor's queued update.
if (_deferredDeterministicFactorUpdates == null)
{
_deferredDeterministicFactorUpdates =
new KeyedPriorityQueue<ISolverFactorGibbs, SFactorUpdate>(11,
SFactorUpdate.DeterministicOrder.INSTANCE);
}
SFactorUpdate update = requireNonNull(_deferredDeterministicFactorUpdates).get(sfactor);
if (update == null)
{
update = new SFactorUpdate(sfactor);
requireNonNull(_deferredDeterministicFactorUpdates).offer(update);
}
update.addVariableUpdate(changedArgIndex, oldValue);
}
else
{
// Not deferred: update the factor's neighbors right away.
final Factor factor = sfactor.getModelObject();
final int nEdges = factor.getSiblingCount();
// The old value is only passed along when the factor function reports a positive
// deterministic update limit; otherwise null is passed (presumably requesting a full
// recomputation - confirm against updateNeighborVariableValuesNow).
IndexedValue.SingleList oldValues = null;
if (factor.getFactorFunction().updateDeterministicLimit(nEdges) > 0)
{
oldValues = IndexedValue.SingleList.create(changedArgIndex, oldValue);
}
// Queue any updates triggered by this one and process them once it has completed.
deferDeterministicUpdates();
sfactor.updateNeighborVariableValuesNow(oldValues);
if (oldValues != null)
{
oldValues.release();
}
processDeferredDeterministicUpdates();
}
}
/**
* Ends one deferral request made by {@link #deferDeterministicUpdates()} and, when no
* requests remain outstanding, performs all queued deterministic factor updates.
* <p>
* The deferral counter is decremented; only when it drops to zero or below is the queue
* drained. While draining, the counter is held at one so that updates scheduled by the
* updates being performed are added to the queue (and picked up by the same drain loop)
* instead of being executed recursively. The counter is reset to zero when the queue is empty.
*/
public void processDeferredDeterministicUpdates()
{
if (--_deferDeterministicFactorUpdatesCounter <= 0)
{
// Keep deferral active while draining so nested requests are queued, not run inline.
_deferDeterministicFactorUpdatesCounter = 1;
final KeyedPriorityQueue<ISolverFactorGibbs, SFactorUpdate> deferredUpdates =
_deferredDeterministicFactorUpdates;
if (deferredUpdates != null)
{
SFactorUpdate update = null;
while ((update = deferredUpdates.poll()) != null)
{
update.performUpdate();
}
}
_deferDeterministicFactorUpdatesCounter = 0;
}
}
/**
* Requests that deterministic factor updates be deferred by incrementing the deferral counter.
* <p>
* Each call is expected to be paired with a later call to
* {@link #processDeferredDeterministicUpdates()}, which decrements the counter again.
*/
public void deferDeterministicUpdates()
{
++_deferDeterministicFactorUpdatesCounter;
}
/**
* Always returns false for the Gibbs solver graph.
* <p>
* NOTE(review): presumably this disables a base-class check that every edge appears in the
* schedule - confirm against the superclass contract.
*/
@Override
public boolean checkAllEdgesAreIncludedInSchedule()
{
return false;
}
/**
 * Registers a block initializer, creating the list of initializers on first use.
 * <p>
 * Registered initializers are invoked at the end of {@link #initialize()} and on every
 * {@link #randomRestart(int)}. Note that {@link #initialize()} discards previously
 * registered initializers before initializing variables and factors.
 *
 * @param blockInitializer is the initializer to add.
 */
public void addBlockInitializer(IBlockInitializer blockInitializer)
{
    ArrayList<IBlockInitializer> initializers = _blockInitializers;
    if (initializers == null)
    {
        initializers = new ArrayList<IBlockInitializer>();
        _blockInitializers = initializers;
    }
    initializers.add(blockInitializer);
}
/**
 * Unregisters a block initializer; does nothing if none have been registered.
 *
 * @param blockInitializer is the initializer to remove.
 */
public void removeBlockInitializer(IBlockInitializer blockInitializer)
{
    final ArrayList<IBlockInitializer> initializers = _blockInitializers;
    if (initializers == null)
    {
        return;
    }
    initializers.remove(blockInitializer);
}
/**
* Does nothing in the Gibbs solver graph.
*/
@Override
protected void doUpdateEdge(int edge)
{
}
/**
* Returns the name of this solver, {@code "Gibbs"}.
*/
@Override
protected String getSolverName()
{
return "Gibbs";
}
}