/******************************************************************************* * Copyright 2012 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.solvers.gibbs; import static com.analog.lyric.dimple.environment.DimpleEnvironment.*; import static com.analog.lyric.dimple.solvers.gibbs.GibbsSolverVariableEvent.*; import static java.util.Objects.*; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Objects; import java.util.Set; import org.eclipse.jdt.annotation.Nullable; import com.analog.lyric.collect.ArrayUtil; import com.analog.lyric.collect.ReleasableIterator; import com.analog.lyric.dimple.data.IDatum; import com.analog.lyric.dimple.environment.DimpleEnvironment; import com.analog.lyric.dimple.exceptions.DimpleException; import com.analog.lyric.dimple.factorfunctions.core.FactorFunctionUtilities; import com.analog.lyric.dimple.factorfunctions.core.IUnaryFactorFunction; import com.analog.lyric.dimple.model.core.EdgeState; import com.analog.lyric.dimple.model.core.FactorGraph; import com.analog.lyric.dimple.model.domains.RealDomain; import com.analog.lyric.dimple.model.factors.Factor; import com.analog.lyric.dimple.model.values.RealValue; import com.analog.lyric.dimple.model.values.Value; import com.analog.lyric.dimple.model.variables.Real; import com.analog.lyric.dimple.solvers.core.PriorAndCondition; import com.analog.lyric.dimple.solvers.core.SRealVariableBase; import com.analog.lyric.dimple.solvers.core.parameterizedMessages.IParameterizedMessage; import com.analog.lyric.dimple.solvers.core.proposalKernels.IProposalKernel; import com.analog.lyric.dimple.solvers.core.proposalKernels.NormalProposalKernel; import com.analog.lyric.dimple.solvers.gibbs.customFactors.IRealConjugateFactor; import com.analog.lyric.dimple.solvers.gibbs.samplers.ISampler; import com.analog.lyric.dimple.solvers.gibbs.samplers.conjugate.IConjugateSampler; import com.analog.lyric.dimple.solvers.gibbs.samplers.conjugate.IRealConjugateSampler; import com.analog.lyric.dimple.solvers.gibbs.samplers.conjugate.IRealConjugateSamplerFactory; import com.analog.lyric.dimple.solvers.gibbs.samplers.conjugate.RealConjugateSamplerRegistry; import com.analog.lyric.dimple.solvers.gibbs.samplers.generic.IGenericSampler; import com.analog.lyric.dimple.solvers.gibbs.samplers.generic.IMCMCSampler; import com.analog.lyric.dimple.solvers.gibbs.samplers.generic.IRealSamplerClient; import com.analog.lyric.dimple.solvers.gibbs.samplers.generic.MHSampler; import com.analog.lyric.dimple.solvers.interfaces.ISolverEdgeState; import com.analog.lyric.dimple.solvers.interfaces.ISolverFactorGraph; import com.analog.lyric.dimple.solvers.interfaces.ISolverNode; import com.analog.lyric.dimple.solvers.interfaces.SolverNodeMapping; import com.analog.lyric.options.IOptionHolder; import com.analog.lyric.util.misc.Internal; import com.analog.lyric.util.misc.Matlab; import com.google.common.primitives.Doubles; import cern.colt.list.DoubleArrayList; /* * WARNING: Whenever editing this class, also make the corresponding edit to SRealJointVariable. * The two are nearly identical, but unfortunately couldn't easily be shared due to the class hierarchy * */ /** * Real-valued solver variable for Gibbs solver. * @since 0.07 */ public class GibbsReal extends SRealVariableBase implements ISolverVariableGibbs, ISolverRealVariableGibbs, IRealSamplerClient { /*----------- * Constants */ /** * Bits in {@link #_flags} reserved by this class and its superclasses. */ @SuppressWarnings("hiding") protected final static int RESERVED_FLAGS = 0xFFFF0003; @SuppressWarnings("hiding") protected static final int EVENT_MASK = 0x03; /*------- * State */ private class CurrentSample extends RealValue { private static final long serialVersionUID = 1L; CurrentSample(RealDomain domain) { super(0.0); reset(); } @Override public void setDouble(double value) { // If the sample value is being held, don't modify the value if (_holdSampleValue) { return; } // Also return if the variable is set to a fixed value if (_model.hasFixedValue()) { return; } if (value != _value) { setDoubleForce(value); } } void setDoubleForce(double value) { final GibbsNeighbors neighbors = _neighbors; final boolean hasDeterministicDependents = neighbors != null && neighbors.hasDeterministicDependents(); RealValue oldValue = null; if (hasDeterministicDependents) { oldValue = RealValue.create(_value); } _value = value; // If this variable has deterministic dependents, then set their values if (hasDeterministicDependents) { requireNonNull(neighbors).update(requireNonNull(oldValue)); } } void reset() { double knownValue = getKnownReal(); _value = knownValue == knownValue ? knownValue : _initialSampleValue; } } public static final String DEFAULT_REAL_SAMPLER_NAME = "SliceSampler"; private final CurrentSample _currentSample; private RealValue _prevSample; // Used only by BlastFromThePast factors private boolean _repeatedVariable; private double _initialSampleValue = 0; private boolean _initialSampleValueSet = false; private final RealDomain _domain; private @Nullable IMCMCSampler _sampler = null; private @Nullable IRealConjugateSampler _conjugateSampler = null; private boolean _samplerSpecificallySpecified = false; private @Nullable DoubleArrayList _sampleArray; private double _sampleSum; private double _sampleSumSquare; private long _sampleCount; private double _bestSampleValue; private double _beta = 1; private boolean _holdSampleValue = false; private boolean _visited = false; private long _updateCount; private long _rejectCount; private long _scoreCount; /** * List of neighbors for sample scoring. Instantiated during initialization. */ private @Nullable GibbsNeighbors _neighbors = null; /*-------------- * Construction */ // Primary constructor public GibbsReal(Real var, GibbsSolverGraph parent) { super(var, parent); _domain = var.getDomain(); _prevSample = _currentSample = new CurrentSample(_domain); } /*--------------------- * ISolverNode methods */ @Override public GibbsSolverGraph getParentGraph() { return (GibbsSolverGraph)_parent; } @Override public ISolverFactorGibbs getSibling(int edge) { return (ISolverFactorGibbs)super.getSibling(edge); } @Override protected void doUpdateEdge(int outPortNum) { throw new DimpleException("Method not supported in Gibbs sampling solver."); } @Override public final void update() { // If the sample value is being held, don't modify the value if (_holdSampleValue) return; final Real model = _model; // Don't bother to re-sample deterministic dependent variables (those that are the output of a directional deterministic factor) if (model.isDeterministicOutput()) return; // Also return if the variable is set to a fixed value if (model.hasFixedValue()) return; final int updateEventFlags = GibbsSolverVariableEvent.getVariableUpdateEventFlags(this); Value oldValue = null; double oldSampleScore = 0.0; switch (updateEventFlags) { case UPDATE_EVENT_SCORED: // TODO: non-conjugate samplers already compute sample scores, so we shouldn't have to do here. oldSampleScore = getCurrentSampleScore(); //$FALL-THROUGH$ case UPDATE_EVENT_SIMPLE: oldValue = _currentSample.clone(); break; } // Get the next sample value from the sampler boolean rejected = false; _updateCount++; final IRealConjugateSampler conjugateSampler = _conjugateSampler; if (conjugateSampler == null) { // Use MCMC sampler RealValue nextSample = RealValue.create(_currentSample.getDouble()); rejected = !Objects.requireNonNull(_sampler).nextSample(nextSample, this); if (rejected) _rejectCount++; } else { // Use conjugate sampler, first update the messages from all factors // Factor messages represent the current distribution parameters from each factor final int numEdges = model.getSiblingCount(); ISolverEdgeState[] sedges = new ISolverEdgeState[numEdges]; final FactorGraph fg = model.requireParentGraph(); final SolverNodeMapping solvers = getSolverMapping(); final ISolverFactorGraph sfg = solvers.getSolverGraph(fg); for (int portIndex = 0; portIndex < numEdges; portIndex++) { final int edgeIndex = model.getSiblingEdgeIndex(portIndex); final EdgeState edge = requireNonNull(fg.getGraphEdgeState(edgeIndex)); final GibbsSolverEdge<?> sedge = requireNonNull((GibbsSolverEdge<?>)sfg.getSolverEdge(edgeIndex)); final ISolverFactorGibbs factor = (ISolverFactorGibbs)sfg.getSolverFactorForEdge(edge); sedges[portIndex] = sedge; factor.updateEdgeMessage(edge, sedge); // Run updateEdgeMessage for each neighboring factor } PriorAndCondition inputs = getPriorAndCondition(); double nextSampleValue = conjugateSampler.nextSample(sedges, inputs); inputs.release(); if (nextSampleValue != _currentSample.getDouble()) // Would be exactly equal if not changed since last value tested setCurrentSample(nextSampleValue); } switch (updateEventFlags) { case UPDATE_EVENT_SCORED: // TODO: non-conjugate samplers already compute sample scores, so we shouldn't have to do here. raiseEvent(new GibbsScoredVariableUpdateEvent(this, Objects.requireNonNull(oldValue), oldSampleScore, _currentSample, getCurrentSampleScore(), rejected ? 1 : 0)); break; case UPDATE_EVENT_SIMPLE: raiseEvent(new GibbsVariableUpdateEvent(this, Objects.requireNonNull(oldValue), _currentSample, rejected ? 1 : 0)); break; } } /*--------------------------- * SolverEventSource methods */ @Override protected int getEventMask() { return EVENT_MASK | super.getEventMask(); } /*------------------------- * ISolverVariable methods */ // IRealSampleScorer methods... // The following methods are for the IRealSampleScorer interface, meant to be called by a sampler // These are not intended for other purposes @Override public final double getSampleScore(Value sampleValue) { return getSampleScore(sampleValue.getDouble()); } @Override public final double getSampleScore(double sampleValue) { // WARNING: Side effect is that the current sample value changes to this sample value // Could change back but less efficient to do this, since we'll be updating the sample value anyway setCurrentSample(sampleValue); return getCurrentSampleScore(); } @Override public final double getCurrentSampleScore() { double sampleScore = Double.POSITIVE_INFINITY; _scoreCount++; computeScore: { if (!_domain.inDomain(_currentSample.getDouble())) break computeScore; // outside the domain // Sum up the potentials from the prior, condition and all connected factors PriorAndCondition known = getPriorAndCondition(); double potential = known.evalEnergy(_currentSample); known = known.release(); if (!Doubles.isFinite(potential)) { break computeScore; } ReleasableIterator<ISolverNodeGibbs> scoreNodes = getSampleScoreNodes(); while (scoreNodes.hasNext()) { final ISolverNodeGibbs node = scoreNodes.next(); potential += node.getPotential(); if (!Doubles.isFinite(potential)) { break computeScore; } } scoreNodes.release(); sampleScore = potential * _beta; // Incorporate current temperature } return sampleScore; } @Override public final void setNextSampleValue(Value sampleValue) { setNextSampleValue(sampleValue.getDouble()); } @Override public final void setNextSampleValue(double sampleValue) { if (sampleValue != _currentSample.getDouble()) setCurrentSample(sampleValue); } // TODO move to local methods? // For conjugate samplers public final @Nullable IRealConjugateSampler getConjugateSampler() { return _conjugateSampler; } @Override public void updatePriorAndCondition() { Value value = getKnownValue(); if (value != null) { setCurrentSampleForce(value.getDouble()); } } /*---------------------------------- * ISolverRealVariableGibbs methods * TODO: move below ISolverVariableGibbs methods */ @Override public final void getAggregateMessages(IParameterizedMessage outputMessage, int outPortNum, ISampler conjugateSampler) { final Real model = _model; final FactorGraph fg = model.requireParentGraph(); final SolverNodeMapping solvers = getSolverMapping(); final ISolverFactorGraph sfg = solvers.getSolverGraph(fg); final int numEdges = model.getSiblingCount(); final ISolverEdgeState[] sedges = new ISolverEdgeState[numEdges - 1]; for (int port = 0, i = 0; port < numEdges; port++) { if (port != outPortNum) { final int edgeIndex = model.getSiblingEdgeIndex(port); final EdgeState edgeState = requireNonNull(fg.getGraphEdgeState(edgeIndex)); final GibbsSolverEdge<?> sedge = requireNonNull((GibbsSolverEdge<?>)sfg.getSolverEdge(edgeIndex)); final ISolverFactorGibbs factor = (ISolverFactorGibbs)sfg.getSolverFactorForEdge(edgeState); sedges[i++] = sfg.getSolverEdge(edgeIndex); factor.updateEdgeMessage(edgeState, sedge); // Run updateEdgeMessage for each neighboring factor } } PriorAndCondition known = getPriorAndCondition(); ((IConjugateSampler)conjugateSampler).aggregateParameters(outputMessage, sedges, known); known.release(); } /*-------------------------- * ISolverNodeGibbs methods */ @Override public boolean setVisited(boolean visited) { boolean changed = _visited ^ visited; _visited = visited; return changed; } /*------------------------------ * ISolverVariableGibbs methods */ @Override public RealValue getCurrentSampleValue() { return _currentSample; } @Internal @Override public final RealValue getPrevSampleValue() { return _prevSample; } @Override public ReleasableIterator<ISolverNodeGibbs> getSampleScoreNodes() { return GibbsNeighbors.iteratorFor(_neighbors, this); } @Override public void randomRestart(int restartCount) { // If the sample value is being held, don't modify the value if (_holdSampleValue) return; // If the variable is the output of a directed deterministic factor, then don't modify the value--it should already be set correctly if (getModelObject().isDeterministicOutput()) return; // If the variable has a fixed value, then set the current sample to that value and return double knownValue = getKnownReal(); if (knownValue == knownValue) { setCurrentSample(knownValue); return; } if (_initialSampleValueSet && restartCount == 0) { setCurrentSample(_initialSampleValue); return; } // If there are inputs, see if there's an available conjugate sampler IRealConjugateSampler inputConjugateSampler = null; // Don't use the global conjugate sampler since other factors might not be conjugate final IUnaryFactorFunction prior = _model.getPriorFunction(); if (prior != null) { inputConjugateSampler = RealConjugateSamplerRegistry.findCompatibleSampler(prior); } // Determine if there are bounds double hi = _domain.getUpperBound(); double lo = _domain.getLowerBound(); if (inputConjugateSampler != null) { // Sample from the input if there's an available sampler List<? extends IDatum> priorList; if (prior == null) priorList = Collections.emptyList(); else priorList = Collections.singletonList(prior); double sampleValue = inputConjugateSampler.nextSample(new ISolverEdgeState[0], priorList); // If there are also bounds, clip at the bounds if (sampleValue > hi) sampleValue = hi; if (sampleValue < lo) sampleValue = lo; setCurrentSample(sampleValue); } else { // No input or no available sampler, so if bounded, sample uniformly from the bounds if (hi < Double.POSITIVE_INFINITY && lo > Double.NEGATIVE_INFINITY) { setCurrentSample(activeRandom().nextDouble() * (hi - lo) + lo); } else { double sampleValue = _currentSample.getDouble(); if (hi < sampleValue) { setCurrentSample(hi); } else if (lo > sampleValue) { setCurrentSample(lo); } } } } @Override public final void updateBelief() { // Update the sums for computing moments final double currentSampleValue = _currentSample.getDouble(); _sampleSum += currentSampleValue; _sampleSumSquare += currentSampleValue * currentSampleValue; _sampleCount++; } @Override public Object getBelief() { return 0d; } @Matlab public final double getSampleMean() { return _sampleSum / _sampleCount; } public final double getSampleVariance() { return (_sampleSumSquare - (_sampleSum * (_sampleSum / _sampleCount)) ) / (_sampleCount - 1); } @SuppressWarnings("null") @Override public void postAddFactor(@Nullable Factor f) { } @Override public Object getGuess() { if (_guessWasSet) return Double.valueOf(_guessValue); double knownValue = getKnownReal(); if (knownValue == knownValue) return Double.valueOf(knownValue); else return _currentSample.getObject(); } /** * @deprecated Instead set {@link GibbsOptions#saveAllSamples} to true using {@link #setOption}. */ @Deprecated @Override public final void saveAllSamples() { _sampleArray = new DoubleArrayList(); setOption(GibbsOptions.saveAllSamples, true); } /** * @deprecated Instead set {@link GibbsOptions#saveAllSamples} to false using {@link #setOption}. */ @Deprecated @Override public void disableSavingAllSamples() { _sampleArray = null; setOption(GibbsOptions.saveAllSamples, false); } @Override public final void saveCurrentSample() { final DoubleArrayList sampleArray = _sampleArray; if (sampleArray != null) sampleArray.add(_currentSample.getDouble()); } @Override public final void saveBestSample() { _bestSampleValue = _currentSample.getDouble(); } @Override public final double getPotential() { return evalPriorAndConditionEnergy(_currentSample); } @Override public final boolean hasPotential() { return canHavePriorAndConditionEnergy(); } @Override public final void setCurrentSample(Object value) { _currentSample.setDouble(FactorFunctionUtilities.toDouble(value)); } @Override public final void setCurrentSample(Value value) { _currentSample.setFrom(value); } /*--------------- * Local methods */ public final void setCurrentSample(double value) { _currentSample.setDouble(value); } // Sets the sample regardless of whether the value is fixed or held private final void setCurrentSampleForce(double value) { _currentSample.setDoubleForce(value); } @Matlab public final double getCurrentSample() { return _currentSample.getDouble(); } @Matlab public final double getBestSample() { return _bestSampleValue; } @Matlab @Override public final double[] getAllSamples() { final DoubleArrayList sampleArray = _sampleArray; if (sampleArray == null) { return ArrayUtil.EMPTY_DOUBLE_ARRAY; } return Arrays.copyOf(sampleArray.elements(), sampleArray.size()); } @Override public final double getRejectionRate() { return (_updateCount > 0) ? (double)_rejectCount / (double)_updateCount : 0; } @Override public final double getNumScoresPerUpdate() { return (_updateCount > 0) ? (double)_scoreCount / (double)_updateCount : 0; } @Override public final void resetRejectionRateStats() { _updateCount = 0; _rejectCount = 0; _scoreCount = 0; } @Override public final long getUpdateCount() { return _updateCount; } @Override public final long getRejectionCount() { return _rejectCount; } // This is meant for internal use, not as a user accessible method @Internal public final @Nullable DoubleArrayList _getSampleArrayUnsafe() { return _sampleArray; } public final void setAndHoldSampleValue(double value) { releaseSampleValue(); setCurrentSample(value); holdSampleValue(); } public final void holdSampleValue() { _holdSampleValue = true; } public final void releaseSampleValue() { _holdSampleValue = false; } /** * @deprecated Will be removed in future release. Instead set corresponding option * for desired proposal kernel (e.g. {@link NormalProposalKernel#standardDeviation}. */ @SuppressWarnings("null") @Deprecated public final void setProposalStandardDeviation(double stdDev) { if (_sampler instanceof MHSampler) ((MHSampler)_sampler).getProposalKernel().setParameters(stdDev); } /** * @deprecated Will be removed in future release. Instead lookup corresponding option * for desired proposal kernel (e.g. {@link NormalProposalKernel#standardDeviation}. */ @SuppressWarnings("null") @Deprecated public final double getProposalStandardDeviation() { if (_sampler instanceof MHSampler) return (Double)((MHSampler)_sampler).getProposalKernel().getParameters()[0]; else return 0; } /** * @deprecated Will be removed in future release. Instead set appropriate options * for proposal kernel using {@link #setOption}. */ @SuppressWarnings("null") @Deprecated public final void setProposalKernelParameters(Object... parameters) { if (_sampler instanceof MHSampler) ((MHSampler)_sampler).getProposalKernel().setParameters(parameters); } /** * @deprecated Will be removed in future release. Instead set corresponding option * for sampler (e.g. {@link MHSampler#realProposalKernel}). */ @SuppressWarnings("null") @Deprecated public final void setProposalKernel(IProposalKernel proposalKernel) // IProposalKernel object { if (_sampler instanceof MHSampler) ((MHSampler)_sampler).setProposalKernel(proposalKernel); } /** * @deprecated Will be removed in future release. Instead lookup corresponding option * for sampler (e.g. {@link MHSampler#realProposalKernel}). */ @SuppressWarnings("null") @Deprecated public final void setProposalKernel(String proposalKernelName) // Name of proposal kernel { if (_sampler instanceof MHSampler) ((MHSampler)_sampler).setProposalKernel(proposalKernelName); } /** * @deprecated Will be removed in future release. Instead get kernel directly * from {@linkplain #getSampler() sampler}. */ @SuppressWarnings("null") @Deprecated public final @Nullable IProposalKernel getProposalKernel() { if (_sampler instanceof MHSampler) return ((MHSampler)_sampler).getProposalKernel(); else return null; } /** * @deprecated Will be removed in future release. Instead use {@link GibbsOptions#realSampler} * option. */ @Deprecated public final void setDefaultSampler(String samplerName) { GibbsOptions.realSampler.convertAndSet(this, samplerName); } /** * @deprecated Will be removed in future release. Instead use {@link GibbsOptions#realSampler} * option. */ @Deprecated public final String getDefaultSamplerName() { return getOptionOrDefault(GibbsOptions.realSampler).getSimpleName(); } /** * Sets sampler to be used for this variable. * <p> * In general, it is usually easier to configure the sampler using the * {@link GibbsOptions#realSampler} option. This method should only be * required when the sampler class is not registered with the * {@linkplain DimpleEnvironment#genericSamplers() generic sampler registry} * for the current environment. * <p> * @param sampler is a non-null sampler. */ public final void setSampler(ISampler sampler) { _sampler = (IMCMCSampler)sampler; _samplerSpecificallySpecified = true; } /** * @deprecated Will be removed in future release. Instead set sampler by setting * {@link GibbsOptions#realSampler} option using {@link #setOption}. */ @Matlab @Deprecated public final void setSampler(String samplerName) { GibbsOptions.realSampler.convertAndSet(this, samplerName); IMCMCSampler sampler = (IMCMCSampler) GibbsOptions.realSampler.instantiateIfDifferent(this, _sampler); _sampler = sampler; _samplerSpecificallySpecified = true; sampler.initializeFromVariable(this); } @Matlab @Override public final @Nullable ISampler getSampler() { if (_samplerSpecificallySpecified) { Objects.requireNonNull(_sampler).initializeFromVariable(this); return _sampler; } else { initialize(); // To determine the appropriate sampler if (_conjugateSampler == null) { Objects.requireNonNull(_sampler).initializeFromVariable(this); return _sampler; } else { return _conjugateSampler; } } } @Matlab public final String getSamplerName() { ISampler sampler = getSampler(); if (sampler != null) return sampler.getClass().getSimpleName(); else return ""; } @Matlab public final void setInitialSampleValue(double initialSampleValue) { _initialSampleValue = initialSampleValue; _initialSampleValueSet = true; } @Matlab public final double getInitialSampleValue() { return _initialSampleValue; } // TODO move to ISolverVariableGibbs @Override public final void setBeta(double beta) // beta = 1/temperature { _beta = beta; } @Deprecated @Override public @Nullable Object getOutputMsg(int portIndex) { return _currentSample; } // TODO move to ISolverNode @Override public void initialize() { final boolean saveAllSamples = getOptionOrDefault(GibbsOptions.saveAllSamples); super.initialize(); // We actually only need to change this if the model has changed in the vicinity of this variable, // but that may not be worth the trouble to figure out. _neighbors = GibbsNeighbors.create(this); // Unless this is a dependent of a deterministic factor, then set the starting sample value if (!getModelObject().isDeterministicOutput()) { final double knownValue = getKnownReal(); final double initialSampleValue = knownValue == knownValue ? knownValue : _initialSampleValue; if (!_holdSampleValue) setCurrentSampleForce(initialSampleValue); } // Clear out sample state _bestSampleValue = _currentSample.getDouble(); DoubleArrayList sampleArray = null; if (saveAllSamples) { sampleArray = _sampleArray; if (sampleArray == null) { sampleArray = new DoubleArrayList(); } else { sampleArray.clear(); } } _sampleArray = sampleArray; // Clear out the Belief statistics _sampleSum = 0; _sampleSumSquare = 0; _sampleCount = 0; updatePriorAndCondition(); // // Determine which sampler to use // _conjugateSampler = null; if (!_samplerSpecificallySpecified) { IOptionHolder[] source = new IOptionHolder[1]; Class<? extends IGenericSampler> samplerClass = getOptionAndSource(GibbsOptions.realSampler, source); if (samplerClass == null || source[0] != this && source[0] != _model) { if (getOptionOrDefault(GibbsOptions.enableAutomaticConjugateSampling)) { // See if there's an available conjugate sampler, and if so, use it _conjugateSampler = findConjugateSampler(); } } if (_conjugateSampler == null) { if (samplerClass == null) { samplerClass = GibbsOptions.realSampler.defaultValue(); } IMCMCSampler sampler = _sampler; if (sampler == null || sampler.getClass() != samplerClass) { try { _sampler = sampler = (IMCMCSampler)samplerClass.newInstance(); } catch (InstantiationException | IllegalAccessException ex) { throw new RuntimeException(ex); } } sampler.initializeFromVariable(this); resetRejectionRateStats(); } } } // TODO move to ISolverVariable @Override public void createNonEdgeSpecificState() { _currentSample.reset(); _bestSampleValue = _currentSample.getDouble(); if (_sampleArray != null) saveAllSamples(); } // TODO move to ISolverVariable @Override public void moveNonEdgeSpecificState(ISolverNode other) { GibbsReal ovar = ((GibbsReal)other); ovar._prevSample = _currentSample; ovar._repeatedVariable = true; if (!_repeatedVariable) { if (_prevSample == _currentSample) { // If not already pointing at a different value object, then this must be the // the first variable in the stream, so make a copy of its state. _prevSample = _currentSample.clone(); } else { _prevSample.setFrom(_currentSample); } } Value fixedValue = _model.getPriorValue(); if (fixedValue != null) _currentSample.setDoubleForce(fixedValue.getDouble()); else _currentSample.setFrom(ovar._currentSample); _initialSampleValue = ovar._initialSampleValue; _initialSampleValueSet = ovar._initialSampleValueSet; _sampleArray = ovar._sampleArray; _bestSampleValue = ovar._bestSampleValue; _beta = ovar._beta; _holdSampleValue = ovar._holdSampleValue; _sampleSum = ovar._sampleSum; _sampleSumSquare = ovar._sampleSumSquare; _sampleCount = ovar._sampleCount; // Field values intentionally NOT moved: // _sampler // _conjugateSampler // _samplerSpecificallySpecified // _updateCount // _rejectCount } // Find a single conjugate sampler consistent with all neighboring factors and the Input public @Nullable IRealConjugateSampler findConjugateSampler() { Set<IRealConjugateSamplerFactory> availableSamplerFactories = findConjugateSamplerFactories(); if (availableSamplerFactories.isEmpty()) return null; // No available conjugate sampler else return availableSamplerFactories.iterator().next().create(); // Get the first one and create the sampler } // Find the set of all available conjugate samplers, but don't create it yet public Set<IRealConjugateSamplerFactory> findConjugateSamplerFactories() { return findConjugateSamplerFactories(_model.getSiblingEdgeState()); } // Find the set of available conjugate samplers consistent with a specific set of neighboring factors (as well as the Input) public Set<IRealConjugateSamplerFactory> findConjugateSamplerFactories(Collection<EdgeState> edges) { final Set<IRealConjugateSamplerFactory> commonSamplers = new HashSet<IRealConjugateSamplerFactory>(); final FactorGraph fg = _model.requireParentGraph(); final SolverNodeMapping solvers = getSolverMapping(); // Check all the adjacent factors to see if they all support a common conjugate factor for (EdgeState edgeState : edges) { ISolverNode factor = solvers.getSolverFactor(edgeState.getFactor(fg)); if (!(factor instanceof IRealConjugateFactor)) { commonSamplers.clear(); // At least one connected factor does not support conjugate sampling return commonSamplers; } int factorPortNumber = edgeState.getFactorToVariableEdgeNumber(); Set<IRealConjugateSamplerFactory> availableSamplers = ((IRealConjugateFactor)factor).getAvailableRealConjugateSamplers(factorPortNumber); if (commonSamplers.isEmpty()) { commonSamplers.addAll(availableSamplers); } else { commonSamplers.retainAll(availableSamplers); } } // Next, check conjugate samplers are also compatible with the input and the domain of this variable IUnaryFactorFunction input = _model.getPriorFunction(); IUnaryFactorFunction condition = getConditionFunction(); Iterator<IRealConjugateSamplerFactory> iter = commonSamplers.iterator(); while (iter.hasNext()) { IRealConjugateSamplerFactory sampler = iter.next(); if (!sampler.isCompatible(input) || !sampler.isCompatible(condition) || !sampler.isCompatible(_domain)) { iter.remove(); } } return commonSamplers; } @SuppressWarnings("null") @Override public GibbsSolverEdge<?> getSiblingEdgeState(int siblingIndex) { return (GibbsSolverEdge<?>)getSiblingEdgeState_(siblingIndex); } }