/*******************************************************************************
* Copyright 2012 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.factorfunctions;
import org.eclipse.jdt.annotation.Nullable;
import com.analog.lyric.dimple.factorfunctions.core.FactorFunctionUtilities;
import com.analog.lyric.dimple.factorfunctions.core.UnaryFactorFunction;
import com.analog.lyric.dimple.model.values.Value;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.MultivariateNormalParameters;
import com.analog.lyric.util.misc.Matlab;
/**
* Factor for an exchangeable set of multivariate Normally distributed variables associated
* with a vector representing the mean and a matrix (in the form of an array of vectors)
* representing the covariance, or alternatively a vector representing the information vector and
* a matrix representing the information matrix.
* <p>
* TODO: THE CURRENT IMPLEMENTATION OF THIS FACTOR FUNCTION SUPPORTS ONLY CONSTANT PARAMETERS,
* SPECIFIED IN THE CONSTRUCTOR. VARIABLE PARAMETER SUPPORT SHOULD BE ADDED.
* <p>
* The variables are ordered as follows in the argument list:
*
* 1...) An arbitrary number of RealJoint variables
*
* @since 0.05
*/
@Matlab(wrapper="MultivariateNormalParameters")
public class MultivariateNormal extends UnaryFactorFunction
{
	private static final long serialVersionUID = 1L;

	// Distribution parameters. This instance keeps its own copy: inputs are cloned on the way in
	// (constructor, setParameters) and getParameters() hands out a clone, so no caller can alias
	// and mutate the factor's internal state.
	private MultivariateNormalParameters _parameters;

	// Only constant parameters are currently supported, so this is always true.
	private boolean _parametersConstant = true; // TODO: support variable parameters

	// Index of the first argument (edge) that is a directed-to output variable. Always 0 while
	// parameters are constant, because no parameter edges precede the output variables.
	private int _firstDirectedToIndex;

	// 0.5 * log(2*pi), i.e. log(sqrt(2*pi)). Not referenced in this class; it is protected so
	// subclasses may use it.
	protected static final double _logSqrt2pi = Math.log(2*Math.PI)*0.5;

	/*--------------
	 * Construction
	 */

	// public MultivariateNormal() {super();} // TODO: Implement variable parameters case

	/**
	 * Constructs a factor with constant mean and covariance.
	 *
	 * @param mean the mean vector
	 * @param covariance the covariance matrix, as a two-dimensional array
	 */
	public MultivariateNormal(double[] mean, double[][] covariance)
	{
		this(new MultivariateNormalParameters(mean, covariance));
	}

	/**
	 * Constructs a factor with constant parameters, given either as mean/covariance or in
	 * information form. Interpretation of the arguments is delegated to
	 * {@link MultivariateNormalParameters}.
	 *
	 * @param vector the mean vector, or the information vector when {@code informationForm} is true
	 * @param matrix the covariance matrix, or the information matrix when {@code informationForm}
	 *        is true
	 * @param informationForm selects how {@code vector} and {@code matrix} are interpreted
	 */
	public MultivariateNormal(double[] vector, double[][] matrix, boolean informationForm)
	{
		this(new MultivariateNormalParameters(vector, matrix, informationForm));
	}

	/**
	 * Constructs a factor with the given constant parameters.
	 * <p>
	 * A copy of {@code parameters} is stored, so later changes to the argument do not affect
	 * this factor.
	 *
	 * @param parameters the distribution parameters
	 */
	public MultivariateNormal(MultivariateNormalParameters parameters)
	{
		super((String)null);
		// NOTE(review): the field is assigned directly in the constructor (not only inside the
		// helper), presumably so null analysis sees the non-null field definitely initialized.
		_parameters = parameters.clone();
		initializeConstantParameters(_parameters);
	}

	/**
	 * Copy constructor used by {@link #clone()}. The parameters are copied via
	 * {@code MultivariateNormalParameters.clone()}; the remaining fields are copied by value.
	 *
	 * @param other the factor to copy
	 */
	protected MultivariateNormal(MultivariateNormal other)
	{
		super(other);
		_parameters = other._parameters.clone();
		_parametersConstant = other._parametersConstant;
		_firstDirectedToIndex = other._firstDirectedToIndex;
	}

	/**
	 * Returns a copy of this factor function, with its own copy of the parameters.
	 */
	@Override
	public MultivariateNormal clone()
	{
		return new MultivariateNormal(this);
	}

	/**
	 * Common initialization used whenever parameters are constant.
	 * <p>
	 * Stores {@code parameters} by reference; callers are responsible for passing an object this
	 * instance may own exclusively (i.e. a fresh copy).
	 *
	 * @param parameters the parameters to install
	 */
	private void initializeConstantParameters(MultivariateNormalParameters parameters)
	{
		_parameters = parameters;
		_parametersConstant = true;
		_firstDirectedToIndex = 0;
	}

	/*----------------
	 * IDatum methods
	 */

	/**
	 * Returns true if {@code other} is this object, or is a {@code MultivariateNormal} with the
	 * same constant-parameter flag, the same first directed-to index, and parameters that are
	 * equal according to {@code MultivariateNormalParameters.objectEquals}.
	 */
	@Override
	public boolean objectEquals(@Nullable Object other)
	{
		if (this == other)
		{
			return true;
		}

		if (other instanceof MultivariateNormal)
		{
			MultivariateNormal that = (MultivariateNormal)other;
			return _parametersConstant == that._parametersConstant &&
				_firstDirectedToIndex == that._firstDirectedToIndex &&
				_parameters.objectEquals(that._parameters);
		}

		return false;
	}

	/*-------------------------
	 * FactorFunction methods
	 */

	/**
	 * Computes the energy of the given arguments.
	 * <p>
	 * Returns the sum of {@code evalEnergy} from the parameters over every non-parameter
	 * argument, minus one normalization-energy term per such argument.
	 *
	 * @param arguments the variable values; arguments before the first directed-to index would be
	 *        parameter variables (there are none while parameters are constant)
	 * @return the total energy
	 */
	@Override
	public final double evalEnergy(Value[] arguments)
	{
		final MultivariateNormalParameters params = _parameters;
		final int length = arguments.length;

		// Start at the first directed-to index so this agrees with getDirectedToIndices().
		final int start = _firstDirectedToIndex;
		final int N = length - start; // Number of non-parameter variables

		double sum = 0;
		for (int index = start; index < length; index++)
		{
			sum += params.evalEnergy(arguments[index]);
		}

		return sum - N * params.getNormalizationEnergy();
	}

	/**
	 * This factor is always directed.
	 */
	@Override
	public final boolean isDirected() {return true;}

	/**
	 * Returns the indices of the directed-to (output) edges. These run from the first directed-to
	 * index through {@code numEdges - 1} (presumably inclusive, per
	 * {@code FactorFunctionUtilities.getListOfIndices}).
	 *
	 * @param numEdges the total number of edges connected to the factor
	 * @return the directed-to edge indices
	 */
	@Override
	public final int[] getDirectedToIndices(int numEdges)
	{
		// All edges except the parameter edges (if present) are directed-to edges
		return FactorFunctionUtilities.getListOfIndices(_firstDirectedToIndex, numEdges-1);
	}

	/*--------------------------
	 * Factor-specific methods
	 */

	/**
	 * Returns true if the parameters are constant (currently always the case).
	 */
	public final boolean hasConstantParameters()
	{
		return _parametersConstant;
	}

	/**
	 * Returns a copy of the current parameters; modifying the result does not affect this factor.
	 */
	public final MultivariateNormalParameters getParameters()
	{
		return _parameters.clone();
	}

	/**
	 * Replaces the parameters with a copy of {@code parameters} and resets the factor to the
	 * constant-parameter configuration. Later changes to the argument do not affect this factor.
	 *
	 * @param parameters the new distribution parameters
	 */
	public final void setParameters(MultivariateNormalParameters parameters)
	{
		// Clone so the caller cannot alias internal state (mirrors getParameters()).
		initializeConstantParameters(parameters.clone());
	}
}