/******************************************************************************* * Copyright 2012-2015 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.solvers.minsum; import java.util.Arrays; import org.eclipse.jdt.annotation.Nullable; import com.analog.lyric.collect.ArrayUtil; import com.analog.lyric.collect.DoubleArrayCache; import com.analog.lyric.dimple.environment.DimpleEnvironment; import com.analog.lyric.dimple.model.values.Value; import com.analog.lyric.dimple.model.variables.Discrete; import com.analog.lyric.dimple.options.BPOptions; import com.analog.lyric.dimple.solvers.core.PriorAndCondition; import com.analog.lyric.dimple.solvers.core.SDiscreteVariableDoubleArray; import com.analog.lyric.dimple.solvers.core.parameterizedMessages.DiscreteMessage; /** * Solver variable for Discrete variables under Min-Sum solver. * * @since 0.07 */ public class MinSumDiscrete extends SDiscreteVariableDoubleArray { /*------- * State */ protected @Nullable double[] _dampingParams = null; protected double[][] _inMsgs = ArrayUtil.EMPTY_DOUBLE_ARRAY_ARRAY; protected double[][] _outMsgs = ArrayUtil.EMPTY_DOUBLE_ARRAY_ARRAY; /*-------------- * Construction */ MinSumDiscrete(Discrete var, MinSumSolverGraph parent) { super(var, parent); } /*--------------------- * ISolverNode methods */ @Override public void initialize() { super.initialize(); final int nEdges = _model.getSiblingCount(); if (nEdges != _inMsgs.length) { _inMsgs = new double[nEdges][]; _outMsgs = new double[nEdges][]; } for (int i = 0; i < nEdges; ++i) { MinSumDiscreteEdge edge = getSiblingEdgeState(i); _inMsgs[i] = edge.factorToVarMsg.representation(); _outMsgs[i] = edge.varToFactorMsg.representation(); } configureDampingFromOptions(); } /*--------------- * SNode methods */ @Override protected void doUpdateEdge(int outPortNum) { final double[] outMsgs = _outMsgs[outPortNum]; PriorAndCondition known = getPriorAndCondition(); final Value fixedValue = known.value(); if (fixedValue != null) { Arrays.fill(outMsgs, MessageConverter.maxPotential); outMsgs[fixedValue.getIndex()] = 0.0; known.release(); return; } final int numPorts = _model.getSiblingCount(); final int numValue = getDomain().size(); final double[] dampingParams = _dampingParams; final double damping = dampingParams != null ? dampingParams[outPortNum] : 0.0; if (damping != 0.0) { // Save previous output for damping final double[] savedOutMsgArray = DimpleEnvironment.doubleArrayCache.allocateAtLeast(numValue); System.arraycopy(outMsgs, 0, savedOutMsgArray, 0, numValue); copyPrior(known, outMsgs); int port = numPorts; while (--port > outPortNum) { final double[] energies = _inMsgs[port]; for (int i = numValue; --i>=0;) { outMsgs[i] += energies[i]; } } while(--port >= 0) { final double[] energies = _inMsgs[port]; for (int i = numValue; --i>=0;) { outMsgs[i] += energies[i]; } } // Apply damping final double inverseDamping = 1.0 - damping; for (int m = numValue; --m>=0;) { outMsgs[m] = outMsgs[m]*inverseDamping + savedOutMsgArray[m]*damping; } // Release temp array DimpleEnvironment.doubleArrayCache.release(savedOutMsgArray); } else { copyPrior(known, outMsgs); int port = numPorts; while (--port > outPortNum) { final double[] energies = _inMsgs[port]; for (int i = numValue; --i>=0;) { outMsgs[i] += energies[i]; } } while(--port >= 0) { final double[] energies = _inMsgs[port]; for (int i = numValue; --i>=0;) { outMsgs[i] += energies[i]; } } } known = known.release(); // Normalize the min double minPotential = outMsgs[0]; for (int i = 1; i < numValue; ++i) { minPotential = Math.min(minPotential, outMsgs[i]); } if (minPotential != 0.0) { for (int i = numValue; --i>=0;) { outMsgs[i] -= minPotential; } } } @Override protected void doUpdate() { PriorAndCondition known = getPriorAndCondition(); final Value fixedValue = known.value(); if (fixedValue != null) { final int index = fixedValue.getIndex(); for (double[] outMsg : _outMsgs) { Arrays.fill(outMsg, MessageConverter.maxPotential); outMsg[index] = 0.0; } known.release(); return; } int numPorts = _model.getSiblingCount(); int numValue = getDomain().size(); // Compute the sum of all messages final DoubleArrayCache cache = DimpleEnvironment.doubleArrayCache; final double[] beliefs = cache.allocateAtLeast(numValue); copyPrior(known, beliefs); known = known.release(); for (int port = numPorts; --port>=0;) { final double[] inMsgs = _inMsgs[port]; for (int i = numValue; --i>=0;) { beliefs[i] += inMsgs[i]; } } final double[] dampingParams = _dampingParams; if (dampingParams != null) { final double[] savedOutMsgArray = cache.allocateAtLeast(numValue); for (int port = numPorts; --port>=0; ) { final double[] outMsgs = _outMsgs[port]; double minPotential = Double.POSITIVE_INFINITY; final double damping = dampingParams[port]; if (damping != 0.0) { System.arraycopy(outMsgs, 0, savedOutMsgArray, 0, numValue); } final double[] inPortMsgsThisPort = _inMsgs[port]; for (int i = numValue; --i>=0;) { double out = beliefs[i] - inPortMsgsThisPort[i]; minPotential = Math.min(minPotential, out); outMsgs[i] = out; } // Damping if (damping != 0) { final double inverseDamping = 1.0 - damping; for (int m = numValue; --m>=0;) { outMsgs[m] = outMsgs[m]*inverseDamping + savedOutMsgArray[m]*damping; } } // Normalize the min if (minPotential != 0.0) { for (int i = numValue; --i>=0;) outMsgs[i] -= minPotential; } } // Release temp array cache.release(savedOutMsgArray); } else { for (int port = numPorts; --port>=0; ) { final double[] outMsgs = _outMsgs[port]; double minPotential = Double.POSITIVE_INFINITY; final double[] inPortMsgsThisPort = _inMsgs[port]; for (int i = numValue; --i>=0;) { double out = beliefs[i] - inPortMsgsThisPort[i]; minPotential = Math.min(minPotential, out); outMsgs[i] = out; } // Normalize the min if (minPotential != 0.0) { for (int i = numValue; --i>=0;) outMsgs[i] -= minPotential; } } } cache.release(beliefs); } /*------------------------- * ISolverVariable methods */ @Override public double[] getBelief() { final int numValue = getDomain().size(); final double[] outBelief = new double[numValue]; PriorAndCondition known = getPriorAndCondition(); final Value fixedValue = known.value(); if (fixedValue != null) { outBelief[fixedValue.getIndex()] = 1.0; known.release(); return outBelief; } copyPrior(known, outBelief); known = known.release(); int numPorts = _model.getSiblingCount(); for (int i = 0; i < numValue; i++) { double sum = outBelief[i]; for (int port = 0; port < numPorts; port++) { sum += getSiblingEdgeState(port).factorToVarMsg.getEnergy(i); } outBelief[i] = sum; } // Convert to probabilities since that's what the interface expects return MessageConverter.toProb(outBelief); } /** * @deprecated Use {@link BPOptions#damping} or {@link BPOptions#nodeSpecificDamping} options instead. */ @Deprecated public void setDamping(int siblingNumber, double dampingVal) { double[] params = BPOptions.nodeSpecificDamping.getOrDefault(this).toPrimitiveArray(); if (params.length == 0 && dampingVal != 0.0) { params = new double[getSiblingCount()]; } if (params.length != 0) { params[siblingNumber] = dampingVal; } BPOptions.nodeSpecificDamping.set(this, params); configureDampingFromOptions(); } public double getDamping(int siblingNumber) { final double[] dampingParams = _dampingParams; return dampingParams != null ? dampingParams[siblingNumber] : 0.0; } /*--------------- * SNode methods */ @Override protected boolean supportsMessageEvents() { return true; } /*----------------- * Private methods */ private void configureDampingFromOptions() { final int size = getSiblingCount(); double[] dampingParams = _dampingParams = getReplicatedNonZeroListFromOptions(BPOptions.nodeSpecificDamping, BPOptions.damping, size, _dampingParams); if (dampingParams.length > 0 && dampingParams.length != size) { DimpleEnvironment.logWarning("%s has wrong number of parameters for %s\n", BPOptions.nodeSpecificDamping, this); _dampingParams = null; } if (dampingParams.length == 0) { _dampingParams = null; } } private void copyPrior(PriorAndCondition known, double[] out) { final DiscreteMessage prior = toEnergyMessage(known); if (prior != null) { prior.getEnergies(out); MessageConverter.clipEnergies(out); } else { Arrays.fill(out, 0.0); } } @Override @SuppressWarnings("null") public MinSumDiscreteEdge getSiblingEdgeState(int siblingIndex) { return (MinSumDiscreteEdge)getSiblingEdgeState_(siblingIndex); } }