/*******************************************************************************
* Copyright 2015 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.test.solvers.gibbs;
import static com.analog.lyric.dimple.model.sugar.ModelSyntacticSugar.*;
import static com.analog.lyric.math.MoreMatrixUtils.*;
import static java.util.Objects.*;
import static org.junit.Assert.*;
import org.apache.commons.math3.linear.RealMatrix;
import org.junit.Test;
import com.analog.lyric.dimple.factorfunctions.core.FactorTable;
import com.analog.lyric.dimple.factorfunctions.core.IFactorTable;
import com.analog.lyric.dimple.model.core.FactorGraph;
import com.analog.lyric.dimple.model.domains.DiscreteDomain;
import com.analog.lyric.dimple.model.repeated.BitStream;
import com.analog.lyric.dimple.model.repeated.DoubleArrayDataSink;
import com.analog.lyric.dimple.model.repeated.DoubleArrayDataSource;
import com.analog.lyric.dimple.model.variables.Bit;
import com.analog.lyric.dimple.solvers.gibbs.GibbsSolver;
import com.analog.lyric.dimple.solvers.gibbs.GibbsSolverGraph;
import com.analog.lyric.dimple.solvers.sumproduct.SumProductSolver;
import com.analog.lyric.dimple.solvers.sumproduct.SumProductSolverGraph;
import com.analog.lyric.dimple.test.DimpleTestBase;
/**
 * Tests for the Gibbs solver on rolled-up (repeated) factor graphs, checking
 * results against the sum-product solver on the same graph and input data.
 *
 * @since 0.08
 * @author Christopher Barber
 */
public class TestGibbsRolledUp extends DimpleTestBase
{
    /**
     * Adapted from MATLAB algoRolledUpGraphs/testGibbsTableFactor.m
     * <p>
     * Builds a nested graph {@code sg} containing a single two-bit table factor over
     * ({@code xi}, {@code xo}), repeats it along the bit stream {@code x}, feeds the
     * same generated input into the stream, and solves once with sum-product and once
     * with Gibbs. The beliefs collected by the two data sinks must be identical.
     */
    @Test
    public void testGibbsTableFactor()
    {
        final int N = 100;
        final int bufferSize = 1;

        // Create the nested graph: one table factor connecting the boundary variables xi and xo.
        final Bit xi = name("xi", new Bit());
        final Bit xo = name("xo", new Bit());
        final FactorGraph sg = new FactorGraph(xi, xo);
        final IFactorTable table = FactorTable.create(DiscreteDomain.bit(), DiscreteDomain.bit());
        table.setWeightsDense(new double[] {0,1,1,0});
        sg.addFactor(table, xi, xo);

        // Repeat the nested graph along the stream. NOTE(review): x and x.getSlice(1) are
        // presumably bound positionally to xi and xo, linking consecutive stream elements -- confirm.
        final FactorGraph fg = new FactorGraph();
        final BitStream x = new BitStream("x");
        fg.addRepeatedFactorWithBufferSize(sg, bufferSize, x, x.getSlice(1));

        // Generate input data. Each row is the two-element vector {val, 1 - val}
        // (presumably per-value input weights for the bit domain -- confirm against
        // DoubleArrayDataSource). The chain starts from a fixed value; each later value is
        // chosen by comparing a uniform random draw against the table weight at indices
        // (previous row's first entry, 0).
        final double[][] input = new double[N][];
        double val = 1.0;
        input[0] = new double[] { val, 1 - val };
        for (int i = 1; i < N; ++i)
        {
            final double p = testRand.nextDouble();
            val = p > table.getWeightForIndices((int)input[i-1][0], 0) ? 1 : 0;
            input[i] = new double[] { val, 1 - val };
        }

        // Solve using sum-product.
        final SumProductSolverGraph sfg1 = requireNonNull(fg.setSolverFactory(new SumProductSolver()));
        x.setDataSource(new DoubleArrayDataSource(input));
        final DoubleArrayDataSink sink1 = new DoubleArrayDataSink();
        x.setDataSink(sink1);
        sfg1.solve();

        // Solve again using Gibbs, with a fresh data source over the same input and a separate sink.
        final GibbsSolverGraph sfg2 = requireNonNull(fg.setSolverFactory(new GibbsSolver()));
        x.setDataSource(new DoubleArrayDataSource(input));
        final DoubleArrayDataSink sink2 = new DoubleArrayDataSink();
        x.setDataSink(sink2);
        sfg2.solve();

        final RealMatrix b1 = wrapRealMatrix(sink1.getArray());
        final RealMatrix b2 = wrapRealMatrix(sink2.getArray());

        // Compare shapes first so a length mismatch is reported as an assertion failure
        // rather than as a dimension-mismatch exception thrown by subtract().
        assertEquals(b1.getRowDimension(), b2.getRowDimension());
        assertEquals(b1.getColumnDimension(), b2.getColumnDimension());

        // The tolerance is effectively exact: both solvers must produce identical beliefs here.
        assertEquals("Gibbs beliefs differ from sum-product beliefs",
            0.0, b1.subtract(b2).getNorm(), 1e-20);
    }
}