package math; /** * @author Marc Suchard */ public class MultivariateNormalDistribution implements MultivariateDistribution { public static final String TYPE = "MultivariateNormal"; private final double[] mean; private final double[][] precision; private double[][] variance = null; private double[][] cholesky = null; private Double logDet = null; public MultivariateNormalDistribution(double[] mean, double[][] precision) { this.mean = mean; this.precision = precision; } public String getType() { return TYPE; } public double[][] getVariance() { if (variance == null) { variance = new SymmetricMatrix(precision).inverse().toComponents(); } return variance; } public double[][] getCholeskyDecomposition() { if (cholesky == null) { cholesky = getCholeskyDecomposition(getVariance()); } return cholesky; } public double getLogDet() { if (logDet == null) { logDet = Math.log(calculatePrecisionMatrixDeterminate(precision)); } return logDet; } public double[][] getScaleMatrix() { return precision; } public double[] getMean() { return mean; } public double[] nextMultivariateNormal() { return nextMultivariateNormalCholesky(mean, getCholeskyDecomposition(), 1.0); } public double[] nextMultivariateNormal(double[] x) { return nextMultivariateNormalCholesky(x, getCholeskyDecomposition(), 1.0); } // Scale lives in variance-space public double[] nextScaledMultivariateNormal(double[] mean, double scale) { return nextMultivariateNormalCholesky(mean, getCholeskyDecomposition(), Math.sqrt(scale)); } // Scale lives in variance-space public void nextScaledMultivariateNormal(double[] mean, double scale, double[] result) { nextMultivariateNormalCholesky(mean, getCholeskyDecomposition(), Math .sqrt(scale), result); } public static double calculatePrecisionMatrixDeterminate( double[][] precision) { try { return new Matrix(precision).determinant(); } catch (IllegalDimension e) { throw new RuntimeException(e.getMessage()); } } public double logPdf(double[] x) { return logPdf(x, mean, precision, getLogDet(), 1.0); } public static double logPdf(double[] x, double[] mean, double[][] precision, double logDet, double scale) { if (logDet == Double.NEGATIVE_INFINITY) return logDet; final int dim = x.length; final double[] delta = new double[dim]; final double[] tmp = new double[dim]; for (int i = 0; i < dim; i++) { delta[i] = x[i] - mean[i]; } for (int i = 0; i < dim; i++) { for (int j = 0; j < dim; j++) { tmp[i] += delta[j] * precision[j][i]; } } double SSE = 0; for (int i = 0; i < dim; i++) SSE += tmp[i] * delta[i]; return dim * logNormalize + 0.5 * (logDet - dim * Math.log(scale) - SSE / scale); // There was // an error // here. // Variance = (scale * Precision^{-1}) } /* Equal precision, independent dimensions */ public static double logPdf(double[] x, double[] mean, double precision, double scale) { final int dim = x.length; double SSE = 0; for (int i = 0; i < dim; i++) { double delta = x[i] - mean[i]; SSE += delta * delta; } return dim * logNormalize + 0.5 * (dim * (Math.log(precision) - Math.log(scale)) - SSE * precision / scale); } private static double[][] getInverse(double[][] x) { return new SymmetricMatrix(x).inverse().toComponents(); } private static double[][] getCholeskyDecomposition(double[][] variance) { double[][] cholesky; try { cholesky = (new CholeskyDecomposition(variance)).getL(); } catch (IllegalDimension illegalDimension) { throw new RuntimeException( "Attempted Cholesky decomposition on non-square matrix"); } return cholesky; } public static double[] nextMultivariateNormalPrecision(double[] mean, double[][] precision) { return nextMultivariateNormalVariance(mean, getInverse(precision)); } public static double[] nextMultivariateNormalVariance(double[] mean, double[][] variance) { return nextMultivariateNormalVariance(mean, variance, 1.0); } public static double[] nextMultivariateNormalVariance(double[] mean, double[][] variance, double scale) { return nextMultivariateNormalCholesky(mean, getCholeskyDecomposition(variance), Math.sqrt(scale)); } public static double[] nextMultivariateNormalCholesky(double[] mean, double[][] cholesky) { return nextMultivariateNormalCholesky(mean, cholesky, 1.0); } public static double[] nextMultivariateNormalCholesky(double[] mean, double[][] cholesky, double sqrtScale) { double[] result = new double[mean.length]; nextMultivariateNormalCholesky(mean, cholesky, sqrtScale, result); return result; } public static void nextMultivariateNormalCholesky(double[] mean, double[][] cholesky, double sqrtScale, double[] result) { final int dim = mean.length; System.arraycopy(mean, 0, result, 0, dim); double[] epsilon = new double[dim]; for (int i = 0; i < dim; i++) epsilon[i] = MathUtils.nextGaussian() * sqrtScale; for (int i = 0; i < dim; i++) { for (int j = 0; j <= i; j++) { result[i] += cholesky[i][j] * epsilon[j]; // caution: decomposition returns lower triangular } } } // TODO should be a junit test public static void main(String[] args) { testPdf(); testRandomDraws(); } public static void testPdf() { double[] start = { 1, 2 }; double[] stop = { 0, 0 }; double[][] precision = { { 2, 0.5 }, { 0.5, 1 } }; double scale = 0.2; System.err.println("logPDF = " + logPdf(start, stop, precision, Math .log(calculatePrecisionMatrixDeterminate(precision)), scale)); System.err.println("Should = -19.94863\n"); System.err.println("logPDF = " + logPdf(start, stop, 2, 0.2)); System.err.println("Should = -24.53529\n"); } public static void testRandomDraws() { double[] start = { 1, 2 }; double[][] precision = { { 2, 0.5 }, { 0.5, 1 } }; int length = 100000; System.err.println("Random draws (via precision) ..."); double[] mean = new double[2]; double[] SS = new double[2]; double[] var = new double[2]; double ZZ = 0; for (int i = 0; i < length; i++) { double[] draw = nextMultivariateNormalPrecision(start, precision); for (int j = 0; j < 2; j++) { mean[j] += draw[j]; SS[j] += draw[j] * draw[j]; } ZZ += draw[0] * draw[1]; } for (int j = 0; j < 2; j++) { mean[j] /= length; SS[j] /= length; var[j] = SS[j] - mean[j] * mean[j]; } ZZ /= length; ZZ -= mean[0] * mean[1]; System.err.println("Mean: " + new Vector(mean)); System.err.println("TRUE: [ 1 2 ]\n"); System.err.println("MVar: " + new Vector(var)); System.err.println("TRUE: [ 0.571 1.14 ]\n"); System.err.println("Covv: " + ZZ); System.err.println("TRUE: -0.286"); } public static final double logNormalize = -0.5 * Math.log(2.0 * Math.PI); }