/*******************************************************************************
* Copyright 2012 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.solvers.sumproduct.customFactors;
import static java.util.Objects.*;
import com.analog.lyric.dimple.model.domains.RealJointDomain;
import com.analog.lyric.dimple.model.factors.Factor;
import com.analog.lyric.dimple.model.variables.Variable;
import com.analog.lyric.dimple.solvers.core.SolverFactorCreationException;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.MultivariateNormalParameters;
import com.analog.lyric.dimple.solvers.sumproduct.SumProductSolverGraph;
/**
 * Sum-product custom solver factor for a linear relation of the form {@code y = A*x}, where
 * {@code y} (sibling 0) and {@code x} (sibling 1) are real-joint variables and {@code A} is a
 * constant {@code double[][]} matrix supplied as the factor's second argument.
 * <p>
 * Messages on both edges are {@link MultivariateNormalParameters}; the actual message computation
 * is delegated to {@code MutlivariateGaussianMatrixProduct}.
 */
public class CustomMultivariateGaussianProduct extends MultivariateGaussianFactorBase
{
	/** The constant matrix {@code A}; validated to have (dimension of y) rows and (dimension of x) columns. */
	private final double[][] _constant;

	/**
	 * Builds the solver factor after verifying that {@code factor} has the required shape.
	 *
	 * @param factor model factor; must have exactly two siblings and exactly one constant, located
	 *        at argument index 1, whose value is a rectangular {@code double[][]} with one row per
	 *        dimension of sibling 0 and one column per dimension of sibling 1
	 * @param parent solver graph that owns this solver factor
	 * @throws SolverFactorCreationException if any of the structural requirements above is violated
	 * @throws NullPointerException if the constant's underlying object is {@code null}
	 */
	public CustomMultivariateGaussianProduct(Factor factor, SumProductSolverGraph parent)
	{
		super(factor, parent);

		// Make sure this is of the form a = b*c where b is a constant matrix.
		if (factor.getSiblingCount() != 2)
			throw new SolverFactorCreationException("factor must be of form a = b*c where b is a constant matrix");

		// TODO: alternatively, one of the ports could be a discrete variable with a single domain element
		if (factor.getConstantCount() != 1)
			throw new SolverFactorCreationException("expected one constant");

		if (!factor.hasConstantAtIndex(1))
			throw new SolverFactorCreationException("Expect matrix to be second arg");

		Object constantObj = requireNonNull(factor.getConstantValues().get(0).getObject());
		if (!(constantObj instanceof double[][]))
		{
			throw new SolverFactorCreationException("Constant not a double[][] matrix");
		}
		final double[][] constant = (double[][])constantObj;

		// Helper not visible in this file (presumably inherited from the base class); by its name it
		// rejects factors whose variables are not unbounded real-joint -- confirm in the base class.
		assertUnboundedRealJoint(factor);

		final int yDim = factor.getSibling(0).getDomain().getDimensions();
		final int xDim = factor.getSibling(1).getDomain().getDimensions();

		// Check every row, not just the first, so ragged/empty/null-row matrices are rejected here
		// rather than failing later inside the matrix-product computation.
		if (!hasDimensions(constant, yDim, xDim))
		{
			throw new SolverFactorCreationException("Constant matrix does not have expected dimensions");
		}

		_constant = constant;
	}

	/**
	 * Computes the factor-to-variable message for one edge from the variable-to-factor message on
	 * the other edge.
	 *
	 * @param outPortNum index of the edge to update (0 or 1); the input message is read from edge
	 *        {@code 1 - outPortNum}
	 */
	@Override
	public void doUpdateEdge(int outPortNum)
	{
		MutlivariateGaussianMatrixProduct matMult = new MutlivariateGaussianMatrixProduct(_constant);

		// Port 0 is the product side, so updating it uses direction 'F'; any other port uses 'R'
		// (presumably "forward"/"reverse" -- confirm against MutlivariateGaussianMatrixProduct).
		final char direction = (outPortNum == 0) ? 'F' : 'R';

		MultivariateNormalParameters outMsg = getSiblingEdgeState(outPortNum).factorToVarMsg;
		MultivariateNormalParameters inMsg = getSiblingEdgeState(1 - outPortNum).varToFactorMsg;

		matMult.ComputeMsg(inMsg, outMsg, direction);
	}

	/**
	 * Utility to indicate whether or not a factor is compatible with the requirements of this custom factor
	 *
	 * @param factor the model factor to examine
	 * @return {@code true} if the factor has two unbounded real-joint siblings and a single constant
	 *         at argument index 1 that is a rectangular {@code double[][]} of matching dimensions
	 * @deprecated as of release 0.08
	 */
	@Deprecated
	public static boolean isFactorCompatible(Factor factor)
	{
		// Must be of the form y = A*x where A is a constant matrix.
		if (factor.getSiblingCount() != 2)
			return false;

		// Must have exactly one constant
		if (factor.getConstantCount() != 1)
			return false;

		// The constant must be the second argument, matching what the constructor requires.
		if (!factor.hasConstantAtIndex(1))
			return false;

		Variable y = factor.getSibling(0);
		Variable x = factor.getSibling(1);

		RealJointDomain yDomain = y.getDomain().asRealJoint();
		RealJointDomain xDomain = x.getDomain().asRealJoint();

		// Variables must be unbounded multivariate reals
		if (yDomain == null || xDomain == null || yDomain.isBounded() || xDomain.isBounded())
		{
			return false;
		}

		// Constant must be a matrix of the proper size
		Object constant = factor.getConstantValues().get(0).getObject();
		if (!(constant instanceof double[][]))
			return false;

		return hasDimensions((double[][])constant, yDomain.getDimensions(), xDomain.getDimensions());
	}

	/**
	 * Tests whether {@code matrix} is rectangular with exactly {@code rows} rows and {@code columns}
	 * columns. Every row is inspected; a {@code null} row counts as a mismatch.
	 */
	private static boolean hasDimensions(double[][] matrix, int rows, int columns)
	{
		if (matrix.length != rows)
			return false;

		for (double[] row : matrix)
		{
			if (row == null || row.length != columns)
				return false;
		}

		return true;
	}
}