/*******************************************************************************
* Copyright 2012 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.factorfunctions;
import java.util.Map;
import org.eclipse.jdt.annotation.Nullable;
import com.analog.lyric.dimple.factorfunctions.core.FactorFunctionUtilities;
import com.analog.lyric.dimple.factorfunctions.core.IParametricFactorFunction;
import com.analog.lyric.dimple.factorfunctions.core.UnaryFactorFunction;
import com.analog.lyric.dimple.model.values.Value;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.DirichletParameters;
/**
 * Factor for an exchangeable set of Dirichlet distributed variables
 * associated with a variable or fixed parameter. In this version,
 * all of the Dirichlet parameter values are common and specified as a single real
 * value, alpha. The variables are ordered as follows in the argument list:
 * <ol>
 * <li>Parameter alpha (positive Real variable); omitted when alpha is a constant.
 * <li>An arbitrary number of RealJoint variables, each one a Dirichlet distributed random variable.
 * </ol>
 * The dimension of the Dirichlet variables must be specified in the constructor.
 * <p>
 * The parameter may optionally be specified as a constant in the constructor.
 * In this case, the parameter is not included in the list of arguments.
 *
 * @since 0.05
 */
public class ExchangeableDirichlet extends UnaryFactorFunction implements IParametricFactorFunction
{
	private static final long serialVersionUID = 1L;

	/**
	 * Symmetric Dirichlet parameters, stored in alpha-minus-one form. When the parameter
	 * is variable, these are overwritten on every call to {@link #evalEnergy(Value[])}.
	 */
	protected DirichletParameters _parameters;

	/** True when alpha was fixed at construction, so no parameter argument is present. */
	private boolean _parametersConstant;

	/** Argument index of the first Dirichlet variable: 0 when alpha is constant, 1 otherwise. */
	private int _firstDirectedToIndex;

	/**
	 * Common constructor.
	 *
	 * @param parameters symmetric Dirichlet parameters (alpha-minus-one form)
	 * @param index index of the first directed-to argument; 0 means the parameter is constant
	 * @throws IllegalArgumentException if {@code parameters} is not symmetric
	 */
	private ExchangeableDirichlet(DirichletParameters parameters, int index)
	{
		super((String)null);
		_parameters = parameters;
		_parametersConstant = index == 0;
		_firstDirectedToIndex = index;
		if (!parameters.isSymmetric())
		{
			throw new IllegalArgumentException("ExchangeableDirichlet requires symmetric arguments");
		}
	}

	/**
	 * Constructs a factor with constant parameters.
	 *
	 * @param parameters symmetric Dirichlet parameters
	 * @throws IllegalArgumentException if {@code parameters} is not symmetric
	 */
	public ExchangeableDirichlet(DirichletParameters parameters)
	{
		this(parameters, 0);
	}

	/**
	 * Constructs a factor whose alpha is supplied as the first argument variable.
	 *
	 * @param dimension dimension of each Dirichlet variable
	 */
	public ExchangeableDirichlet(int dimension) // Variable parameter
	{
		this(new DirichletParameters(dimension), 1);
	}

	/**
	 * Constructs a factor with a constant, common alpha.
	 *
	 * @param dimension dimension of each Dirichlet variable
	 * @param alpha common Dirichlet parameter; converted to alpha-1 before storage
	 */
	public ExchangeableDirichlet(int dimension, double alpha) // Constant parameter
	{
		this(new DirichletParameters(dimension, alpha - 1));
	}

	/**
	 * Copy constructor.
	 *
	 * @param other factor to copy; its parameters are deep-copied
	 */
	protected ExchangeableDirichlet(ExchangeableDirichlet other)
	{
		super(other);
		// Clone from 'other' (this._parameters is still null here). A deep copy is needed
		// because evalEnergy mutates the parameters in the variable-parameter case.
		_parameters = other._parameters.clone();
		_firstDirectedToIndex = other._firstDirectedToIndex;
		_parametersConstant = other._parametersConstant;
	}

	@Override
	public ExchangeableDirichlet clone()
	{
		return new ExchangeableDirichlet(this);
	}

	/*----------------
	 * IDatum methods
	 */

	/**
	 * Value equality: same constancy, same directed-to offset and equal parameters.
	 */
	@Override
	public boolean objectEquals(@Nullable Object other)
	{
		if (this == other)
		{
			return true;
		}

		if (other instanceof ExchangeableDirichlet)
		{
			ExchangeableDirichlet that = (ExchangeableDirichlet)other;
			return _parametersConstant == that._parametersConstant &&
				_parameters.objectEquals(that._parameters) &&
				_firstDirectedToIndex == that._firstDirectedToIndex;
		}

		return false;
	}

	/**
	 * Computes the normalized energy of the Dirichlet variables.
	 * <p>
	 * When alpha is variable it is read from {@code arguments[0]}. A non-positive alpha
	 * yields {@link Double#POSITIVE_INFINITY}. Otherwise the value is written into the
	 * shared {@link #_parameters}.
	 * NOTE: because of this mutation, concurrent calls on one instance may interfere.
	 *
	 * @param arguments parameter (if variable) followed by the Dirichlet variables
	 * @return energy (negative log density) of the arguments
	 */
	@Override
	public final double evalEnergy(Value[] arguments)
	{
		int index = 0;
		if (!_parametersConstant)
		{
			double alpha = arguments[index++].getDouble();
			if (alpha <= 0)
				return Double.POSITIVE_INFINITY;
			_parameters.fillAlphaMinusOne(alpha - 1.0);
		}

		return _parameters.evalNormalizedEnergy(arguments, index);
	}

	@Override
	public final boolean isDirected() {return true;}

	/**
	 * @param numEdges total number of edges of the factor
	 * @return indices of all edges except the parameter edge (if present)
	 */
	@Override
	public final int[] getDirectedToIndices(int numEdges)
	{
		// All edges except the parameter edges (if present) are directed-to edges
		return FactorFunctionUtilities.getListOfIndices(_firstDirectedToIndex, numEdges-1);
	}

	/*-----------------------------------
	 * IParametricFactorFunction methods
	 */

	/**
	 * Stores the constant alpha under the key "alpha".
	 *
	 * @param parameters destination map
	 * @return number of entries added: 1 if alpha is constant, otherwise 0
	 */
	@Override
	public int copyParametersInto(Map<String, Object> parameters)
	{
		if (_parametersConstant)
		{
			parameters.put("alpha", getParameter("alpha"));
			return 1;
		}
		return 0;
	}

	/**
	 * @param parameterName "alpha" or "alphas"
	 * @return the common alpha if the parameter is constant and the name is recognized,
	 *         otherwise {@code null}
	 */
	@Override
	public @Nullable Object getParameter(String parameterName)
	{
		if (_parametersConstant)
		{
			switch (parameterName)
			{
			case "alpha":
			case "alphas":
				return getAlphaMinusOne() + 1;
			}
		}
		return null;
	}

	/**
	 * @return the underlying parameter object itself (not a copy)
	 */
	@Override
	public DirichletParameters getParameterizedMessage()
	{
		return _parameters;
	}

	@Override
	public final boolean hasConstantParameters()
	{
		return _parametersConstant;
	}

	/*--------------------------
	 * Factor-specific methods
	 */

	/**
	 * @return the common alpha-1, read from the first element (parameters are required
	 *         to be symmetric at construction)
	 */
	public final double getAlphaMinusOne()
	{
		return _parameters.getAlphaMinusOne(0);
	}

	/**
	 * @return alpha-1 for every dimension, as if the parameters were separate
	 *         (NOTE(review): whether this is a copy depends on DirichletParameters — confirm)
	 */
	public final double[] getAlphaMinusOneArray() // Get parameters as if they were separate
	{
		return _parameters.getAlphaMinusOneArray();
	}

	/**
	 * @return dimension of each Dirichlet variable
	 */
	public final int getDimension()
	{
		return _parameters.getSize();
	}
}