/*******************************************************************************
* Copyright 2014 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.test.core.parameterizedMessages;
import static com.analog.lyric.util.test.ExceptionTester.*;
import static java.util.Objects.*;
import static org.junit.Assert.*;
import org.junit.Test;
import com.analog.lyric.dimple.exceptions.InvalidDistributionException;
import com.analog.lyric.dimple.factorfunctions.Normal;
import com.analog.lyric.dimple.model.domains.RealDomain;
import com.analog.lyric.dimple.model.values.Value;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.IParameterizedMessage;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.NormalParameters;
import com.analog.lyric.util.test.SerializationTester;
/**
 * Unit tests for {@link NormalParameters}.
 * <p>
 * Covers construction, mean/precision/variance/standard-deviation accessors, copying via
 * {@code set}, KL divergence, combining messages with {@code addFrom} (including deterministic
 * and conflicting deterministic messages), and argument validation.
 *
 * @since 0.06
 * @author Christopher Barber
 */
public class TestNormalParameters extends TestParameterizedMessage
{
    @Test
    public void test()
    {
        // A default-constructed message has zero mean and zero precision.
        NormalParameters msg = new NormalParameters();
        assertInvariants(msg);
        assertEquals(0.0, msg.getMean(), 0.0);
        assertEquals(0.0, msg.getPrecision(), 0.0);

        msg.setPrecision(Double.POSITIVE_INFINITY);
        assertEquals(Double.POSITIVE_INFINITY, msg.getPrecision(), 0.0);
        assertInvariants(msg);

        msg.setMean(4.2);
        assertEquals(4.2, msg.getMean(), 0.0);

        msg.setVariance(1.5);
        assertEquals(1.5, msg.getVariance(), 1e-14);
        assertInvariants(msg);

        msg.setStandardDeviation(2.3);
        assertEquals(2.3, msg.getStandardDeviation(), 1e-14);
        assertInvariants(msg);

        // Two-argument constructor takes (mean, precision).
        NormalParameters msg2 = new NormalParameters(10.0, 2.0);
        assertEquals(10.0, msg2.getMean(), 0.0);
        assertEquals(2.0, msg2.getPrecision(), 0.0);

        // set() must copy both mean and precision into the target message.
        msg.set(msg2);
        assertEquals(10.0, msg.getMean(), 0.0);
        assertEquals(2.0, msg.getPrecision(), 0.0);

        //
        // Test computeKLDivergence
        //
        // If the precisions are equal, the KL divergence is one half the precision of the second
        // message times the squared difference of the means (and is then symmetric).
        msg.setMean(9.0);
        assertEquals(1, msg.computeKLDivergence(msg2), 1e-14);
        assertEquals(1, msg2.computeKLDivergence(msg), 1e-14);
        msg.setMean(12.0);
        assertEquals(4, msg.computeKLDivergence(msg2), 1e-14);

        // Equal means, different precisions (msg: 2, msg2: 10). The expected value is
        // KL(msg || msg2) = (tau2/tau1 - 1 + ln(tau1/tau2)) / 2 = (5 - 1 - ln 5) / 2.
        msg2.setMean(12);
        msg2.setPrecision(10);
        assertEquals((4 - Math.log(5))/2, msg.computeKLDivergence(msg2), 1e-14);

        // setNull() and setUniform() both reset to zero mean and zero precision.
        msg.setNull();
        assertEquals(0.0, msg.getMean(), 0.0);
        assertEquals(0.0, msg.getPrecision(), 0.0);
        msg2.setUniform();
        assertEquals(0.0, msg2.getMean(), 0.0);
        assertEquals(0.0, msg2.getPrecision(), 0.0);

        //
        // Test addFrom
        //

        // Adding a null message doesn't change anything.
        msg.addFrom(msg2);
        assertEquals(0.0, msg.getMean(), 0.0);
        assertEquals(0.0, msg.getPrecision(), 0.0);

        msg2.setMean(1.0);
        msg2.setPrecision(2.0);
        // The cast exercises the generic IParameterizedMessage overload of addFrom.
        msg.addFrom((IParameterizedMessage)msg2);
        assertEquals(1.0, msg.getMean(), 0.0);
        assertEquals(2.0, msg.getPrecision(), 0.0);

        // Precisions add; the mean is the precision-weighted average of the means.
        msg.addFrom(msg2);
        assertEquals(1.0, msg.getMean(), 0.0);
        assertEquals(4.0, msg.getPrecision(), 0.0);

        msg2.setMean(2.0);
        msg2.setPrecision(.5);
        msg.addFrom(msg2);
        assertEquals(4.5, msg.getPrecision(), 0.0);
        assertEquals(5/4.5, msg.getMean(), 1e-15);

        // Adding a deterministic message makes the result deterministic. Adding the same value
        // again, or adding any finite-precision message afterwards, leaves it unchanged.
        msg2.setDeterministic(45);
        msg.addFrom(msg2);
        msg.addFrom(msg2);
        msg2.setMean(-3);
        msg2.setPrecision(100);
        msg.addFrom(msg2); // has no effect
        assertEquals(45, msg.getMean(), 0.0);
        assertEquals(Double.POSITIVE_INFINITY, msg.getPrecision(), 0.0);

        // Combining two deterministic messages with different values is inconsistent.
        msg2.setDeterministic(Value.createReal(44));
        expectThrow(InvalidDistributionException.class, msg, "addFrom", msg2);

        //
        // Other errors
        //
        expectThrow(IllegalArgumentException.class, msg, "setStandardDeviation", -1.0);
    }

    /**
     * Checks properties that should hold for any {@link NormalParameters} instance.
     * <p>
     * These are:
     * <ul>
     * <li>generic message invariants;</li>
     * <li>variance equals 1/precision, and standard deviation equals sqrt(variance);</li>
     * <li>{@code clone()} and serialization round-trip preserve mean and precision;</li>
     * <li>{@code isNull()} holds exactly when precision is zero;</li>
     * <li>energy is zero for a null message, and otherwise matches {@link Normal} after removing
     * the normalization energy;</li>
     * <li>a deterministic value exists exactly when precision is infinite.</li>
     * </ul>
     */
    private void assertInvariants(NormalParameters message)
    {
        assertGenericInvariants(message);

        assertEquals(1.0/message.getPrecision(), message.getVariance(), 1e-14);
        assertEquals(Math.sqrt(message.getVariance()), message.getStandardDeviation(), 1e-14);

        NormalParameters message2 = message.clone();
        assertEquals(message.getPrecision(), message2.getPrecision(), 0.0);
        assertEquals(message.getMean(), message2.getMean(), 0.0);

        NormalParameters message3 = SerializationTester.clone(message);
        assertEquals(message.getPrecision(), message3.getPrecision(), 0.0);
        assertEquals(message.getMean(), message3.getMean(), 0.0);

        assertEquals(message.getPrecision() == 0.0, message.isNull());

        Value value = Value.createReal(0.0);
        if (message.getPrecision() == 0.0)
        {
            // A zero-precision (null) message assigns zero energy everywhere.
            assertEquals(0.0, message.evalEnergy(value), 0.0);
            value.setDouble(testRand.nextDouble());
            assertEquals(0.0, message.evalEnergy(value), 0.0);
        }
        else
        {
            // Compare with the Normal factor function built from the same mean and precision.
            Normal normal = new Normal(message.getMean(), message.getPrecision());
            for (int i = 0; i < 10; ++i)
            {
                value.setDouble(testRand.nextDouble());
                assertEquals(normal.evalEnergy(value), message.evalEnergy(value) - message.getNormalizationEnergy(),
                    1e-15);
            }
        }

        if (message.getPrecision() == Double.POSITIVE_INFINITY)
        {
            assertTrue(message.hasDeterministicValue());
            assertEquals(message.getMean(), message.toDeterministicValue(), 0.0);
            assertEquals(message.getMean(),
                requireNonNull(message.toDeterministicValue(RealDomain.unbounded())).getDouble(), 0.0);
        }
        else
        {
            assertFalse(message.hasDeterministicValue());
            assertTrue(Double.isNaN(message.toDeterministicValue()));
            assertNull(message.toDeterministicValue(RealDomain.unbounded()));
        }
    }
}