/*******************************************************************************
* Copyright 2012 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.factorfunctions;
import java.util.Map;
import org.eclipse.jdt.annotation.Nullable;
import com.analog.lyric.dimple.exceptions.DimpleException;
import com.analog.lyric.dimple.factorfunctions.core.FactorFunctionUtilities;
import com.analog.lyric.dimple.factorfunctions.core.IParametricFactorFunction;
import com.analog.lyric.dimple.factorfunctions.core.UnaryFactorFunction;
import com.analog.lyric.dimple.model.values.Value;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.NormalParameters;
/**
* Log-normal distribution.
* <p>
* The variables in the argument list are ordered as follows:
* <ol>
* <li>Mean parameter
* <li>Precision parameter (inverse variance) (non-negative)
* <li>An arbitrary number of real variables
* </ol>
* Mean and precision parameters may optionally be specified as constants in the constructor.
* In this case, the mean and precision are not included in the list of arguments.
*/
public class LogNormal extends UnaryFactorFunction implements IParametricFactorFunction
{
private static final long serialVersionUID = 1L;
protected double _mean;
protected double _precision;
protected double _logSqrtPrecisionOver2Pi;
protected double _precisionOverTwo;
protected boolean _parametersConstant = false;
protected int _firstDirectedToIndex = 2;
protected static final double _logSqrt2pi = Math.log(2*Math.PI)*0.5;
/*--------------
* Construction
*/
public LogNormal() {super((String)null);}
public LogNormal(double mean, double precision)
{
this();
_mean = mean;
_precision = precision;
_logSqrtPrecisionOver2Pi = Math.log(_precision)*0.5 - _logSqrt2pi;
_precisionOverTwo = _precision*0.5;
_parametersConstant = true;
_firstDirectedToIndex = 0;
if (_precision < 0) throw new DimpleException("Negative precision value. This must be a non-negative value.");
}
/**
* Constructs log-normal distribution with fixed mean and precision.
* @param parameters is in the same form accepted by {@link Normal#Normal(Map)}.
* @since 0.07
*/
public LogNormal(Map<String,Object> parameters)
{
this(new NormalParameters(parameters));
}
protected LogNormal(LogNormal other)
{
super(other);
_mean = other._mean;
_precision = other._precision;
_logSqrtPrecisionOver2Pi = other._logSqrtPrecisionOver2Pi;
_precisionOverTwo = other._precisionOverTwo;
_parametersConstant = other._parametersConstant;
_firstDirectedToIndex = other._firstDirectedToIndex;
}
@Override
public LogNormal clone()
{
return new LogNormal(this);
}
private LogNormal(NormalParameters parameters)
{
this(parameters.getMean(), parameters.getPrecision());
}
/*----------------
* IDatum methods
*/
@Override
public boolean objectEquals(@Nullable Object other)
{
if (this == other)
{
return true;
}
if (other instanceof LogNormal)
{
LogNormal that = (LogNormal)other;
return _parametersConstant == that._parametersConstant &&
_mean == that._mean &&
_precision == that._precision &&
_firstDirectedToIndex == that._firstDirectedToIndex;
}
return false;
}
/*------------------------
* FactorFunction methods
*/
@Override
public final double evalEnergy(Value[] arguments)
{
int index = 0;
if (!_parametersConstant)
{
_mean = arguments[index++].getDouble(); // First variable is mean parameter
_precision = arguments[index++].getDouble(); // Second variable is precision (must be non-negative)
_logSqrtPrecisionOver2Pi = Math.log(_precision)*0.5 - _logSqrt2pi;
_precisionOverTwo = _precision*0.5;
if (_precision < 0) return Double.POSITIVE_INFINITY;
}
final int length = arguments.length;
final int N = length - index; // Number of non-parameter variables
double sum = 0;
for (; index < length; index++)
{
final double x = arguments[index].getDouble(); // Remaining inputs are LogNormal variables
if (x <= 0)
return Double.POSITIVE_INFINITY;
else
{
final double logX = Math.log(x);
final double relLogX = logX - _mean;
sum += logX + relLogX*relLogX*_precisionOverTwo;
}
}
return sum - N * _logSqrtPrecisionOver2Pi;
}
@Override
public final boolean isDirected() {return true;}
@Override
public final int[] getDirectedToIndices(int numEdges)
{
// All edges except the parameter edges (if present) are directed-to edges
return FactorFunctionUtilities.getListOfIndices(_firstDirectedToIndex, numEdges-1);
}
/*-----------------------------------
* IParametricFactorFunction methods
*/
@Override
public int copyParametersInto(Map<String, Object> parameters)
{
if (_parametersConstant)
{
parameters.put("mean", _mean);
parameters.put("precision", _precision);
return 2;
}
return 0;
}
@Override
public @Nullable Object getParameter(String parameterName)
{
if (_parametersConstant)
{
switch (parameterName)
{
case "mean":
case "mu":
return _mean;
case "precision":
return _precision;
case "variance":
return 1.0 / _precision;
case "sigma":
case "std":
return Math.sqrt(1.0 / _precision);
}
}
return null;
}
@Override
public final boolean hasConstantParameters()
{
return _parametersConstant;
}
/*-------------------------
* Factor-specific methods
*/
public final double getMean()
{
return _mean;
}
public final double getPrecision()
{
return _precision;
}
}