/*******************************************************************************
* Copyright 2014 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.test.core.parameterizedMessages;
import static com.analog.lyric.math.Utilities.*;
import static com.analog.lyric.util.test.ExceptionTester.*;
import static java.util.Objects.*;
import static org.junit.Assert.*;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import org.eclipse.jdt.annotation.Nullable;
import org.junit.Test;
import com.analog.lyric.dimple.data.DataRepresentationType;
import com.analog.lyric.dimple.data.IDatum;
import com.analog.lyric.dimple.exceptions.NormalizationException;
import com.analog.lyric.dimple.factorfunctions.core.IUnaryFactorFunction;
import com.analog.lyric.dimple.model.domains.DiscreteDomain;
import com.analog.lyric.dimple.model.values.Value;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.DiscreteEnergyMessage;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.DiscreteMessage;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.DiscreteWeightMessage;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.IParameterizedMessage;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.NormalParameters;
import com.analog.lyric.util.test.SerializationTester;
import com.google.common.primitives.Doubles;
/**
 * Unit test for {@link DiscreteMessage} implementations, exercising both the
 * weight-based ({@link DiscreteWeightMessage}) and energy-based
 * ({@link DiscreteEnergyMessage}) representations.
 *
 * @since 0.06
 * @author Christopher Barber
 */
public class TestDiscreteMessage extends TestParameterizedMessage
{
    /**
     * Exercises the weight/energy accessors, normalization, {@code setNull}/{@code setUniform},
     * deterministic-value setters and KL divergence on both message representations, running
     * {@link #assertInvariants} after each significant state change.
     */
    @Test
    public void test()
    {
        // --- Weight message constructed from a size ---
        DiscreteMessage msg = new DiscreteWeightMessage(10);
        assertInvariants(msg);
        for (int i = msg.size(); --i>=0;)
        {
            // A freshly constructed message is expected to have weight 1.0 at every index.
            assertEquals(1.0, msg.getWeight(i), 1e-14);
            msg.setWeight(i,i);
            assertEquals(i, msg.getWeight(i), 0.0);
            assertEquals(weightToEnergy(i), msg.getEnergy(i), 1e-14);
            msg.setEnergy(i,i);
            assertEquals(i, msg.getEnergy(i), 0.0);
        }

        // With no explicit normalization energy, the getter should agree with the sum of weights.
        assertEquals(weightToEnergy(msg.sumOfWeights()), msg.getNormalizationEnergy(), 0.0);

        // Normalizing with the normalization energy reset to zero should leave the
        // pre-normalization total recoverable from the normalization energy.
        double sum = msg.sumOfWeights();
        msg.setNormalizationEnergy(0.0);
        msg.normalize();
        assertEquals(1.0, msg.sumOfWeights(), 1e-10);
        assertInvariants(msg);
        assertEquals(sum, energyToWeight(msg.getNormalizationEnergy()), 1e-10);

        // Fill the raw representation with a junk value so the setter under test must overwrite every entry.
        Arrays.fill(msg.representation(), 23);
        msg.setNull();
        for (int i= msg.size(); --i>= 0;)
        {
            assertEquals(0.0, msg.getEnergy(i), 0.0);
        }

        Arrays.fill(msg.representation(), 23);
        msg.setUniform();
        for (int i = msg.size(); --i>= 0;)
        {
            assertEquals(1.0 / msg.size(), msg.getWeight(i), 1e-14);
        }

        // --- Weight message constructed from explicit weights ---
        msg = new DiscreteWeightMessage(new double[] { 4, 5, 6 });
        assertInvariants(msg);
        assertEquals(4, msg.getWeight(0), 0.0);
        assertEquals(5, msg.getWeight(1), 0.0);
        assertEquals(6, msg.getWeight(2), 0.0);

        msg.setDeterministicIndex(1);
        assertEquals(1, msg.toDeterministicValueIndex());
        assertInvariants(msg);

        // --- Energy message constructed from a size ---
        msg = new DiscreteEnergyMessage(10);
        assertInvariants(msg);
        for (int i = msg.size(); --i>=0;)
        {
            assertEquals(1.0, msg.getWeight(i), 1e-14);
            msg.setWeight(i,i);
            // Weights round-trip through the energy representation here, hence the non-zero tolerance.
            assertEquals(i, msg.getWeight(i), 1e-14);
            assertEquals(weightToEnergy(i), msg.getEnergy(i), 1e-14);
            msg.setEnergy(i,i);
            assertEquals(i, msg.getEnergy(i), 0.0);
        }

        Arrays.fill(msg.representation(), 23);
        msg.setNull();
        for (int i= msg.size(); --i>= 0;)
        {
            assertEquals(0, msg.getEnergy(i), 0.0);
        }

        Arrays.fill(msg.representation(), 23);
        msg.setUniform();
        for (int i= msg.size(); --i>= 0;)
        {
            assertEquals(0, msg.getEnergy(i), 0.0);
        }

        // The value 2 in the domain range(1, size) is expected at index 1.
        msg.setDeterministic(Value.create(DiscreteDomain.range(1, msg.size()),2));
        assertEquals(1, msg.toDeterministicValueIndex());
        assertInvariants(msg);

        // --- Energy message constructed from explicit energies ---
        msg = new DiscreteEnergyMessage(new double[] { 4, 5, 6 });
        assertInvariants(msg);
        assertEquals(4, msg.getEnergy(0), 0.0);
        assertEquals(5, msg.getEnergy(1), 0.0);
        assertEquals(6, msg.getEnergy(2), 0.0);

        sum = msg.sumOfWeights();
        msg.setNormalizationEnergy(0.0);
        msg.normalize();
        assertEquals(sum, energyToWeight(msg.getNormalizationEnergy()), 1e-10);

        // --- KL divergence between a weight message and an energy message ---
        msg = new DiscreteWeightMessage(10);
        DiscreteEnergyMessage msg2 = new DiscreteEnergyMessage(10);
        Random rand = new Random(42); // fixed seed keeps the test deterministic
        for (int i = 0; i < msg.size(); ++i)
        {
            msg.setWeight(i, rand.nextDouble());
            msg2.setWeight(i, rand.nextDouble());
        }
        assertEquals(expectedKL(msg, msg2), msg.computeKLDivergence(msg2), 1e-14);
        assertInvariants(msg);
        assertInvariants(msg2);

        // After normalizeEnergy() the smallest energy is expected to be exactly zero.
        msg2.normalizeEnergy();
        assertInvariants(msg2);
        assertEquals(0.0, Doubles.min(msg2.representation()), 0.0);

        // KL divergence against a message of a different size must be rejected.
        expectThrow(IllegalArgumentException.class, msg, "computeKLDivergence", new DiscreteEnergyMessage(3));
    }

    /**
     * Test {@link DiscreteEnergyMessage#createFrom}/{@linkplain DiscreteEnergyMessage#convertFrom convertFrom}.
     *
     * @since 0.08
     */
    @SuppressWarnings("unchecked")
    @Test
    public void testCreateFrom()
    {
        DiscreteDomain domain = DiscreteDomain.range(1,3);
        DiscreteEnergyMessage msg123 = new DiscreteEnergyMessage(new double[] {1,2,3});

        // An empty list converts to null.
        List<? extends IDatum> empty = Collections.emptyList();
        assertNull(DiscreteEnergyMessage.convertFrom(domain, empty));

        // convertFrom hands back the very same instance for a single energy message,
        // whereas createFrom yields a distinct but equal copy.
        assertSame(msg123, DiscreteEnergyMessage.convertFrom(domain, Arrays.asList(msg123)));
        DiscreteEnergyMessage msg = DiscreteEnergyMessage.createFrom(domain, Arrays.asList(msg123));
        assertNotSame(msg123, msg);
        assertTrue(msg123.objectEquals(msg));

        // A single fixed value converts to a deterministic message with zero energy at that value.
        Value value = Value.createWithIndex(domain, 1);
        msg = DiscreteEnergyMessage.convertFrom(domain, Arrays.asList(value));
        assertEquals(1, requireNonNull(msg).toDeterministicValueIndex());
        assertEquals(0.0, msg.evalEnergy(value), 0.0);

        msg = DiscreteEnergyMessage.createFrom(domain, Arrays.asList(msg123, value));
        assertEquals(1, requireNonNull(msg).toDeterministicValueIndex());
        // Still deterministic, but incorporates energy from msg123
        assertEquals(2.0, msg.evalEnergy(value), 0.0);

        msg = DiscreteEnergyMessage.createFrom(domain, Arrays.asList(value, msg123));
        assertEquals(1, requireNonNull(msg).toDeterministicValueIndex());
        // Elements after value are ignored:
        assertEquals(0.0, msg.evalEnergy(value), 0.0);

        // Two energy messages combine by element-wise addition of their energies.
        DiscreteEnergyMessage msg456 = new DiscreteEnergyMessage(new double[] {4, 5, 6});
        msg = DiscreteEnergyMessage.createFrom(domain, Arrays.asList(msg456, msg123));
        assertArrayEquals(new double[] { 5, 7, 9 }, requireNonNull(msg).getEnergies(), 0.0);
    }

    /**
     * Reference computation of the KL divergence between two messages, used to check
     * {@code computeKLDivergence}.
     * <p>
     * Each message's weights are first divided by that message's total weight, then
     * {@code sum(p * log(p/q))} is accumulated over all indices. The result is asserted
     * to be non-negative before it is returned.
     *
     * @param msg1 message supplying the {@code p} distribution; its size drives both loops
     * @param msg2 message supplying the {@code q} distribution; it is indexed over
     *        {@code msg1.size()} entries
     * @return the computed divergence
     */
    private double expectedKL(DiscreteMessage msg1, DiscreteMessage msg2)
    {
        double KL = 0.0;

        // First pass: totals used to normalize each message's weights.
        double total1 = 0.0, total2 = 0.0;
        for (int i = msg1.size(); --i>=0;)
        {
            total1 += msg1.getWeight(i);
            total2 += msg2.getWeight(i);
        }

        // Second pass: accumulate p * log(p/q) over the normalized weights.
        for (int i = msg1.size(); --i>=0;)
        {
            double p = msg1.getWeight(i)/total1;
            double q = msg2.getWeight(i)/total2;
            KL += p * Math.log(p/q);
        }

        assertTrue(KL >= 0.0);
        return KL;
    }

    /**
     * Asserts a set of invariants on {@code message}.
     * <p>
     * Mutating operations are exercised on copies (a {@code clone()}, a serialization clone,
     * and freshly constructed messages). The only mutators invoked on {@code message} itself
     * are calls that are expected to throw.
     *
     * @param message the message to check
     */
    private void assertInvariants(DiscreteMessage message)
    {
        // Not defined in this class; presumably inherited from TestParameterizedMessage.
        assertGenericInvariants(message);

        assertFalse(message.objectEquals(null));
        assertFalse(message.objectEquals("bogus"));
        assertTrue(message.objectEquals(message));

        final int size = message.size();

        // --- Weight/energy consistency and evalEnergy on both object and discrete values ---
        final Value objValue = Value.create("");
        final Value discreteValue = Value.create(DiscreteDomain.range(1, size));
        for (int i = 0; i < size; ++i)
        {
            assertEquals(message.getWeight(i), energyToWeight(message.getEnergy(i)), 1e-15);
            objValue.setObject(i);
            assertEquals(message.getEnergy(i), message.evalEnergy(objValue), 0.0);
            discreteValue.setIndex(i);
            assertEquals(message.getEnergy(i), message.evalEnergy(discreteValue), 0.0);
        }

        // --- The raw representation must match the natural accessor for the storage kind ---
        if (message.storesWeights())
        {
            for (int i = 0; i < size; ++i)
            {
                assertEquals(message.getWeight(i), message.representation()[i], 0.0);
            }
        }
        else
        {
            for (int i = 0; i < size; ++i)
            {
                assertEquals(message.getEnergy(i), message.representation()[i], 0.0);
            }
            if (message instanceof DiscreteEnergyMessage)
            {
                assertEquals(Doubles.min(message.representation()), ((DiscreteEnergyMessage)message).minEnergy(), 0.0);
            }
        }

        // --- Deterministic value detection ---
        // Find the index of the single non-zero weight, or -1 if there are none or more than one.
        int onlyNonZeroWeightIndex = -1;
        for (int i = 0; i < size; ++i)
        {
            final double w = message.getWeight(i);
            if (w == 0)
            {
                assertTrue(message.hasZeroWeight(i));
            }
            else
            {
                assertFalse(message.hasZeroWeight(i));
                if (onlyNonZeroWeightIndex < 0)
                    onlyNonZeroWeightIndex = i;
                else
                {
                    // A second non-zero weight: not deterministic, stop scanning.
                    onlyNonZeroWeightIndex = -1;
                    break;
                }
            }
        }

        if (onlyNonZeroWeightIndex >= 0)
        {
            assertTrue(message.hasDeterministicValue());
            assertEquals(onlyNonZeroWeightIndex, message.toDeterministicValueIndex());
            assertEquals(onlyNonZeroWeightIndex,
                requireNonNull(message.toDeterministicValue(DiscreteDomain.range(1, size))).getIndex());
        }
        else
        {
            assertFalse(message.hasDeterministicValue());
            assertNull(message.toDeterministicValue(DiscreteDomain.range(1, size)));
            assertEquals(-1, message.toDeterministicValueIndex());
        }

        // --- Argument validation ---
        expectThrow(ArrayIndexOutOfBoundsException.class, message, "getWeight", -1);
        expectThrow(ArrayIndexOutOfBoundsException.class, message, "getWeight", size);
        expectThrow(ClassCastException.class, message, "setFrom", new NormalParameters());
        expectThrow(IllegalArgumentException.class, message, "setWeights", new double[size+1]);
        expectThrow(IllegalArgumentException.class, message, "setEnergies", new double[size+1]);
        expectThrow(IllegalArgumentException.class, ".* is not discrete", message, "setDeterministic",
            Value.create("hi"));

        // --- clone() ---
        DiscreteMessage message2 = message.clone();
        assertTrue(message.objectEquals(message2));
        assertEquals(message.size(), message2.size());
        for (int i = 0; i < size; ++i)
        {
            assertEquals(message.getWeight(i), message2.getWeight(i), 0.0);
        }
        assertEquals(message.getNormalizationEnergy(), message2.getNormalizationEnergy(), 0.0);

        // Changing the normalization energy of the clone must make it unequal to the source
        // message, unless the source's normalization energy already happened to be 42.
        // The setter has to run before the comparison, otherwise the check can never fire.
        double prevDenormalizer = message2.getNormalizationEnergy();
        message2.setNormalizationEnergy(42);
        assertEquals(42, message2.getNormalizationEnergy(), 0.0);
        if (prevDenormalizer != message2.getNormalizationEnergy())
        {
            assertFalse(message.objectEquals(message2));
        }

        // --- Serialization round trip ---
        DiscreteMessage message3 = SerializationTester.clone(message);
        assertTrue(message.objectEquals(message3));
        assertEquals(message.size(), message3.size());
        for (int i = 0; i < size; ++i)
        {
            assertEquals(message.getWeight(i), message3.getWeight(i), 0.0);
        }
        assertEquals(message.getNormalizationEnergy(), message3.getNormalizationEnergy(), 0.0);

        // --- All-zero weights: infinite energies, and normalize() must fail ---
        message3.setWeightsToZero();
        assertEquals(0.0, message3.sumOfWeights(), 0.0);
        for (int i = message3.size(); --i>=0;)
        {
            assertEquals(0.0, message3.getWeight(i), 0.0);
            assertEquals(Double.POSITIVE_INFINITY, message3.getEnergy(i), 0.0);
        }
        assertEquals(Double.POSITIVE_INFINITY, message3.getNormalizationEnergy(), 0.0);
        expectThrow(NormalizationException.class, ".*weights add up to zero", message3, "normalize");

        // --- normalize() with the normalization energy unset vs. explicitly zero ---
        message3.setFrom(message);
        assertTrue(message.objectEquals(message3));
        message3.setNormalizationEnergy(Double.NaN); // unset
        message3.normalize();
        assertEquals(1.0, message3.sumOfWeights(), 1e-15);
        assertEquals(0.0, message3.getNormalizationEnergy(), 0.0);

        message3.setFrom(message);
        message3.setNormalizationEnergy(0.0);
        message3.normalize();
        assertEquals(1.0, message3.sumOfWeights(), 1e-15);
        assertEquals(weightToEnergy(message.sumOfWeights()), message3.getNormalizationEnergy(), 1e-15);

        // --- setNull() ---
        message3.setNull();
        assertEquals(weightToEnergy(size), message3.getNormalizationEnergy(), 0.0);
        for (int i = 0; i < size; ++i)
        {
            // Only exercises evalEnergy on the null message; the result is not checked.
            discreteValue.setIndex(i);
            message3.evalEnergy(discreteValue);
        }

        // A null message has total weight equal to its size regardless of representation.
        assertEquals(size, message3.sumOfWeights(), 0.0);

        // Restore the contents from the source message's raw representation using the setter
        // that matches the storage kind.
        if (message3.storesWeights())
        {
            message3.setWeights(message.representation());
        }
        else
        {
            message3.setEnergies(message.representation());
        }
        message3.setNormalizationEnergy(message.getNormalizationEnergy());
        assertTrue(message3.objectEquals(message));

        // --- objectEquals is sensitive to both normalization energy and weights ---
        message3.setNormalizationEnergy(message3.getNormalizationEnergy() + 1.0);
        assertFalse(message.objectEquals(message3));
        assertEquals(message.getNormalizationEnergy() + 1.0, message3.getNormalizationEnergy(), 0.0);
        message3.setWeight(0, message3.getWeight(0) + 1.0);
        assertFalse(message.objectEquals(message3));
        message3.setNormalizationEnergy(message.getNormalizationEnergy());
        assertFalse(message.objectEquals(message3));

        // --- setFrom via the IParameterizedMessage overload, then the add* operations ---
        message3.setFrom((IParameterizedMessage)message);
        assertTrue(message.objectEquals(message3));

        message3.addWeightsFrom(message);
        for (int i = size; --i>=0;)
        {
            assertEquals(message.getWeight(i) * 2, message3.getWeight(i), 1e-10);
        }

        message3.setNormalizationEnergy(Double.NaN); // unset normalization energy
        message3.normalize();
        assertEquals(0.0, message3.getNormalizationEnergy(), 0.0);

        message3.setNull();
        message3.addEnergiesFrom(message);
        assertArrayEquals(message.getEnergies(), message3.getEnergies(), 0.0);
        message3.addFrom(message);
        for (int i = size; --i>=0;)
        {
            assertEquals(message.getEnergy(i) * 2, message3.getEnergy(i), 1e-10);
        }

        // --- Same operations against a message of the opposite representation ---
        DiscreteMessage message4 = message.storesWeights() ? new DiscreteEnergyMessage(size) :
            new DiscreteWeightMessage(size);
        assertFalse(message.objectEquals(message4));
        message4.setFrom(message);
        // Still expected to be unequal after copying the contents across representations.
        assertFalse(message.objectEquals(message4));
        assertEquals(message.getNormalizationEnergy(), message4.getNormalizationEnergy(), 0.0);
        assertEquals(0.0, message.computeKLDivergence(message4), 1e-10);

        message4.addWeightsFrom(message);
        for (int i = size; --i>=0;)
        {
            assertEquals(message.getWeight(i) * 2, message4.getWeight(i), 1e-10);
        }

        message4.setNull();
        message4.addEnergiesFrom(message);
        assertArrayEquals(message.getEnergies(), message4.getEnergies(), 1e-15);
        message4.addFrom(message);
        for (int i = size; --i>=0;)
        {
            assertEquals(message.getEnergy(i) * 2, message4.getEnergy(i), 1e-10);
        }

        // --- Copy constructors ---
        DiscreteMessage message5 = new DiscreteEnergyMessage(message4);
        assertArrayEquals(message4.getWeights(), message5.getWeights(), 1e-15);
        message5 = new DiscreteWeightMessage(message4);
        assertArrayEquals(message4.getWeights(), message5.getWeights(), 1e-15);

        // --- setFrom(domain, datum) with a function, an indexed value, a plain value and a message ---
        DiscreteDomain domain = DiscreteDomain.range(1, size);
        message5.setFrom(domain, new IndexPlusOne());
        for (int i = size; --i>=0;)
        {
            assertEquals(i + 1, message5.getEnergy(i), 0.0);
        }

        message5.setFrom(domain, Value.createWithIndex(domain, 0));
        assertEquals(1.0, message5.getWeight(0), 0.0);
        for (int i = 1; i < size; ++i)
        {
            assertEquals(0.0, message5.getWeight(i), 0.0);
        }

        // The value 1 is the first element of range(1, size), so index 0 is again expected.
        message5.setFrom(domain, Value.create(1));
        assertEquals(1.0, message5.getWeight(0), 0.0);
        for (int i = 1; i < size; ++i)
        {
            assertEquals(0.0, message5.getWeight(i), 0.0);
        }

        message5.setFrom(domain, message);
        assertArrayEquals(message.getWeights(), message5.getWeights(), 1e-15);

        // --- Constructors taking (domain, datum) ---
        // A null datum is expected to produce uniform weights.
        message5 = new DiscreteWeightMessage(domain, null);
        assertEquals(size, message5.size());
        for (double w : message5.getWeights())
        {
            assertEquals(1.0/size, w, 1e-15);
        }

        message5 = new DiscreteWeightMessage(domain, new IndexPlusOne());
        for (int i = size; --i>=0;)
        {
            assertEquals(i + 1, message5.getEnergy(i), 0.0);
        }

        message5 = new DiscreteEnergyMessage(domain, new IndexPlusOne());
        for (int i = size; --i>=0;)
        {
            assertEquals(i + 1, message5.getEnergy(i), 0.0);
        }
    }

    /**
     * Test-only unary factor function whose energy for a {@link Value} is that value's
     * index plus one, and whose energy for an arbitrary object is positive infinity.
     */
    private static class IndexPlusOne implements IUnaryFactorFunction
    {
        private static final long serialVersionUID = 1L;

        @Override
        public DataRepresentationType representationType()
        {
            return DataRepresentationType.FUNCTION;
        }

        /** All instances are considered equal to one another. */
        @Override
        public boolean objectEquals(@Nullable Object other)
        {
            return other instanceof IndexPlusOne;
        }

        /** Returns this same instance; the class declares no instance fields. */
        @Override
        public IUnaryFactorFunction clone()
        {
            return this;
        }

        /** @return the index of {@code value} plus one */
        @Override
        public double evalEnergy(Value value)
        {
            return value.getIndex() + 1;
        }

        /** @return {@link Double#POSITIVE_INFINITY} for any argument */
        @Override
        public double evalEnergy(Object value)
        {
            return Double.POSITIVE_INFINITY;
        }
    }
}