/*******************************************************************************
* Copyright 2013 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.solvers.core.parameterizedMessages;
import static org.apache.commons.math3.special.Gamma.*;
import java.io.PrintStream;
import java.util.Arrays;
import org.apache.commons.math3.special.Gamma;
import org.eclipse.jdt.annotation.Nullable;
import com.analog.lyric.collect.ArrayUtil;
import com.analog.lyric.dimple.data.IDatum;
import com.analog.lyric.dimple.exceptions.DimpleException;
import com.analog.lyric.dimple.factorfunctions.Dirichlet;
import com.analog.lyric.dimple.model.values.Value;
import com.google.common.math.DoubleMath;
/**
 * Parameterized message holding the parameters of a Dirichlet distribution.
 * <p>
 * The parameters are stored internally in their natural additive form, i.e. as
 * (alpha - 1) for each dimension. Accessors are provided for both the stored
 * (alpha - 1) values and the conventional alpha values.
 * <p>
 * The class lazily tracks whether all parameters are identical ("symmetric") so that
 * some computations can take a cheaper path. That cached flag is reset by
 * {@link #forgetNormalizationEnergy()}, which every mutating method invokes.
 */
public class DirichletParameters extends ParameterizedMessageBase
{
    private static final long serialVersionUID = 1L;

    /** Per-dimension tolerance used when testing that a value sums to one. */
    private static final double SIMPLEX_THRESHOLD = 1e-12;

    /*
     * Encoding of the cached symmetry flag. UNKNOWN is deliberately zero: that is the value
     * the JVM gives the field when a constructor does not assign it and when a transient
     * field is restored by deserialization, so in both cases the flag is recomputed on
     * demand instead of being mistaken for a definite answer.
     */
    private static final byte SYMMETRY_UNKNOWN = 0;
    private static final byte SYMMETRY_FALSE = -1;
    private static final byte SYMMETRY_TRUE = 1;

    // The parameters used are the natural additive parameters, (alpha-1)
    private double[] _alphaMinusOne;

    // Cached symmetry state: 0 not computed, < 0 false, > 0 true
    private transient byte _symmetric;

    /*--------------
     * Construction
     */

    /**
     * Constructs an instance with zero dimensions.
     */
    public DirichletParameters()
    {
        _alphaMinusOne = ArrayUtil.EMPTY_DOUBLE_ARRAY;
        _symmetric = SYMMETRY_UNKNOWN;
    }

    /**
     * Constructs an instance with {@code length} dimensions whose (alpha - 1) values are all zero.
     */
    public DirichletParameters(int length)
    {
        _alphaMinusOne = new double[length];
        _symmetric = SYMMETRY_UNKNOWN;
    }

    /**
     * Constructs a symmetric instance with {@code length} dimensions, each having the
     * given (alpha - 1) value.
     */
    public DirichletParameters(int length, double alphaMinusOne)
    {
        _alphaMinusOne = new double[length];
        Arrays.fill(_alphaMinusOne, alphaMinusOne);
        _symmetric = SYMMETRY_TRUE;
    }

    /**
     * Constructs an instance from a copy of the given (alpha - 1) values.
     */
    public DirichletParameters(double[] alphaMinusOne)
    {
        _alphaMinusOne = alphaMinusOne.clone();
        // Symmetry of the supplied values is not known yet; it is determined lazily.
        _symmetric = SYMMETRY_UNKNOWN;
    }

    /**
     * Copy constructor. The parameter array is copied, not shared.
     */
    public DirichletParameters(DirichletParameters other)
    {
        super(other);
        _alphaMinusOne = other._alphaMinusOne.clone();
        _symmetric = other._symmetric;
    }

    @Override
    public DirichletParameters clone()
    {
        return new DirichletParameters(this);
    }

    /**
     * Converts a datum to Dirichlet parameters when possible.
     *
     * @return {@code datum} itself (not a copy) if it already is a {@link DirichletParameters};
     * a new instance built from {@link Dirichlet#getAlphaMinusOneArray()} if it is a
     * {@link Dirichlet}; otherwise {@code null}.
     */
    public static @Nullable DirichletParameters fromDatum(IDatum datum)
    {
        if (datum instanceof DirichletParameters)
        {
            return (DirichletParameters)datum;
        }
        if (datum instanceof Dirichlet)
        {
            return new DirichletParameters(((Dirichlet)datum).getAlphaMinusOneArray());
        }
        return null;
    }

    /*---------
     * IEquals
     */

    @Override
    public boolean objectEquals(@Nullable Object other)
    {
        if (other == this)
        {
            return true;
        }

        if (other instanceof DirichletParameters)
        {
            final DirichletParameters that = (DirichletParameters)other;
            return super.objectEquals(other) && Arrays.equals(_alphaMinusOne, that._alphaMinusOne);
        }

        return false;
    }

    /*----------------------
     * IUnaryFactorFunction
     */

    /**
     * Computes the unnormalized energy, i.e. the sum over dimensions of
     * -(alpha_i - 1) * log(x_i), for the real joint value held by {@code value}.
     *
     * @return the energy, or {@link Double#POSITIVE_INFINITY} if any component is not
     * strictly positive or the components do not sum to one within a tolerance of
     * {@code SIMPLEX_THRESHOLD} times the number of dimensions.
     * @throws DimpleException if the value's array length differs from {@link #getSize()}.
     */
    @Override
    public double evalEnergy(Value value)
    {
        final double[] x = value.getDoubleArray();
        final int n = _alphaMinusOne.length;

        if (x.length != n)
        {
            throw new DimpleException("Argument does not contain %d-dimensional real joint value", n);
        }

        double sum = 0.0, xSum = 0.0;

        // The symmetric shortcut reads element 0, so it is only valid for a non-empty array.
        if (n > 0 && isSymmetric())
        {
            for (int i = n; --i >= 0;)
            {
                final double xi = x[i];
                if (xi <= 0)
                {
                    return Double.POSITIVE_INFINITY;
                }
                sum -= Math.log(xi);
                xSum += xi;
            }
            // All (alpha - 1) are equal, so the common factor is applied once.
            sum *= _alphaMinusOne[0];
        }
        else
        {
            for (int i = n; --i >= 0;)
            {
                final double xi = x[i];
                if (xi <= 0)
                {
                    return Double.POSITIVE_INFINITY;
                }
                sum -= (_alphaMinusOne[i]) * Math.log(xi); // -log(x_i ^ (a_i-1))
                xSum += xi;
            }
        }

        if (!DoubleMath.fuzzyEquals(xSum, 1, SIMPLEX_THRESHOLD * n)) // Values must be on the probability simplex
        {
            return Double.POSITIVE_INFINITY;
        }

        return sum;
    }

    /*---------------
     * Local methods
     */

    /** Returns the number of dimensions. */
    public final int getSize()
    {
        return _alphaMinusOne.length;
    }

    /**
     * Sets the number of dimensions and resets every (alpha - 1) value to zero.
     */
    public final void setSize(int size)
    {
        setAlphaMinusOne(new double[size]);
    }

    /** Returns alpha for the given dimension. */
    public final double getAlpha(int index)
    {
        return _alphaMinusOne[index] + 1.0;
    }

    /** Returns a newly allocated array of the alpha values. */
    public final double[] getAlphas()
    {
        final int n = _alphaMinusOne.length;
        final double[] alphas = new double[n];
        for (int i = n; --i >= 0;)
        {
            alphas[i] = _alphaMinusOne[i] + 1;
        }
        return alphas;
    }

    /** Returns a copy of the stored (alpha - 1) values. */
    public final double[] getAlphaMinusOneArray()
    {
        return _alphaMinusOne.clone();
    }

    /** Returns (alpha - 1) for the given dimension. */
    public final double getAlphaMinusOne(int index)
    {
        return _alphaMinusOne[index];
    }

    /**
     * Replaces the (alpha - 1) values with a copy of the given array, changing the
     * number of dimensions if the lengths differ.
     */
    public final void setAlphaMinusOne(double[] alphaMinusOne)
    {
        final int length = alphaMinusOne.length;
        if (length != _alphaMinusOne.length)
        {
            _alphaMinusOne = alphaMinusOne.clone();
        }
        else
        {
            // Same size: reuse the existing storage.
            System.arraycopy(alphaMinusOne, 0, _alphaMinusOne, 0, length);
        }
        forgetNormalizationEnergy();
    }

    /**
     * Replaces the parameters using conventional alpha values. The argument is not modified.
     */
    public final void setAlpha(double[] alpha)
    {
        // Copies the values and invalidates cached state; the shift below happens before
        // anything can recompute that state.
        setAlphaMinusOne(alpha);
        final double[] params = _alphaMinusOne;
        for (int i = params.length; --i >= 0;)
        {
            params[i] -= 1.0;
        }
    }

    /**
     * Sets every (alpha - 1) value to the given value.
     */
    public final void fillAlphaMinusOne(double alphaMinusOne)
    {
        Arrays.fill(_alphaMinusOne, alphaMinusOne); // Replicate a single value into all entries
        forgetNormalizationEnergy();
        // Must follow forgetNormalizationEnergy(), which resets the flag.
        _symmetric = SYMMETRY_TRUE;
    }

    // Operations on the parameters

    /** Adds one to the parameter at the given index. */
    public final void increment(int index)
    {
        _alphaMinusOne[index]++;
        forgetNormalizationEnergy();
    }

    /** Adds {@code value} to the parameter at the given index. */
    public final void add(int index, double value)
    {
        _alphaMinusOne[index] += value;
        forgetNormalizationEnergy();
    }

    /**
     * Adds {@code values[i]} to the i-th parameter for every index of {@code values}.
     * {@code values} must not be longer than {@link #getSize()}.
     */
    public final void add(double[] values)
    {
        final int length = values.length;
        for (int i = 0; i < length; i++)
        {
            _alphaMinusOne[i] += values[i];
        }
        forgetNormalizationEnergy();
    }

    /**
     * Adds {@code values[i]} to the i-th parameter for every index of {@code values}.
     * {@code values} must not be longer than {@link #getSize()}.
     */
    public final void add(int[] values)
    {
        final int length = values.length;
        for (int i = 0; i < length; i++)
        {
            _alphaMinusOne[i] += values[i];
        }
        forgetNormalizationEnergy();
    }

    /** Equivalent to {@link #addFrom(DirichletParameters)}. */
    public final void add(DirichletParameters parameters)
    {
        addFrom(parameters);
    }

    /**
     * True if all the parameters are the same.
     * <p>
     * The answer is computed on first use and cached until the parameters next change.
     * Instances with fewer than two dimensions are considered symmetric.
     * @since 0.08
     */
    public final boolean isSymmetric()
    {
        byte symmetric = _symmetric;
        if (symmetric == SYMMETRY_UNKNOWN)
        {
            symmetric = SYMMETRY_TRUE;
            final double[] params = _alphaMinusOne;
            final int n = params.length;
            if (n > 1)
            {
                final double first = params[0];
                for (int i = 1; i < n; ++i)
                {
                    if (params[i] != first)
                    {
                        symmetric = SYMMETRY_FALSE;
                        break;
                    }
                }
            }
            // Remember the result so later calls do not rescan the array.
            _symmetric = symmetric;
        }
        return symmetric > 0;
    }

    /*--------------------
     * IPrintable methods
     */

    /**
     * Prints the alpha values in the form {@code Dirichlet(a,b,...)}. Nothing is printed for
     * negative verbosity; for verbosity above one each value is labelled {@code aN=} and
     * separators include a space.
     */
    @Override
    public void print(PrintStream out, int verbosity)
    {
        if (verbosity < 0)
        {
            return;
        }

        final boolean verbose = verbosity > 1;

        out.print("Dirichlet(");
        for (int i = 0, end = getSize(); i < end; ++i)
        {
            if (i > 0)
            {
                out.print(',');
                if (verbose)
                {
                    out.print(' ');
                }
            }
            if (verbose)
            {
                out.format("a%d=", i);
            }
            out.format("%g", getAlpha(i));
        }
        out.print(')');
    }

    /*-------------------------------
     * IParameterizedMessage methods
     */

    /**
     * Adds the parameters of {@code other}, which must be a {@link DirichletParameters}.
     */
    @Override
    public void addFrom(IParameterizedMessage other)
    {
        addFrom((DirichletParameters)other);
    }

    /**
     * Adds each (alpha - 1) value of {@code other} to the corresponding value of this object.
     *
     * @throws IllegalArgumentException if the sizes differ.
     */
    public void addFrom(DirichletParameters other)
    {
        final double[] params = _alphaMinusOne;
        final double[] otherParams = other._alphaMinusOne;
        final int n = params.length;

        if (n != otherParams.length)
        {
            throw new IllegalArgumentException("Cannot add from DirichletParameters with different size");
        }

        // Capture before forgetNormalizationEnergy(), which resets the cached flag.
        final boolean wasSymmetric = _symmetric > 0;

        forgetNormalizationEnergy();

        // The symmetric shortcut reads element 0, so it requires a non-empty array.
        if (n > 0 && other.isSymmetric())
        {
            final double a = otherParams[0];
            if (wasSymmetric)
            {
                // Symmetric plus symmetric stays symmetric.
                Arrays.fill(params, a + params[0]);
                _symmetric = SYMMETRY_TRUE;
            }
            else
            {
                for (int i = n; --i >= 0;)
                {
                    params[i] += a;
                }
            }
        }
        else
        {
            for (int i = n; --i >= 0;)
            {
                params[i] += otherParams[i];
            }
        }
    }

    /**
     * {@inheritDoc}
     * <p>
     * With P denoting this object and Q denoting {@code that},
     * Dirichlet parameter messages are computed using:
     * <blockquote>
     * log &Beta;(<b>&alpha;<sub>Q</sub></b>) - log &Beta;(<b>&alpha;<sub>P</sub></b>)
     * + <big><big>&Sigma;</big></big>(&alpha;<sub>P<sub>i</sub></sub>-&alpha;<sub>Q<sub>i</sub></sub>)
     * (&psi;(&alpha;<sub>P<sub>i</sub></sub>) - &psi;(&Sigma;&alpha;<sub>P<sub>j</sub></sub>))
     * </blockquote>
     * where &Beta;(x) is the multinomial beta function:
     * <blockquote>
     * &Beta;(<b>&alpha;</b>) =
     * <big>&Pi;</big>&Gamma;&alpha;<sub>i</sub> <big>/</big> <big>&Gamma;</big>&Sigma;&alpha;<sub>i</sub>
     * </blockquote>
     * so
     * <blockquote>
     * log &Beta;(<b>&alpha;</b>) =
     * <big>&Sigma;</big>log &Gamma;&alpha;<sub>i</sub> - log <big>&Gamma;</big>&Sigma;&alpha;<sub>i</sub>
     * </blockquote>
     * and &psi;(x) is the digamma function.
     *
     * @throws IllegalArgumentException if {@code that} is not a {@link DirichletParameters}
     * or has a different size.
     */
    @Override
    public double computeKLDivergence(IParameterizedMessage that)
    {
        if (that instanceof DirichletParameters)
        {
            final DirichletParameters P = this, Q = (DirichletParameters)that;
            final double[] alphasP = P._alphaMinusOne, alphasQ = Q._alphaMinusOne;
            final int size = alphasP.length;
            assertSameSize(alphasQ.length);

            // To summarize the doc comment in plain ascii:
            //
            //   logGamma(sum(alphaPi)) - logGamma(sum(alphaQi))
            //     + sum(logGamma(alphaQi)) - sum(logGamma(alphaPi))
            //     + sum((alphaPi - alphaQi) * (digamma(alphaPi) - digamma(sum(alphaPj))))
            //

            double divergence = 0.0;

            if (size > 0)
            {
                // TODO optimize for symmetric case

                // The arrays hold (alpha - 1), so starting from 'size' yields the sums of the alphas.
                double alphaSumP = size, alphaSumQ = size;
                for (int i = 0; i < size; ++i)
                {
                    alphaSumP += alphasP[i];
                    alphaSumQ += alphasQ[i];
                }

                final double digammaAlphaSumP = digamma(alphaSumP);

                // Terms that cancel exactly are skipped.
                if (alphaSumP != alphaSumQ)
                {
                    divergence += logGamma(alphaSumP) - logGamma(alphaSumQ);
                }

                for (int i = 0; i < size; ++i)
                {
                    final double alphaP = alphasP[i] + 1, alphaQ = alphasQ[i] + 1;
                    if (alphaP != alphaQ)
                    {
                        divergence += logGamma(alphaQ);
                        divergence -= logGamma(alphaP);
                        divergence += (alphaP - alphaQ) * (digamma(alphaP) - digammaAlphaSumP);
                    }
                }
            }

            return divergence;
        }

        throw new IllegalArgumentException(String.format("Expected '%s' but got '%s'", getClass(), that.getClass()));
    }

    /**
     * {@inheritDoc}
     * <p>
     * True when every (alpha - 1) value is zero, which includes the zero-dimensional case.
     */
    @Override
    public boolean isNull()
    {
        if (_symmetric > 0)
        {
            // All entries are known to be equal, so checking the first one suffices.
            return _alphaMinusOne.length == 0 || _alphaMinusOne[0] == 0.0;
        }

        for (double alphaMinusOne : _alphaMinusOne)
        {
            if (alphaMinusOne != 0.0)
            {
                return false;
            }
        }

        // Every entry is zero, hence they are all equal.
        _symmetric = SYMMETRY_TRUE;
        return true;
    }

    /**
     * {@inheritDoc}
     * <p>
     * Copies the parameters of {@code other}, which must be a {@link DirichletParameters}
     * of the same size, together with its normalization energy and cached symmetry state.
     *
     * @throws IllegalArgumentException if the sizes differ.
     */
    @Override
    public void setFrom(IParameterizedMessage other)
    {
        final DirichletParameters that = (DirichletParameters)other;
        final double[] newAlphaMinusOne = that._alphaMinusOne;
        final int size = newAlphaMinusOne.length;
        assertSameSize(size);
        System.arraycopy(newAlphaMinusOne, 0, _alphaMinusOne, 0, size);
        copyNormalizationEnergy(that);
        // The previous flag described the old contents; take the source's state instead.
        _symmetric = that._symmetric;
    }

    /**
     * {@inheritDoc}
     * <p>
     * Sets all alphas to one.
     */
    @Override
    public final void setUniform()
    {
        fillAlphaMinusOne(0);
    }

    /**
     * Resizes to {@code size} dimensions and sets all alphas to one.
     */
    public final void setNull(int size)
    {
        setSize(size); // Create the array if it isn't already there, or change the size
        setUniform();
    }

    /*-------------------
     * Protected methods
     */

    /**
     * @throws IllegalArgumentException if {@code otherSize} differs from {@link #getSize()}.
     */
    protected void assertSameSize(int otherSize)
    {
        final int size = _alphaMinusOne.length;
        if (size != otherSize)
        {
            throw new IllegalArgumentException(
                String.format("Incompatible Dirichlet sizes '%d' and '%d'", size, otherSize));
        }
    }

    /**
     * Computes logGamma(sum(alpha_i)) - sum(logGamma(alpha_i)), i.e. the negated log of the
     * multinomial beta function of the alphas.
     */
    @Override
    protected final double computeNormalizationEnergy()
    {
        final double[] alphaMinusOne = _alphaMinusOne;
        final int n = alphaMinusOne.length;

        // When all entries are equal, evaluate one entry and scale by n.
        final boolean symmetric = n > 1 && isSymmetric();
        final int end = symmetric ? 1 : n;

        double sumAlpha = 0;
        double sumLogGamma = 0;

        for (int i = 0; i < end; ++i)
        {
            final double alpha = alphaMinusOne[i] + 1;
            sumAlpha += alpha;
            sumLogGamma += Gamma.logGamma(alpha);
        }

        if (symmetric)
        {
            sumAlpha *= n;
            sumLogGamma *= n;
        }

        return -(sumLogGamma - Gamma.logGamma(sumAlpha));
    }

    /**
     * {@inheritDoc}
     * <p>
     * Also discards the cached symmetry state, since this is invoked whenever the
     * parameters change.
     */
    @Override
    protected void forgetNormalizationEnergy()
    {
        super.forgetNormalizationEnergy();
        _symmetric = SYMMETRY_UNKNOWN;
    }
}