/******************************************************************************* * Copyright 2012 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.test; import static org.junit.Assert.*; import java.util.Arrays; import org.junit.After; import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; import org.junit.Test; import com.analog.lyric.dimple.factorfunctions.XorDelta; import com.analog.lyric.dimple.factorfunctions.core.CustomFactorFunctionWrapper; import com.analog.lyric.dimple.factorfunctions.core.FactorFunction; import com.analog.lyric.dimple.factorfunctions.core.FactorTable; import com.analog.lyric.dimple.factorfunctions.core.IFactorTable; import com.analog.lyric.dimple.factorfunctions.core.TableFactorFunction; import com.analog.lyric.dimple.model.core.FactorGraph; import com.analog.lyric.dimple.model.domains.DiscreteDomain; import com.analog.lyric.dimple.model.domains.JointDomainIndexer; import com.analog.lyric.dimple.model.factors.DiscreteFactor; import com.analog.lyric.dimple.model.variables.Discrete; public class FactorFunctionTest extends DimpleTestBase { @BeforeClass public static void setUpBeforeClass() { } @AfterClass public static void tearDownAfterClass() { } @Before public void setUp() { } @After public void tearDown() { } @Test(expected=Exception.class) public void test_simpleStuff() { String name = "name"; FactorFunction ff = new CustomFactorFunctionWrapper(name); assertTrue(ff.getName() == name); //should kaboom ff.eval(new Object[]{.5, .5}); } @Test public void test_variable_constructor() { // public FactorTable(int [][] indices, double [] weights, Discrete... variables) // public FactorTable(int [][] indices, double [] weights,DiscreteDomain ... domains) Discrete[] discretes = new Discrete[]{new Discrete(0, 1), new Discrete(0, 1), new Discrete(0, 1)}; Discrete[] discretes6 = new Discrete[]{new Discrete(0, 1), new Discrete(0, 1), new Discrete(0, 1),new Discrete(0, 1), new Discrete(0, 1), new Discrete(0, 1)}; DiscreteDomain[] domains = new DiscreteDomain[]{DiscreteDomain.bit(), DiscreteDomain.bit(), DiscreteDomain.bit()}; DiscreteDomain[] vDomains = new DiscreteDomain[]{discretes[0].getDiscreteDomain(), discretes[1].getDiscreteDomain(), discretes[2].getDiscreteDomain()}; XorDelta xorFF = new XorDelta(); IFactorTable xorFT = xorFF.getFactorTable(domains); int[][] xorIndices = xorFT.getIndicesSparseUnsafe(); double[] xorWeights = xorFT.getWeightsSparseUnsafe(); int[][] table = new int[xorIndices.length][]; double[] weights = new double[xorWeights.length]; for(int i = 0; i < table.length; ++i) { table[i] = new int[xorIndices[i].length]; for(int j = 0; j < table[i].length; ++j) { table[i][j] = xorIndices[i][j]; } weights[i] = xorWeights[i]; } IFactorTable ftVar = FactorTable.create(table, weights, discretes); IFactorTable ftVDomain = FactorTable.create(table, weights, vDomains); IFactorTable ftDomain = FactorTable.create(table, weights, domains); TableFactorFunction tffVar = new TableFactorFunction("tffVar", table, weights, discretes); TableFactorFunction tffDVar = new TableFactorFunction("tffDVar", ftVar); TableFactorFunction tffVDomain = new TableFactorFunction("tffVar", table, weights, vDomains); TableFactorFunction tffDomain = new TableFactorFunction("tffVar", table, weights, domains); FactorGraph fg = new FactorGraph(); DiscreteFactor fVar = (DiscreteFactor) fg.addFactor(tffVar, (Object[])discretes); DiscreteFactor fDomain = (DiscreteFactor) fg.addFactor(tffDomain, (Object[])discretes); DiscreteFactor fxd = (DiscreteFactor) fg.addFactor(xorFF, (Object[])discretes); DiscreteFactor fVar2 = (DiscreteFactor) fg.addFactor(tffVar, (Object[])discretes); DiscreteFactor fDomain2 = (DiscreteFactor) fg.addFactor(tffDomain, (Object[])discretes); DiscreteFactor fxd2 = (DiscreteFactor) fg.addFactor(xorFF, (Object[])discretes); DiscreteFactor fxd6 = (DiscreteFactor) fg.addFactor(xorFF, (Object[])discretes6); DiscreteFactor fxd6_2 = (DiscreteFactor) fg.addFactor(xorFF, (Object[])discretes6); assertEquals(fVar.getFactorTable().hashCode(), fVar2.getFactorTable().hashCode()); assertEquals(fDomain.getFactorTable().hashCode(), fDomain2.getFactorTable().hashCode()); assertEquals(fxd.getFactorTable().hashCode(), fxd2.getFactorTable().hashCode()); assertEquals(fxd6.getFactorTable().hashCode(), fxd6_2.getFactorTable().hashCode()); assertTrue(fVar.getFactorTable().hashCode() != fDomain.getFactorTable().hashCode()); assertTrue(fxd.getFactorTable().hashCode() != fxd6.getFactorTable().hashCode()); JointDomainIndexer ftDomains = ftVar.getDomainIndexer(); for(int i = 0; i < ftDomains.size(); ++i) { DiscreteDomain idomain = ftDomains.get(i); for(int j = 0; j < idomain.size(); ++j) { Object ijelt = idomain.getElement(j); assertEquals(ijelt ,ftVDomain.getDomainIndexer().get(i).getElement(j)); assertEquals(ijelt, ftDomain.getDomainIndexer().get(i).getElement(j)); assertEquals(ijelt, tffVar.getFactorTable(domains).getDomainIndexer().get(i).getElement(j)); assertEquals(ijelt, tffDVar.getFactorTable(domains).getDomainIndexer().get(i).getElement(j)); assertEquals(ijelt, tffVDomain.getFactorTable(domains).getDomainIndexer().get(i).getElement(j)); assertEquals(ijelt, tffDomain.getFactorTable(domains).getDomainIndexer().get(i).getElement(j)); assertEquals(ijelt, fVar2.getFactorTable().getDomainIndexer().get(i).getElement(j)); } } for(int i = 0; i < ftVar.sparseSize(); ++i) { int[] ftVarRow = ftVar.sparseIndexToIndices(i); int[] xorFTRow = xorFT.sparseIndexToIndices(i); assertArrayEquals(ftVarRow, xorFTRow); } DiscreteDomain[] domains6 = new DiscreteDomain[6]; Arrays.fill(domains6, DiscreteDomain.bit()); DiscreteDomain domainThreeEntries = DiscreteDomain.range(0.0, 2.0); IFactorTable ftThreeBinary = xorFF.getFactorTable(domains); IFactorTable ftThreeBinary2 = xorFF.getFactorTable(vDomains); assertSame(ftThreeBinary, ftThreeBinary2); assertEquals(ftThreeBinary.hashCode(), ftThreeBinary2.hashCode()); IFactorTable xor6FT = xorFF.getFactorTable(domains6); IFactorTable xor3AryFT = xorFF.getFactorTable(new DiscreteDomain[]{domains[0], domains[1], domainThreeEntries}); assertNotSame(ftThreeBinary, xor6FT); assertTrue(ftThreeBinary.hashCode() != xor6FT.hashCode()); assertNotSame(ftThreeBinary, xor3AryFT); assertTrue(ftThreeBinary.hashCode() != xor3AryFT.hashCode()); //no kaboom assertTrue(ftThreeBinary.toString().length() != 0); } }