/* * WishartDistribution.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.math.distributions; import dr.math.GammaFunction; import dr.math.MathUtils; import dr.math.matrixAlgebra.CholeskyDecomposition; import dr.math.matrixAlgebra.IllegalDimension; import dr.math.matrixAlgebra.Matrix; /** * @author Marc Suchard */ public class WishartDistribution implements MultivariateDistribution, WishartStatistics { public static final String TYPE = "Wishart"; private double df; private int dim; private double[][] scaleMatrix; private double[] Sinv; private Matrix SinvMat; private double logNormalizationConstant; /** * A Wishart distribution class for \nu degrees of freedom and scale matrix S * Expectation = \nu * S * * @param df degrees of freedom * @param scaleMatrix scaleMatrix */ public WishartDistribution(double df, double[][] scaleMatrix) { this.df = df; this.scaleMatrix = scaleMatrix; this.dim = scaleMatrix.length; SinvMat = new Matrix(scaleMatrix).inverse(); double[][] tmp = SinvMat.toComponents(); Sinv = new double[dim * dim]; for (int i = 0; i < dim; i++) { System.arraycopy(tmp[i], 0, Sinv, i * dim, dim); } computeNormalizationConstant(); } public WishartDistribution(int dim) { // returns a non-informative (unormalizable) density this.df = 0; this.scaleMatrix = null; this.dim = dim; logNormalizationConstant = 0.0; } private void computeNormalizationConstant() { logNormalizationConstant = computeNormalizationConstant(new Matrix(scaleMatrix), df, dim); } public static double computeNormalizationConstant(Matrix Sinv, double df, int dim) { if (df == 0) { return 0.0; } double logNormalizationConstant = 0; try { logNormalizationConstant = -df / 2.0 * Math.log(Sinv.determinant()); } catch (IllegalDimension illegalDimension) { illegalDimension.printStackTrace(); } logNormalizationConstant -= df * dim / 2.0 * Math.log(2); logNormalizationConstant -= dim * (dim - 1) / 4.0 * Math.log(Math.PI); for (int i = 1; i <= dim; i++) { logNormalizationConstant -= GammaFunction.lnGamma((df + 1 - i) / 2.0); } return logNormalizationConstant; } public String getType() { return TYPE; } public double[][] getScaleMatrix() { return scaleMatrix; } public double[] getMean() { return null; } public void testMe() { int length = 100000; double save1 = 0; double save2 = 0; double save3 = 0; double save4 = 0; for (int i = 0; i < length; i++) { double[][] draw = nextWishart(); save1 += draw[0][0]; save2 += draw[0][1]; save3 += draw[1][0]; save4 += draw[1][1]; } save1 /= length; save2 /= length; save3 /= length; save4 /= length; System.err.println("S1: " + save1); System.err.println("S2: " + save2); System.err.println("S3: " + save3); System.err.println("S4: " + save4); } public double getDF() { return df; } public double[][] nextWishart() { return nextWishart(df, scaleMatrix); } /** * Generate a random draw from a Wishart distribution * Follows Odell and Feiveson (1996) JASA 61, 199-203 * <p/> * Returns a random variable with expectation = df * scaleMatrix * * @param df degrees of freedom * @param scaleMatrix scaleMatrix * @return a random draw */ public static double[][] nextWishart(double df, double[][] scaleMatrix) { int dim = scaleMatrix.length; double[][] draw = new double[dim][dim]; double[][] z = new double[dim][dim]; for (int i = 0; i < dim; i++) { for (int j = 0; j < i; j++) { z[i][j] = MathUtils.nextGaussian(); } } for (int i = 0; i < dim; i++) z[i][i] = Math.sqrt(MathUtils.nextGamma((df - i) * 0.5, 0.5)); // sqrt of chisq with df-i dfs double[][] cholesky = new double[dim][dim]; for (int i = 0; i < dim; i++) { for (int j = i; j < dim; j++) cholesky[i][j] = cholesky[j][i] = scaleMatrix[i][j]; } try { cholesky = (new CholeskyDecomposition(cholesky)).getL(); // caution: this returns the lower triangular form } catch (IllegalDimension illegalDimension) { throw new RuntimeException("Numerical exception in WishartDistribution"); } double[][] result = new double[dim][dim]; for (int i = 0; i < dim; i++) { for (int j = 0; j < dim; j++) { // lower triangular for (int k = 0; k < dim; k++) // can also be shortened result[i][j] += cholesky[i][k] * z[k][j]; } } for (int i = 0; i < dim; i++) { // lower triangular, so more efficiency is possible for (int j = 0; j < dim; j++) { for (int k = 0; k < dim; k++) draw[i][j] += result[i][k] * result[j][k]; // transpose of 2nd element } } return draw; } public double logPdf(double[] x) { if (x.length == 4) { // bivariate return logPdf2D(x, Sinv, df, dim, logNormalizationConstant); } else { return logPdfSlow(x); } } public double logPdfSlow(double[] x) { Matrix W = new Matrix(x, dim, dim); return logPdf(W, SinvMat, df, dim, logNormalizationConstant); } public static double logPdf2D(double[] W, double[] Sinv, double df, int dim, double logNormalizationConstant) { final double det = W[0] * W[3] - W[1] * W[2]; if (det <= 0) { return Double.NEGATIVE_INFINITY; } double logDensity = Math.log(det); logDensity *= 0.5 * (df - dim - 1); // logDensity -= 0.5 * tr(Sinv %*% W) final double trace = Sinv[0] * W[0] + Sinv[1] * W[2] + Sinv[2] * W[1] + Sinv[3] * W[3]; logDensity -= 0.5 * trace; logDensity += logNormalizationConstant; return logDensity; } public static double logPdf(Matrix W, Matrix Sinv, double df, int dim, double logNormalizationConstant) { double logDensity = 0; try { // if (!W.isPD()) { // TODO isPD() does not appear to work // return Double.NEGATIVE_INFINITY; // } logDensity = W.logDeterminant(); // Returns NaN is W is not positive-definite. if (Double.isInfinite(logDensity) || Double.isNaN(logDensity)) { return Double.NEGATIVE_INFINITY; } logDensity *= 0.5; logDensity *= df - dim - 1; // need only diagonal, no? seems a waste to compute // the whole matrix if (Sinv != null) { Matrix product = Sinv.product(W); for (int i = 0; i < dim; i++) logDensity -= 0.5 * product.component(i, i); } } catch (IllegalDimension illegalDimension) { illegalDimension.printStackTrace(); } logDensity += logNormalizationConstant; return logDensity; } public static void testBivariateMethod() { System.out.println("Testing new computations ..."); WishartDistribution wd = new WishartDistribution(5, new double[][]{{2.0, -0.5}, {-0.5, 2.0}}); double[] W = new double[]{4.0, 1.0, 1.0, 3.0}; System.out.println("Fast logPdf = " + wd.logPdf(W)); System.out.println("Slow logPdf = " + wd.logPdfSlow(W)); } public static void main(String[] argv) { WishartDistribution wd = new WishartDistribution(2, new double[][]{{500.0}}); // The above is just an approximation GammaDistribution gd = new GammaDistribution(1.0 / 1000.0, 1000.0); double[] x = new double[]{1.0}; System.out.println("Wishart, df=2, scale = 500, PDF(1.0): " + wd.logPdf(x)); System.out.println("Gamma, shape = 1/1000, scale = 1000, PDF(1.0): " + gd.logPdf(x[0])); wd = new WishartDistribution(4, new double[][]{{5.0}}); gd = new GammaDistribution(2.0, 10.0); x = new double[]{1.0}; System.out.println("Wishart, df=4, scale = 5, PDF(1.0): " + wd.logPdf(x)); System.out.println("Gamma, shape = 1/1000, scale = 10, PDF(1.0): " + gd.logPdf(x[0])); // These tests show the correspondence between a 1D Wishart and a Gamma wd = new WishartDistribution(1); x = new double[]{0.1}; System.out.println("Wishart, uninformative, PDF(0.1): " + wd.logPdf(x)); x = new double[]{1.0}; System.out.println("Wishart, uninformative, PDF(1.0): " + wd.logPdf(x)); x = new double[]{10.0}; System.out.println("Wishart, uninformative, PDF(10.0): " + wd.logPdf(x)); // These tests show the correspondence between a 1D Wishart and a Gamma testBivariateMethod(); } }