/******************************************************************************* * Copyright 2012 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.solvers.sumproduct; import org.eclipse.jdt.annotation.Nullable; import com.analog.lyric.dimple.data.IDatum; import com.analog.lyric.dimple.environment.DimpleEnvironment; import com.analog.lyric.dimple.factorfunctions.MultivariateNormal; import com.analog.lyric.dimple.model.values.Value; import com.analog.lyric.dimple.model.variables.RealJoint; import com.analog.lyric.dimple.solvers.core.PriorAndCondition; import com.analog.lyric.dimple.solvers.core.SMultivariateNormalEdge; import com.analog.lyric.dimple.solvers.core.SRealJointVariableBase; import com.analog.lyric.dimple.solvers.core.parameterizedMessages.MultivariateNormalParameters; /** * Solver variable for RealJoint variables under Sum-Product solver. * * @since 0.07 */ public class SumProductRealJoint extends SRealJointVariableBase { private int _numVars; public SumProductRealJoint(RealJoint var, SumProductSolverGraph parent) { super(var, parent); _numVars = _model.getDomain().getNumVars(); } @Override public Object getBelief() { MultivariateNormalParameters m = new MultivariateNormalParameters(getDomain().getDimensions()); doUpdate(m,-1); return m; } @Override public Object getValue() { MultivariateNormalParameters m = (MultivariateNormalParameters)getBelief(); return m.getMean(); } @Override protected void doUpdateEdge(int outPortNum) { doUpdate(getSiblingEdgeState(outPortNum).varToFactorMsg, outPortNum); } private void doUpdate(MultivariateNormalParameters outMsg, int outPortNum) { PriorAndCondition known = getPriorAndCondition(); Value fixedValue = known.value(); if (fixedValue != null) { // If fixed value, just return the input, which has been set to a zero-variance message outMsg.setDeterministic(fixedValue); } else { outMsg.setNull(); for (IDatum datum : known) { final MultivariateNormalParameters input = datumToNormal(datum); if (input != null) { outMsg.addFrom(input); } } for (int i = 0, n = getSiblingCount(); i < n; i++ ) { if (i != outPortNum) { final MultivariateNormalParameters inMsg = getSiblingEdgeState(i).factorToVarMsg; outMsg.addFrom(inMsg); } } } known.release(); } public MultivariateNormalParameters createDefaultMessage() { return new MultivariateNormalParameters(_numVars); } @Deprecated @Override public void setInputMsgValues(int portIndex, Object obj) { getSiblingEdgeState(portIndex).factorToVarMsg.set((MultivariateNormalParameters)obj); } public MultivariateNormalParameters createFixedValueMessage(double[] fixedValue) { double[] variance = new double[_numVars]; MultivariateNormalParameters message = new MultivariateNormalParameters(fixedValue, variance); return message; } /*----------------------- * SVariableBase methods */ @Override protected MultivariateNormalParameters cloneMessage(int edge) { return getSiblingEdgeState(edge).varToFactorMsg.clone(); } @Override protected boolean supportsMessageEvents() { return true; } @SuppressWarnings("null") @Override public SMultivariateNormalEdge getSiblingEdgeState(int siblingIndex) { return (SMultivariateNormalEdge)getSiblingEdgeState_(siblingIndex); } /*----------------- * Private methods */ private @Nullable MultivariateNormalParameters datumToNormal(@Nullable IDatum prior) { if (prior instanceof MultivariateNormalParameters) { return (MultivariateNormalParameters)prior; } else if (prior instanceof MultivariateNormal) { return ((MultivariateNormal)prior).getParameters(); } else if (prior != null) { DimpleEnvironment.logError( "Ignoring prior on %s: sum-product reals only supports MultivariateNormalParameters for priors but got %s", _model, prior); } return null; } }