/*******************************************************************************
* Copyright 2014 Analog Devices, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
********************************************************************************/
package com.analog.lyric.dimple.test.solvers.core;
import static java.util.Objects.*;
import static org.junit.Assert.*;
import java.util.Arrays;
import org.junit.Test;
import com.analog.lyric.dimple.factorfunctions.Sum;
import com.analog.lyric.dimple.factorfunctions.core.FactorFunction;
import com.analog.lyric.dimple.factorfunctions.core.FactorTableEntry;
import com.analog.lyric.dimple.factorfunctions.core.IFactorTable;
import com.analog.lyric.dimple.factorfunctions.core.IFactorTableIterator;
import com.analog.lyric.dimple.model.core.FactorGraph;
import com.analog.lyric.dimple.model.domains.DiscreteDomain;
import com.analog.lyric.dimple.model.factors.Factor;
import com.analog.lyric.dimple.model.values.Value;
import com.analog.lyric.dimple.model.variables.Bit;
import com.analog.lyric.dimple.model.variables.Discrete;
import com.analog.lyric.dimple.options.SolverOptions;
import com.analog.lyric.dimple.solvers.core.STableFactorBase;
import com.analog.lyric.dimple.solvers.gibbs.GibbsSolver;
import com.analog.lyric.dimple.solvers.gibbs.GibbsSolverGraph;
import com.analog.lyric.dimple.test.DimpleTestBase;
/**
 * Unit test for {@link STableFactorBase}.
 * <p>
 * Exercises when a solver factor's factor table has been computed after solver-graph
 * initialization under different values of {@link SolverOptions#maxAutomaticFactorTableSize},
 * and cross-checks any computed table against the factor function it was built from.
 */
public class TestSTableFactorBase extends DimpleTestBase
{
	@Test
	public void test()
	{
		final FactorGraph graph = new FactorGraph();

		// Forty bit variables named bit0 .. bit39.
		final int bitCount = 40;
		final Discrete[] bitVars = new Discrete[bitCount];
		for (int index = 0; index < bitCount; index++)
		{
			final Bit bit = new Bit();
			bit.setName("bit" + index);
			bitVars[index] = bit;
		}
		graph.addVariables(bitVars);

		// Variable with domain 0..Short.MAX_VALUE; used as the first argument of the Sum factor below.
		final Discrete bigVar = new Discrete(DiscreteDomain.range(0, Short.MAX_VALUE));

		graph.setOption(SolverOptions.maxAutomaticFactorTableSize, Integer.MAX_VALUE);

		// Factors over the first 2, the first 8, and all 40 bits, each using the test Function.
		final Factor pairFactor = graph.addFactor(new Function(), bitVars[0], bitVars[1]);
		final Factor octetFactor = graph.addFactor(new Function(), Arrays.copyOf(bitVars, 8));
		final Factor wideFactor = graph.addFactor(new Function(), bitVars);

		final Discrete[] sumArgs = { bigVar, bitVars[0], bitVars[1], bitVars[2], bitVars[3] };
		final Factor sumFactor = graph.addFactor(new Sum(), sumArgs);
		assertTrue(sumFactor.isDirected());
		assertTrue(sumFactor.getFactorFunction().isDeterministicDirected());

		final GibbsSolverGraph solverGraph = requireNonNull(graph.setSolverFactory(new GibbsSolver()));
		final STableFactorBase pair = solverFactorOf(solverGraph, pairFactor);
		final STableFactorBase octet = solverFactorOf(solverGraph, octetFactor);
		final STableFactorBase wide = solverFactorOf(solverGraph, wideFactor);
		final STableFactorBase summer = solverFactorOf(solverGraph, sumFactor);

		// Before the solver graph is initialized, no table is expected to be computed.
		assertTableAbsent(pair);
		assertTableAbsent(octet);
		assertTableAbsent(wide);
		assertTableAbsent(summer);

		// With the limit at Integer.MAX_VALUE, every table except the 40-bit one is expected
		// after initialization (presumably because 2^40 entries exceed the limit).
		solverGraph.initialize();
		assertTablePresent(pair);
		assertTablePresent(octet);
		assertTableAbsent(wide);
		assertTablePresent(summer);

		// Clearing a table leaves it uncomputed.
		pair.clearFactorTable();
		assertNull(pair.getFactorTableIfComputed());
		octet.clearFactorTable();
		assertNull(octet.getFactorTableIfComputed());

		// With a limit of zero, re-initialization is expected to compute none of these tables.
		graph.setOption(SolverOptions.maxAutomaticFactorTableSize, 0);
		solverGraph.initialize();
		assertNull(pair.getFactorTableIfComputed());
		assertNull(octet.getFactorTableIfComputed());
		assertNull(wide.getFactorTableIfComputed());

		// With a limit of 100, the 2-bit table is expected to be computed but not the 8-bit one.
		graph.setOption(SolverOptions.maxAutomaticFactorTableSize, 100);
		solverGraph.initialize();
		assertNotNull(pair.getFactorTableIfComputed());
		assertNull(octet.getFactorTableIfComputed());
		assertNull(wide.getFactorTableIfComputed());
	}

	/**
	 * Returns the solver factor that {@code solverGraph} holds for {@code factor}, cast to
	 * {@link STableFactorBase} and checked to be non-null.
	 */
	private static STableFactorBase solverFactorOf(GibbsSolverGraph solverGraph, Factor factor)
	{
		return requireNonNull((STableFactorBase)solverGraph.getSolverFactor(factor));
	}

	/**
	 * Asserts that {@code sfactor} has no computed factor table, then checks its invariants.
	 */
	private void assertTableAbsent(STableFactorBase sfactor)
	{
		assertNull(sfactor.getFactorTableIfComputed());
		assertInvariants(sfactor);
	}

	/**
	 * Asserts that {@code sfactor} has a computed factor table, then checks its invariants.
	 */
	private void assertTablePresent(STableFactorBase sfactor)
	{
		assertNotNull(sfactor.getFactorTableIfComputed());
		assertInvariants(sfactor);
	}

	/**
	 * Cross-checks the factor table of {@code sfactor}, if one has been computed.
	 * <p>
	 * Verifies that {@code getFactorTable()} and {@code getFactorTableIfComputed()} both return the
	 * same instance, that the table's sparse indices equal {@code getPossibleBeliefIndices()}, and
	 * that each entry's energy equals the factor function's energy for that entry's values. Does
	 * nothing when no table has been computed.
	 */
	private void assertInvariants(STableFactorBase sfactor)
	{
		final FactorFunction function = sfactor.getFactor().getFactorFunction();
		final IFactorTable table = sfactor.getFactorTableIfComputed();
		if (table == null)
		{
			return; // nothing computed, nothing to cross-check
		}

		// Both accessors must hand back the very same table instance.
		assertSame(table, sfactor.getFactorTable());
		assertSame(table, sfactor.getFactorTableIfComputed());
		assertTrue(Arrays.deepEquals(table.getIndicesSparseUnsafe(), sfactor.getPossibleBeliefIndices()));

		// Every entry's stored energy must match the function evaluated at that entry's values.
		final IFactorTableIterator entries = table.fullIterator();
		FactorTableEntry entry = entries.next();
		while (entry != null)
		{
			assertEquals(entry.energy(), function.evalEnergy(entry.values()), 0.0);
			entry = entries.next();
		}
	}

	/**
	 * Test factor function whose energy is the integer formed by concatenating the low-order bit
	 * of each argument's integer value, with the first argument as the most significant bit.
	 */
	private static final class Function extends FactorFunction
	{
		@Override
		public double evalEnergy(Value[] values)
		{
			long packed = 0L;
			for (int i = 0; i < values.length; i++)
			{
				packed = (packed << 1) | (values[i].getInt() & 1);
			}
			return packed;
		}
	}
}