/*******************************************************************************
* Copyright 2012 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.factorfunctions;
import java.util.Map;
import org.eclipse.jdt.annotation.Nullable;
import com.analog.lyric.dimple.exceptions.DimpleException;
import com.analog.lyric.dimple.factorfunctions.core.FactorFunctionUtilities;
import com.analog.lyric.dimple.factorfunctions.core.IParametricFactorFunction;
import com.analog.lyric.dimple.factorfunctions.core.UnaryFactorFunction;
import com.analog.lyric.dimple.model.values.Value;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.NormalParameters;
import com.analog.lyric.util.misc.Matlab;
/**
 * Factor for an exchangeable set of Normally distributed variables associated
 * with a variable representing the mean parameter and a second parameter
 * representing the precision. The variables are ordered as follows in
 * the argument list:
 *
 * 1) Mean parameter (real variable)
 * 2) Precision parameter (inverse variance) (real variable; domain must be non-negative)
 * 3...) An arbitrary number of real variables
 *
 * Mean and precision parameters may optionally be specified as constants in the constructor.
 * In this case, the mean and precision are not included in the list of arguments.
 *
 */
@Matlab(wrapper="NormalParameters")
public class Normal extends UnaryFactorFunction implements IParametricFactorFunction
{
	private static final long serialVersionUID = 1L;

	/** Mean/precision holder; overwritten per evaluation when parameters are variable. */
	protected NormalParameters _parameters;

	/** True when mean and precision were fixed at construction (not passed as arguments). */
	protected boolean _parametersConstant;

	/** Index of the first argument that is a directed-to (output) variable. */
	protected int _firstDirectedToIndex;

	/*--------------
	 * Construction
	 */

	/**
	 * Common constructor.
	 * <p>
	 * When {@code constant} is false the first two arguments are the mean and
	 * precision variables, so directed-to edges start at index 2; otherwise every
	 * argument is a directed-to variable.
	 */
	private Normal(NormalParameters parameters, boolean constant)
	{
		super((String)null);
		_parameters = parameters;
		_parametersConstant = constant;
		_firstDirectedToIndex = constant? 0 : 2;
	}

	/**
	 * Construct a Normal function whose mean and precision are supplied as the
	 * first two arguments of each evaluation.
	 */
	public Normal()
	{
		this(new NormalParameters(), false);
	}

	/**
	 * Construct a Normal function with fixed mean and precision.
	 * @param mean the constant mean.
	 * @param precision the constant precision (inverse variance).
	 * @throws DimpleException if {@code precision} is negative.
	 */
	public Normal(double mean, double precision)
	{
		this(new NormalParameters(mean, precision));
		if (precision < 0)
			throw new DimpleException("Negative precision value. This must be a non-negative value.");
	}

	/**
	 * Construct a Normal function with fixed parameters taken from {@code parameters}.
	 * <p>
	 * The given object is used directly (not copied).
	 * @since 0.05
	 */
	public Normal(NormalParameters parameters)
	{
		this(parameters, true);
	}

	/**
	 * Construct a Normal function with fixed parameters.
	 * <p>
	 * @param parameters specifies the mean and precision. Several different
	 * keywords are supported. To set the mean parameter, this will first look
	 * for a value using the keyword "mean" and then using "mu" and will otherwise
	 * default to a value of zero. To set the precision parameter, this will first
	 * look for a value using the keyword "precision" and then "tau". If not found,
	 * this will next look for "variance" and if found will set the precision to its
	 * reciprocal. If still not found, it will see if the value is specified as
	 * standard deviation under the keywords "std" or "sigma". If no matching keyword
	 * is found the precision will default to one.
	 * @since 0.07
	 */
	public Normal(Map<String,Object> parameters)
	{
		this(new NormalParameters(parameters));
	}

	/**
	 * Copy constructor; the parameter object is cloned so the copy does not share
	 * state with {@code other}.
	 */
	protected Normal(Normal other)
	{
		super(other);
		_parameters = other._parameters.clone();
		_firstDirectedToIndex = other._firstDirectedToIndex;
		_parametersConstant = other._parametersConstant;
	}

	@Override
	public Normal clone()
	{
		return new Normal(this);
	}

	/*----------------
	 * IDatum methods
	 */

	/**
	 * Two Normal functions are equal when they agree on constant-ness, parameter
	 * values and the first directed-to index.
	 */
	@Override
	public boolean objectEquals(@Nullable Object other)
	{
		if (this == other)
		{
			return true;
		}

		if (other instanceof Normal)
		{
			Normal that = (Normal)other;
			return _parametersConstant == that._parametersConstant &&
				_parameters.objectEquals(that._parameters) &&
				_firstDirectedToIndex == that._firstDirectedToIndex;
		}

		return false;
	}

	/*------------------------
	 * FactorFunction methods
	 */

	/**
	 * Computes the normalized energy of the Normal variables in {@code arguments}.
	 * <p>
	 * When parameters are not constant, the first two arguments are read as the mean
	 * and precision and are written into this function's shared parameter object
	 * before evaluation. A negative precision yields positive infinity.
	 */
	@Override
	public final double evalEnergy(Value[] arguments)
	{
		int index = 0;
		if (!_parametersConstant)
		{
			double mean = arguments[index++].getDouble();			// First variable is mean parameter
			double precision = arguments[index++].getDouble();		// Second variable is precision (must be non-negative)
			if (precision < 0) return Double.POSITIVE_INFINITY;
			_parameters.setMean(mean);
			_parameters.setPrecision(precision);
		}

		// Remaining arguments (starting at index) are the Normally distributed variables.
		return _parameters.evalNormalizedEnergy(arguments, index);
	}

	@Override
	public final boolean isDirected() {return true;}

	/**
	 * Returns the directed-to edge indices: every edge from the first
	 * non-parameter edge through {@code numEdges-1}.
	 */
	@Override
	public final int[] getDirectedToIndices(int numEdges)
	{
		// All edges except the parameter edges (if present) are directed-to edges
		return FactorFunctionUtilities.getListOfIndices(_firstDirectedToIndex, numEdges-1);
	}

	/*-----------------------------------
	 * IParametricFactorFunction methods
	 */

	/**
	 * Stores the constant parameters under the keys "mean" and "precision".
	 * @return the number of parameters written: 2 if parameters are constant, otherwise 0
	 * (in which case {@code parameters} is left unchanged).
	 */
	@Override
	public int copyParametersInto(Map<String, Object> parameters)
	{
		if (_parametersConstant)
		{
			parameters.put("mean", getMean());
			parameters.put("precision", getPrecision());
			return 2;
		}

		return 0;
	}

	/**
	 * Looks up a constant parameter by name.
	 * <p>
	 * Accepts the same keywords as {@link #Normal(Map)}: "mean" or "mu" for the mean,
	 * "precision" or "tau" for the precision, "variance" for the variance, and
	 * "std" or "sigma" for the standard deviation.
	 * @return the parameter value, or null if the parameters are not constant or the
	 * name is not recognized.
	 */
	@Override
	public @Nullable Object getParameter(String parameterName)
	{
		if (_parametersConstant)
		{
			switch (parameterName)
			{
			case "mean":
			case "mu":
				return getMean();
			case "precision":
			case "tau":
				// "tau" is the documented alias accepted by the Map constructor.
				return getPrecision();
			case "variance":
				return getVariance();
			case "sigma":
			case "std":
				return getStandardDeviation();
			}
		}
		return null;
	}

	/** Returns the underlying parameter object (not a copy). */
	@Override
	public NormalParameters getParameterizedMessage()
	{
		return _parameters;
	}

	@Override
	public final boolean hasConstantParameters()
	{
		return _parametersConstant;
	}

	/*-------------------------
	 * Factor-specific methods
	 */

	/**
	 * Returns the underlying parameter object (not a copy).
	 * @since 0.05
	 */
	public final NormalParameters getParameters()
	{
		return _parameters;
	}

	/** Returns the current mean. */
	@Matlab
	public final double getMean()
	{
		return _parameters.getMean();
	}

	/** Returns the current precision (inverse variance). */
	@Matlab
	public final double getPrecision()
	{
		return _parameters.getPrecision();
	}

	/**
	 * Returns the current variance.
	 * @since 0.05
	 */
	@Matlab
	public final double getVariance()
	{
		return _parameters.getVariance();
	}

	/**
	 * Returns the current standard deviation.
	 * @since 0.05
	 */
	@Matlab
	public final double getStandardDeviation()
	{
		return _parameters.getStandardDeviation();
	}

	/**
	 * Sets the mean on the underlying parameter object.
	 * @since 0.05
	 */
	public final void setMean(double mean)
	{
		_parameters.setMean(mean);
	}

	/**
	 * Sets the precision on the underlying parameter object.
	 * <p>
	 * This method performs no sign check itself; NOTE(review): confirm whether
	 * NormalParameters.setPrecision rejects negative values.
	 * @since 0.05
	 */
	public final void setPrecision(double precision)
	{
		_parameters.setPrecision(precision);
	}
}