/*******************************************************************************
* Copyright 2013 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.solvers.core.parameterizedMessages;
import static org.apache.commons.math3.special.Gamma.*;
import java.io.PrintStream;
import org.eclipse.jdt.annotation.Nullable;
import com.analog.lyric.dimple.model.values.Value;
public class GammaParameters extends ParameterizedMessageBase
{
private static final long serialVersionUID = 1L;
private double _alphaMinusOne = 0;
private double _beta = 0;
/*--------------
* Construction
*/
public GammaParameters() {}
public GammaParameters(double alphaMinusOne, double beta)
{
// FIXME - it seems that really this constructor should take alpha, not alpha - 1
_alphaMinusOne = alphaMinusOne;
_beta = beta;
if (alphaMinusOne <= -1)
throw new IllegalArgumentException("Non-positive alpha parameter. This must be a positive value.");
if (beta < 0)
throw new IllegalArgumentException("Negative beta parameter. This must be a positive value.");
}
public GammaParameters(GammaParameters other) // Copy constructor
{
super(other);
_alphaMinusOne = other._alphaMinusOne;
_beta = other._beta;
}
@Override
public GammaParameters clone()
{
return new GammaParameters(this);
}
/*---------
* IEquals
*/
@Override
public boolean objectEquals(@Nullable Object other)
{
if (this == other)
{
return true;
}
if (other instanceof GammaParameters)
{
GammaParameters that = (GammaParameters)other;
return _alphaMinusOne == that._alphaMinusOne && _beta == that._beta && super.objectEquals(other);
}
return false;
}
/*-----------------------
* IUnaryFactorFunctions
*/
@Override
public double evalEnergy(Value value)
{
final double x = value.getDouble();
if (x < 0)
{
return Double.POSITIVE_INFINITY;
}
if (_alphaMinusOne == 0.0)
{
return x * _beta;
}
else
{
return x * _beta - Math.log(x) * _alphaMinusOne;
}
}
/*---------------
* Local methods
*/
// Natural parameters are alpha-1 and beta
public final double getAlphaMinusOne() {return _alphaMinusOne;}
public final double getBeta() {return _beta;}
public final void setAlphaMinusOne(double alphaMinusOne)
{
_alphaMinusOne = alphaMinusOne;
forgetNormalizationEnergy();
}
public final void setBeta(double beta)
{
_beta = beta;
forgetNormalizationEnergy();
}
// Ordinary alpha parameter
public final double getAlpha() {return _alphaMinusOne + 1;}
public final void setAlpha(double alpha)
{
setAlphaMinusOne(alpha - 1.0);
}
/*--------------------
* IPrintable methods
*/
@Override
public void print(PrintStream out, int verbosity)
{
if (verbosity >= 0)
{
String fmt;
switch (verbosity)
{
case 0:
fmt = "Gamma(%g,%g)";
break;
default:
fmt = "Gamma(alpha=%g, beta=%g)";
break;
}
out.format(fmt, getAlpha(), getBeta());
}
}
/*-------------------------------
* IParameterizedMessage methods
*/
@Override
public void addFrom(IParameterizedMessage other)
{
addFrom((GammaParameters)other);
}
public void addFrom(GammaParameters other)
{
_alphaMinusOne += other._alphaMinusOne;
_beta += other._beta;
}
/**
* {@inheritDoc}
* <p>
* Computes KL as follows, where Γ(x) is the gamma function and
* ψ(x) is the digamma function.
* <p>
* (α<sub>P</sub>-α<sub>Q</sub>)ψ(α<sub>P</sub>)
* - ln(Γ(α<sub>P</sub>)) + ln(Γ(α<sub>Q</sub>))
* + α<sub>Q</sub>(ln(β<sub>P</sub>/β<sub>Q</sub>))
* + α<sub>P</sub>(β<sub>Q</sub>-β<sub>P</sub>)/β<sub>P</sub>
*
* @see <a href="http://en.wikipedia.org/wiki/Gamma_distribution#Kullback.E2.80.93Leibler_divergence"
* >Gamma distribution (Wikipedia)</a>
*/
@Override
public double computeKLDivergence(IParameterizedMessage that)
{
if (that instanceof GammaParameters)
{
// KL(P|Q) == (ap-aq)*digamma(ap) - log(gamma(ap)) + log(gamma(aq)) + aq*(log(bp)-log(bq)) + ap*(bq-bp)/bp
final GammaParameters P = this, Q = (GammaParameters)that;
final double ap = P.getAlpha(), aq = Q.getAlpha();
final double bp = P.getBeta(), bq = Q.getBeta();
double divergence = 0.0;
if (ap != aq)
{
divergence += (ap-aq)*digamma(ap);
divergence -= logGamma(ap);
divergence += logGamma(aq);
}
if (bp != bq)
{
divergence += aq*(Math.log(bp)-Math.log(bq)) + ap * ((bq-bp)/bp);
}
return divergence;
}
throw new IllegalArgumentException(String.format("Expected '%s' but got '%s'", getClass(), that.getClass()));
}
@Override
public boolean isNull()
{
return _beta == 0.0 && _alphaMinusOne == 0.0;
}
@Override
public void setFrom(IParameterizedMessage other)
{
GammaParameters that = (GammaParameters)other;
_alphaMinusOne = that._alphaMinusOne;
_beta = that._beta;
copyNormalizationEnergy(that);
}
/**
* {@inheritDoc}
* <p>
* Sets alpha to one and beta to zero.
*/
@Override
public final void setUniform()
{
_alphaMinusOne = 0;
_beta = 0;
_normalizationEnergy = 0.0;
}
/*-------------------
* Protected methods
*/
@Override
protected double computeNormalizationEnergy()
{
final double alpha = _alphaMinusOne + 1;
final double logBeta = _beta != 0 ? Math.log(_beta) : 0.0;
return -(org.apache.commons.math3.special.Gamma.logGamma(alpha) - alpha * logBeta);
}
}