/******************************************************************************* * Copyright 2012-2015 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.solvers.particleBP; import static com.analog.lyric.math.Utilities.*; import static java.util.Objects.*; import com.analog.lyric.collect.CombinatoricIterator; import com.analog.lyric.dimple.factorfunctions.core.FactorFunction; import com.analog.lyric.dimple.model.factors.Factor; import com.analog.lyric.dimple.model.values.RealValue; import com.analog.lyric.dimple.model.values.Value; import com.analog.lyric.dimple.model.variables.Constant; import com.analog.lyric.dimple.model.variables.IConstantOrVariable; import com.analog.lyric.dimple.model.variables.Variable; import com.analog.lyric.dimple.solvers.core.SDiscreteWeightEdge; import com.analog.lyric.dimple.solvers.core.SFactorBase; import com.analog.lyric.dimple.solvers.core.parameterizedMessages.DiscreteMessage; import com.analog.lyric.dimple.solvers.interfaces.ISolverFactorGraph; import com.analog.lyric.dimple.solvers.sumproduct.SumProductDiscreteEdge; /** * Real solver factor under Particle BP solver. * * @since 0.07 */ public class ParticleBPRealFactor extends SFactorBase { protected double _beta = 1; ParticleBPRealFactor(Factor factor, ISolverFactorGraph parent) { super(factor, parent); } @Override public ParticleBPSolverGraph getParentGraph() { return (ParticleBPSolverGraph)_parent; } public double getMarginalPotential(double value, int outPortIndex) { final int nEdges = getSiblingCount(); FactorFunction factorFunction = _model.getFactorFunction(); double marginal = 0; CombinatoricIterator<Value> iter = getCombinatoricIterator(value, outPortIndex); final int[] variableIndices = iter.indices(); final double[][] inputWeightsPerEdge = new double[nEdges][]; for (int i = 0; i < nEdges; ++i) { inputWeightsPerEdge[i] = getSiblingEdgeState(i).varToFactorMsg.representation(); } while (iter.hasNext()) { double prob = factorFunction.eval(iter.next()); if (_beta != 1) prob = Math.pow(prob, _beta); for (int i = 0; i < outPortIndex; ++i) { prob *= inputWeightsPerEdge[i][variableIndices[i]]; } for (int i = outPortIndex + 1; i < nEdges; ++i) { prob *= inputWeightsPerEdge[i][variableIndices[i]]; } marginal += prob; } // FIXME: Should do bounds checking return weightToEnergy(marginal); } @Override public void doUpdateEdge(int outPortNum) { final int nEdges = getSiblingCount(); FactorFunction factorFunction = _model.getFactorFunction(); final DiscreteMessage outputMsg = getSiblingEdgeState(outPortNum).factorToVarMsg; outputMsg.setWeightsToZero(); final double[] outputWeights = outputMsg.representation(); final CombinatoricIterator<Value> iter = getCombinatoricIterator(); final int[] variableIndices = iter.indices(); while (iter.hasNext()) { Value[] values = iter.next(); double prob = factorFunction.eval(values); if (_beta != 1) prob = Math.pow(prob, _beta); for (int inPortNum = 0; inPortNum < nEdges; inPortNum++) { if (inPortNum != outPortNum) { prob *= getSiblingEdgeState(inPortNum).varToFactorMsg.getWeight(variableIndices[inPortNum]); } } outputWeights[variableIndices[outPortNum]] += prob; } outputMsg.normalize(); } @Override protected void doUpdate() { FactorFunction factorFunction = _model.getFactorFunction(); final CombinatoricIterator<Value> iter = getCombinatoricIterator(); final int[] variableIndices = iter.indices(); for (int outPortNum = 0, n = getSiblingCount(); outPortNum < n; outPortNum++) { final DiscreteMessage outputMsg = getSiblingEdgeState(outPortNum).factorToVarMsg; outputMsg.setWeightsToZero(); final double[] outputWeights = outputMsg.representation(); iter.reset(); while (iter.hasNext()) { Value[] variableValues = iter.next(); double prob = 1; prob = factorFunction.eval(variableValues); if (_beta != 1) prob = Math.pow(prob, _beta); for (int inPortNum = 0; inPortNum < outPortNum; inPortNum++) { prob *= getSiblingEdgeState(inPortNum).varToFactorMsg.getWeight(variableIndices[inPortNum]); } for (int inPortNum = outPortNum + 1; inPortNum < n; inPortNum++) { prob *= getSiblingEdgeState(inPortNum).varToFactorMsg.getWeight(variableIndices[inPortNum]); } outputWeights[variableIndices[outPortNum]] += prob; } outputMsg.normalize(); } } public void setBeta(double beta) // beta = 1/temperature { _beta = beta; } @Override public void initialize() { super.initialize(); } @Deprecated @Override public Object getInputMsg(int portIndex) { return getSiblingEdgeState(portIndex).varToFactorMsg.representation(); } @Deprecated @Override public Object getOutputMsg(int portIndex) { return getSiblingEdgeState(portIndex).factorToVarMsg.representation(); } @SuppressWarnings("null") @Override public SDiscreteWeightEdge getSiblingEdgeState(int siblingIndex) { return (SDiscreteWeightEdge)getSiblingEdgeState_(siblingIndex); } @SuppressWarnings("null") protected SumProductDiscreteEdge getDiscreteEdge(int siblingIndex) { return (SumProductDiscreteEdge)getSiblingEdgeState_(siblingIndex); } @SuppressWarnings("null") protected ParticleBPRealEdge getRealEdge(int siblingIndex) { return (ParticleBPRealEdge)getSiblingEdgeState_(siblingIndex); } @Override public IParticleBPVariable getSibling(int edge) { return (IParticleBPVariable) super.getSibling(edge); } /** * Returns an iterator over all combination of variable values */ private CombinatoricIterator<Value> getCombinatoricIterator() { final int nEdges = getSiblingCount(); final Value[][] particlesPerVar = new Value[nEdges][]; for (int i = 0; i < nEdges; ++i) { particlesPerVar[i] = getSibling(i).getParticleValueObjects(); } return new CombinatoricIterator<>(Value.class, particlesPerVar); } /** * Returns an iterator over all combination of variable values except for edge */ private CombinatoricIterator<Value> getCombinatoricIterator(double frozenValue, int frozenEdge) { final Factor factor = _model; final int nArgs = factor.getArgumentCount(); final Value[][] particlesPerVar = new Value[nArgs][]; for (int i = 0; i < nArgs; ++i) { IConstantOrVariable arg = factor.getArgument(i); if (arg instanceof Constant) { particlesPerVar[i] = new Value[] { ((Constant)arg).value() }; } else if (i == frozenEdge) { particlesPerVar[i] = new Value[] { RealValue.create(frozenValue) }; } else { Variable var = (Variable)arg; IParticleBPVariable svar = requireNonNull(getParentGraph().getSolverVariable(var)); particlesPerVar[i] = svar.getParticleValueObjects(); } } return new CombinatoricIterator<>(Value.class, particlesPerVar); } }