/******************************************************************************* * Copyright 2014 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.test.core.parameterizedMessages; import static com.analog.lyric.util.test.ExceptionTester.*; import static org.junit.Assert.*; import java.util.Arrays; import org.junit.Test; import com.analog.lyric.dimple.factorfunctions.Dirichlet; import com.analog.lyric.dimple.model.domains.RealJointDomain; import com.analog.lyric.dimple.model.values.Value; import com.analog.lyric.dimple.solvers.core.parameterizedMessages.DirichletParameters; import com.analog.lyric.util.test.SerializationTester; /** * * @since 0.06 * @author Christopher Barber */ public class TestDirichletParameters extends TestParameterizedMessage { @Test public void test() { DirichletParameters msg = new DirichletParameters(); assertEquals(0, msg.getSize()); assertInvariants(msg); expectThrow(ArrayIndexOutOfBoundsException.class, msg, "getAlpha", 0); msg = new DirichletParameters(3); assertEquals(3, msg.getSize()); for (int i = 0; i < 3; ++i) { assertEquals(1.0, msg.getAlpha(i), 0.0); } assertInvariants(msg); DirichletParameters msg2 = new DirichletParameters(new double[] {2,2,2}); assertEquals(3, msg2.getSize()); for (int i = 0; i < 3; ++i) { assertEquals(3.0, msg2.getAlpha(i), 0.0); } assertInvariants(msg2); assertEquals(1.16798581949, msg.computeKLDivergence(msg2), 1e-9); assertEquals(0.5248713233626, msg2.computeKLDivergence(msg), 1e-9); msg2.setSize(4); assertEquals(4, msg2.getSize()); for (int i = 0; i < 4; ++i) { assertEquals(1.0, msg2.getAlpha(i), 0.0); } expectThrow(IllegalArgumentException.class, "Incompatible Dirichlet sizes.*", msg, "computeKLDivergence", msg2); msg2.add(0, 1.2); assertEquals(2.2, msg2.getAlpha(0), 1e-15); assertEquals(1.0, msg2.getAlpha(1), 0.0); msg2.increment(0); msg2.increment(1); assertEquals(3.2, msg2.getAlpha(0), 1e-15); assertEquals(2, msg2.getAlpha(1), 0.0); assertEquals(1, msg2.getAlpha(2), 0.0); msg2.fillAlphaMinusOne(4); for (int i = 0; i < msg2.getSize(); ++i) { assertEquals(5, msg2.getAlpha(i), 0.0); } expectThrow(IllegalArgumentException.class, msg2, "add", msg); msg2.add(new DirichletParameters(4, 1.5)); for (int i = 0; i < 4; ++i) { assertEquals(6.5, msg2.getAlpha(i), 1e-15); } msg2.setAlphaMinusOne(new double[] { 10, 11, 12 }); assertEquals(3, msg2.getSize()); msg2.add(new int[] { 1, 2, 3}); assertEquals(11, msg2.getAlphaMinusOne(0), 0.0); assertEquals(13, msg2.getAlphaMinusOne(1), 0.0); assertEquals(15, msg2.getAlphaMinusOne(2), 0.0); assertFalse(msg2.isSymmetric()); msg2.setNull(); assertEquals(3, msg2.getSize()); for (int i = 0; i < 3; ++i) { assertEquals(0.0, msg2.getAlphaMinusOne(i), 0.0); } assertTrue(msg2.isSymmetric()); msg2.setUniform(); assertEquals(3, msg2.getSize()); for (int i = 0; i < 3; ++i) { assertEquals(0.0, msg2.getAlphaMinusOne(i), 0.0); } assertTrue(msg2.isSymmetric()); msg2.setAlphaMinusOne(new double[] { 1, 2, 3}); for (int i = 0; i < 3; ++i) { assertEquals(i + 1, msg2.getAlphaMinusOne(i), 0.0); } } private void assertInvariants(DirichletParameters msg) { final int n = msg.getSize(); assertGenericInvariants(msg); DirichletParameters msg2 = msg.clone(); assertNotSame(msg2, msg); assertEquals(n, msg2.getSize()); for (int i = 0; i < n; ++i) { assertEquals(msg.getAlpha(i), msg2.getAlpha(i), 0.0); } assertEquals(msg.isSymmetric(), msg2.isSymmetric()); DirichletParameters msg3 = SerializationTester.clone(msg); assertNotSame(msg3, msg); assertEquals(n, msg3.getSize()); for (int i = 0; i < n; ++i) { assertEquals(msg.getAlpha(i), msg3.getAlpha(i), 0.0); } for (int i = 0; i < n; ++i) { assertEquals(msg.getAlpha(i) - 1, msg.getAlphaMinusOne(i), 1e-14); } if (n > 0) { Value value = Value.create(RealJointDomain.create(n)); Dirichlet factor = new Dirichlet(msg.getAlphas()); boolean isNull = msg.isNull(); for (int i = 0; i < 10; ++i) { double[] array = value.getDoubleArray(); double sum = 0.0; for (int j = n; --j>=0;) sum += array[j] = testRand.nextDouble(); for (int j = n; --j>=0;) array[j] /= sum; assertEquals(factor.evalEnergy(value), msg.evalEnergy(value)- msg.getNormalizationEnergy(), 1e-12); if (isNull) assertEquals(0.0, msg.evalEnergy(value), 0.0); } Arrays.fill(value.getDoubleArray(), 2.0); // not on probability simplex assertEquals(Double.POSITIVE_INFINITY, msg.evalEnergy(value), 0.0); value.getDoubleArray()[0] = -2*n - 1; assertEquals(Double.POSITIVE_INFINITY, msg.evalEnergy(value), 0.0); } } }