/*******************************************************************************
 *   Copyright 2013-2015 Analog Devices, Inc.
 *
 *   Licensed under the Apache License, Version 2.0 (the "License");
 *   you may not use this file except in compliance with the License.
 *   You may obtain a copy of the License at
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 *   Unless required by applicable law or agreed to in writing, software
 *   distributed under the License is distributed on an "AS IS" BASIS,
 *   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *   See the License for the specific language governing permissions and
 *   limitations under the License.
 ********************************************************************************/

package com.analog.lyric.dimple.solvers.core.parameterizedMessages;

import static org.apache.commons.math3.special.Beta.*;
import static org.apache.commons.math3.special.Gamma.*;

import java.io.PrintStream;

import org.eclipse.jdt.annotation.Nullable;

import com.analog.lyric.dimple.model.values.Value;

/**
 * Parameterized message for beta distribution.
 * <p>
 * The state is held as the natural additive parameters (alpha - 1) and (beta - 1), so that
 * combining two messages with {@link #addFrom(BetaParameters)} is a plain sum of the stored
 * values. The ordinary alpha/beta parameters are exposed through {@link #getAlpha()} and
 * {@link #getBeta()}.
 */
public class BetaParameters extends ParameterizedMessageBase
{
    /*-------
     * State
     */

    private static final long serialVersionUID = 1L;

    // The parameters used are the natural additive parameters, (alpha-1) and (beta-1)
    private double _alphaMinusOne;
    private double _betaMinusOne;

    /*--------------
     * Construction
     */

    /**
     * Constructs a message with both natural parameters equal to zero, i.e. alpha = beta = 1.
     */
    public BetaParameters()
    {
    }

    /**
     * Constructs a message from the natural parameters.
     *
     * @param alphaMinusOne value of (alpha - 1)
     * @param betaMinusOne value of (beta - 1)
     */
    public BetaParameters(double alphaMinusOne, double betaMinusOne)
    {
        _alphaMinusOne = alphaMinusOne;
        _betaMinusOne = betaMinusOne;
    }

    /**
     * Copy constructor.
     *
     * @param other message whose base-class state and natural parameters are copied
     */
    public BetaParameters(BetaParameters other)
    {
        super(other);
        _alphaMinusOne = other._alphaMinusOne;
        _betaMinusOne = other._betaMinusOne;
    }

    /**
     * Returns a copy of this message built with the copy constructor.
     */
    @Override
    public BetaParameters clone()
    {
        return new BetaParameters(this);
    }

    /*----------------
     * Object methods
     */

    /**
     * Formats the message as {@code Beta(alpha,beta)} using the ordinary parameters.
     */
    @Override
    public String toString()
    {
        return String.format("Beta(%g,%g)", getAlpha(), getBeta());
    }

    /*-----------------
     * IEquals methods
     */

    /**
     * Two messages are equal when {@code other} is a {@code BetaParameters} with exactly the same
     * natural parameters and the superclass comparison also succeeds.
     */
    @Override
    public boolean objectEquals(@Nullable Object other)
    {
        if (other == this)
        {
            return true;
        }

        if (other instanceof BetaParameters)
        {
            BetaParameters that = (BetaParameters)other;
            return _alphaMinusOne == that._alphaMinusOne
                && _betaMinusOne == that._betaMinusOne
                && super.objectEquals(other);
        }

        return false;
    }

    /*----------------------
     * IUnaryFactorFunction
     */

    /**
     * Evaluates -((alpha-1)*ln(x) + (beta-1)*ln(1-x)) for the double held by {@code value}.
     * <p>
     * The result computed here does not include a normalization term.
     *
     * @param value value whose {@code getDouble()} is the point x to evaluate
     * @return {@link Double#POSITIVE_INFINITY} if x is below 0 or above 1, otherwise the
     * energy described above
     */
    @Override
    public double evalEnergy(Value value)
    {
        final double x = value.getDouble();

        if (x < 0 | x > 1)
        {
            return Double.POSITIVE_INFINITY;
        }

        // A term whose coefficient is exactly zero is skipped rather than multiplied out, so
        // that 0 * log(0) (which would be NaN) never arises at x == 0 or x == 1.
        double y;

        if (_alphaMinusOne == 0.0)
        {
            if (_betaMinusOne == 0.0)
            {
                return 0.0;
            }
            else
            {
                y = Math.log(1 - x) * _betaMinusOne;
            }
        }
        else if (_betaMinusOne == 0.0)
        {
            y = Math.log(x) * _alphaMinusOne;
        }
        else
        {
            y = _alphaMinusOne * Math.log(x) + _betaMinusOne * Math.log(1 - x);
        }

        return -y;
    }

    /*--------------------
     * IPrintable methods
     */

    /**
     * Writes the ordinary parameters to {@code out}.
     * <p>
     * Nothing is written for negative verbosity. Verbosity zero uses the compact form
     * {@code Beta(a,b)}; any higher verbosity labels the parameters.
     *
     * @param out stream to write to
     * @param verbosity level of detail requested
     */
    @Override
    public void print(PrintStream out, int verbosity)
    {
        if (verbosity >= 0)
        {
            String fmt;
            switch (verbosity)
            {
            case 0:
                fmt = "Beta(%g,%g)";
                break;
            default:
                fmt = "Beta(alpha=%g, beta=%g)";
                break;
            }
            out.format(fmt, getAlpha(), getBeta());
        }
    }

    /*-------------------------------
     * IParameterizedMessage methods
     */

    /**
     * Casts {@code other} to {@code BetaParameters} and delegates to
     * {@link #addFrom(BetaParameters)}.
     *
     * @throws ClassCastException if {@code other} is not a {@code BetaParameters}
     */
    @Override
    public void addFrom(IParameterizedMessage other)
    {
        addFrom((BetaParameters)other);
    }

    /**
     * Adds the natural parameters of {@code other} to those of this message.
     *
     * @param other message whose (alpha - 1) and (beta - 1) values are added to this one
     */
    public void addFrom(BetaParameters other)
    {
        _alphaMinusOne += other._alphaMinusOne;
        _betaMinusOne += other._betaMinusOne;
        // The parameters changed, so any previously cached normalization energy no longer
        // describes this message; invalidate it the same way the parameter setters do.
        forgetNormalizationEnergy();
    }

    /**
     * {@inheritDoc}
     * <p>
     * Computes divergences as follows, where P is this, Q is that, and Β(x) and ψ(x) refer to the
     * beta and digamma functions respectively.
     * <p>
     * ln(Β(α<sub>Q</sub>, β<sub>Q</sub>))
     * - ln(Β(α<sub>P</sub>, β<sub>P</sub>))
     * + (α<sub>P</sub>-α<sub>Q</sub>)ψ(α<sub>P</sub>)
     * + (β<sub>P</sub>-β<sub>Q</sub>)ψ(β<sub>P</sub>)
     * + (α<sub>Q</sub>-α<sub>P</sub>+β<sub>Q</sub>-β<sub>P</sub>)
     * ψ(α<sub>P</sub>+β<sub>P</sub>)
     *
     * @throws IllegalArgumentException if {@code that} is not a {@code BetaParameters}
     * @see <a href="http://en.wikipedia.org/wiki/Beta_distribution#Quantities_of_information_.28entropy.29">
     * Beta distribution (Wikipedia)</a>
     * @since 0.06
     */
    @Override
    public double computeKLDivergence(IParameterizedMessage that)
    {
        if (that instanceof BetaParameters)
        {
            // http://en.wikipedia.org/wiki/Beta_distribution#Quantities_of_information_.28entropy.29
            //
            // KL(P|Q) == log(beta(aq,bq)/beta(ap,bp)) - (aq-ap)*digamma(ap) - (bq-bp)*digamma(bp) +
            //            (aq - ap + bq - bp)*digamma(ap + bp)

            final BetaParameters P = this, Q = (BetaParameters)that;

            final double ap = P.getAlpha(), aq = Q.getAlpha();
            final double bp = P.getBeta(), bq = Q.getBeta();

            final double adiff = ap - aq;
            final double bdiff = bp - bq;

            double divergence = 0.0;

            // Identical parameters give a divergence of exactly zero; skip the special
            // function evaluations entirely in that case.
            if (adiff != 0 | bdiff !=0)
            {
                divergence += logBeta(aq,bq);
                divergence -= logBeta(ap,bp);

                if (adiff != 0.0)
                {
                    divergence += adiff * digamma(ap);
                }

                if (bdiff != 0.0)
                {
                    divergence += bdiff * digamma(bp);
                }

                final double ndiff = -adiff - bdiff;

                if (ndiff != 0.0)
                {
                    divergence += ndiff * digamma(ap+bp);
                }
            }

            return divergence;
        }

        throw new IllegalArgumentException(String.format("Expected '%s' but got '%s'", getClass(), that.getClass()));
    }

    /**
     * Returns true when both natural parameters are exactly zero (alpha = beta = 1).
     */
    @Override
    public boolean isNull()
    {
        return _alphaMinusOne == 0 && _betaMinusOne == 0;
    }

    /**
     * Copies the natural parameters and the normalization energy from {@code other}.
     *
     * @throws ClassCastException if {@code other} is not a {@code BetaParameters}
     */
    @Override
    public void setFrom(IParameterizedMessage other)
    {
        BetaParameters that = (BetaParameters)other;
        _alphaMinusOne = that._alphaMinusOne;
        _betaMinusOne = that._betaMinusOne;
        copyNormalizationEnergy(that);
    }

    /**
     * {@inheritDoc}
     * <p>
     * Sets alpha and beta parameters both to 1.
     */
    @Override
    public final void setUniform()
    {
        _alphaMinusOne = 0;
        _betaMinusOne = 0;
        _normalizationEnergy = 0.0;
    }

    /*---------------
     * Local methods
     */

    // Natural parameters are alpha-1 and beta-1

    /** Returns the natural parameter (alpha - 1). */
    public final double getAlphaMinusOne()
    {
        return _alphaMinusOne;
    }

    /** Returns the natural parameter (beta - 1). */
    public final double getBetaMinusOne()
    {
        return _betaMinusOne;
    }

    /**
     * Sets the natural parameter (alpha - 1) and invalidates the normalization energy.
     */
    public final void setAlphaMinusOne(double alphaMinusOne)
    {
        _alphaMinusOne = alphaMinusOne;
        forgetNormalizationEnergy();
    }

    /**
     * Sets the natural parameter (beta - 1) and invalidates the normalization energy.
     */
    public final void setBetaMinusOne(double betaMinusOne)
    {
        _betaMinusOne = betaMinusOne;
        forgetNormalizationEnergy();
    }

    // Ordinary parameters, alpha and beta

    /** Returns the ordinary alpha parameter, i.e. the stored (alpha - 1) plus one. */
    public final double getAlpha()
    {
        return _alphaMinusOne + 1;
    }

    /** Returns the ordinary beta parameter, i.e. the stored (beta - 1) plus one. */
    public final double getBeta()
    {
        return _betaMinusOne + 1;
    }

    /** Sets the ordinary alpha parameter; equivalent to {@code setAlphaMinusOne(alpha - 1)}. */
    public final void setAlpha(double alpha)
    {
        setAlphaMinusOne(alpha - 1);
    }

    /** Sets the ordinary beta parameter; equivalent to {@code setBetaMinusOne(beta - 1)}. */
    public final void setBeta(double beta)
    {
        setBetaMinusOne(beta - 1);
    }

    /*-------------------
     * Protected methods
     */

    /**
     * Returns the negated log of the beta function evaluated at (alpha, beta).
     */
    @Override
    protected double computeNormalizationEnergy()
    {
        return -org.apache.commons.math3.special.Beta.logBeta(_alphaMinusOne + 1, _betaMinusOne + 1);
    }
}