/******************************************************************************* * Copyright 2012 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.solvers.particleBP; import static com.analog.lyric.dimple.environment.DimpleEnvironment.*; import static java.util.Objects.*; import org.eclipse.jdt.annotation.Nullable; import com.analog.lyric.dimple.model.core.EdgeState; import com.analog.lyric.dimple.model.core.FactorGraph; import com.analog.lyric.dimple.model.factors.Factor; import com.analog.lyric.dimple.model.variables.Discrete; import com.analog.lyric.dimple.model.variables.Real; import com.analog.lyric.dimple.model.variables.Variable; import com.analog.lyric.dimple.model.variables.VariableList; import com.analog.lyric.dimple.solvers.core.BPSolverGraph; import com.analog.lyric.dimple.solvers.core.NoSolverEdge; import com.analog.lyric.dimple.solvers.interfaces.ISolverEdgeState; import com.analog.lyric.dimple.solvers.interfaces.ISolverFactor; import com.analog.lyric.dimple.solvers.interfaces.ISolverFactorGraph; import com.analog.lyric.dimple.solvers.interfaces.ISolverVariable; import com.analog.lyric.dimple.solvers.interfaces.SolverNodeMapping; import com.analog.lyric.dimple.solvers.sumproduct.STableFactor; import com.analog.lyric.dimple.solvers.sumproduct.SumProductDiscreteEdge; import com.analog.lyric.dimple.solvers.sumproduct.SumProductTableFactor; import com.analog.lyric.util.misc.Matlab; /** * Solver-specific factor graph for Particle BP solver. * <p> * <em>Previously was com.analog.lyric.dimple.solvers.particleBP.SFactorGraph</em> * <p> * @since 0.07 */ @SuppressWarnings("deprecation") // TODO remove when SDiscreteVariable removed public class ParticleBPSolverGraph extends BPSolverGraph<ISolverFactor, IParticleBPVariable, ISolverEdgeState> { protected int _numIterationsBetweenResampling = 1; protected boolean _temper = false; protected double _initialTemperature; protected double _temperingDecayConstant; protected double _temperature; protected final double LOG2 = Math.log(2); protected ParticleBPSolverGraph(FactorGraph factorGraph, @Nullable ISolverFactorGraph parent) { super(factorGraph, parent); } @Override public boolean hasEdgeState() { return true; } @Override public ISolverEdgeState createEdgeState(EdgeState edge) { Variable var = edge.getVariable(_model); if (var instanceof Real) { return new ParticleBPRealEdge(requireNonNull(getRealSolverVariable((Real)var))); } else if (var instanceof Discrete) { return new SumProductDiscreteEdge((Discrete)var); } return NoSolverEdge.INSTANCE; } @SuppressWarnings("deprecation") // TODO remove when S*Factor classes removed. @Override public ISolverFactor createFactor(Factor factor) { if (factor.isDiscrete()) return new STableFactor(factor, this); else return new SRealFactor(factor, this); } @Override public ISolverFactorGraph createSubgraph(FactorGraph subgraph) { return new SFactorGraph(subgraph, this); } @SuppressWarnings("deprecation") // TODO remove when S*Variable classes removed. @Override public IParticleBPVariable createVariable(Variable var) { if (var instanceof Real) { ParticleBPReal v = new SRealVariable((Real)var, this); return v; } else if (var instanceof Discrete) { return new ParticleBPDiscrete((Discrete)var, this); } // TODO support RealJoint variables throw unsupportedVariableType(var); } @Override public IParticleBPVariable getSolverVariable(Variable variable) { return (IParticleBPVariable)super.getSolverVariable(variable); } @Override public void initialize() { _temper = getOptionOrDefault(ParticleBPOptions.enableAnnealing); _initialTemperature = getOptionOrDefault(ParticleBPOptions.initialTemperature); _numIterationsBetweenResampling = getOptionOrDefault(ParticleBPOptions.iterationsBetweenResampling); _temperingDecayConstant = 1 - LOG2/getOptionOrDefault(ParticleBPOptions.annealingHalfLife); super.initialize(); if (_temper) setTemperature(_initialTemperature); for (ISolverFactor sf : getSolverFactorsRecursive()) { if (sf instanceof SumProductTableFactor) { SumProductTableFactor tf = (SumProductTableFactor)sf; tf.setupTableFactorEngine(); } } } @Override public void iterate(int numIters) { final VariableList vars = _model.getVariables(); final SolverNodeMapping solvers = getSolverMapping(); int iterationsBeforeResampling = 1; for (int iterNum = 0; iterNum < numIters; iterNum++) { if (--iterationsBeforeResampling <= 0) { for (Variable v : vars) { ISolverVariable vs = solvers.getSolverVariable(v); if (vs instanceof ParticleBPReal) { ((ParticleBPReal)vs).resample(); } } iterationsBeforeResampling = _numIterationsBetweenResampling; } update(); if (_temper) { _temperature *= _temperingDecayConstant; setTemperature(_temperature); } // Allow interruption (if the solver is run as a thread) // Currently interruption is allowed only between iterations, not within a single iteration try {Thread.sleep(0);} catch (InterruptedException e) { Thread.currentThread().interrupt(); return; } } } // Set/get the current temperature for all variables in the graph (for tempering) @Matlab public void setTemperature(double T) { _temperature = T; double beta = 1/T; // All real factors have temperatures SolverNodeMapping solvers = getSolverMapping(); for (Factor f : _model.getNonGraphFactors()) { ISolverFactor fs = solvers.getSolverFactor(f); if (fs instanceof ParticleBPRealFactor) ((ParticleBPRealFactor)fs).setBeta(beta); } for (Variable v : _model.getVariables()) { ISolverVariable vs = solvers.getSolverVariable(v); if (vs instanceof ParticleBPReal) ((ParticleBPReal)vs).setBeta(beta); } // TODO: discrete factors could have temperatures too } @Matlab public double getTemperature() {return _temperature;} // Sets the random seed for the Particle BP solver. This allows runs of the solver to be repeatable. public void setSeed(long seed) { activeRandom().setSeed(seed); } // Set the number of particle values globally for all real variables public void setNumParticles(int numParticles) { setOption(ParticleBPOptions.numParticles, numParticles); } // Set the number of re-sampling updates per particle when re-sampling the particle values, globally for all real variables public void setResamplingUpdatesPerParticle(int updatesPerParticle) { setOption(ParticleBPOptions.resamplingUpdatesPerParticle, updatesPerParticle); } // Set/get the number of iterations between resamplings public void setNumIterationsBetweenResampling(int numIterationsBetweenResampling) { setOption(ParticleBPOptions.iterationsBetweenResampling, numIterationsBetweenResampling); _numIterationsBetweenResampling = numIterationsBetweenResampling; } public int getNumIterationsBetweenResampling() {return _numIterationsBetweenResampling;} /** * Returns real solver variable for given model variable. * @since 0.08 */ public @Nullable ParticleBPReal getRealSolverVariable(Real modelVar) { return (ParticleBPReal)super.getSolverVariable(modelVar); } /** * @deprecated Instead set {@link ParticleBPOptions#initialTemperature} option using {@link #setOption}. */ @Deprecated public void setInitialTemperature(double initialTemperature) { setOption(ParticleBPOptions.initialTemperature, initialTemperature); setTempering(true); _initialTemperature = initialTemperature; } /** * @deprecated Instead get {@link ParticleBPOptions#initialTemperature} option using {@link #getOption}. */ @Deprecated public double getInitialTemperature() {return _initialTemperature;} /** * @deprecated Instead set {@link ParticleBPOptions#annealingHalfLife} option using {@link #setOption}. */ @Deprecated public void setTemperingHalfLifeInIterations(double temperingHalfLifeInIterations) { setOption(ParticleBPOptions.annealingHalfLife, temperingHalfLifeInIterations); setTempering(true); _temperingDecayConstant = 1 - LOG2/temperingHalfLifeInIterations; } /** * @deprecated Instead get {@link ParticleBPOptions#annealingHalfLife} option using {@link #getOption}. */ @Deprecated public double getTemperingHalfLifeInIterations() {return LOG2/(1 - _temperingDecayConstant);} /** * @deprecated Instead set {@link ParticleBPOptions#enableAnnealing} option using {@link #setOption}. */ @Deprecated protected void setTempering(boolean temper) { setOption(ParticleBPOptions.enableAnnealing, temper); _temper = temper; } /** * @deprecated Instead set {@link ParticleBPOptions#enableAnnealing} option to true using {@link #setOption}. */ @Deprecated public final void enableTempering() { setTempering(true); } /** * @deprecated Instead set {@link ParticleBPOptions#enableAnnealing} option to false using {@link #setOption}. */ @Deprecated public final void disableTempering() { setTempering(false); } /** * @deprecated Instead get {@link ParticleBPOptions#enableAnnealing} option using {@link #getOption}. */ @Deprecated public boolean isTemperingEnabled() { return _temper; } /* * */ @Override protected void doUpdateEdge(int edge) { } @Override protected String getSolverName() { return "Particle BP"; } }