/* * File: EdgeMergingEnergyFunctionTest.java * Authors: Jeremy D. Wendt * Company: Sandia National Laboratories * Project: Cognitive Foundry * * Copyright 2016, Sandia Corporation. * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive * license for use of this work by or on behalf of the U.S. Government. * Export of this program may require a license from the United States * Government. See CopyrightHistory.txt for complete details. * */ package gov.sandia.cognition.graph.inference; import gov.sandia.cognition.graph.DenseMemoryGraph; import gov.sandia.cognition.graph.DirectedNodeEdgeGraph; import java.util.ArrayList; import java.util.List; import static org.junit.Assert.assertTrue; import org.junit.Test; /** * Tests the edge merging energy function -- the energy function that makes it * so that BP sees at most one edge between any pair of nodes and handles * merging the potentials properly. * * @author jdwendt */ public class EdgeMergingEnergyFunctionTest { private static class FlippedEdgePotentialHandler implements GraphWrappingEnergyFunction.PotentialHandler<Integer, Integer> { private boolean isFlipped0_0; private boolean isFlipped0_1; private final double[][] pairwisePotentials0_0 = { { 1, 1 }, { 4, 1 } }; private final double[][] pairwisePotentials0_1 = { { 1, 2 }, { 1, 1 } }; private final double[][] pairwisePotentials1 = { { 1, 3 }, { 3, 1 } }; @Override public double getPairwisePotential( DirectedNodeEdgeGraph<Integer> graph, int edgeId, Integer ilabel, Integer jlabel) { // NOTE: By looking only at the edge ID, this leads to complications // in the test below. However, I can think of no other easy way of // handling the fact that the first two edges might be flipped // between 0->1 or 1->0. if (edgeId == 0) { if (isFlipped0_0) { return pairwisePotentials0_0[jlabel][ilabel]; } else { return pairwisePotentials0_0[ilabel][jlabel]; } } else if (edgeId == 1) { if (isFlipped0_1) { return pairwisePotentials0_1[jlabel][ilabel]; } else { return pairwisePotentials0_1[ilabel][jlabel]; } } else if (edgeId == 2) { return pairwisePotentials1[ilabel][jlabel]; } throw new RuntimeException("Input edge id: " + edgeId + " does not exist in the graph"); } @Override public double getUnaryPotential( DirectedNodeEdgeGraph<Integer> graph, int i, Integer label, Integer assignedLabel) { return 1.0; } @Override public List<Integer> getPossibleLabels( DirectedNodeEdgeGraph<Integer> graph, int nodeId) { List<Integer> ret = new ArrayList<>(); ret.add(0); ret.add(1); return ret; } }; @Test public void basicTest() { double[] correct = new double[] { 0.375, 0.625, 0.625, 0.375, 0.4375, 0.5625 }; DenseMemoryGraph<Integer> graph = new DenseMemoryGraph<>(3, 3); // Adding nodes is necessary for this test so that the edges always show up in the correct order graph.addNode(0); graph.addNode(1); graph.addNode(2); graph.addEdge(0, 1); graph.addEdge(1, 0); graph.addEdge(1, 2); FlippedEdgePotentialHandler handler = new FlippedEdgePotentialHandler(); handler.isFlipped0_0 = false; handler.isFlipped0_1 = true; GraphWrappingEnergyFunction<Integer, Integer> fn = new GraphWrappingEnergyFunction<>(graph, handler); EnergyFunctionSolver<Integer> bp = new SumProductBeliefPropagation<>(); bp.init(new EdgeMergingEnergyFunction<>(fn)); assertTrue(bp.solve()); InferenceHelper.testExactResults(graph, fn, bp, correct, false); graph = new DenseMemoryGraph<>(3, 3); // Adding nodes is necessary for this test so that the edges always show up in the correct order graph.addNode(0); graph.addNode(1); graph.addNode(2); graph.addEdge(1, 0); graph.addEdge(1, 0); graph.addEdge(1, 2); handler.isFlipped0_0 = true; handler.isFlipped0_1 = true; fn = new GraphWrappingEnergyFunction<>(graph, handler); bp = new SumProductBeliefPropagation<>(); bp.init(new EdgeMergingEnergyFunction<>(fn)); assertTrue(bp.solve()); InferenceHelper.testExactResults(graph, fn, bp, correct, false); graph = new DenseMemoryGraph<>(3, 3); // Adding nodes is necessary for this test so that the edges always show up in the correct order graph.addNode(0); graph.addNode(1); graph.addNode(2); graph.addEdge(1, 0); graph.addEdge(0, 1); graph.addEdge(1, 2); // NOTTE: This seems out of order after seeing the order the edges are added above. The problem // is that the edges are sorted internal to the object, so the edge 0->1 becomes the first edge, // and 1->0 becomes the second. handler.isFlipped0_0 = false; handler.isFlipped0_1 = true; fn = new GraphWrappingEnergyFunction<>(graph, handler); bp = new SumProductBeliefPropagation<>(); bp.init(new EdgeMergingEnergyFunction<>(fn)); assertTrue(bp.solve()); InferenceHelper.testExactResults(graph, fn, bp, correct, false); graph = new DenseMemoryGraph<>(3, 3); // Adding nodes is necessary for this test so that the edges always show up in the correct order graph.addNode(0); graph.addNode(1); graph.addNode(2); graph.addEdge(0, 1); graph.addEdge(0, 1); graph.addEdge(1, 2); handler.isFlipped0_0 = false; handler.isFlipped0_1 = false; fn = new GraphWrappingEnergyFunction<>(graph, handler); bp = new SumProductBeliefPropagation<>(); bp.init(new EdgeMergingEnergyFunction<>(fn)); assertTrue(bp.solve()); InferenceHelper.testExactResults(graph, fn, bp, correct, false); } }