/*******************************************************************************
* Copyright 2015 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.test.solvers;
import static com.analog.lyric.dimple.model.sugar.ModelSyntacticSugar.*;
import static java.util.Objects.*;
import static org.junit.Assert.*;
import org.junit.Test;
import com.analog.lyric.dimple.factorfunctions.core.FactorFunction;
import com.analog.lyric.dimple.model.core.FactorGraph;
import com.analog.lyric.dimple.model.domains.DiscreteDomain;
import com.analog.lyric.dimple.model.factors.Factor;
import com.analog.lyric.dimple.model.sugar.ModelSyntacticSugar.CurrentModel;
import com.analog.lyric.dimple.model.values.Value;
import com.analog.lyric.dimple.model.variables.Discrete;
import com.analog.lyric.dimple.solvers.gibbs.GibbsOptions;
import com.analog.lyric.dimple.solvers.gibbs.GibbsSolver;
import com.analog.lyric.dimple.solvers.gibbs.GibbsSolverGraph;
import com.analog.lyric.dimple.solvers.minsum.MinSumSolver;
import com.analog.lyric.dimple.solvers.minsum.MinSumSolverGraph;
import com.analog.lyric.dimple.solvers.sumproduct.SumProductSolver;
import com.analog.lyric.dimple.solvers.sumproduct.SumProductSolverGraph;
import com.analog.lyric.dimple.test.DimpleTestBase;
/**
*
* @since 0.08
* @author Christopher Barber
*/
public class TestConstantsWithSolvers extends DimpleTestBase
{
@SuppressWarnings("unused")
@Test
public void trivialConstantTest()
{
try (CurrentModel model = using(new FactorGraph()))
{
DiscreteDomain d3 = DiscreteDomain.range(1, 3);
Discrete a = discrete("a",d3);
a.setPrior(Value.constant(3));
Discrete b = discrete("b",d3);
Discrete c = discrete("c",d3);
FactorFunction func = new FactorFunction () {
@Override
public double evalEnergy(Value[] values)
{
double sum = 0.0;
for (Value value : values)
{
sum += value.getDouble();
}
return sum;
}
};
Factor fab = name("f(a,b)", addFactor(func, a, b));
Factor fc = name("f(3,c)", addFactor(func, 3, c));
requireNonNull(fc.getConstantValueByIndex(0)).valueEquals(requireNonNull(a.getPriorValue()));
SumProductSolverGraph sumproduct = requireNonNull(model.graph.setSolverFactory(new SumProductSolver()));
sumproduct.solve();
assertArrayEquals(b.getBelief(), c.getBelief(), 1e-15);
MinSumSolverGraph minsum = requireNonNull(model.graph.setSolverFactory(new MinSumSolver()));
minsum.solve();
assertArrayEquals(b.getBelief(), c.getBelief(), 1e-15);
GibbsSolverGraph gibbs = requireNonNull(model.graph.setSolverFactory(new GibbsSolver()));
gibbs.setOption(GibbsOptions.numSamples, 20000);
gibbs.setOption(GibbsOptions.saveAllSamples, true);
gibbs.solve();
assertArrayEquals(b.getBelief(), c.getBelief(), 1e-2);
}
}
}