package beast.math.distributions; import org.apache.commons.math.MathException; import org.apache.commons.math.distribution.ContinuousDistribution; import org.apache.commons.math.distribution.Distribution; import beast.core.Description; import beast.core.Function; import beast.core.Input; import beast.core.Input.Validate; import beast.core.parameter.RealParameter; @Description("Dirichlet distribution. p(x_1,...,x_n;alpha_1,...,alpha_n) = 1/B(alpha) prod_{i=1}^K x_i^{alpha_i - 1} " + "where B() is the beta function B(alpha) = prod_{i=1}^K Gamma(alpha_i)/ Gamma(sum_{i=1}^K alpha_i}. ") public class Dirichlet extends ParametricDistribution { final public Input<RealParameter> alphaInput = new Input<>("alpha", "coefficients of the Dirichlet distribution", Validate.REQUIRED); @Override public void initAndValidate() { } @Override public Distribution getDistribution() { return null; } class DirichletImpl implements ContinuousDistribution { Double[] m_fAlpha; void setAlpha(Double[] alpha) { m_fAlpha = alpha; } @Override public double cumulativeProbability(double x) throws MathException { throw new MathException("Not implemented yet"); } @Override public double cumulativeProbability(double x0, double x1) throws MathException { throw new MathException("Not implemented yet"); } @Override public double inverseCumulativeProbability(double p) throws MathException { throw new MathException("Not implemented yet"); } @Override public double density(double x) { return Double.NaN; } @Override public double logDensity(double x) { return Double.NaN; } } // class DirichletImpl @Override public double calcLogP(Function pX) { Double[] alpha = alphaInput.get().getValues(); if (alphaInput.get().getDimension() != pX.getDimension()) { throw new IllegalArgumentException("Dimensions of alpha and x should be the same, but dim(alpha)=" + alphaInput.get().getDimension() + " and dim(x)=" + pX.getDimension()); } double logP = 0; double sumAlpha = 0; for (int i = 0; i < pX.getDimension(); i++) { double x = pX.getArrayValue(i); logP += (alpha[i] - 1) * Math.log(x); logP -= org.apache.commons.math.special.Gamma.logGamma(alpha[i]); sumAlpha += alpha[i]; } logP += org.apache.commons.math.special.Gamma.logGamma(sumAlpha); return logP; } }