/*******************************************************************************
* Copyright 2013 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.solvers.gibbs;
import static com.analog.lyric.dimple.environment.DimpleEnvironment.*;
import static com.analog.lyric.dimple.solvers.gibbs.GibbsSolverVariableEvent.*;
import static java.util.Objects.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import org.eclipse.jdt.annotation.Nullable;
import com.analog.lyric.collect.ArrayUtil;
import com.analog.lyric.collect.ReleasableIterator;
import com.analog.lyric.dimple.environment.DimpleEnvironment;
import com.analog.lyric.dimple.exceptions.DimpleException;
import com.analog.lyric.dimple.factorfunctions.core.FactorFunctionUtilities;
import com.analog.lyric.dimple.factorfunctions.core.IUnaryFactorFunction;
import com.analog.lyric.dimple.factorfunctions.core.UnaryJointRealFactorFunction;
import com.analog.lyric.dimple.model.core.EdgeState;
import com.analog.lyric.dimple.model.core.FactorGraph;
import com.analog.lyric.dimple.model.domains.RealDomain;
import com.analog.lyric.dimple.model.domains.RealJointDomain;
import com.analog.lyric.dimple.model.factors.Factor;
import com.analog.lyric.dimple.model.values.RealJointValue;
import com.analog.lyric.dimple.model.values.RealValue;
import com.analog.lyric.dimple.model.values.Value;
import com.analog.lyric.dimple.model.variables.RealJoint;
import com.analog.lyric.dimple.solvers.core.PriorAndCondition;
import com.analog.lyric.dimple.solvers.core.SRealJointVariableBase;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.IParameterizedMessage;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.MultivariateNormalParameters;
import com.analog.lyric.dimple.solvers.core.proposalKernels.IProposalKernel;
import com.analog.lyric.dimple.solvers.core.proposalKernels.NormalProposalKernel;
import com.analog.lyric.dimple.solvers.gibbs.customFactors.IRealJointConjugateFactor;
import com.analog.lyric.dimple.solvers.gibbs.samplers.ISampler;
import com.analog.lyric.dimple.solvers.gibbs.samplers.conjugate.IRealConjugateSampler;
import com.analog.lyric.dimple.solvers.gibbs.samplers.conjugate.IRealJointConjugateSampler;
import com.analog.lyric.dimple.solvers.gibbs.samplers.conjugate.IRealJointConjugateSamplerFactory;
import com.analog.lyric.dimple.solvers.gibbs.samplers.conjugate.RealConjugateSamplerRegistry;
import com.analog.lyric.dimple.solvers.gibbs.samplers.conjugate.RealJointConjugateSamplerRegistry;
import com.analog.lyric.dimple.solvers.gibbs.samplers.generic.IGenericSampler;
import com.analog.lyric.dimple.solvers.gibbs.samplers.generic.IMCMCSampler;
import com.analog.lyric.dimple.solvers.gibbs.samplers.generic.IRealSamplerClient;
import com.analog.lyric.dimple.solvers.gibbs.samplers.generic.MHSampler;
import com.analog.lyric.dimple.solvers.interfaces.ISolverEdgeState;
import com.analog.lyric.dimple.solvers.interfaces.ISolverFactorGraph;
import com.analog.lyric.dimple.solvers.interfaces.ISolverNode;
import com.analog.lyric.dimple.solvers.interfaces.SolverNodeMapping;
import com.analog.lyric.math.DimpleRandom;
import com.analog.lyric.options.IOptionHolder;
import com.analog.lyric.util.misc.Internal;
import com.analog.lyric.util.misc.Matlab;
import com.google.common.primitives.Doubles;
/*
* WARNING: Whenever editing this class, also make the corresponding edit to SRealVariable.
* The two are nearly identical, but unfortunately couldn't easily be shared due to the class hierarchy
*/
/**
* RealJoint-valued solver variable for Gibbs solver.
* @since 0.07
*/
public class GibbsRealJoint extends SRealJointVariableBase
implements ISolverVariableGibbs, ISolverRealVariableGibbs, IRealSamplerClient
{
/*-----------
* Constants
*/
/**
 * Event mask bits contributed by this class; {@link #getEventMask()} ORs this value with the
 * mask returned by the superclass. The {@code "hiding"} suppression indicates that a field of
 * the same name is also declared in an enclosing or inherited scope.
 */
@SuppressWarnings("hiding")
protected static final int EVENT_MASK = 0x03;
/*-------
* State
*/
/**
 * Value object holding this variable's current sample.
 * <p>
 * The setters are overridden so that changes respect the enclosing variable's "hold" flag and
 * fixed value, and so that the enclosing variable's deterministic dependents are updated
 * whenever the sample changes.
 */
private final class CurrentSample extends RealJointValue
{
    private static final long serialVersionUID = 1L;

    /**
     * Creates the sample for the given domain and initializes it through {@link #reset()}.
     */
    CurrentSample(RealJointDomain domain)
    {
        super(domain);
        reset();
    }

    /**
     * Replaces the whole sample unless the sample is being held or the model variable
     * has a fixed value, in which case the call is ignored.
     */
    @Override
    public void setValue(double[] value)
    {
        // If the sample value is being held, don't modify the value
        if (_holdSampleValue) return;

        // Also return if the variable is set to a fixed value
        if (_model.hasFixedValue()) return;

        setCurrentSampleForce(value);
    }

    /**
     * Sets a single element of the sample and, when the model variable is an input to a
     * deterministic factor, propagates the change to the deterministic dependents.
     * Does nothing if the element already has the given value.
     */
    @Override
    public final void setValue(int index, double value)
    {
        if (value == _value[index])
        {
            return;
        }

        final boolean hasDeterministicDependents = getModelObject().isDeterministicInput();

        // Snapshot the old value (with its own copy of the array) before modifying in place
        RealJointValue oldValue = null;
        if (hasDeterministicDependents)
        {
            oldValue = _currentSample.clone();
            oldValue.setValue(oldValue.getValue().clone());
        }

        // Store the element exactly once. The previous implementation additionally re-invoked
        // this same method on _currentSample; that call was a no-op for ordinary values but
        // recursed without bound for NaN, because NaN never compares equal in the guard above.
        _value[index] = value;

        if (hasDeterministicDependents)
        {
            // If this variable has deterministic dependents, then set their values
            setDeterministicDependentValues(Objects.requireNonNull(oldValue));
        }
    }

    /**
     * Replaces the whole sample with a copy of {@code value} regardless of the hold flag or any
     * fixed value, and propagates the change to deterministic dependents when there are any.
     */
    private final void setValueForce(double[] value)
    {
        // FIXME - check for changed value.
        final boolean hasDeterministicDependents = getModelObject().isDeterministicInput();

        RealJointValue oldValue = null;
        if (hasDeterministicDependents)
        {
            oldValue = Value.create(getDomain(), _value);
        }

        _value = value.clone();

        if (hasDeterministicDependents)
        {
            // If this variable has deterministic dependents, then set their values
            setDeterministicDependentValues(Objects.requireNonNull(oldValue));
        }
    }

    /**
     * Sets the sample to a copy of the variable's known value if it has one, otherwise to a
     * copy of the initial sample value. Deterministic dependents are not updated here.
     */
    private void reset()
    {
        final double[] knownValue = getKnownRealJoint();
        _value = knownValue != null ? knownValue.clone() : _initialSampleValue.clone();
    }
}
// The live sample value for this variable; returned by getCurrentSampleValue().
private final CurrentSample _currentSample;
private RealJointValue _prevSample; // Used only by BlastFromThePast factors
// Set to true on the other variable passed to moveNonEdgeSpecificState().
private boolean _repeatedVariable;
// Returned as-is by the deprecated getInputMsg().
private @Nullable Object[] _inputMsg = null;
// Sample value used on initialization/reset when the variable has no known value.
private double[] _initialSampleValue;
// Set to true by setInitialSampleValue().
private boolean _initialSampleValueSet = false;
private RealJointDomain _domain;
// MCMC sampler used by update() when no conjugate sampler has been selected.
private @Nullable IMCMCSampler _sampler = null;
// Conjugate sampler selected by initialize(); null means _sampler is used instead.
private @Nullable IRealJointConjugateSampler _conjugateSampler = null;
// Set to true by the setSampler() methods; makes initialize() skip automatic sampler selection.
private boolean _samplerSpecificallySpecified = false;
// Samples collected by saveCurrentSample(); null when samples are not being saved.
private @Nullable ArrayList<double[]> _sampleArray;
// Running per-element sums, pairwise product sums (upper triangle) and sample count
// accumulated by updateBelief() and read by getSampleMean()/getSampleCovariance().
private @Nullable double[] _sampleSum;
private @Nullable double[][] _sampleSumSquare;
private long _sampleCount;
// Value recorded by saveBestSample().
private double[] _bestSampleValue;
// Multiplier applied to the summed potential in getCurrentSampleScore() (beta = 1/temperature).
private double _beta = 1;
// While true, update(), randomRestart() and the guarded setters leave the sample untouched.
private boolean _holdSampleValue = false;
// Number of real elements in the joint domain.
private int _numRealVars;
// Index of the element currently being sampled; read by the sampler call-back methods.
private int _tempIndex = 0;
private boolean _visited = false;
// Statistics reported by getRejectionRate()/getNumScoresPerUpdate().
private long _updateCount;
private long _rejectCount;
private long _scoreCount;
/**
* List of neighbors for sample scoring. Instantiated during initialization.
*/
private @Nullable GibbsNeighbors _neighbors = null;
/*--------------
* Construction
*/
public GibbsRealJoint(RealJoint var, GibbsSolverGraph parent)
{
super(var, parent);
_domain = var.getDomain();
_numRealVars = _domain.getNumVars();
_initialSampleValue = new double[_numRealVars];
_bestSampleValue = new double[_numRealVars];
_prevSample = _currentSample = new CurrentSample(_domain);
resetCurrentSample();
}
/*---------------------
* ISolverNode methods
*/
/**
 * Edge updates are not supported by this solver variable.
 *
 * @throws DimpleException always
 */
@Override
protected void doUpdateEdge(int outPortNum)
{
throw new DimpleException("Method not supported in Gibbs sampling solver.");
}
/**
 * Returns the parent solver graph cast to {@link GibbsSolverGraph}.
 */
@Override
public GibbsSolverGraph getParentGraph()
{
return (GibbsSolverGraph)_parent;
}
/**
 * Returns the sibling for the given edge, cast to {@link ISolverFactorGibbs}.
 */
@Override
public ISolverFactorGibbs getSibling(int edge)
{
return (ISolverFactorGibbs)super.getSibling(edge);
}
/**
 * Draws the next sample for this variable.
 * <p>
 * Nothing happens if the sample is being held, if the model variable is the output of a
 * deterministic factor, or if it has a fixed value. Otherwise the conjugate sampler is used
 * when one has been selected; if not, the MCMC sampler is run once per element of the joint
 * variable. An update event is raised afterwards when update events are enabled.
 * <p>
 * Rejections are added to the rejection statistics exactly once per call. (They used to be
 * added inside the per-element loop, which re-added the running total on every iteration.)
 */
@Override
public final void update()
{
    // If the sample value is being held, don't modify the value
    if (_holdSampleValue) return;

    final RealJoint model = _model;

    // Don't bother to re-sample deterministic dependent variables (those that are the output of a directional deterministic factor)
    if (model.isDeterministicOutput()) return;

    // Also return if the variable is set to a fixed value
    if (model.hasFixedValue()) return;

    final int updateEventFlags = GibbsSolverVariableEvent.getVariableUpdateEventFlags(this);
    Value oldValue = null;
    double oldSampleScore = 0.0;
    switch (updateEventFlags)
    {
    case UPDATE_EVENT_SCORED:
        // TODO: non-conjugate samplers already compute sample scores, so we shouldn't have to do here.
        oldSampleScore = getCurrentSampleScore();
        //$FALL-THROUGH$
    case UPDATE_EVENT_SIMPLE:
        oldValue = _currentSample.clone();
        break;
    }

    // Get the next sample value from the sampler
    int rejectCount = 0;
    final IRealJointConjugateSampler conjugateSampler = _conjugateSampler;
    if (conjugateSampler == null)
    {
        // Use MCMC sampler
        final RealValue nextSample = RealValue.create();
        final IMCMCSampler sampler = Objects.requireNonNull(_sampler);
        for (int i = 0; i < _numRealVars; i++)
        {
            _tempIndex = i; // Save this to be used by the call-back from sampler
            nextSample.setDouble(_currentSample.getValue(i));
            if (!sampler.nextSample(nextSample, this))
            {
                ++rejectCount;
            }
            _updateCount++; // Updates count each real variable when using an MCMC sampler
        }

        // rejectCount is the total for this call, so it must be accumulated only once
        _rejectCount += rejectCount;
    }
    else
    {
        // Use conjugate sampler, first update the messages from all factors
        // Factor messages represent the current distribution parameters from each factor
        final int numEdges = model.getSiblingCount();
        final ISolverEdgeState[] sedges = new ISolverEdgeState[numEdges];
        final FactorGraph fg = model.requireParentGraph();
        final SolverNodeMapping solvers = getSolverMapping();
        final ISolverFactorGraph sfg = solvers.getSolverGraph(fg);
        for (int portIndex = 0; portIndex < numEdges; portIndex++)
        {
            final int edgeIndex = model.getSiblingEdgeIndex(portIndex);
            final EdgeState edge = requireNonNull(fg.getGraphEdgeState(edgeIndex));
            final GibbsSolverEdge<?> sedge = requireNonNull((GibbsSolverEdge<?>)sfg.getSolverEdge(edgeIndex));
            final ISolverFactorGibbs factor = (ISolverFactorGibbs)sfg.getSolverFactorForEdge(edge);
            sedges[portIndex] = sedge;
            factor.updateEdgeMessage(edge, sedge); // Run updateEdgeMessage for each neighboring factor
        }

        final PriorAndCondition inputs = getPriorAndCondition();
        try
        {
            setCurrentSample(conjugateSampler.nextSample(sedges, inputs));
        }
        finally
        {
            // Release the inputs even if the sampler throws
            inputs.release();
        }
        _updateCount++;
    }

    switch (updateEventFlags)
    {
    case UPDATE_EVENT_SCORED:
        // TODO: non-conjugate samplers already compute sample scores, so we shouldn't have to do here.
        raiseEvent(new GibbsScoredVariableUpdateEvent(this, Objects.requireNonNull(oldValue), oldSampleScore,
            _currentSample.clone(), getCurrentSampleScore(), rejectCount));
        break;
    case UPDATE_EVENT_SIMPLE:
        raiseEvent(new GibbsVariableUpdateEvent(this, Objects.requireNonNull(oldValue),
            _currentSample.clone(), rejectCount));
        break;
    }
}
/*---------------------------
* SolverEventSource methods
*/
/**
 * Returns the superclass event mask combined (bitwise OR) with {@link #EVENT_MASK}.
 */
@Override
protected int getEventMask()
{
return EVENT_MASK | super.getEventMask();
}
/*-------------------------
* ISolverVariable methods
*/
// IRealSampleScorer methods...
// The following methods are for the IRealSampleScorer interface, meant to be called by a sampler
// These are not intended for other purposes
/**
 * Scores the given value by delegating to {@link #getSampleScore(double)} with
 * {@code sampleValue.getDouble()}.
 */
@Override
public final double getSampleScore(Value sampleValue)
{
return getSampleScore(sampleValue.getDouble());
}
/**
 * Stores {@code sampleValue} into the element of the current sample selected by
 * {@code _tempIndex} and returns the resulting {@link #getCurrentSampleScore()}.
 * The previous element value is not restored.
 */
@Override
public final double getSampleScore(double sampleValue)
{
// WARNING: Side effect is that the current sample value changes to this sample value
// Could change back but less efficient to do this, since we'll be updating the sample value anyway
setCurrentSample(_tempIndex, sampleValue);
return getCurrentSampleScore();
}
/**
 * Computes the score of the current sample value.
 * <p>
 * The score is the energy of the prior and condition plus the potentials of all sample score
 * nodes, multiplied by the current {@code beta}. {@link Double#POSITIVE_INFINITY} is returned
 * when the current sample is outside the domain or when the accumulated potential becomes
 * non-finite. Each call increments the score counter.
 * <p>
 * The score-node iterator is released on every exit path, including the early exit taken when
 * a non-finite potential is encountered.
 *
 * @return the score of the current sample, or positive infinity as described above
 */
@Override
public final double getCurrentSampleScore()
{
    _scoreCount++;

    if (!_domain.inDomain(_currentSample.getValue()))
    {
        return Double.POSITIVE_INFINITY; // outside the domain
    }

    // Sum up the potentials from the prior, condition and all connected factors
    PriorAndCondition known = getPriorAndCondition();
    double potential = known.evalEnergy(_currentSample);
    known = known.release();
    if (!Doubles.isFinite(potential))
    {
        return Double.POSITIVE_INFINITY;
    }

    final ReleasableIterator<ISolverNodeGibbs> scoreNodes = getSampleScoreNodes();
    try
    {
        while (scoreNodes.hasNext())
        {
            potential += scoreNodes.next().getPotential();
            if (!Doubles.isFinite(potential))
            {
                return Double.POSITIVE_INFINITY;
            }
        }
    }
    finally
    {
        // Release whether the loop ran to completion or exited early
        scoreNodes.release();
    }

    return potential * _beta; // Incorporate current temperature
}
/**
 * Delegates to {@link #setNextSampleValue(double)} with {@code sampleValue.getDouble()}.
 */
@Override
public final void setNextSampleValue(Value sampleValue)
{
setNextSampleValue(sampleValue.getDouble());
}
/**
 * Stores {@code sampleValue} into the element of the current sample selected by
 * {@code _tempIndex}.
 */
@Override
public final void setNextSampleValue(double sampleValue)
{
setCurrentSample(_tempIndex, sampleValue);
}
// TODO move to local methods?
// For conjugate samplers
/**
 * Returns the conjugate sampler currently selected for this variable, or null if there is none.
 */
public final @Nullable IRealJointConjugateSampler getConjugateSampler()
{
return _conjugateSampler;
}
/*----------------------------------
* ISolverRealVariableGibbs methods
* TODO: move below ISolverVariableGibbs methods
*/
/**
 * Aggregates the parameters of all edges except {@code outPortNum}, together with this
 * variable's prior and condition, into {@code outputMessage}.
 * <p>
 * Before aggregating, the edge message of every included edge is refreshed by its factor.
 *
 * @param outputMessage receives the aggregated parameters
 * @param outPortNum index of the sibling port to leave out of the aggregation
 * @param conjugateSampler sampler that performs the aggregation; it is cast to
 *        {@link IRealJointConjugateSampler}
 */
@Override
public final void getAggregateMessages(IParameterizedMessage outputMessage, int outPortNum, ISampler conjugateSampler)
{
    final RealJoint model = _model;
    final FactorGraph fg = model.requireParentGraph();
    final SolverNodeMapping solvers = getSolverMapping();
    final ISolverFactorGraph sfg = solvers.getSolverGraph(fg);

    final int numEdges = model.getSiblingCount();
    final ISolverEdgeState[] sedges = new ISolverEdgeState[numEdges - 1];
    for (int port = 0, i = 0; port < numEdges; port++)
    {
        if (port == outPortNum)
        {
            continue;
        }

        final int edgeIndex = model.getSiblingEdgeIndex(port);
        final EdgeState edgeState = requireNonNull(fg.getGraphEdgeState(edgeIndex));
        final GibbsSolverEdge<?> sedge = requireNonNull((GibbsSolverEdge<?>)sfg.getSolverEdge(edgeIndex));
        final ISolverFactorGibbs factor = (ISolverFactorGibbs)sfg.getSolverFactorForEdge(edgeState);

        // Reuse the edge that was already looked up and null-checked, as update() does
        sedges[i++] = sedge;
        factor.updateEdgeMessage(edgeState, sedge); // Run updateEdgeMessage for each neighboring factor
    }

    final PriorAndCondition inputs = getPriorAndCondition();
    try
    {
        ((IRealJointConjugateSampler)conjugateSampler).aggregateParameters(outputMessage, sedges, inputs);
    }
    finally
    {
        // Release the inputs even if aggregation throws
        inputs.release();
    }
}
/*--------------------------
* ISolverNodeGibbs methods
*/
/**
 * Sets the visited flag.
 *
 * @param visited the new value of the flag
 * @return true if the flag's value differs from what it was before the call
 */
@Override
public boolean setVisited(boolean visited)
{
    final boolean previous = _visited;
    _visited = visited;
    return previous != visited;
}
/*------------------------------
* ISolverVariableGibbs methods
*/
/**
 * Returns the object holding the current sample (not a copy).
 */
@Override
public RealJointValue getCurrentSampleValue()
{
return _currentSample;
}
/**
 * Returns the object holding the previous sample (not a copy).
 */
@Internal
@Override
public RealJointValue getPrevSampleValue()
{
return _prevSample;
}
/**
 * Returns an iterator over the nodes whose potentials are summed when scoring a sample,
 * obtained from {@code GibbsNeighbors.iteratorFor}. The caller should release the iterator
 * when done, as {@link #getCurrentSampleScore()} does.
 */
@Override
public ReleasableIterator<ISolverNodeGibbs> getSampleScoreNodes()
{
return GibbsNeighbors.iteratorFor(_neighbors, this);
}
/**
 * Re-initializes the current sample at the start of a (re)start of the sampler.
 * <p>
 * In order of precedence:
 * <ol>
 * <li>Nothing is done if the sample is being held or the variable is the output of a
 *     deterministic factor.
 * <li>If the variable has a known value, the sample is set to it.
 * <li>On the first restart ({@code restartCount == 0}) an explicitly set initial sample value
 *     is used.
 * <li>Otherwise, if there is a prior, the sample is drawn from it when a compatible conjugate
 *     sampler is available (clipped to the domain bounds), else drawn uniformly within the
 *     bounds of each bounded element.
 * <li>With no prior, bounded elements are drawn uniformly and unbounded ones are left as is.
 * </ol>
 * The known value is applied with the forcing setter because the ordinary setter ignores
 * requests while the model has a fixed value, which would make that step a no-op.
 *
 * @param restartCount zero-based index of the restart
 */
@Override
public void randomRestart(int restartCount)
{
    // If the sample value is being held, don't modify the value
    if (_holdSampleValue) return;

    // If the variable is the output of a directed deterministic factor, then don't modify the value--it should already be set correctly
    if (getModelObject().isDeterministicOutput()) return;

    // If the variable has a known value, then set the current sample to that value and return
    final double[] knownValue = getKnownRealJoint();
    if (knownValue != null)
    {
        setCurrentSampleForce(knownValue);
        return;
    }

    if (_initialSampleValueSet && restartCount == 0)
    {
        setCurrentSample(_initialSampleValue);
        return;
    }

    final DimpleRandom rand = activeRandom();

    // If the variable has a prior, sample from that (bounded by the domain)
    final IUnaryFactorFunction priorJoint = _model.getPriorFunction();
    List<? extends IUnaryFactorFunction> priorArray = null;
    if (priorJoint instanceof UnaryJointRealFactorFunction)
    {
        priorArray = ((UnaryJointRealFactorFunction)priorJoint).realFunctions();
    }
    else if (priorJoint instanceof MultivariateNormalParameters)
    {
        priorArray = ((MultivariateNormalParameters)priorJoint).getDiagonalNormals();
    }

    if (priorArray != null)
    {
        // The prior decomposes into one function per element
        for (int i = 0; i < _numRealVars; i++)
        {
            final RealDomain realDomain = _domain.getRealDomain(i);
            final IUnaryFactorFunction prior = priorArray.get(i);

            // If there are inputs, see if there's an available conjugate sampler
            IRealConjugateSampler inputConjugateSampler = null; // Don't use the global conjugate sampler since other factors might not be conjugate
            if (prior != null)
                inputConjugateSampler = RealConjugateSamplerRegistry.findCompatibleSampler(prior);

            // Determine if there are bounds
            final double hi = realDomain.getUpperBound();
            final double lo = realDomain.getLowerBound();

            if (inputConjugateSampler != null)
            {
                // Sample from the input if there's an available sampler
                double sampleValue = inputConjugateSampler.nextSample(new ISolverEdgeState[0],
                    Collections.singletonList(prior));

                // If there are also bounds, clip at the bounds
                if (sampleValue > hi) sampleValue = hi;
                if (sampleValue < lo) sampleValue = lo;

                setCurrentSample(i, sampleValue);
            }
            else
            {
                // No available sampler, so if bounded, sample uniformly from the bounds
                if (hi < Double.POSITIVE_INFINITY && lo > Double.NEGATIVE_INFINITY)
                    setCurrentSample(i, rand.nextDouble() * (hi - lo) + lo);
            }
        }
    }
    else if (priorJoint != null) // Input is a joint input
    {
        // Don't use the global conjugate sampler since other factors might not be conjugate
        final IRealJointConjugateSampler inputConjugateSampler =
            RealJointConjugateSamplerRegistry.findCompatibleSampler(priorJoint);
        if (inputConjugateSampler != null)
        {
            final double[] sampleValue =
                inputConjugateSampler.nextSample(new ISolverEdgeState[0], Collections.singletonList(priorJoint));

            // Clip if necessary
            for (int i = 0; i < _numRealVars; i++)
            {
                // Determine if there are bounds
                final RealDomain realDomain = _domain.getRealDomain(i);
                final double hi = realDomain.getUpperBound();
                final double lo = realDomain.getLowerBound();

                // If there are also bounds, clip at the bounds
                if (sampleValue[i] > hi) sampleValue[i] = hi;
                if (sampleValue[i] < lo) sampleValue[i] = lo;
            }

            setCurrentSample(sampleValue);
        }
        else // No available conjugate sampler
        {
            for (int i = 0; i < _numRealVars; i++)
            {
                // Determine if there are bounds
                final RealDomain realDomain = _domain.getRealDomain(i);
                final double hi = realDomain.getUpperBound();
                final double lo = realDomain.getLowerBound();

                // No available sampler, so if bounded, sample uniformly from the bounds;
                // if only one side is bounded, pull the current value back inside that bound
                if (hi < Double.POSITIVE_INFINITY && lo > Double.NEGATIVE_INFINITY)
                    setCurrentSample(i, rand.nextDouble() * (hi - lo) + lo);
                else if (hi < _currentSample.getValue(i))
                    setCurrentSample(i, hi);
                else if (lo > _currentSample.getValue(i))
                    setCurrentSample(i, lo);
            }
        }
    }
    else // There are no inputs
    {
        for (int i = 0; i < _numRealVars; i++)
        {
            // Determine if there are bounds
            final RealDomain realDomain = _domain.getRealDomain(i);
            final double hi = realDomain.getUpperBound();
            final double lo = realDomain.getLowerBound();

            // If bounded, sample uniformly from the bounds, otherwise leave current sample value
            if (hi < Double.POSITIVE_INFINITY && lo > Double.NEGATIVE_INFINITY)
                setCurrentSample(i, rand.nextDouble() * (hi - lo) + lo);
        }
    }
}
/**
 * Folds the current sample into the running statistics.
 * <p>
 * When moment computation is active (the sum array exists), each element is added to its
 * running sum and each pairwise product {@code v[i] * v[j]} with {@code j >= i} is added to
 * the upper triangle of the product-sum matrix. The sample count is always incremented.
 */
@Override
public final void updateBelief()
{
    final double[] sums = _sampleSum;
    if (sums != null)
    {
        // Update the sums for computing moments
        final int dimension = _numRealVars;
        for (int row = 0; row < dimension; row++)
        {
            final double rowValue = _currentSample.getValue(row);
            sums[row] += rowValue;

            for (int col = row; col < dimension; col++)
            {
                final double colValue = _currentSample.getValue(col);
                requireNonNull(_sampleSumSquare)[row][col] += rowValue * colValue;
            }
        }
    }

    _sampleCount++;
}
/**
 * Forces the current sample to the variable's known value, if it has one.
 * Does nothing when there is no known value.
 */
@Override
public void updatePriorAndCondition()
{
    final Value known = getKnownValue();
    if (known == null)
    {
        return;
    }

    setCurrentSampleForce(known.getDoubleArray());
}
/**
 * Always returns the boxed double {@code 0.0}; this implementation does not compute a belief
 * from the samples.
 */
@Override
public Object getBelief()
{
return 0d;
}
/**
 * Returns the mean of each element over the samples accumulated by {@link #updateBelief()}.
 *
 * @return a new array holding each element's running sum divided by the sample count
 * @throws DimpleException if moment computation is not active
 */
@Matlab
public final double[] getSampleMean()
{
    final double[] sums = _sampleSum;
    if (sums == null)
    {
        throw new DimpleException("The sample mean is only computed if the option GibbsOptions.computeRealJointBeliefMoments has been set to true");
    }

    final int dimension = _numRealVars;
    final double[] mean = new double[dimension];
    for (int i = 0; i < dimension; i++)
    {
        mean[i] = sums[i] / _sampleCount;
    }
    return mean;
}
/**
 * Returns the sample covariance matrix computed from the statistics accumulated by
 * {@link #updateBelief()}.
 * <p>
 * Each entry is {@code (sum_ij - sum_i * sum_j / n) / (n - 1)}, where {@code n} is the sample
 * count. Only the upper triangle is computed; it is mirrored into the lower triangle.
 *
 * @return a new symmetric matrix of size dimension x dimension
 * @throws DimpleException if moment computation is not active
 */
@Matlab
public final double[][] getSampleCovariance()
{
    final double[] sums = _sampleSum;
    if (sums == null)
    {
        throw new DimpleException("The sample covariance is only computed if the option GibbsOptions.computeRealJointBeliefMoments has been set to true");
    }

    // For all sample values, compute the covariance matrix
    // For now, use the naive algorithm; could be improved
    final int dimension = _numRealVars;
    final double[][] covariance = new double[dimension][dimension];
    final double count = _sampleCount;
    final double countMinusOne = (count - 1);

    for (int row = 0; row < dimension; row++)
    {
        for (int col = row; col < dimension; col++)
        {
            final double rowSum = sums[row];
            final double colSum = sums[col];
            final double productSum = requireNonNull(_sampleSumSquare)[row][col];
            final double entry = (productSum - rowSum * (colSum / count)) / countMinusOne;
            covariance[row][col] = entry;
            covariance[col][row] = entry; // Fill in lower triangular half
        }
    }
    return covariance;
}
/**
 * Currently does nothing; the body holds only a placeholder comment.
 */
@SuppressWarnings("null")
@Override
public void postAddFactor(@Nullable Factor f)
{
// Set the default sampler
}
/**
 * Returns the guess for this variable: the explicitly set guess if there is one, otherwise
 * the known value if there is one, otherwise the current sample value.
 */
@Override
public double[] getGuess()
{
    if (_guessWasSet)
    {
        return _guessValue;
    }

    final double[] known = getKnownRealJoint();
    return known != null ? known : _currentSample.getValue();
}
/**
* Starts saving samples: installs a fresh, empty sample list and sets the
* {@link GibbsOptions#saveAllSamples} option to true on this object.
*
* @deprecated Instead set {@link GibbsOptions#saveAllSamples} to true using {@link #setOption}.
*/
@Deprecated
@Override
public final void saveAllSamples()
{
_sampleArray = new ArrayList<double[]>();
setOption(GibbsOptions.saveAllSamples, true);
}
/**
* Stops saving samples: discards the sample list and sets the
* {@link GibbsOptions#saveAllSamples} option to false on this object.
*
* @deprecated Instead set {@link GibbsOptions#saveAllSamples} to false using {@link #setOption}.
*/
@Deprecated
@Override
public void disableSavingAllSamples()
{
_sampleArray = null;
setOption(GibbsOptions.saveAllSamples, false);
}
/**
 * Appends a copy of the current sample value to the saved-sample list, if samples are
 * being saved; otherwise does nothing.
 */
@Override
public final void saveCurrentSample()
{
final ArrayList<double[]> sampleArray = _sampleArray;
if (sampleArray != null)
sampleArray.add(_currentSample.getValue().clone());
}
/**
 * Records a copy of the current sample value as the best sample.
 */
@Override
public final void saveBestSample()
{
_bestSampleValue = _currentSample.getValue().clone();
}
/**
 * Returns the prior-and-condition energy evaluated at the current sample.
 */
@Override
public final double getPotential()
{
return evalPriorAndConditionEnergy(_currentSample);
}
/**
 * Returns the result of {@code canHavePriorAndConditionEnergy()}.
 */
@Override
public final boolean hasPotential()
{
return canHavePriorAndConditionEnergy();
}
/**
 * Sets the current sample from an object that must be a {@code double[]}.
 *
 * @throws ClassCastException if {@code value} is not a {@code double[]}
 */
@Override
public final void setCurrentSample(Object value)
{
setCurrentSample((double[])value);
}
/**
 * Sets the current sample from {@code value.getDoubleArray()}.
 */
@Override
public final void setCurrentSample(Value value)
{
setCurrentSample(value.getDoubleArray());
}
/*---------------
* Local methods
*/
/**
 * Sets the whole current sample. The request is ignored while the sample is being held or
 * while the model variable has a fixed value (see {@code CurrentSample.setValue(double[])}).
 */
public final void setCurrentSample(double[] value)
{
_currentSample.setValue(value);
}
// Sets the sample regardless of whether the value is fixed or held
private final void setCurrentSampleForce(double[] value)
{
_currentSample.setValueForce(value);
}
// Set a specific element of the sample value
/**
 * Sets one element of the current sample after converting {@code value} to a double with
 * {@code FactorFunctionUtilities.toDouble}.
 */
public final void setCurrentSample(int index, Object value)
{
setCurrentSample(index, FactorFunctionUtilities.toDouble(value));
}
/**
 * Sets one element of the current sample. Unlike the whole-array setter, this does not check
 * the hold flag or the fixed value.
 */
public final void setCurrentSample(int index, double value)
{
_currentSample.setValue(index, value);
}
/**
 * Notifies the neighbor list, if one has been created, that the sample changed from
 * {@code oldValue} so that it can update dependent values.
 */
private final void setDeterministicDependentValues(RealJointValue oldValue)
{
    final GibbsNeighbors neighbors = _neighbors;
    if (neighbors == null)
    {
        return;
    }

    neighbors.update(oldValue);
}
/**
 * Returns the array obtained from the current sample's {@code getValue()}; no copy is made here.
 */
@Matlab
public final double[] getCurrentSample()
{
return _currentSample.getValue();
}
/**
 * Returns the stored best-sample array; no copy is made here.
 */
@Matlab
public final double[] getBestSample()
{
return _bestSampleValue;
}
/**
 * Returns all saved samples, one row per sample in the order saved.
 *
 * @return a new outer array whose rows are the saved sample arrays themselves (not copies),
 *         or an empty array when samples are not being saved
 */
@Matlab
@Override
public final double[][] getAllSamples()
{
    final List<double[]> samples = _sampleArray;
    if (samples == null)
    {
        return ArrayUtil.EMPTY_DOUBLE_ARRAY_ARRAY;
    }

    return samples.toArray(new double[samples.size()][]);
}
/**
 * Returns the rejection count divided by the update count, or 0 when there have been no updates.
 */
@Override
public final double getRejectionRate()
{
return (_updateCount > 0) ? (double)_rejectCount / (double)_updateCount : 0;
}
/**
 * Returns the score count divided by the update count, or 0 when there have been no updates.
 */
@Override
public final double getNumScoresPerUpdate()
{
return (_updateCount > 0) ? (double)_scoreCount / (double)_updateCount : 0;
}
/**
 * Resets the update, rejection and score counters to zero.
 */
@Override
public final void resetRejectionRateStats()
{
_updateCount = 0;
_rejectCount = 0;
_scoreCount = 0;
}
/**
 * Returns the number of updates counted since the statistics were last reset.
 */
@Override
public final long getUpdateCount()
{
return _updateCount;
}
/**
 * Returns the number of rejections counted since the statistics were last reset.
 */
@Override
public final long getRejectionCount()
{
return _rejectCount;
}
// This is meant for internal use, not as a user accessible method
/**
 * Returns the internal saved-sample list itself (not a copy), or null when samples are not
 * being saved.
 */
public final @Nullable List<double[]> _getSampleArrayUnsafe()
{
return _sampleArray;
}
/**
 * Releases any existing hold, sets the current sample to {@code value} through the guarded
 * setter, then holds the sample.
 */
public final void setAndHoldSampleValue(double[] value)
{
releaseSampleValue();
setCurrentSample(value);
holdSampleValue();
}
/**
 * Turns the hold flag on.
 */
public final void holdSampleValue()
{
_holdSampleValue = true;
}
/**
 * Turns the hold flag off.
 */
public final void releaseSampleValue()
{
_holdSampleValue = false;
}
/**
* Passes {@code stdDev} to the proposal kernel's {@code setParameters} when the current sampler
* is an {@link MHSampler}; otherwise does nothing.
*
* @deprecated Will be removed in future release. Instead set corresponding option
* for desired proposal kernel (e.g. {@link NormalProposalKernel#standardDeviation}).
*/
@SuppressWarnings("null")
@Deprecated
public final void setProposalStandardDeviation(double stdDev)
{
if (_sampler instanceof MHSampler)
((MHSampler)_sampler).getProposalKernel().setParameters(stdDev);
}
/**
* Returns the first proposal kernel parameter, cast to {@code Double}, when the current
* sampler is an {@link MHSampler}; otherwise returns 0.
*
* @deprecated Will be removed in future release. Instead lookup corresponding option
* for desired proposal kernel (e.g. {@link NormalProposalKernel#standardDeviation}).
*/
@SuppressWarnings("null")
@Deprecated
public final double getProposalStandardDeviation()
{
if (_sampler instanceof MHSampler)
return (Double)((MHSampler)_sampler).getProposalKernel().getParameters()[0];
else
return 0;
}
/**
* Passes {@code parameters} to the proposal kernel when the current sampler is an
* {@link MHSampler}; otherwise does nothing.
*
* @deprecated Will be removed in future release. Instead set appropriate options
* for proposal kernel using {@link #setOption}.
*/
@SuppressWarnings("null")
@Deprecated
public final void setProposalKernelParameters(Object... parameters)
{
if (_sampler instanceof MHSampler)
((MHSampler)_sampler).getProposalKernel().setParameters(parameters);
}
// There should be a way to call these directly via the samplers
// If so, they should be removed from here since this makes this sampler-specific
/**
 * Sets the proposal kernel on the current sampler when it is an {@link MHSampler};
 * otherwise does nothing.
 */
@SuppressWarnings("null")
@Deprecated
public final void setProposalKernel(IProposalKernel proposalKernel) // IProposalKernel object
{
if (_sampler instanceof MHSampler)
((MHSampler)_sampler).setProposalKernel(proposalKernel);
}
/**
 * Sets the proposal kernel by name on the current sampler when it is an {@link MHSampler};
 * otherwise does nothing.
 */
@SuppressWarnings("null")
@Deprecated
public final void setProposalKernel(String proposalKernelName) // Name of proposal kernel
{
if (_sampler instanceof MHSampler)
((MHSampler)_sampler).setProposalKernel(proposalKernelName);
}
/**
 * Returns the proposal kernel of the current sampler when it is an {@link MHSampler};
 * otherwise returns null.
 */
@SuppressWarnings("null")
@Deprecated
public final @Nullable IProposalKernel getProposalKernel()
{
if (_sampler instanceof MHSampler)
return ((MHSampler)_sampler).getProposalKernel();
else
return null;
}
/**
* Converts {@code samplerName} and stores it as the {@link GibbsOptions#realSampler} option
* on this object.
*
* @deprecated Will be removed in future release. Instead use {@link GibbsOptions#realSampler}
* option.
*/
@Deprecated
public final void setDefaultSampler(String samplerName)
{
GibbsOptions.realSampler.convertAndSet(this, samplerName);
}
/**
* Returns the simple class name of the {@link GibbsOptions#realSampler} option value
* (or of its default) as seen from this object.
*
* @deprecated Will be removed in future release. Instead use {@link GibbsOptions#realSampler}
* option.
*/
@Deprecated
public final String getDefaultSamplerName()
{
return getOptionOrDefault(GibbsOptions.realSampler).getSimpleName();
}
/**
* Sets sampler to be used for this variable.
* <p>
* In general, it is usually easier to configure the sampler using the
* {@link GibbsOptions#realSampler} option. This method should only be
* required when the sampler class is not registered with the
* {@linkplain DimpleEnvironment#genericSamplers() generic sampler registry}
* for the current environment.
* <p>
* The sampler is cast to {@link IMCMCSampler}, so passing any other kind of sampler results
* in a {@link ClassCastException}. The sampler is also marked as explicitly specified, which
* makes {@link #initialize()} skip automatic sampler selection.
* <p>
* @param sampler is a non-null sampler.
*/
public final void setSampler(ISampler sampler)
{
_sampler = (IMCMCSampler)sampler;
_samplerSpecificallySpecified = true;
}
/**
* Sets the sampler by name: stores the name as the {@link GibbsOptions#realSampler} option on
* this object, obtains the sampler from {@code instantiateIfDifferent} using the current
* sampler, marks it as explicitly specified and initializes it from this variable.
*
* @deprecated Will be removed in future release. Instead set sampler by setting
* {@link GibbsOptions#realSampler} option using {@link #setOption}.
*/
@Matlab
@Deprecated
public final void setSampler(String samplerName)
{
GibbsOptions.realSampler.convertAndSet(this, samplerName);
IMCMCSampler sampler = (IMCMCSampler) GibbsOptions.realSampler.instantiateIfDifferent(this, _sampler);
_sampler = sampler;
_samplerSpecificallySpecified = true;
sampler.initializeFromVariable(this);
}
/**
 * Returns the sampler that will be used for this variable.
 * <p>
 * If a sampler was explicitly specified, it is (re)initialized from this variable and
 * returned. Otherwise, if the model variable belongs to a graph, {@link #initialize()} is run
 * to select a sampler: the conjugate sampler is returned if one was chosen, else the MCMC
 * sampler is initialized from this variable and returned. Returns null when no sampler was
 * specified and the model variable has no parent graph.
 */
@Matlab
@Override
public final @Nullable ISampler getSampler()
{
    if (!_samplerSpecificallySpecified)
    {
        if (!_model.hasParentGraph())
        {
            return null;
        }

        initialize(); // To determine the appropriate sampler

        if (_conjugateSampler != null)
        {
            return _conjugateSampler;
        }
    }

    Objects.requireNonNull(_sampler).initializeFromVariable(this);
    return _sampler;
}
/**
 * Returns the simple class name of the sampler returned by {@link #getSampler()}, or the
 * empty string if there is no sampler.
 */
@Matlab
public final String getSamplerName()
{
    final ISampler sampler = getSampler();
    return sampler == null ? "" : sampler.getClass().getSimpleName();
}
/**
 * Stores the given array (without copying it) as the initial sample value and records that
 * an initial value was explicitly set.
 */
@Matlab
public final void setInitialSampleValue(double[] initialSampleValue)
{
_initialSampleValue = initialSampleValue;
_initialSampleValueSet = true;
}
/**
 * Returns the stored initial sample value array (not a copy).
 */
@Matlab
public final double[] getInitialSampleValue()
{
return _initialSampleValue;
}
// TODO move to ISolverVariableGibbs
/**
 * Sets the factor by which {@link #getCurrentSampleScore()} multiplies the summed potential.
 */
@Override
public final void setBeta(double beta) // beta = 1/temperature
{
_beta = beta;
}
/**
 * Requests that the current sample be set to a copy of the known value, if there is one,
 * otherwise to a copy of the initial sample value.
 * <p>
 * NOTE(review): this goes through the guarded setter, which ignores the request while the
 * sample is held or the model has a fixed value - confirm that is intended here.
 */
public void resetCurrentSample()
{
final double[] knownValue = getKnownRealJoint();
_currentSample.setValue(knownValue != null ? knownValue.clone() : _initialSampleValue.clone());
}
/**
 * Returns the whole input message array (possibly null); {@code portIndex} is not used.
 */
@Deprecated
@Override
public @Nullable Object getInputMsg(int portIndex)
{
return _inputMsg;
}
/**
 * Returns the current sample object; {@code portIndex} is not used.
 */
@Deprecated
@Override
public @Nullable Object getOutputMsg(int portIndex)
{
return _currentSample;
}
/**
 * Stores {@code obj} as the input message for the given port.
 * <p>
 * The backing array is created on first use with one slot per sibling of the model variable,
 * and is retained in {@code _inputMsg}. (It used to be allocated into a local variable only,
 * so the stored message was discarded as soon as the method returned.)
 *
 * @param portIndex index of the sibling port whose message is set
 * @param obj the message value to store
 */
@Deprecated
@Override
public void setInputMsgValues(int portIndex, Object obj)
{
    Object[] inputMsg = _inputMsg;
    if (inputMsg == null)
    {
        // Keep the newly created array so that the stored value is not lost
        _inputMsg = inputMsg = new Object[_model.getSiblingCount()];
    }
    inputMsg[portIndex] = obj;
}
// TODO move to ISolverNode
/**
 * Prepares this variable for a new run of the solver.
 * <p>
 * In order: runs the superclass initialization, rebuilds the neighbor list, resets the current
 * sample (unless the variable is a deterministic output or the sample is held), resets the
 * best sample, the saved-sample list and the moment statistics according to the current
 * options, applies any known value, and finally selects the sampler unless one was
 * explicitly specified.
 */
@Override
public void initialize()
{
// Read the option before super.initialize() runs
final boolean saveAllSamples = getOptionOrDefault(GibbsOptions.saveAllSamples);
super.initialize();
// We actually only need to change this if the model has changed in the vicinity of this variable,
// but that may not be worth the trouble to figure out.
_neighbors = GibbsNeighbors.create(this);
if (!getModelObject().isDeterministicOutput())
{
// Start from the known value if there is one, otherwise from the initial sample value
final double[] knownValue = getKnownRealJoint();
double[] initialSampleValue = knownValue != null ? knownValue : _initialSampleValue;
if (!_holdSampleValue)
setCurrentSampleForce(initialSampleValue);
}
// Clear out sample state
// NOTE(review): unlike saveBestSample(), this stores the array returned by getValue()
// without cloning it - confirm that sharing the array is intended.
_bestSampleValue = _currentSample.getValue();
ArrayList<double[]> sampleArray = null;
if (saveAllSamples)
{
// Reuse the existing list when there is one, emptied of old samples
sampleArray = _sampleArray;
if (sampleArray == null)
{
sampleArray = new ArrayList<double[]>();
}
else
{
sampleArray.clear();
}
}
_sampleArray = sampleArray;
// Clear out the Belief statistics
if (getOptionOrDefault(GibbsOptions.computeRealJointBeliefMoments))
{
// Allocate on first use, then zero
if (_sampleSum == null)
_sampleSum = new double[_numRealVars];
if (_sampleSumSquare == null)
_sampleSumSquare = new double[_numRealVars][_numRealVars];
Arrays.fill(_sampleSum, 0);
for (int i = 0; i < _numRealVars; i++)
Arrays.fill(requireNonNull(_sampleSumSquare)[i], 0);
}
else
{
// Null arrays signal that moments are not being computed
_sampleSum = null;
_sampleSumSquare = null;
}
_sampleCount = 0;
updatePriorAndCondition();
//
// Determine which sampler to use
//
_conjugateSampler = null;
if (!_samplerSpecificallySpecified)
{
IOptionHolder[] source = new IOptionHolder[1];
Class<? extends IGenericSampler> samplerClass = getOptionAndSource(GibbsOptions.realSampler, source);
// Only look for a conjugate sampler when the sampler option was not set directly on this
// solver variable or on its model variable
if (samplerClass == null || source[0] != this && source[0] != _model)
{
if (getOptionOrDefault(GibbsOptions.enableAutomaticConjugateSampling))
{
// See if there's an available conjugate sampler, and if so, use it
_conjugateSampler = findConjugateSampler();
}
}
if (_conjugateSampler == null)
{
if (samplerClass == null)
{
samplerClass = GibbsOptions.realSampler.defaultValue();
}
// Keep the existing sampler instance when it already has the requested class
IMCMCSampler sampler = _sampler;
if (sampler == null || sampler.getClass() != samplerClass)
{
try
{
_sampler = sampler = (IMCMCSampler)samplerClass.newInstance();
}
catch (InstantiationException | IllegalAccessException ex)
{
throw new RuntimeException(ex);
}
}
sampler.initializeFromVariable(this);
resetRejectionRateStats();
}
}
}
/**
 * Re-creates the state of this variable that is not tied to a particular edge.
 * <p>
 * Resets the current sample, records the current sample's value as the best sample
 * value, and, when a sample array is already present, invokes {@code saveAllSamples()}.
 */
// TODO move to ISolverVariable
@Override
public void createNonEdgeSpecificState()
{
    resetCurrentSample();
    _bestSampleValue = _currentSample.getValue();

    // Only re-enable sample saving if a sample array existed at this point.
    final boolean hasSampleArray = _sampleArray != null;
    if (hasSampleArray)
    {
        saveAllSamples();
    }
}
/**
 * Moves the non-edge-specific state of {@code other} into this variable.
 * <p>
 * Before copying, {@code other} is marked as a repeated variable and its previous-sample
 * reference is pointed at this variable's current sample object. The fields
 * {@code _sampler}, {@code _conjugateSampler}, {@code _samplerSpecificallySpecified},
 * {@code _updateCount} and {@code _rejectCount} are deliberately left untouched.
 *
 * @param other the solver node to take state from; must be a {@code GibbsRealJoint},
 * otherwise the cast throws {@link ClassCastException}
 */
// TODO move to ISolverVariable
@Override
public void moveNonEdgeSpecificState(ISolverNode other)
{
    final GibbsRealJoint source = (GibbsRealJoint)other;

    // Link the other variable back to this one's current sample object.
    source._repeatedVariable = true;
    source._prevSample = _currentSample;

    if (!_repeatedVariable)
    {
        if (_prevSample != _currentSample)
        {
            // Already a distinct value object: just refresh its contents.
            _prevSample.setFrom(_currentSample);
        }
        else
        {
            // If not already pointing at a different value object, then this must be
            // the first variable in the stream, so make a copy of its state.
            _prevSample = _currentSample.clone();
        }
    }

    // A prior (fixed) value on the model takes precedence over the other variable's sample.
    final Value priorValue = _model.getPriorValue();
    if (priorValue == null)
    {
        _currentSample.setFrom(source._currentSample);
    }
    else
    {
        _currentSample.setValueForce(priorValue.getDoubleArray());
    }

    // Sample configuration
    _numRealVars = source._numRealVars;
    _beta = source._beta;
    _holdSampleValue = source._holdSampleValue;
    _initialSampleValue = source._initialSampleValue;
    _initialSampleValueSet = source._initialSampleValueSet;

    // Saved samples and best sample
    _sampleArray = source._sampleArray;
    _bestSampleValue = source._bestSampleValue;

    // Accumulated sample statistics
    _sampleSum = source._sampleSum;
    _sampleSumSquare = source._sampleSumSquare;
    _sampleCount = source._sampleCount;

    // Field values intentionally NOT moved:
    // _sampler
    // _conjugateSampler
    // _samplerSpecificallySpecified
    // _updateCount
    // _rejectCount
}
/**
 * Gets the dimension of the joint variable.
 *
 * @return the current value of {@code _numRealVars}
 */
public int getDimension()
{
    return this._numRealVars;
}
/**
 * Finds a single conjugate sampler consistent with all neighboring factors and the input.
 * <p>
 * The sampler is created from the first factory yielded by iterating over the set returned
 * by {@link #findConjugateSamplerFactories()}.
 *
 * @return a newly created conjugate sampler, or {@code null} if no factory is available
 */
public @Nullable IRealJointConjugateSampler findConjugateSampler()
{
    final Iterator<IRealJointConjugateSamplerFactory> candidates =
        findConjugateSamplerFactories().iterator();

    // No available conjugate sampler => null; otherwise take the first one and create the sampler
    return candidates.hasNext() ? candidates.next().create() : null;
}
/**
 * Finds the set of all available conjugate sampler factories without creating any sampler.
 * <p>
 * Delegates to {@link #findConjugateSamplerFactories(Collection)} using the model
 * variable's sibling edges.
 *
 * @return the set of available conjugate sampler factories; may be empty
 */
public Set<IRealJointConjugateSamplerFactory> findConjugateSamplerFactories()
{
    final Collection<EdgeState> siblingEdges = _model.getSiblingEdgeState();
    return findConjugateSamplerFactories(siblingEdges);
}
/**
 * Finds the conjugate sampler factories consistent with a specific set of neighboring
 * factors, as well as with this variable's prior function, condition function and domain.
 * <p>
 * The result is the intersection of the factories offered by the solver factor on every
 * given edge, filtered down to those for which {@code isCompatible} holds for the prior
 * function, the condition function and the domain. No sampler is created.
 *
 * @param edges the edges whose factors must all support each returned factory
 * @return a new mutable set of factories; empty if any factor on {@code edges} is not an
 * {@link IRealJointConjugateFactor} or if no factory is common to all of them
 */
public Set<IRealJointConjugateSamplerFactory> findConjugateSamplerFactories(Collection<EdgeState> edges)
{
    final Set<IRealJointConjugateSamplerFactory> commonSamplers = new HashSet<>();
    final FactorGraph fg = _model.requireParentGraph();
    final SolverNodeMapping solvers = getSolverMapping();

    // Check all the adjacent factors to see if they all support a common conjugate factor.
    // The set is seeded from the first factor only; testing commonSamplers.isEmpty() instead
    // would wrongly re-seed the set from a later factor once the intersection became empty.
    boolean firstFactor = true;
    for (EdgeState edgeState : edges)
    {
        ISolverNode factor = solvers.getSolverFactor(edgeState.getFactor(fg));
        if (!(factor instanceof IRealJointConjugateFactor))
        {
            commonSamplers.clear(); // At least one connected factor does not support conjugate sampling
            return commonSamplers;
        }

        int factorPortNumber = edgeState.getFactorToVariableEdgeNumber();
        Set<IRealJointConjugateSamplerFactory> availableSamplers =
            ((IRealJointConjugateFactor)factor).getAvailableRealJointConjugateSamplers(factorPortNumber);

        if (firstFactor)
        {
            commonSamplers.addAll(availableSamplers);
            firstFactor = false;
        }
        else
        {
            commonSamplers.retainAll(availableSamplers);
        }
    }

    // Next, check conjugate samplers are also compatible with the input and the domain of this variable
    final IUnaryFactorFunction inputJoint = _model.getPriorFunction();
    IUnaryFactorFunction condition = getConditionFunction();
    Iterator<IRealJointConjugateSamplerFactory> iter = commonSamplers.iterator();
    while (iter.hasNext())
    {
        IRealJointConjugateSamplerFactory sampler = iter.next();
        if (!sampler.isCompatible(inputJoint) || !sampler.isCompatible(condition) || !sampler.isCompatible(_domain))
        {
            iter.remove();
        }
    }
    return commonSamplers;
}
/**
 * Returns the solver edge state for the sibling at the given index, cast to
 * {@code GibbsSolverEdge}.
 *
 * @param siblingIndex index passed through to {@code getSiblingEdgeState_}
 * @return the value returned by {@code getSiblingEdgeState_(siblingIndex)}
 */
@SuppressWarnings("null")
@Override
public GibbsSolverEdge<?> getSiblingEdgeState(int siblingIndex)
{
    final Object edgeState = getSiblingEdgeState_(siblingIndex);
    return (GibbsSolverEdge<?>)edgeState;
}
}