/*******************************************************************************
* Copyright 2014 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.test.solvers.gibbs;
import static java.util.Objects.*;
import static org.junit.Assert.*;
import java.util.ArrayList;
import java.util.Collections;
import org.eclipse.jdt.annotation.Nullable;
import org.junit.Test;
import com.analog.lyric.dimple.factorfunctions.Sum;
import com.analog.lyric.dimple.factorfunctions.core.FactorFunction;
import com.analog.lyric.dimple.model.core.FactorGraph;
import com.analog.lyric.dimple.model.factors.Factor;
import com.analog.lyric.dimple.model.values.Value;
import com.analog.lyric.dimple.model.variables.Real;
import com.analog.lyric.dimple.solvers.gibbs.GibbsSolver;
import com.analog.lyric.dimple.solvers.gibbs.GibbsSolverGraph;
import com.analog.lyric.dimple.solvers.gibbs.ISolverVariableGibbs;
/**
* Test to ensure that Gibbs solver orders deterministic directed factors correctly.
* <p>
* @since 0.07
* @author Christopher Barber
*/
public class TestGibbsDeterministicFactorOrdering
{
@Test
public void test()
{
// Builds a deterministic tree of the form:
//
// v0
// / \
// v1 v2
// / \/ \
// v3 v4 v5
// / \/ \/ \
// ..............
// \\\\\\\/////
// vx
//
// Where all factors are simple two operand sums with
// output variables above the input variables. E.g.,
// v0 = v1 + v2. All of the bottom variables are set
// deterministically from single vx variable.
final int depth = 5;
final int nVars = depth * (depth + 1) / 2;
Real[] vars = new Real[nVars];
for (int i = 0; i < nVars; ++i)
{
vars[i] = new Real();
vars[i].setName("v" + i);
}
FactorGraph fg = new FactorGraph();
fg.addVariables(vars);
// Create factors in random order.
ArrayList<Integer> outputIndices = new ArrayList<>(nVars - depth);
for (int i = nVars - depth; --i>=0;)
{
outputIndices.add(i);
}
Collections.shuffle(outputIndices);
for (int outputIndex : outputIndices)
{
Real outputVar = vars[outputIndex];
// Triangular root gives you the level in the tree:
int level = (int)((Math.sqrt(8.0 * outputIndex + 1.0) - 1.0) / 2.0);
int inputIndex1 = outputIndex + level + 1;
Real inputVar1 = vars[inputIndex1];
Real inputVar2 = vars[inputIndex1 + 1];
Factor factor =
fg.addFactor(new TestFunction(outputVar.getName(), new Sum()), outputVar, inputVar1, inputVar2);
assertTrue(factor.isDirected());
}
Real startVar = new Real();
Real[] baseVars = new Real[depth + 1];
baseVars[0] = startVar;
for (int i = 0; i < depth; ++i)
{
baseVars[i + 1] = vars[nVars - depth + i];
}
fg.addFactor(new TestFunction("dup", new Duplicate()), baseVars);
GibbsSolverGraph sfg = requireNonNull(fg.setSolverFactory(new GibbsSolver()));
sfg.initialize();
for (Factor factor : fg.getFactors())
{
TestFunction func = (TestFunction)factor.getFactorFunction();
func._evalCount = 0;
}
ISolverVariableGibbs[] svars = new ISolverVariableGibbs[nVars];
for (int i = 0; i < nVars; ++i)
{
svars[i] = sfg.getReal(vars[i]);
}
ISolverVariableGibbs startSVar = requireNonNull(sfg.getReal(startVar));
startSVar.setCurrentSample(1);
assertEquals(Math.pow(2.0, depth - 1), svars[0].getCurrentSampleValue().getDouble(), 0.0);
for (Factor factor : fg.getFactors())
{
TestFunction func = (TestFunction)factor.getFactorFunction();
assertEquals(1, func._evalCount);
}
}
private class TestFunction extends FactorFunction
{
// private final String _name;
private FactorFunction _delegate;
int _evalCount = 0;
private TestFunction(String name, FactorFunction delegate)
{
// _name = name;
_delegate = delegate;
}
@Override
public final @Nullable int[] getDirectedToIndices(int numEdges)
{
return _delegate.getDirectedToIndices(numEdges);
}
@Override
public double evalEnergy(Value[] values)
{
return _delegate.evalEnergy(values);
}
@Override
public void evalDeterministic(Value[] arguments)
{
_delegate.evalDeterministic(arguments);
++_evalCount;
}
@Override
public boolean isDeterministicDirected()
{
return _delegate.isDeterministicDirected();
}
@Override
public boolean isDirected()
{
return _delegate.isDirected();
}
}
private class Duplicate extends FactorFunction
{
@Override
public final @Nullable int[] getDirectedToIndices(int numEdges)
{
int[] indices = new int[numEdges - 1];
for (int i = 1; i < numEdges; ++i)
{
indices[i-1] = i;
}
return indices;
}
@Override
public double evalEnergy(Value[] values)
{
Value value = values[0];
for (int i = values.length; --i>=1;)
{
if (!value.valueEquals(values[i]))
{
return Double.POSITIVE_INFINITY;
}
}
return 0.0;
}
@Override
public void evalDeterministic(Value[] arguments)
{
Value value = arguments[0];
for (int i = arguments.length; --i>=1;)
{
arguments[i].setFrom(value);
}
}
@Override
public boolean isDeterministicDirected()
{
return true;
}
@Override
public boolean isDirected()
{
return true;
}
}
}