/*
 * File:            SumProductDirectedPropagationTest.java
 * Authors:         Jeremy D. Wendt
 * Company:         Sandia National Laboratories
 * Project:         Cognitive Foundry
 *
 * Copyright 2016, Sandia Corporation.
 * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
 * license for use of this work by or on behalf of the U.S. Government.
 * Export of this program may require a license from the United States
 * Government. See CopyrightHistory.txt for complete details.
 *
 */

package gov.sandia.cognition.graph.inference;

import gov.sandia.cognition.graph.DenseMemoryGraph;
import java.util.Map;
import org.junit.Test;
import static org.junit.Assert.*;

/**
 * Tests the SumProductDirectedPropagation class.
 *
 * @author jdwendt
 */
public class SumProductDirectedPropagationTest {

    /**
     * Runs the solver on the chain 1 -> 2 -> 3 and compares the outcome
     * against hard-coded expected values, first with only node 1 labeled and
     * then again after node 3 is given a special unary potential and a label.
     */
    @Test
    public void exactTest() {
        // NOTE(review): the constructor arguments (3, 2) presumably are the
        // expected node and edge counts -- confirm against DenseMemoryGraph.
        DenseMemoryGraph<Integer> graph = new DenseMemoryGraph<>(3, 2);
        graph.addEdge(1, 2);
        graph.addEdge(2, 3);

        // The handler is kept in a local so that it can be adjusted between
        // the two solver runs below.
        BasicHomogeneousHandler<Integer> handler =
            new BasicHomogeneousHandler<>();
        NodeNameAwareEnergyFunction<Integer, Integer> fn =
            new GraphWrappingEnergyFunction<>(graph, handler);
        fn.setLabel(1, 0);

        EnergyFunctionSolver<Integer> bayes =
            new SumProductDirectedPropagation<>(100, 1e-6, 1);

        // First run: only node 1 carries a label.
        bayes.init(fn);
        assertTrue("solve() should return true on the first run",
            bayes.solve());
        // NOTE(review): expected values appear to be listed as one pair per
        // node, in node order -- confirm against InferenceHelper.
        InferenceHelper.testExactResults(graph, fn, bayes,
            new double[] {
                1.0, 0.0,
                0.99, 0.01,
                0.9802, 0.0198
            }, false);

        // Second run: node 3 additionally gets a special unary potential and
        // a label; only the last pair of expected values changes.
        handler.setSpecialUnaryPotential(3, graph, 0.8);
        fn.setLabel(3, 1);
        bayes.init(fn);
        assertTrue("solve() should return true on the second run",
            bayes.solve());
        InferenceHelper.testExactResults(graph, fn, bayes,
            new double[] {
                1.0, 0.0,
                0.99, 0.01,
                0.9252, 0.0748
            }, false);
    }

    /**
     * Runs the solver on a six-node graph with node 1 labeled, compares the
     * outcome against approximate expected values (tolerance 0.01), and then
     * checks that the belief stored under key 0 strictly decreases from
     * node 1 to node 2 to node 3.
     */
    @Test
    public void basicTest() {
        DenseMemoryGraph<Integer> graph = new DenseMemoryGraph<>(6, 5);
        graph.addEdge(0, 1);
        graph.addEdge(1, 2);
        graph.addEdge(2, 3);
        graph.addEdge(2, 4);
        graph.addEdge(5, 4);

        NodeNameAwareEnergyFunction<Integer, Integer> fn =
            new GraphWrappingEnergyFunction<>(graph,
                new BasicHomogeneousHandler<Integer>());
        fn.setLabel(1, 0);

        EnergyFunctionSolver<Integer> bayes =
            new SumProductDirectedPropagation<>(100, 1e-6, 1);
        bayes.init(fn);
        assertTrue("solve() should return true", bayes.solve());
        InferenceHelper.testApproximateResults(graph, fn, bayes,
            new double[] {
                0.5, 0.5,
                0.99, 0.01,
                0.98, 0.02,
                0.98, 0.02,
                0.98, 0.02,
                0.5, 0.5
            }, 0.01, false);

        Map<Integer, Double> beliefs1 = fn.getBeliefs(1, bayes);
        Map<Integer, Double> beliefs2 = fn.getBeliefs(2, bayes);
        // 2's belief should be less than 1's because it's learning from 1
        assertTrue("beliefs2.get(0) should be less than beliefs1.get(0)",
            beliefs2.get(0) < beliefs1.get(0));

        Map<Integer, Double> beliefs3 = fn.getBeliefs(3, bayes);
        // 3's belief should be less than 2's because it's learning from 2
        assertTrue("beliefs3.get(0) should be less than beliefs2.get(0)",
            beliefs3.get(0) < beliefs2.get(0));
    }

}