/******************************************************************************* * Copyright 2012-2015 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.solvers.particleBP; import static com.analog.lyric.dimple.environment.DimpleEnvironment.*; import static com.analog.lyric.math.Utilities.*; import static java.util.Objects.*; import java.util.Arrays; import org.eclipse.jdt.annotation.Nullable; import com.analog.lyric.dimple.environment.DimpleEnvironment; import com.analog.lyric.dimple.exceptions.DimpleException; import com.analog.lyric.dimple.model.core.EdgeState; import com.analog.lyric.dimple.model.core.FactorGraph; import com.analog.lyric.dimple.model.domains.Domain; import com.analog.lyric.dimple.model.domains.RealDomain; import com.analog.lyric.dimple.model.factors.Factor; import com.analog.lyric.dimple.model.values.RealValue; import com.analog.lyric.dimple.model.values.Value; import com.analog.lyric.dimple.model.variables.Real; import com.analog.lyric.dimple.solvers.core.PriorAndCondition; import com.analog.lyric.dimple.solvers.core.SRealVariableBase; import com.analog.lyric.dimple.solvers.core.parameterizedMessages.DiscreteMessage; import com.analog.lyric.dimple.solvers.core.proposalKernels.IProposalKernel; import com.analog.lyric.dimple.solvers.core.proposalKernels.NormalProposalKernel; import com.analog.lyric.dimple.solvers.core.proposalKernels.Proposal; import com.analog.lyric.dimple.solvers.interfaces.SolverNodeMapping; import com.analog.lyric.math.DimpleRandom; import com.analog.lyric.options.OptionDoubleList; import com.analog.lyric.options.OptionValidationException; import com.analog.lyric.util.misc.Matlab; /** * Solver variable for Real variables under Particle BP solver. * * @since 0.07 */ public class ParticleBPReal extends SRealVariableBase implements IParticleBPVariable { /*------- * State */ protected RealValue[] _particleValues; protected int _numParticles = 1; protected int _resamplingUpdatesPerSample = 1; protected @Nullable IProposalKernel _proposalKernel; /** * True if {@link #_proposalKernel} was set explicitly via {@link #setProposalKernel(IProposalKernel)} * and should not be overridden by option settings. */ protected boolean _explicitProposalKernel; protected RealDomain _initialParticleDomain; protected RealDomain _domain; double [] _particleEnergy; protected double _beta = 1; /*-------------- * Construction */ public ParticleBPReal(Real var, ParticleBPSolverGraph parent) { super(var, parent); // Since numParticles is used to configure the message sizes, we set this at construction // time to make it less likely that the messages will have to be recreated at initialize time. _numParticles = getOptionOrDefault(ParticleBPOptions.numParticles); _initialParticleDomain = _domain = var.getDomain(); _particleValues = new RealValue[_numParticles]; for (int i = 0; i < _numParticles; ++i) { _particleValues[i] = RealValue.create(); } _particleEnergy = new double[_numParticles]; } @Override public void initialize() { _resamplingUpdatesPerSample = getOptionOrDefault(ParticleBPOptions.resamplingUpdatesPerParticle); updateNumParticles(getOptionOrDefault(ParticleBPOptions.numParticles)); if (!_explicitProposalKernel) { Class<? extends IProposalKernel> kernelClass = getOptionOrDefault(ParticleBPOptions.proposalKernel); if (_proposalKernel == null || kernelClass != requireNonNull(_proposalKernel).getClass()) { try { _proposalKernel = kernelClass.getConstructor().newInstance(); } catch (Exception ex) { // Option validation should already have made sure that the constructor // exists, so this should only happen if the constructor itself throws an exception. DimpleEnvironment.logError("Could not create proposal kernel instance for '%s': %s", kernelClass, ex.toString()); } } } requireNonNull(_proposalKernel).configureFromOptions(this); OptionDoubleList range = getOptionOrDefault(ParticleBPOptions.initialParticleRange); RealDomain initialDomain = RealDomain.create(range.get(0), range.get(1)); _initialParticleDomain = initialDomain.isSubsetOf(_domain) ? initialDomain : _domain; double particleMin = 0; double particleMax = 0; RealDomain domain = _initialParticleDomain; if (domain.isBounded()) { particleMin = domain.getLowerBound(); particleMax = domain.getUpperBound(); } int length = _particleValues.length; double particleIncrement = (length > 1) ? (particleMax - particleMin) / (length - 1) : 0; double particleValue = particleMin; for (int i = 0; i < length; i++) { _particleValues[i].setDouble(particleValue); particleValue += particleIncrement; } super.initialize(); } @Override protected void doUpdateEdge(int outPortNum) { final double maxEnergy = 100; int M = _numParticles; int D = getSiblingCount(); double minEnergy = Double.POSITIVE_INFINITY; final double[] outMsgs = getSiblingEdgeState(outPortNum).varToFactorMsg.representation(); PriorAndCondition known = getPriorAndCondition(); for (int m = 0; m < M; m++) { double prior = known.evalEnergy(_particleValues[m]); // FIXME: why does infinity get turned into minLog but values between minLog // and infinity not? double out = (prior == Double.POSITIVE_INFINITY) ? maxEnergy : prior * _beta; for (int d = 0; d < D; d++) { if (d != outPortNum) // For all ports except the output port { double tmp = getSiblingEdgeState(d).factorToVarMsg.getEnergy(m); out += (tmp == Double.POSITIVE_INFINITY) ? maxEnergy : tmp; } } // Subtract particle energy out -= _particleEnergy[m]; if (out < minEnergy) minEnergy = out; outMsgs[m] = out; } known = known.release(); //create sum double sum = 0; for (int m = 0; m < M; m++) { double out = energyToWeight(outMsgs[m] - minEnergy); outMsgs[m] = out; sum += out; } //calculate message by dividing by sum for (int m = 0; m < M; m++) outMsgs[m] /= sum; } @Override protected void doUpdate() { final double maxEnergy = 100; final int M = _numParticles; final int D = _model.getSiblingCount(); PriorAndCondition known = getPriorAndCondition(); // FIXME - handle fixed values //Compute alphas final double[] logInPortMsgs = DimpleEnvironment.doubleArrayCache.allocateAtLeast(M*D); final double[] alphas = DimpleEnvironment.doubleArrayCache.allocateAtLeast(M); for (int m = 0; m < M; m++) { double prior = known.evalEnergy(_particleValues[m]); double alpha = (prior == Double.POSITIVE_INFINITY) ? maxEnergy : prior * _beta; for (int d = 0, i = m; d < D; d++, i += M) { double tmp = getSiblingEdgeState(d).factorToVarMsg.getEnergy(m); double logtmp = (tmp == Double.POSITIVE_INFINITY) ? maxEnergy : tmp; logInPortMsgs[i] = logtmp; alpha += logtmp; } alphas[m] = alpha; } known = known.release(); //Now compute output messages for each outgoing edge for (int out_d = 0, dm = 0; out_d < D; out_d++, dm += M ) { final DiscreteMessage outMsg = getSiblingEdgeState(out_d).varToFactorMsg; final double[] outWeights = outMsg.representation(); double minEnergy = Double.POSITIVE_INFINITY; //set outMsgs to alpha - mu_d,m //find max alpha for (int m = 0; m < M; m++) { double out = alphas[m] - logInPortMsgs[dm + m]; // Subtract particle energy out -= _particleEnergy[m]; if (out < minEnergy) minEnergy = out; outWeights[m] = out; } //create sum double sum = 0; for (int m = 0; m < M; m++) { double out = energyToWeight(outWeights[m] - minEnergy); outWeights[m] = out; sum += out; } //calculate message by dividing by sum for (int m = 0; m < M; m++) { outWeights[m] /= sum; } } DimpleEnvironment.doubleArrayCache.release(logInPortMsgs); DimpleEnvironment.doubleArrayCache.release(alphas); } public void resample() { final DimpleRandom rand = activeRandom(); final FactorGraph fg = _model.requireParentGraph(); int numPorts = _model.getSiblingCount(); Domain varDomain = _model.getDomain(); double _lowerBound = _domain.getLowerBound(); double _upperBound = _domain.getUpperBound(); int M = _numParticles; PriorAndCondition known = getPriorAndCondition(); final IProposalKernel kernel = requireNonNull(_proposalKernel); // For each sample value for (int m = 0; m < M; m++) { final RealValue sampleValue = _particleValues[m]; // Start with the potential for the current particle value double potential = known.evalEnergy(sampleValue) * _beta; double potentialProposed = 0; for (int portIndex = 0; portIndex < numPorts; portIndex++) { EdgeState edge = _model.getSiblingEdgeState(portIndex); int factorPortNumber = edge.getSibling(_model).indexOfSiblingEdgeState(edge); ParticleBPRealFactor factor = (ParticleBPRealFactor)getSibling(portIndex); potential += factor.getMarginalPotential(sampleValue.getDouble(), factorPortNumber); } // Now repeat resampling this sample for (int update = 0; update < _resamplingUpdatesPerSample; update++) { Proposal proposal = kernel.next(sampleValue, varDomain); double proposalValue = proposal.value.getDouble(); // If outside the bounds, then reject if (proposalValue < _lowerBound) continue; if (proposalValue > _upperBound) continue; // Sum up the potentials from the input and all connected factors potentialProposed = known.evalEnergy(proposal.value) * _beta; for (int portIndex = 0; portIndex < numPorts; portIndex++) { EdgeState edge = _model.getSiblingEdgeState(portIndex); int factorPortNumber = edge.getSibling(_model).indexOfSiblingEdgeState(edge); ParticleBPRealFactor factor = (ParticleBPRealFactor)getSibling(portIndex); potentialProposed += factor.getMarginalPotential(proposalValue, factorPortNumber); } // Accept or reject double rejectionThreshold = Math.exp(potential - potentialProposed + proposal.forwardEnergy - proposal.reverseEnergy); if (Double.isNaN(rejectionThreshold)) // Account for invalid forward or reverse proposals { if (potentialProposed != Double.POSITIVE_INFINITY && proposal.forwardEnergy != Double.POSITIVE_INFINITY) rejectionThreshold = Double.POSITIVE_INFINITY; else rejectionThreshold = 0; } if (rand.nextDouble() < rejectionThreshold) { sampleValue.setDouble(proposalValue); potential = potentialProposed; } } _particleEnergy[m] = potential; // Sum-product code uses log(p) instead of -log(p) // Update the incoming messages for the new particle value SolverNodeMapping solvers = getSolverMapping(); for (int d = 0; d < numPorts; d++) { final EdgeState edge = _model.getSiblingEdgeState(d); Factor factorNode = edge.getFactor(fg); int factorPortNumber = edge.getFactorToVariableEdgeNumber(); ParticleBPRealFactor factor = (ParticleBPRealFactor)(solvers.getSolverFactor(factorNode)); getSiblingEdgeState(d).factorToVarMsg.setWeight(m, Math.exp(factor.getMarginalPotential(sampleValue.getDouble(), factorPortNumber))); } } known.release(); // Update the outgoing messages associated with the new particle locations doUpdate(); } @Override public double[] getBelief() { final double maxEnergy = 100; int M = _numParticles; int D = _model.getSiblingCount(); double minEnergy = Double.POSITIVE_INFINITY; PriorAndCondition known = getPriorAndCondition(); double[] outBelief = new double[M]; for (int m = 0; m < M; m++) { double prior = known.evalEnergy(_particleValues[m]); double out = (prior == Double.POSITIVE_INFINITY) ? maxEnergy : prior * _beta; for (int d = 0; d < D; d++) { double tmp = getSiblingEdgeState(d).factorToVarMsg.getEnergy(m); out += (tmp == Double.POSITIVE_INFINITY) ? maxEnergy : tmp; } // // Subtract the log weight // out -= _logWeight[m]; if (out < minEnergy) minEnergy = out; outBelief[m] = out; } known.release(); //create sum double sum = 0; for (int m = 0; m < M; m++) { double out = energyToWeight(outBelief[m] - minEnergy); outBelief[m] = out; sum += out; } //calculate belief by dividing by sum for (int m = 0; m < M; m++) outBelief[m] /= sum; return outBelief; } // Alternative belief, returned for a specified set of variable values @Matlab public double [] getBelief(double[] valueSet) { final double maxEnergy = 100; int M = valueSet.length; int D = _model.getSiblingCount(); double minEnergy = Double.POSITIVE_INFINITY; PriorAndCondition known = getPriorAndCondition(); double[] outBelief = new double[M]; Value value = Value.create(getDomain()); for (int m = 0; m < M; m++) { double real = valueSet[m]; value.setDouble(real); double prior = known.evalEnergy(value); double out = (prior == Double.POSITIVE_INFINITY) ? maxEnergy : prior * _beta; for (int d = 0; d < D; d++) { int factorPortNumber = _model.getReverseSiblingNumber(d); ParticleBPRealFactor factor = (ParticleBPRealFactor)getSibling(d); out += factor.getMarginalPotential(real, factorPortNumber); // Potential is -log(p) } if (out < minEnergy) minEnergy = out; outBelief[m] = out; } known.release(); //create sum double sum = 0; for (int m = 0; m < M; m++) { double out = energyToWeight(outBelief[m] - minEnergy); outBelief[m] = out; sum += out; } //calculate belief by dividing by sum for (int m = 0; m < M; m++) outBelief[m] /= sum; return outBelief; } @Matlab public double[] getParticleValues() { double[] particles = new double[_numParticles]; for (int i = 0; i < _numParticles; i++) { particles[i] = _particleValues[i].getDouble(); } return particles; } @Override public final Value[] getParticleValueObjects() { return _particleValues; } public void setNumParticles(int numParticles) { setOption(ParticleBPOptions.numParticles, numParticles); updateNumParticles(numParticles); } private void updateNumParticles(int numParticles) { if (numParticles != _numParticles) { _particleValues = Arrays.copyOf(_particleValues, numParticles); for (int i = _numParticles; i < numParticles; ++i) { _particleValues[i] = RealValue.create(); } _numParticles = numParticles; _particleEnergy = Arrays.copyOf(_particleEnergy, numParticles); for (int i = 0, n = getSiblingCount(); i < n; ++i) { getSiblingEdgeState(i).resize(numParticles); } } } public int getNumParticles() {return _numParticles;} public void setResamplingUpdatesPerParticle(int updatesPerParticle) { setOption(ParticleBPOptions.resamplingUpdatesPerParticle, updatesPerParticle); _resamplingUpdatesPerSample = updatesPerParticle; } public int getResamplingUpdatesPerParticle() {return _resamplingUpdatesPerSample;} /** * @deprecated instead set {@link NormalProposalKernel#standardDeviation} option using * {@link #setOption} method. */ @Matlab @Deprecated public void setProposalStandardDeviation(double stdDev) { setOption(NormalProposalKernel.standardDeviation, stdDev); } /** * @deprecated instead lookup {@link NormalProposalKernel#standardDeviation} option using * {@link #getOptionOrDefault} method. */ @Matlab @Deprecated public double getProposalStandardDeviation() { return getOptionOrDefault(NormalProposalKernel.standardDeviation); } /** * Current proposal kernel for variable. * <p> * May be null if {@link #initialize()} not yet invoked. * @since 0.07 */ public @Nullable IProposalKernel getProposalKernel() { return _proposalKernel; } /** * @deprecated instead set appropriate proposal-specific options using {@link #setOption}. */ @Deprecated public final void setProposalKernelParameters(Object... parameters) { requireNonNull(_proposalKernel).setParameters(parameters); } // Override the default proposal kernel public final void setProposalKernel(@Nullable IProposalKernel proposalKernel) { _proposalKernel = proposalKernel; _explicitProposalKernel = proposalKernel != null; } public final void setProposalKernel(String proposalKernelName) { ParticleBPOptions.proposalKernel.convertAndSet(this, proposalKernelName); } // Sets the range of initial particle values // Overrides the domain (if one is specified) in determining the initial particle values public void setInitialParticleRange(double min, double max) { RealDomain domain = RealDomain.create(min, max); if (!domain.isSubsetOf(_domain)) { throw new OptionValidationException("Bounds [%g,%g] are not within variable bounds [%g,%g]", min, max, _domain.getLowerBound(), _domain.getUpperBound()); } ParticleBPOptions.initialParticleRange.set(this, min, max); _initialParticleDomain = domain; } public void setBeta(double beta) // beta = 1/temperature { _beta = beta; } @Deprecated @Override public double getScore() { if (_guessWasSet) return super.getScore(); else throw new DimpleException("This solver doesn't provide a default value. Must set guesses for all variables."); } public void remove(Factor factor) { } /*-------------------- * Deprecated methods */ @Deprecated @Override public Object getInputMsg(int portIndex) { return getSiblingEdgeState(portIndex).factorToVarMsg.representation(); } @Deprecated @Override public Object getOutputMsg(int portIndex) { return getSiblingEdgeState(portIndex).varToFactorMsg.representation(); } @Deprecated @Override public void setInputMsgValues(int portIndex, Object obj) { final DiscreteMessage message = getSiblingEdgeState(portIndex).factorToVarMsg; if (obj instanceof DiscreteMessage) { message.setFrom((DiscreteMessage)obj); } else { double[] target = message.representation(); System.arraycopy(obj, 0, target, 0, target.length); } } @SuppressWarnings("null") @Override public ParticleBPRealEdge getSiblingEdgeState(int siblingIndex) { return (ParticleBPRealEdge)getSiblingEdgeState_(siblingIndex); } }