/*******************************************************************************
* Copyright 2014 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.test.core.parameterizedMessages;
import static org.junit.Assert.*;
import org.junit.Test;
import com.analog.lyric.dimple.factorfunctions.Beta;
import com.analog.lyric.dimple.model.values.Value;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.BetaParameters;
import com.analog.lyric.util.test.SerializationTester;
/**
 * Unit test for {@link BetaParameters}.
 * <p>
 * Exercises the alpha/beta accessors and their "minus one" counterparts, the
 * {@code setNull}/{@code setUniform} resets, clone and serialization copies,
 * energy evaluation compared with the {@link Beta} factor function, and
 * KL divergence against reference values.
 *
 * @since 0.06
 * @author Christopher Barber
 */
public class TestBetaParameters extends TestParameterizedMessage
{
    /** Delta used for floating point comparisons that must match exactly. */
    private static final double EXACT = 0.0;

    /**
     * Single test entry point: runs the accessor checks followed by the
     * KL divergence checks on the same message instance.
     */
    @Test
    public void test()
    {
        final BetaParameters params = new BetaParameters();
        checkAccessors(params);
        checkKLDivergence(params);
    }

    /**
     * Walks {@code params} through a sequence of mutations, verifying the
     * getters and the general invariants along the way.
     */
    private void checkAccessors(BetaParameters params)
    {
        // State immediately after construction
        assertInvariants(params);
        assertZeroOffsets(params);

        // Direct alpha/beta setters
        params.setAlpha(3.5);
        assertEquals(3.5, params.getAlpha(), EXACT);
        params.setBeta(2.75);
        assertEquals(2.75, params.getBeta(), EXACT);
        assertInvariants(params);

        // "Minus one" setters are reflected in the plain getters
        params.setAlphaMinusOne(2.0);
        assertEquals(3.0, params.getAlpha(), EXACT);
        params.setBetaMinusOne(2.0);
        assertEquals(3.0, params.getBeta(), EXACT);
        assertInvariants(params);

        // Both resets are expected to zero the offsets
        params.setNull();
        assertZeroOffsets(params);
        assertInvariants(params);

        params.setUniform();
        assertZeroOffsets(params);

        // Only beta differs from one
        params.setBeta(2.0);
        assertInvariants(params);

        // Only alpha differs from one
        params.setAlpha(2.0);
        params.setBeta(1.0);
        assertInvariants(params);
    }

    /**
     * Checks {@code computeKLDivergence} against reference values. The
     * previous contents of {@code p} are overwritten.
     */
    private void checkKLDivergence(BetaParameters p)
    {
        // Wikipedia example
        p.setAlpha(1.0);
        p.setBeta(1.0);
        final BetaParameters q = new BetaParameters();
        q.setAlpha(3.0);
        q.setBeta(3.0);
        final double uniformToPeaked = p.computeKLDivergence(q);
        assertEquals(0.598803, uniformToPeaked, 1e-6);
        final double peakedToUniform = q.computeKLDivergence(p);
        assertEquals(0.267864, peakedToUniform, 1e-6);

        // Wikipedia example
        p.setAlpha(3);
        p.setBeta(.5);
        q.setAlpha(.5);
        q.setBeta(3);
        final double forward = p.computeKLDivergence(q);
        assertEquals(7.21574, forward, 1e-5);
        final double backward = q.computeKLDivergence(p);
        assertEquals(7.21574, backward, 1e-5);

        // Hand computed in MATLAB
        p.setAlpha(1);
        p.setBeta(2);
        q.setAlpha(1);
        q.setBeta(3);
        final double betaChanged = p.computeKLDivergence(q);
        assertEquals(0.09453489, betaChanged, 1e-7);

        q.setAlpha(3);
        q.setBeta(2);
        final double alphaChanged = p.computeKLDivergence(q);
        assertEquals(1.20824053, alphaChanged, 1e-7);
    }

    /**
     * Asserts that both "minus one" parameters of {@code params} are exactly zero.
     */
    private void assertZeroOffsets(BetaParameters params)
    {
        assertEquals(0.0, params.getAlphaMinusOne(), EXACT);
        assertEquals(0.0, params.getBetaMinusOne(), EXACT);
    }

    /**
     * Asserts that {@code actual} reports exactly the same alpha and beta as {@code expected}.
     */
    private void assertSameParameters(BetaParameters expected, BetaParameters actual)
    {
        assertEquals(expected.getAlpha(), actual.getAlpha(), EXACT);
        assertEquals(expected.getBeta(), actual.getBeta(), EXACT);
    }

    /**
     * Asserts properties that should hold for {@code msg} regardless of its
     * current parameter values.
     */
    private void assertInvariants(BetaParameters msg)
    {
        assertGenericInvariants(msg);

        // Each "minus one" getter is the plain getter less one
        assertEquals(msg.getAlpha() - 1, msg.getAlphaMinusOne(), EXACT);
        assertEquals(msg.getBeta() - 1, msg.getBetaMinusOne(), EXACT);

        // Null exactly when alpha and beta are both one
        final boolean bothOne = msg.getAlpha() == 1.0 && msg.getBeta() == 1.0;
        assertEquals(bothOne, msg.isNull());

        // clone() yields a distinct object with the same parameters
        final BetaParameters copy = msg.clone();
        assertNotSame(msg, copy);
        assertSameParameters(msg, copy);

        // Copy made through SerializationTester keeps the same parameters
        final BetaParameters serialized = SerializationTester.clone(msg);
        assertSameParameters(msg, serialized);

        // Compare energies with a Beta factor function built from the same parameters
        final Beta betaFunction = new Beta(msg.getAlpha(), msg.getBeta());
        final Value sample = Value.createReal(1.5);
        assertEquals(Double.POSITIVE_INFINITY, msg.evalEnergy(sample), EXACT); // outside range [0,1]
        sample.setDouble(-1e-15);
        assertEquals(Double.POSITIVE_INFINITY, msg.evalEnergy(sample), EXACT); // outside range [0,1]

        // Ten random draws from testRand
        for (int remaining = 10; remaining > 0; --remaining)
        {
            sample.setDouble(testRand.nextDouble());
            final double expected = betaFunction.evalEnergy(sample);
            final double actual = msg.evalEnergy(sample) - msg.getNormalizationEnergy();
            assertEquals(expected, actual, 1e-15);
        }
    }
}