/******************************************************************************* * Copyright 2014 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.test.core.parameterizedMessages; import static com.analog.lyric.util.test.ExceptionTester.*; import static org.junit.Assert.*; import org.junit.Test; import com.analog.lyric.collect.ArrayUtil; import com.analog.lyric.dimple.factorfunctions.MultivariateNormal; import com.analog.lyric.dimple.model.domains.RealJointDomain; import com.analog.lyric.dimple.model.values.Value; import com.analog.lyric.dimple.solvers.core.parameterizedMessages.MultivariateNormalParameters; import com.analog.lyric.util.test.SerializationTester; /** * * @since 0.06 * @author Christopher Barber */ public class TestMultivariateNormalParameters extends TestParameterizedMessage { @Test public void test() { assertInvariants(new MultivariateNormalParameters(0)); MultivariateNormalParameters msg = new MultivariateNormalParameters(3); assertTrue(msg.isNull()); assertFalse(msg.isInInformationForm()); assertInvariants(msg); assertEquals(3, msg.getInformationVector().length); assertFalse(msg.isInInformationForm()); assertInvariants(msg); double[] means = new double[2]; double[][] covariance = new double[2][2]; // mean covariance // | 5.0 | | 1.0 2.0 | Not positive definite // | 6.0 | | 2.0 1.0 | // means[0] = 5.0; means[1] = 6.0; // covariance[0][0] = 1.0; // covariance[1][1] = 1.0; // covariance[0][1] = 2.0; // covariance[1][0] = 2.0; // expectThrow(DimpleException.class, "Matrix is not positive definite", msg, "setMeanAndCovariance", // means, covariance); // // Message not changed due to exception // assertFalse(msg.isInInformationForm()); // assertTrue(msg.isNull()); // assertInvariants(msg); // mean covariance inverse // | 5.0 | | 2.0 1.0 | det = 3 | 2/3 -1/3 | // | 6.0 | | 1.0 2.0 | | -1/3 2/3 | // covariance[0][0] = 2.0; covariance[1][1] = 2.0; covariance[0][1] = 1.0; covariance[1][0] = 1.0; msg.setMeanAndCovariance(means, covariance); assertFalse(msg.isInInformationForm()); assertArrayEquals(means, msg.getMean(), 0.0); assertArrayEquals(covariance[0], msg.getCovariance()[0], 0.0); assertEquals(-2.3871832107434, msg.getNormalizationEnergy(), 1e-13); // computed by hand in MATLAB double[][] infoMatrix = msg.getInformationMatrix(); assertTrue(msg.isInInformationForm()); assertEquals(2.0/3.0, infoMatrix[0][0], 1e-10); assertEquals(2.0/3.0, infoMatrix[1][1], 1e-10); assertEquals(-1.0/3.0, infoMatrix[0][1], 1e-10); assertEquals(-1.0/3.0, infoMatrix[1][0], 1e-10); double[] infoVector = msg.getInformationVector(); assertArrayEquals(new double[] { 4.0/3.0, 7.0/3.0 }, infoVector, 1e-10); assertArrayEquals(means, msg.getMean(), 0.0); assertTrue(msg.isInInformationForm()); assertEquals(-2.3871832107434, msg.getNormalizationEnergy(), 1e-13); assertInvariants(msg); double[][] covariance2 = msg.getCovariance(); assertArrayEquals(covariance[0], covariance2[0], 1e-10); assertArrayEquals(covariance[1], covariance2[1], 1e-10); // mean covariance inverse // | 7.0 | | 4.0 2.0 | det = 8 | 3/8 -1/4 | // | 8.0 | | 2.0 3.0 | | -1/3 1/2 | // means[0] = 7.0; means[1] = 8.0; covariance[0][0] = 4.0; covariance[1][1] = 3.0; covariance[0][1] = 2.0; covariance[1][0] = 2.0; MultivariateNormalParameters msg2 = new MultivariateNormalParameters(means, covariance); assertInvariants(msg2); assertFalse(msg2.isDiagonal()); // Computed these by hand in MATLAB assertEquals(.865414626505863, msg.computeKLDivergence(msg2), 1e-10); assertEquals(1.50958537349414, msg2.computeKLDivergence(msg), 1e-10); expectThrow(IllegalArgumentException.class, "Incompatible vector sizes.*", msg, "computeKLDivergence", new MultivariateNormalParameters(4)); msg2.setUniform(); assertTrue(msg2.isNull()); assertArrayEquals(new double[2], msg2.getMean(), 0.0); assertTrue(msg2.isDiagonal()); assertEquals(0.0, msg2.evalEnergy(new double[] { 1, 2 }), 1e-6); assertEquals(0.0, msg2.evalEnergy(new double[] { 0, 0 }), 0.0); assertEquals(0.0, msg2.getNormalizationEnergy(), 0.0); assertArrayEquals(new double[2], msg2.getDiagonalPrecision(), MultivariateNormalParameters.MIN_EIGENVALUE); assertInvariants(msg2); msg2.setNull(); assertTrue(msg2.isNull()); assertTrue(msg2.isInInformationForm()); msg2.setInformation(means, covariance); assertArrayEquals(means, msg2.getInformationVector(), 1e-14); assertTrue(msg2.isInInformationForm()); assertInvariants(msg2); means = new double[] {1.0, 2.0}; double[] variance = new double[] { 2.0, 3.0 }; msg = new MultivariateNormalParameters(means, variance); assertInvariants(msg); assertArrayEquals(variance, msg.getDiagonalVariance(), 0.0); assertTrue(msg.isDiagonal()); covariance[0][0] = variance[0]; covariance[0][1] = 0.0; covariance[1][0] = 0.0; covariance[1][1] = variance[1]; msg2 = new MultivariateNormalParameters(means, covariance); assertInvariants(msg2); assertTrue(msg2.isDiagonal()); } private void assertInvariants(MultivariateNormalParameters msg) { assertGenericInvariants(msg); final int n = msg.getVectorLength(); if (n == 0) assertTrue(msg.isNull()); MultivariateNormalParameters msg2 = msg.clone(); assertEquals(msg.isInInformationForm(), msg2.isInInformationForm()); assertEquals(n, msg2.getVectorLength()); if (msg.isInInformationForm()) { assertArrayEquals(msg2.getInformationVector(), msg.getInformationVector(), 0.0); } else { assertArrayEquals(msg2.getMean(), msg.getMean(), 0.0); } MultivariateNormalParameters msg3 = SerializationTester.clone(msg); assertEquals(msg.isInInformationForm(), msg3.isInInformationForm()); assertEquals(n, msg3.getVectorLength()); if (msg.isInInformationForm()) { assertArrayEquals(msg.getInformationVector(), msg3.getInformationVector(), 0.0); } else { assertArrayEquals(msg.getMean(), msg3.getMean(), 0.0); } double[] means = msg2.getMean(); assertArrayEquals(msg2.getMeans(), means, 0.0); if (n > 0) assertNotSame(means, msg2.getMean()); else assertSame(ArrayUtil.EMPTY_DOUBLE_ARRAY, msg2.getMean()); assertEquals(n, means.length); double[] infoVector = msg2.getInformationVector(); assertEquals(n, infoVector.length); double[][] covariance = msg2.getCovariance(); assertFalse(msg2.isInInformationForm()); assertEquals(n, covariance.length); for (int row = 0; row < n; ++row) { assertEquals(n, covariance[row].length); for (int col = 0; col < n; ++col) { assertEquals(covariance[row][col], covariance[col][row], 1e-14); } } double[][] infoMatrix = msg2.getInformationMatrix(); assertTrue(msg2.isInInformationForm()); assertEquals(n, infoMatrix.length); for (int row = 0; row < n; ++row) { assertEquals(n, infoMatrix[row].length); for (int col = 0; col < n; ++col) { assertEquals(infoMatrix[row][col], infoMatrix[col][row], 1e-14); } } if (msg.isDiagonal()) { double[] precision = msg.getDiagonalPrecision(); double[] variance = msg.getDiagonalVariance(); assertEquals(n, precision.length); assertEquals(n, variance.length); for (int i = 0; i < n; ++i) { assertEquals(precision[i], 1.0/variance[i], 1e-15); double[] covarianceRow = covariance[i]; double[] infoRow = infoMatrix[i]; for (int j = 0; j < n; ++j) { if (i == j) { assertEquals(precision[i], infoRow[i], 1e-15); assertEquals(variance[i], covarianceRow[i], 1e-15); } else { assertEquals(0.0, covarianceRow[j], 0.0); assertEquals(0.0, infoRow[j], 0.0); } } } } else { assertSame(ArrayUtil.EMPTY_DOUBLE_ARRAY, msg.getDiagonalPrecision()); assertSame(ArrayUtil.EMPTY_DOUBLE_ARRAY, msg.getDiagonalVariance()); } if (n > 0) { final double normalizer = msg.getNormalizationEnergy(); MultivariateNormal function = new MultivariateNormal(msg.clone()); Value value = Value.create(RealJointDomain.create(n)); final double[] array = value.getDoubleArray(); for (int i = 0; i < 10; ++i) { for (int j = 0; j < n; ++j) { array[j] = testRand.nextDouble(); } final double expectedEnergy = function.evalEnergy(value); final double unnormalizedEnergy = msg.evalEnergy(value); final double normalizedEnergy = unnormalizedEnergy - normalizer; assertEquals(expectedEnergy, normalizedEnergy, Math.abs(expectedEnergy) / 1e-12); } } } }