/*******************************************************************************
* Copyright 2015 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.test.model.variables;
import static com.analog.lyric.util.test.ExceptionTester.*;
import static java.util.Objects.*;
import static org.junit.Assert.*;
import java.util.List;
import org.junit.Test;
import com.analog.lyric.dimple.data.DataLayer;
import com.analog.lyric.dimple.data.IDatum;
import com.analog.lyric.dimple.factorfunctions.Negate;
import com.analog.lyric.dimple.factorfunctions.Normal;
import com.analog.lyric.dimple.factorfunctions.core.IUnaryFactorFunction;
import com.analog.lyric.dimple.model.core.Edge;
import com.analog.lyric.dimple.model.core.FactorGraph;
import com.analog.lyric.dimple.model.core.NodeType;
import com.analog.lyric.dimple.model.domains.RealDomain;
import com.analog.lyric.dimple.model.factors.Factor;
import com.analog.lyric.dimple.model.values.Value;
import com.analog.lyric.dimple.model.variables.Discrete;
import com.analog.lyric.dimple.model.variables.Real;
import com.analog.lyric.dimple.model.variables.Variable;
import com.analog.lyric.dimple.solvers.core.parameterizedMessages.NormalParameters;
import com.analog.lyric.dimple.solvers.interfaces.ISolverVariable;
import com.analog.lyric.dimple.test.DimpleTestBase;
/**
 * Unit test for {@link Variable} class.
 * <p>
 * @since 0.08
 * @author Christopher Barber
 */
// Deprecated Variable methods (getModelerClassName, getGuess, setSolver, ...) are exercised on purpose.
@SuppressWarnings("deprecation")
public class TestVariable extends DimpleTestBase
{
    /**
     * Walks a {@link Real} variable through its basic lifecycle and checks
     * {@link #assertInvariants(Variable)} at each interesting step:
     * <ul>
     * <li>defaults of a freshly constructed, parentless variable;</li>
     * <li>setting/clearing the deterministic input/output flags;</li>
     * <li>setting priors as a {@link Value}, a number, a {@link NormalParameters}, or {@code null};</li>
     * <li>setting conditions with and without a parent graph / default conditioning layer;</li>
     * <li>attaching factors so the variable has sibling edges.</li>
     * </ul>
     */
    @Test
    public void test()
    {
        FactorGraph fg = new FactorGraph();
        assertNull(fg.getDefaultConditioningLayer());

        // Freshly constructed variable: no parent, prior, or condition.
        Real r = new Real();
        assertInvariants(r);
        assertEquals("Real", r.getModelerClassName()); // deprecated
        assertNull(r.getParentGraph());
        assertNull(r.getPrior());
        assertNull(r.getCondition());
        assertEquals(RealDomain.unbounded(), r.getDomain());

        // deterministic input/output flags
        assertFalse(r.isDeterministicInput());
        assertFalse(r.isDeterministicOutput());
        r.setDeterministicInput();
        assertTrue(r.isDeterministicInput());
        assertFalse(r.isDeterministicOutput());
        r.initialize(); // clears flags
        assertFalse(r.isDeterministicInput());
        assertFalse(r.isDeterministicOutput());
        r.setDeterministicOutput();
        assertFalse(r.isDeterministicInput());
        assertTrue(r.isDeterministicOutput());
        r.initialize();

        // Priors
        Value prior = Value.createReal(Math.PI);
        r.setPrior(prior);
        assertSame(prior, r.getPrior());
        assertInvariants(r);

        r.setPrior(42);
        assertEquals(42.0, requireNonNull(r.getPriorValue()).getDouble(), 0.0);
        // A non-numeric prior is rejected and must leave the previous prior intact.
        expectThrow(ClassCastException.class, r, "setPrior", "barf");
        assertEquals(42.0, requireNonNull(r.getPriorValue()).getDouble(), 0.0);

        NormalParameters normal = new NormalParameters(1.0, 2.0);
        r.setPrior(normal);
        assertSame(normal, r.getPrior());
        assertInvariants(r);

        r.setPrior(null);
        assertNull(r.getPrior());

        // Deprecated fixed-value / input API is expressed in terms of the prior.
        r.setFixedValue(1.2);
        assertEquals(1.2, requireNonNull(r.getPriorValue()).getDouble(), 1e-15);
        assertEquals(1.2, ((Double)requireNonNull(r.getFixedValueAsObject())).doubleValue(), 1e-15);
        r.setFixedValueObject(null);
        assertNull(r.getPrior());
        r.setInputObject(normal);
        assertSame(normal, r.getPrior());

        // Conditions
        // no parent - cannot create a default conditioning layer
        expectThrow(IllegalStateException.class, r, "setCondition", 12);
        r.setCondition(null); // ok to set to null when there is no parent
        assertNull(r.getCondition());

        fg.addVariables(r);
        assertSame(fg, r.getParentGraph());
        r.setCondition(normal);
        assertSame(normal, r.getCondition());
        assertNotNull(fg.getDefaultConditioningLayer());
        assertInvariants(r);

        r.setCondition(null);
        assertNull(r.getCondition());
        r.setCondition(1.2345);
        assertEquals(1.2345, Value.class.cast(r.getCondition()).getDouble(), 1e-15);

        fg.setDefaultConditioningLayer(null);
        r.setCondition(null); // ok to set to null when there is no layer - it won't create a layer
        assertNull(r.getCondition());
        assertNull(fg.getDefaultConditioningLayer());

        // Siblings: only the side effect of adding the factors is needed here.
        fg.addFactor(new Normal(0.0, 1.0), r);
        fg.initialize(); // needed to update directedness
        assertInvariants(r);

        Real notr = new Real();
        fg.addFactor(new Negate(), notr, r);
        fg.initialize();
        assertInvariants(r);
        assertInvariants(notr);
    }

    /**
     * Asserts the invariants that must hold for any {@link Variable}, whatever its current state.
     * <p>
     * Checks node-type accessors and casting, consistency between the prior and its typed views,
     * consistency between the condition and the parent graph's default conditioning layer,
     * the solver back-reference, consistency of the sibling factor/edge accessors, the
     * deterministic input/output flags derived from directed sibling edges, and the deprecated API.
     *
     * @param var the variable to check
     */
    public static void assertInvariants(Variable var)
    {
        //
        // Node type
        //
        assertTrue(var.isVariable());
        assertSame(var, var.asVariable());
        assertEquals(NodeType.VARIABLE, var.getNodeType());

        if (var instanceof Discrete)
        {
            assertSame(var, var.asDiscreteVariable());
        }
        else
        {
            expectThrow(ClassCastException.class, var, "asDiscreteVariable");
        }

        //
        // Prior: the typed accessors are views of whatever getPrior() returns.
        //
        IDatum prior = var.getPrior();
        assertSame(prior instanceof IUnaryFactorFunction ? prior : null, var.getPriorFunction());
        assertSame(prior instanceof Value ? prior : null, var.getPriorValue());
        assertEquals(prior instanceof Value, var.hasFixedValue());

        //
        // Condition: must come from the parent graph's default conditioning layer, if any.
        //
        IDatum condition = var.getCondition();
        final FactorGraph graph = var.getParentGraph();
        if (graph == null)
        {
            assertNull(condition);
        }
        else
        {
            DataLayer<?> conditioningLayer = graph.getDefaultConditioningLayer();
            if (conditioningLayer == null)
            {
                assertNull(condition);
            }
            else
            {
                assertSame(conditioningLayer.get(var), condition);
            }
        }

        //
        // Solver
        //
        ISolverVariable svar = var.getSolver();
        if (svar != null)
        {
            assertSame(var, svar.getModelObject());
            assertSame(svar, var.getSolverIfType(svar.getClass()));
        }

        //
        // Siblings
        //
        final int nSiblings = var.getSiblingCount();
        final List<Factor> siblings = var.getSiblings();
        final Factor[] factors = var.getFactors();
        assertTrue(nSiblings >= 0);
        assertEquals(nSiblings, siblings.size());
        assertEquals(nSiblings, factors.length);

        boolean deterministicInput = false, deterministicOutput = false;
        for (int i = 0; i < nSiblings; ++i)
        {
            Factor factor = factors[i];
            assertSame(factor, siblings.get(i));
            assertSame(factor, var.getSibling(i));
            Edge edge = var.getSiblingEdge(i);
            assertSame(factor, edge.factor());
            assertSame(var, edge.variable());
            // Compare by value: assertSame on autoboxed ints only succeeds inside the Integer cache.
            assertEquals(i, edge.edgeState().getVariableToFactorEdgeNumber());

            // Directed edges of deterministic-directed factors determine the variable's flags.
            if (factor.getFactorFunction().isDeterministicDirected())
            {
                switch (edge.direction())
                {
                case UNDIRECTED:
                    break;
                case FROM_FACTOR:
                    deterministicOutput = true;
                    break;
                case TO_FACTOR:
                    deterministicInput = true;
                    break;
                }
            }
        }
        assertEquals(deterministicInput, var.isDeterministicInput());
        assertEquals(deterministicOutput, var.isDeterministicOutput());

        //
        // Deprecated stuff
        //
        assertEquals("Variable", var.getClassLabel());
        if (svar == null)
        {
            assertFalse(var.guessWasSet());
            expectThrow(NullPointerException.class, var, "getGuess");
            expectThrow(NullPointerException.class, var, "setGuess", (Object)null);
        }
        assertSame(var.getPriorFunction(), var.getInputObject());
        expectThrow(UnsupportedOperationException.class, var, "setSolver", (ISolverVariable)null);

        if (prior instanceof Value)
        {
            Value priorValue = (Value)prior;
            assertEquals(priorValue.getObject(), var.getFixedValueAsObject());
        }
        else
        {
            assertNull(var.getFixedValueAsObject());
            assertNull(var.getFixedValueObject());
        }
    }
}