/*******************************************************************************
*   Copyright 2012 Analog Devices, Inc.
*
*   Licensed under the Apache License, Version 2.0 (the "License");
*   you may not use this file except in compliance with the License.
*   You may obtain a copy of the License at
*
*       http://www.apache.org/licenses/LICENSE-2.0
*
*   Unless required by applicable law or agreed to in writing, software
*   distributed under the License is distributed on an "AS IS" BASIS,
*   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*   See the License for the specific language governing permissions and
*   limitations under the License.
********************************************************************************/

package com.analog.lyric.dimple.solvers.sumproduct.customFactors;

import java.util.List;

import com.analog.lyric.dimple.exceptions.DimpleException;
import com.analog.lyric.dimple.model.factors.Factor;
import com.analog.lyric.dimple.model.values.Value;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.NormalParameters;
import com.analog.lyric.dimple.solvers.sumproduct.SumProductSolverGraph;

/**
 * Gaussian sum-product custom factor for a weighted linear relation between its
 * sibling variables.
 * <p>
 * The outgoing message computed by {@link #doUpdateEdge(int)} is consistent with the
 * constraint {@code sum_i (c_i * x_i) = total}, where {@code c_i} is the coefficient
 * associated with sibling {@code i} and {@code total} is an optional scalar offset
 * (zero when not supplied).
 * <p>
 * This class is for backward compatibility only.
 * This is to work with the deprecated "linear" custom factor.
 * This has been replaced by the LinearEquation factor function, which has a different
 * interface to the "linear" custom factor, and is associated with
 * the CustomGaussianLinearEquation custom factor.
 */
public class CustomGaussianLinear extends GaussianFactorBase
{
	/** Per-sibling coefficients; index {@code i} belongs to sibling edge {@code i}. */
	private final double [] _constants;

	/** Right-hand side of the linear relation; zero when only the coefficient vector is given. */
	private final double _total;

	/**
	 * Reads the coefficient vector (and optional total) from the factor's constant values.
	 *
	 * @param factor the model factor; its first constant value must be a double array with
	 *        one coefficient per sibling, and an optional second constant value supplies
	 *        the scalar total
	 * @param parent the solver graph that owns this solver factor
	 * @throws DimpleException if the factor does not have exactly one or two constant values,
	 *         or if the number of coefficients differs from the number of siblings
	 */
	public CustomGaussianLinear(Factor factor, SumProductSolverGraph parent)
	{
		super(factor, parent);

		// Expect one constant (the coefficient vector) or two (coefficient vector
		// followed by the scalar total).
		final List<Value> constants = factor.getConstantValues();
		final int n = constants.size();

		if (n < 1 || n > 2)
			throw new DimpleException("Need to specify vector of constants");

		_constants = constants.get(0).getDoubleArray();
		_total = (n == 2) ? constants.get(1).getDouble() : 0;

		if (factor.getSiblingCount() != _constants.length)
			throw new DimpleException("Length of constants must equal the size of the number of variables");
	}

	/**
	 * Computes the Gaussian factor-to-variable message for the given edge.
	 * <p>
	 * When the coefficient of the output edge is zero, the message is set to mean 0 and
	 * infinite variance. Otherwise:
	 * <pre>
	 *   mu_i     = (total - sum_{j != i} c_j * mu_j) / c_i
	 *   sigma2_i = (sum_{j != i} c_j^2 * sigma2_j) / c_i^2
	 * </pre>
	 * Siblings whose coefficient is zero are skipped: they contribute nothing to either
	 * sum, and skipping them avoids {@code 0 * Infinity = NaN} when such a sibling's
	 * incoming message has infinite variance.
	 *
	 * @param outPortNum index of the sibling edge whose outgoing message is updated
	 */
	@Override
	public void doUpdateEdge(int outPortNum)
	{
		final double outCoefficient = _constants[outPortNum];

		double mu;
		double sigma2;

		if (outCoefficient == 0)
		{
			// The output variable does not take part in the relation, so the
			// message carries mean 0 with infinite variance.
			mu = 0;
			sigma2 = Double.POSITIVE_INFINITY;
		}
		else
		{
			mu = _total;
			sigma2 = 0;

			for (int i = 0, n = getSiblingCount(); i < n; i++)
			{
				final double coefficient = _constants[i];

				// Skip the output edge itself, and any zero-weighted sibling so that
				// an infinite-variance input cannot turn the sums into NaN.
				if (i == outPortNum || coefficient == 0)
				{
					continue;
				}

				NormalParameters msg = getSiblingEdgeState(i).varToFactorMsg;
				mu -= msg.getMean() * coefficient;
				sigma2 += coefficient * coefficient * msg.getVariance();
			}

			mu /= outCoefficient;
			sigma2 /= (outCoefficient * outCoefficient);
		}

		NormalParameters outMsg = getSiblingEdgeState(outPortNum).factorToVarMsg;
		outMsg.setMean(mu);
		outMsg.setVariance(sigma2);
	}
}