/******************************************************************************* * Copyright 2014 Analog Devices, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. ********************************************************************************/ package com.analog.lyric.dimple.test.solvers.gibbs; import static java.util.Objects.*; import static org.junit.Assert.*; import java.util.ArrayList; import java.util.HashSet; import java.util.Set; import org.junit.Test; import com.analog.lyric.collect.ReleasableIterator; import com.analog.lyric.dimple.factorfunctions.MatrixProduct; import com.analog.lyric.dimple.factorfunctions.Normal; import com.analog.lyric.dimple.factorfunctions.core.FactorFunction; import com.analog.lyric.dimple.model.core.FactorGraph; import com.analog.lyric.dimple.model.core.Node; import com.analog.lyric.dimple.model.domains.RealDomain; import com.analog.lyric.dimple.model.factors.Factor; import com.analog.lyric.dimple.model.values.Value; import com.analog.lyric.dimple.model.variables.Real; import com.analog.lyric.dimple.model.variables.Variable; import com.analog.lyric.dimple.solvers.gibbs.ISolverNodeGibbs; import com.analog.lyric.dimple.solvers.gibbs.ISolverVariableGibbs; import com.analog.lyric.dimple.solvers.gibbs.Solver; import com.analog.lyric.dimple.test.DimpleTestBase; public class TestGibbsNeighborList extends DimpleTestBase { @SuppressWarnings("null") @Test public void test() { FactorGraph fg = new FactorGraph(); fg.setSolverFactory(new Solver()); Real[][] aMatrix = makeRealMatrix("a", 2,4); Real[] aVars = flattenRealMatrix(aMatrix); Real[][] bMatrix = makeRealMatrix("b", 4,3); Real[] bVars = flattenRealMatrix(bMatrix); Real[][] cMatrix = makeRealMatrix("c", 2,3); Real[] cVars = flattenRealMatrix(cMatrix); Real[] vars = new Real[aVars.length + bVars.length + cVars.length]; System.arraycopy(cVars, 0, vars, 0, cVars.length); System.arraycopy(aVars, 0, vars, cVars.length, aVars.length); System.arraycopy(bVars, 0, vars, cVars.length + aVars.length, bVars.length); Factor matrixProduct = fg.addFactor(new MatrixProduct(2,4,3), vars); fg.initialize(); // The inputs should have empty scoring lists because the output variables // aren't attached to any other factors and don't have inputs. for (Real var : cVars) { assertFalse(((ISolverVariableGibbs)var.getSolver()).hasPotential()); assertScoreNodes(var, matrixProduct); } for (Real var : aVars) { assertScoreNodes(var); } for (Real var : bVars) { assertScoreNodes(var); } // // Now set inputs on the matrix product output variables and verify that the appropriate variables // show up as neighbors of the matrix product input variables. // for (Real var : cVars) { var.setPrior(new Normal()); } fg.initialize(); for (Real var : cVars) { assertTrue(((ISolverVariableGibbs)var.getSolver()).hasPotential()); assertScoreNodes(var, matrixProduct); } for (int row = 0, rows = aMatrix.length; row < rows; ++row) { for (int col = 0, cols = aMatrix[0].length; col < cols; ++col) { assertScoreNodes(aMatrix[row][col], cMatrix[row]); } } for (int col = 0, cols = bMatrix[0].length; col < cols; ++col) { Real[] expected = new Real[cMatrix.length]; for (int crow = 0; crow < cMatrix.length; ++crow) { expected[crow] = cMatrix[crow][col]; } for (int row = 0, rows = bMatrix.length; row < rows; ++row) { assertScoreNodes(bMatrix[row][col], expected); } } // // Clear inputs and instead, attach to Normals as variable // Factor[][] cNormalMatrix = new Factor[cMatrix.length][cMatrix[0].length]; Factor[] cNormals = new Factor[cMatrix.length * cMatrix[0].length]; for (int col = 0, cols = cMatrix[0].length, i = 0; col < cols; ++col) { for (int row = 0, rows = cMatrix.length; row < rows; ++row) { Real cVar = cMatrix[row][col]; cVar.setPrior(null); Factor factor = fg.addFactor(new Normal(), cVar, 1, new Real(RealDomain.nonNegative())); cNormalMatrix[row][col] = factor; cNormals[i++] = factor; } } fg.initialize(); for (int i = 0; i < cVars.length; ++i) { Real var = cVars[i]; assertFalse(((ISolverVariableGibbs)var.getSolver()).hasPotential()); assertScoreNodes(var, matrixProduct, cNormals[i]); } for (int row = 0, rows = aMatrix.length; row < rows; ++row) { for (int col = 0, cols = aMatrix[0].length; col < cols; ++col) { assertScoreNodes(aMatrix[row][col], cNormalMatrix[row]); } } for (int col = 0, cols = bMatrix[0].length; col < cols; ++col) { Node[] expected = new Node[cMatrix.length]; for (int crow = 0; crow < cMatrix.length; ++crow) { expected[crow] = cNormalMatrix[crow][col]; } for (int row = 0, rows = bMatrix.length; row < rows; ++row) { assertScoreNodes(bMatrix[row][col], expected); } } // // Make sure that visiting the same factor twice from different edges gets // reflected in the neighbors. // Real x = new Real(RealDomain.nonNegative()); fg.addFactor(new Copy(), x, aMatrix[0][0], bMatrix[1][1]); fg.initialize(); { ArrayList<Node> expectedList = new ArrayList<Node>(); for (Node node : cNormalMatrix[0]) expectedList.add(node); for (int row = 0, rows = cMatrix.length; row < rows; ++row) expectedList.add(cNormalMatrix[row][1]); assertScoreNodes(x, expectedList.toArray(new Node[expectedList.size()])); } } // test /*---------------- * Helper methods */ private void assertScoreNodes(Variable var, Node ... scoreNodes) { Set<ISolverNodeGibbs> expectedNodes = new HashSet<ISolverNodeGibbs>(scoreNodes.length); for (Node node : scoreNodes) { expectedNodes.add((ISolverNodeGibbs)node.getSolver()); } ISolverVariableGibbs svar = requireNonNull((ISolverVariableGibbs)var.getSolver()); int count = 0; ReleasableIterator<ISolverNodeGibbs> iter = svar.getSampleScoreNodes(); while (iter.hasNext()) { ++count; ISolverNodeGibbs snode = iter.next(); assertTrue(expectedNodes.contains(snode)); } iter.release(); assertEquals(expectedNodes.size(), count); } private Real[][] makeRealMatrix(String namePrefix, int rows, int cols) { Real[][] matrix = new Real[rows][cols]; for (int row = 0; row < rows; ++row) for (int col = 0; col < cols; ++col) { Real var = new Real(RealDomain.unbounded()); var.setName(String.format("%s[%d,%d]", namePrefix, row, col)); matrix[row][col] = var; } return matrix; } private Real[] flattenRealMatrix(Real[][] matrix) { final int rows = matrix.length; final int cols = matrix[0].length; Real[] flattened = new Real[rows * cols]; int cur = 0; for (int col = 0; col < cols; ++col) for (int row = 0; row < rows; ++row) flattened[cur++] = matrix[row][col]; return flattened; } /*----------------- * FactorFunctions */ // Replicates first argument to remaining args. static class Copy extends FactorFunction { @Override public int[] getDirectedToIndices(int numEdges) { int[] edges = new int[numEdges - 1]; for (int i = edges.length; --i>=0;) edges[i] = i + 1; return edges; } @Override public final double evalEnergy(Value[] arguments) { Value value = arguments[0]; for (int i = arguments.length; --i>=1;) if (!value.valueEquals(arguments[i])) return Double.POSITIVE_INFINITY; return 0.0; } @Override public void evalDeterministic(Value[] arguments) { Value value = arguments[0]; for (int i = arguments.length; --i>=1;) arguments[i].setFrom(value); } @Override public boolean isDeterministicDirected() { return true; } @Override public boolean isDirected() { return true; } } }