/******************************************************************************* * Copyright 2015 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.test.data; import static com.analog.lyric.dimple.model.sugar.ModelSyntacticSugar.*; import static com.analog.lyric.util.test.ExceptionTester.*; import static org.junit.Assert.*; import java.util.ArrayList; import org.junit.Test; import com.analog.lyric.dimple.data.DataLayer; import com.analog.lyric.dimple.data.DataStack; import com.analog.lyric.dimple.data.GenericDataLayer; import com.analog.lyric.dimple.data.PriorDataLayer; import com.analog.lyric.dimple.model.core.FactorGraph; import com.analog.lyric.dimple.model.sugar.ModelSyntacticSugar.CurrentModel; import com.analog.lyric.dimple.model.variables.Real; import com.analog.lyric.dimple.solvers.core.parameterizedMessages.NormalParameters; import com.analog.lyric.dimple.test.DimpleTestBase; /** * * @since 0.08 * @author Christopher Barber */ public class TestDataStack extends DimpleTestBase { @Test public void test() { try (CurrentModel root = using(new FactorGraph("root"))) { Real a = real("a"); Real b = real("b"); normal(0.0, 1.0, a, b); PriorDataLayer layer0 = new PriorDataLayer(root.graph); DataStack priorStack = new DataStack(layer0); assertInvariants(priorStack); assertSame(layer0, priorStack.get(0)); GenericDataLayer layer1 = GenericDataLayer.dense(root.graph); DataStack stack = new DataStack(layer0, layer1); assertInvariants(stack); assertSame(layer0, stack.get(0)); assertSame(layer1, stack.get(1)); // Test computeTotalEnergy NormalParameters normal = new NormalParameters(0, 1); a.setPrior(0.0); expectThrow(IllegalStateException.class, "There is no value for.*", stack, "computeTotalEnergy"); b.setPrior(1.0); layer1.set(b, 42); // superceded by prior assertEquals(normal.evalEnergy(0) + normal.evalEnergy(1) - 2 * normal.getNormalizationEnergy(), stack.computeTotalEnergy(), 1e-15); b.setPrior(null); assertEquals(normal.evalEnergy(0) + normal.evalEnergy(42) - 2 * normal.getNormalizationEnergy(), stack.computeTotalEnergy(), 1e-15); NormalParameters bPrior = new NormalParameters(1,2); b.setPrior(bPrior); assertEquals(normal.evalEnergy(0) + normal.evalEnergy(42) - 2 * normal.getNormalizationEnergy() + bPrior.evalEnergy(42), stack.computeTotalEnergy(), 1e-15); } try { @SuppressWarnings("unused") DataStack stack = new DataStack(new ArrayList<DataLayer<?>>()); fail("Expected IllegalArgumentException"); } catch (IllegalArgumentException ex) {} try { @SuppressWarnings("unused") DataStack stack = new DataStack(GenericDataLayer.dense(new FactorGraph()), GenericDataLayer.dense(new FactorGraph())); fail("Expected IllegalArgumentException"); } catch (IllegalArgumentException ex) {} } private void assertInvariants(DataStack stack) { final FactorGraph root = stack.rootGraph(); final int size = stack.size(); assertTrue(size > 0); assertFalse(stack.isEmpty()); for (int i = 0; i < size; ++i) { DataLayer<?> layer = stack.get(i); assertSame(root, layer.rootGraph()); } expectThrow(IndexOutOfBoundsException.class, stack, "get", -1); expectThrow(IndexOutOfBoundsException.class, stack, "get", size); expectThrow(UnsupportedOperationException.class, stack, "add", GenericDataLayer.sparse(root)); assertEquals(size, stack.size()); } }