/******************************************************************************* * Copyright 2012-2015 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.solvers.sumproduct; import static com.analog.lyric.math.Utilities.*; import static java.util.Objects.*; import java.util.Arrays; import java.util.Objects; import org.eclipse.jdt.annotation.Nullable; import com.analog.lyric.collect.ArrayUtil; import com.analog.lyric.dimple.environment.DimpleEnvironment; import com.analog.lyric.dimple.model.core.EdgeState; import com.analog.lyric.dimple.model.factors.Factor; import com.analog.lyric.dimple.model.values.Value; import com.analog.lyric.dimple.model.variables.Discrete; import com.analog.lyric.dimple.model.variables.Variable; import com.analog.lyric.dimple.options.BPOptions; import com.analog.lyric.dimple.solvers.core.PriorAndCondition; import com.analog.lyric.dimple.solvers.core.SDiscreteVariableDoubleArray; import com.analog.lyric.dimple.solvers.core.parameterizedMessages.DiscreteMessage; import com.analog.lyric.dimple.solvers.core.parameterizedMessages.DiscreteWeightMessage; import com.analog.lyric.dimple.solvers.interfaces.ISolverFactorGraph; import com.analog.lyric.util.misc.Internal; /** * Solver variable for Discrete variables under Sum-Product solver. * * @since 0.07 */ public class SumProductDiscrete extends SDiscreteVariableDoubleArray { /*------- * State */ private boolean _calculateDerivative = false; @Nullable private double [][][] _outMessageDerivative; protected @Nullable double[] _dampingParams = null; protected double[][] _inMsgs = ArrayUtil.EMPTY_DOUBLE_ARRAY_ARRAY; protected double[][] _outMsgs = ArrayUtil.EMPTY_DOUBLE_ARRAY_ARRAY; /*-------------- * Construction */ public SumProductDiscrete(Discrete var, ISolverFactorGraph parent) { super(var, parent); } @Override public void initialize() { super.initialize(); final int nEdges = _model.getSiblingCount(); if (nEdges != _inMsgs.length) { _inMsgs = new double[nEdges][]; _outMsgs = new double[nEdges][]; } for (int i = 0; i < nEdges; ++i) { SumProductDiscreteEdge edge = getSiblingEdgeState(i); _inMsgs[i] = edge.factorToVarMsg.representation(); _outMsgs[i] = edge.varToFactorMsg.representation(); } configureDampingFromOptions(); } public Variable getVariable() { return _model; } public void setCalculateDerivative(boolean val) { _calculateDerivative = val; } @Deprecated public void setDamping(int siblingNumber,double dampingVal) { double[] params = BPOptions.nodeSpecificDamping.getOrDefault(this).toPrimitiveArray(); if (params.length == 0 && dampingVal != 0.0) { params = new double[getSiblingCount()]; } if (params.length != 0) { params[siblingNumber] = dampingVal; } BPOptions.nodeSpecificDamping.set(this, params); configureDampingFromOptions(); } public double getDamping(int siblingNumber) { final double[] dampingParams = _dampingParams; return dampingParams != null ? dampingParams[siblingNumber] : 0.0; } @Override protected void doUpdateEdge(int outPortNum) { final double[] outMsgs = _outMsgs[outPortNum]; PriorAndCondition known = getPriorAndCondition(); final Value fixedValue = known.value(); if (fixedValue != null) { Arrays.fill(outMsgs, 0); outMsgs[fixedValue.getIndex()] = 1.0; known.release(); return; } final double minLog = -100; // FIXME DiscreteMessage priors = toEnergyMessage(known); known = known.release(); final int M = getDomain().size(); final int D = _model.getSiblingCount(); double maxLog = Double.NEGATIVE_INFINITY; final double[][] inMsgs = _inMsgs; final double[] dampingParams = _dampingParams; final double damping = dampingParams != null ? dampingParams[outPortNum] : 0.0; final double priorNormalizer = priors != null ? priors.getNormalizationEnergy() : 0; if (damping != 0.0) { // Save previous output for damping final double[] savedOutMsgArray = DimpleEnvironment.doubleArrayCache.allocateAtLeast(M); System.arraycopy(outMsgs, 0, savedOutMsgArray, 0, M); // We do not assume that the prior is normalized for (int m = M; --m>=0;) { double prior = priors != null ? priors.getEnergy(m) : 0; double out = (prior == Double.POSITIVE_INFINITY) ? minLog : priorNormalizer - prior; int d = D; while (--d > outPortNum) { double tmp = inMsgs[d][m]; out += (tmp == 0) ? minLog : Math.log(tmp); } while (--d >= 0) { double tmp = inMsgs[d][m]; out += (tmp == 0) ? minLog : Math.log(tmp); } maxLog = Math.max(maxLog, out); outMsgs[m] = out; } // convert from log domain double sum = 0.0; for (int m = M; --m>=0;) { double out = Math.exp(outMsgs[m] - maxLog); outMsgs[m] = out; sum += out; } // normalize for (int m = M; --m>=0;) { outMsgs[m] /= sum; } // Apply damping final double inverseDamping = 1.0 - damping; for (int m = M; --m>=0;) outMsgs[m] = outMsgs[m]*inverseDamping + savedOutMsgArray[m]*damping; // Release temp array DimpleEnvironment.doubleArrayCache.release(savedOutMsgArray); } else { // Only update normalization energy when damping is disabled because it probably // won't be useful in that case. final DiscreteMessage outMsg = getSiblingEdgeState(outPortNum).varToFactorMsg; final boolean setNormalizationEnergy = true; // make this optional? double normalizationEnergy = 0.0; if (setNormalizationEnergy) { for (int d = D; -- d> outPortNum;) normalizationEnergy += getSiblingEdgeState(d).factorToVarMsg.getNormalizationEnergy(); for (int d = outPortNum; --d >=0;) normalizationEnergy += getSiblingEdgeState(d).factorToVarMsg.getNormalizationEnergy(); } for (int m = M; --m>=0;) { double prior = priors != null ? priors.getEnergy(m) : 0; double out = (prior == Double.POSITIVE_INFINITY) ? minLog : priorNormalizer - prior; int d = D; while (--d > outPortNum) { double tmp = inMsgs[d][m]; out += (tmp == 0) ? minLog : Math.log(tmp); } while (--d >= 0) { double tmp = inMsgs[d][m]; out += (tmp == 0) ? minLog : Math.log(tmp); } maxLog = Math.max(maxLog, out); outMsgs[m] = out; } // Convert from log domain for (int m = M; --m>=0;) { double out = Math.exp(outMsgs[m] - maxLog); outMsgs[m] = out; } if (setNormalizationEnergy) { outMsg.setNormalizationEnergy(normalizationEnergy - maxLog); } outMsg.normalize(); } if (_calculateDerivative) { updateDerivative(outPortNum); } } @Override protected void doUpdate() { PriorAndCondition known = getPriorAndCondition(); final Value fixedValue = known.value(); if (fixedValue != null) { final int index = fixedValue.getIndex(); for (double[] outMsg : _outMsgs) { Arrays.fill(outMsg, 0); outMsg[index] = 1.0; } known.release(); return; } final double minLog = -100; // FIXME final DiscreteMessage priors = toEnergyMessage(known); known = known.release(); final int M = getDomain().size(); final int D = _model.getSiblingCount(); //Compute alphas final double[][] inMsgs = _inMsgs; final double[] logInPortMsgs = DimpleEnvironment.doubleArrayCache.allocateAtLeast(M*D); final double[] alphas = DimpleEnvironment.doubleArrayCache.allocateAtLeast(M); final double priorNormalizer = priors != null ? priors.getNormalizationEnergy() : 0; for (int m = M; --m>=0;) { double prior = priors != null ? priors.getEnergy(m) : 0; double alpha = (prior == Double.POSITIVE_INFINITY) ? minLog : priorNormalizer - prior; for (int d = 0, i = m; d < D; d++, i += M) { double tmp = inMsgs[d][m]; double logtmp = (tmp == 0) ? minLog : Math.log(tmp); logInPortMsgs[i] = logtmp; alpha += logtmp; } alphas[m] = alpha; } final double[] dampingParams = _dampingParams; if (dampingParams != null) { final double[] savedOutMsgArray = DimpleEnvironment.doubleArrayCache.allocateAtLeast(M); for (int out_d = 0, dm = 0; out_d < D; out_d++, dm += M ) { final double[] outMsgs = _outMsgs[out_d]; final double damping = dampingParams[out_d]; if (damping != 0) { System.arraycopy(outMsgs, 0, savedOutMsgArray, 0, M); } double maxLog = Double.NEGATIVE_INFINITY; //set outMsgs to alpha - mu_d,m //find max alpha for (int m = M; --m>=0;) { final double out = alphas[m] - logInPortMsgs[dm + m]; maxLog = Math.max(maxLog, out); outMsgs[m] = out; } // convert from log domain double sum = 0.0; for (int m = M; --m>=0;) { double out = Math.exp(outMsgs[m] - maxLog); outMsgs[m] = out; sum += out; } // normalize for (int m = M; --m>=0;) { outMsgs[m] /= sum; } if (damping != 0) { final double inverseDamping = 1.0 - damping; for (int m = M; --m>=0;) { outMsgs[m] = outMsgs[m]*inverseDamping + savedOutMsgArray[m]*damping; } } } DimpleEnvironment.doubleArrayCache.release(savedOutMsgArray); } else // no damping { double incomingNormalizationEnergy = 0.0; for (int d = 0; d < D; ++d) { incomingNormalizationEnergy += getSiblingEdgeState(d).factorToVarMsg.getNormalizationEnergy(); } for (int out_d = 0, dm = 0; out_d < D; out_d++, dm += M ) { final double[] outMsgs = _outMsgs[out_d]; double maxLog = Double.NEGATIVE_INFINITY; //set outMsgs to alpha - mu_d,m //find max alpha for (int m = M; --m>=0;) { final double out = alphas[m] - logInPortMsgs[dm + m]; maxLog = Math.max(maxLog, out); outMsgs[m] = out; } // convert from log domain double sum = 0.0; for (int m = M; --m>=0;) { final double out = Math.exp(outMsgs[m] - maxLog); sum += out; outMsgs[m] = out; } // normalize for (int m = M; --m>=0;) { outMsgs[m] /= sum; } // Update normalization energy on outgoing message: // includes energy from edges used to compute the outgoing message // plus that used to do the final normalization. final SumProductDiscreteEdge outEdge = getSiblingEdgeState(out_d); final double normalizationEnergy = weightToEnergy(sum) - maxLog + incomingNormalizationEnergy - outEdge.factorToVarMsg.getNormalizationEnergy(); outEdge.varToFactorMsg.setNormalizationEnergy(normalizationEnergy); } } DimpleEnvironment.doubleArrayCache.release(logInPortMsgs); DimpleEnvironment.doubleArrayCache.release(alphas); if (_calculateDerivative) { for (int i = 0; i < D; i++) updateDerivative(i); } } @Override public double[] getBelief() { final int M = getDomain().size(); final double[] outBelief = new double[M]; PriorAndCondition known = getPriorAndCondition(); final Value fixedValue = known.value(); if (fixedValue != null) { outBelief[fixedValue.getIndex()] = 1.0; known.release(); return outBelief; } final DiscreteMessage priors = toEnergyMessage(known); known = known.release(); final int D = _model.getSiblingCount(); final double minLog = -100; double maxLog = Double.NEGATIVE_INFINITY; final double priorNormalizer = priors != null ? priors.getNormalizationEnergy() : 0; for (int m = 0; m < M; m++) { double prior = priors != null ? priors.getEnergy(m) : 0; double out = (prior == Double.POSITIVE_INFINITY) ? minLog : priorNormalizer - prior; for (int d = 0; d < D; d++) { double tmp = getSiblingEdgeState(d).factorToVarMsg.getWeight(m); out += (tmp == 0) ? minLog : Math.log(tmp); } if (out > maxLog) maxLog = out; outBelief[m] = out; } //create sum double sum = 0; for (int m = 0; m < M; m++) { double out = Math.exp(outBelief[m] - maxLog); outBelief[m] = out; sum += out; } // Normalize if (sum > 0) { for (int m = 0; m < M; m++) { outBelief[m] /= sum; } } else { // If all zero, then return uniform Arrays.fill(outBelief, 1.0/M); } return outBelief; } public double [] getNormalizedInputs() { PriorAndCondition known = getPriorAndCondition(); final Value fixedValue = known.value(); if (fixedValue != null) { final double[] belief = new double[getDomain().size()]; belief[fixedValue.getIndex()] = 1.0; known.release(); return belief; } DiscreteMessage prior = toEnergyMessage(known); known.release(); if (prior != null) { prior = new DiscreteWeightMessage(prior); prior.normalize(); return prior.representation(); } final int n = getDomain().size(); final double[] result = new double[n]; Arrays.fill(result, 1.0 / n); return result; } public double [] getUnormalizedBelief() { Value fixedValue = getKnownValue(); if (fixedValue != null) { final double[] belief = new double[getDomain().size()]; belief[fixedValue.getIndex()] = 1.0; return belief; } //TODO: log regime double [] input = getNormalizedInputs(); double [] retval = input.clone(); for (int i = 0, n = getSiblingCount(); i < n; i++) { final double[] inMsg = getSiblingEdgeState(i).factorToVarMsg.representation(); for (int j = 0; j < retval.length; j++) { retval[j] *= inMsg[j]; } } return retval; } /** * Computes the log partition function of the graph (under appropriate conditions) * <p> * This returns the log of sum of the weights of the unnormalized variable belief, which * is the same as the log partition function of the graph as long as it is a tree (or forest) * and solve has been run. * <p> * @category internal * @since 0.08 * @see SumProductSolverGraph#computeLogPartitionFunction() */ @Internal public double computeLogPartitionFunction() { double [] energies = new double[getDomain().size()]; DiscreteMessage prior = knownEnergyMessage(); if (prior != null) { prior.getEnergies(energies); } double normalizationEnergy = 0.0; for (int i = 0, n = getSiblingCount(); i < n; i++) { DiscreteMessage inMsg = getSiblingEdgeState(i).factorToVarMsg; for (int j = 0; j < energies.length; j++) { energies[j] += inMsg.getEnergy(j); } normalizationEnergy += inMsg.getNormalizationEnergy(); } double sum = 0.0; for (int j = 0; j < energies.length; ++j) { sum += energyToWeight(energies[j]); } return weightToEnergy(sum) + normalizationEnergy; } /****************************************************** * Energy, Entropy, and derivatives of all that. ******************************************************/ @Override public double getInternalEnergy() { int domainLength = _model.getDomain().size(); double sum = 0; double [] belief = getBelief(); DiscreteMessage prior = knownEnergyMessage(); final double priorNormalizer = prior != null ? prior.getNormalizationEnergy() : 0; //make sure input is normalized for (int i = 0; i < domainLength; i++) { double tmp = prior != null ? prior.getEnergy(i) : 0; if (tmp != Double.POSITIVE_INFINITY) sum += belief[i] * (tmp - priorNormalizer); } return sum; } @Override public double getBetheEntropy() { double sum = 0; double [] belief = getBelief(); for (int i = 0; i < belief.length; i++) { if (belief[i] != 0) sum -= belief[i] * Math.log(belief[i]); } return sum; } public double calculatedf(double f, int weightIndex, int domain) { double sum = 0; for (int i = 0, n = getSiblingCount(); i < n; i++) { final EdgeState edge = _model.getSiblingEdgeState(i); SumProductTableFactor sft = (SumProductTableFactor)getSibling(i); double inputMsg = getSiblingEdgeState(i).factorToVarMsg.getWeight(domain); double tmp = f / inputMsg; @SuppressWarnings("null") double der = sft.getMessageDerivative(weightIndex, edge.getFactorToVariableEdgeNumber())[domain]; tmp = tmp * der; sum += tmp; } return sum; } public double calculateDerivativeOfBelief(int weightIndex, int domain) { final int n = getDomain().size(); double [] un = getUnormalizedBelief(); //Calculate unormalized belief double f = un[domain]; double g = 0; for (int i = 0; i < n; i++) g += un[i]; double df = calculatedf(f,weightIndex,domain); double dg = 0; for (int i = 0; i < n; i++) { double tmp = un[i]; dg += calculatedf(tmp,weightIndex,i); } //return df; return (df*g - f*dg)/(g*g); } public double calculateDerivativeOfInternalEnergyWithRespectToWeight(int weightIndex) { final int n = getDomain().size(); double sum = 0; //double [] belief = (double[])getBelief(); double [] input = getNormalizedInputs(); //for each domain for (int d = 0; d < n; d++) { //calculate belief(d) //double beliefd = belief[d]; //calculate input(d) double inputd = input[d]; //get derviativebelief(d,weightindex) double dbelief = calculateDerivativeOfBelief(weightIndex,d); sum += dbelief * (-Math.log(inputd)); } return sum; } public double calculateDerivativeOfBetheEntropyWithRespectToWeight(int weightIndex) { double sum = 0; double [] belief = getBelief(); for (int d = belief.length; --d >=0;) { double beliefd = belief[d]; double dbelief = calculateDerivativeOfBelief(weightIndex, d); sum += dbelief * (Math.log(beliefd)) + dbelief; } return -sum * (getSiblingCount()-1); } private double calculateProdFactorMessagesForDomain(int outPortNum, int d) { double f = getNormalizedInputs()[d]; for (int i = 0, n = getSiblingCount(); i < n; i++) { if (i != outPortNum) { f *= getSiblingEdgeState(i).factorToVarMsg.getWeight(d); } } return f; } public void updateDerivativeForWeightNumAndDomainItem(int outPortNum, int weight, int d) { //calculate f double f = calculateProdFactorMessagesForDomain(outPortNum,d); //calculate g double g = 0; for (int i = getDomain().size(); --i>=0;) g += calculateProdFactorMessagesForDomain(outPortNum, i); double derivative = 0; if (g != 0) { //calculate df double df = calculatedf(outPortNum,f,d,weight); //calculate dg double dg = calculatedg(outPortNum,weight); derivative = (df*g - f*dg)/(g*g); } Objects.requireNonNull(_outMessageDerivative)[weight][outPortNum][d] = derivative; } public double calculatedg(int outPortNum,int wn) { double sum = 0; for (int d = getDomain().size(); --d>=0;) { double prod = calculateProdFactorMessagesForDomain(outPortNum, d); sum += calculatedf(outPortNum,prod,d,wn); } return sum; } public void initializeDerivativeMessages(int weights) { _outMessageDerivative = new double[weights][getSiblingCount()][getDomain().size()]; } /** * @deprecated instead use {@link #getMessageDerivative(int, int)}. */ @Deprecated public double [] getMessageDerivative(int wn, Factor f) { return getMessageDerivative(wn, _model.findSibling(f)); } @Internal public double[] getMessageDerivative(int wn, int edgeNumber) { return requireNonNull(_outMessageDerivative)[wn][edgeNumber]; } public double calculatedf(int outPortNum, double f, int d, int wn) { double df = 0; for (int i = 0, n = getSiblingCount(); i < n; i++) { if (i != outPortNum) { final EdgeState edge = _model.getSiblingEdgeState(i); double thisMsg = getSiblingEdgeState(i).factorToVarMsg.getWeight(d); SumProductTableFactor stf = (SumProductTableFactor)getSibling(i); @SuppressWarnings("null") double [] dfactor = stf.getMessageDerivative(wn, edge.getFactorToVariableEdgeNumber()); df += f/thisMsg * dfactor[d]; } } return df; } public void updateDerivativeForWeightNum(int outPortNum, int weight) { for (int d = getDomain().size(); --d >=0; ) { updateDerivativeForWeightNumAndDomainItem(outPortNum,weight,d); } } public void updateDerivative(int outPortNum) { SumProductSolverGraph sfg = (SumProductSolverGraph)getRootSolverGraph(); @SuppressWarnings("null") int numWeights = sfg.getCurrentFactorTable().sparseSize(); for (int wn = 0; wn < numWeights; wn++) { updateDerivativeForWeightNum(outPortNum, wn); } } @Override protected double[] createDefaultMessage() { final double [] retval = super.createDefaultMessage(); Arrays.fill(retval, 1.0); return retval; } /*--------------- * SNode methods */ @Override protected boolean supportsMessageEvents() { return true; } /*----------------- * Private methods */ private void configureDampingFromOptions() { final int size = getSiblingCount(); double[] dampingParams = _dampingParams = getReplicatedNonZeroListFromOptions(BPOptions.nodeSpecificDamping, BPOptions.damping, size, _dampingParams); if (dampingParams.length > 0 && dampingParams.length != size) { DimpleEnvironment.logWarning("%s has wrong number of parameters for %s\n", BPOptions.nodeSpecificDamping, this); _dampingParams = null; } if (dampingParams.length == 0) { _dampingParams = null; } } @Override @SuppressWarnings("null") public SumProductDiscreteEdge getSiblingEdgeState(int siblingIndex) { return (SumProductDiscreteEdge)getSiblingEdgeState_(siblingIndex); } }