/* * File: SumProductBeliefPropagation.java * Authors: Tu-Thach Quach, Jeremy D. Wendt * Company: Sandia National Laboratories * Project: Cognitive Foundry * * Copyright 2016, Sandia Corporation. * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive * license for use of this work by or on behalf of the U.S. Government. * Export of this program may require a license from the United States * Government. See CopyrightHistory.txt for complete details. * */ package gov.sandia.cognition.graph.inference; import gov.sandia.cognition.annotation.PublicationReference; import gov.sandia.cognition.annotation.PublicationType; import gov.sandia.cognition.util.Pair; import java.util.Collection; /** * This class implements the sum-product belief propagation algorithm for * arbitrary energy functions. It has been tested against graph-based energy * functions with widely varying node degrees and overall scales. * * It runs the algorithm in parallel on as many cores as the caller specifies. * * @author tong, jdwendt * @param <LabelType> The type that the nodes' labels can take on. Note that * these values will be considered enumerations internally. */ @PublicationReference(author = "Jonahtan S. Yedidia, William T. Freeman, and Yair Weiss", title = "Understanding Belief Propagation and its Generalizations", type = PublicationType.TechnicalReport, year = 2001, notes = { "Institution: Mitsubishi Electric Research Laboratories" }) public class SumProductBeliefPropagation<LabelType> extends SumProductInferencingAlgorithm<LabelType> { /** * Creates a new BP solver with the specified settings. * * @param maxNumIterations The maximum number of iterations that will be run * @param eps The stopping epsilon that will be used * @param numThreads The number of threads that will be used */ public SumProductBeliefPropagation(int maxNumIterations, double eps, int numThreads) { super(maxNumIterations, eps, numThreads); } /** * Creates a new BP solver with the default settings excepting max number of * iterations. * * @param maxNumIterations The maximum number of iterations that will be run */ public SumProductBeliefPropagation(int maxNumIterations) { this(maxNumIterations, DEFAULT_EPS, DEFAULT_NUM_THREADS); } /** * Creates a new BP solver with the default settings. */ public SumProductBeliefPropagation() { this(DEFAULT_MAX_ITERATIONS, DEFAULT_EPS, DEFAULT_NUM_THREADS); } /** * Private helper that computes the temporary message for the specified edge * for the current iteration going in the specified direction (a different * message for each direction) * * @param edge The edge to compute the temporary message for (you can't * replace the current message yet as the current message may be needed for * other edges). * @param reverse Specifies whether this should compute the message from i * to j or j to i */ private void computeTemporaryMessage(int edge, boolean reverse) { Node<LabelType> sourceNode, targetNode; Pair<Integer, Integer> edgePair = fn.getEdge(edge); if (reverse) { sourceNode = nodes.get(edgePair.getSecond()); targetNode = nodes.get(edgePair.getFirst()); } else { sourceNode = nodes.get(edgePair.getFirst()); targetNode = nodes.get(edgePair.getSecond()); } int sourceNodeId = sourceNode.getId(); int targetNodeId = targetNode.getId(); Message targetMessage = targetNode.getMessageFromSource(sourceNodeId); // Compute m_{ij}(x_j). Collection<LabelType> sourceLabels = fn.getPossibleLabels(sourceNodeId); Collection<LabelType> targetLabels = fn.getPossibleLabels(targetNodeId); int size = sourceLabels.size() * targetLabels.size(); double[] values = new double[size]; double max = -Double.MAX_VALUE; int ij = 0; for (LabelType targetLabel : targetLabels) { int sourceLabelIdx = 0; for (LabelType sourceLabel : sourceLabels) { values[ij] = -fn.getUnaryCost(sourceNodeId, sourceLabel); if (!reverse) { values[ij] += -fn.getPairwiseCost(edge, sourceLabel, targetLabel); } else { values[ij] += -fn.getPairwiseCost(edge, targetLabel, sourceLabel); } values[ij] += sourceNode.getLogMessageSum(sourceLabelIdx, targetNodeId); max = Math.max(values[ij], max); ++sourceLabelIdx; ++ij; } } // Now correct so that all log values are less than or equal to 0 ij = 0; for (int i = 0; i < targetLabels.size(); ++i) { double value = 0; for (int j = 0; j < sourceLabels.size(); ++j) { value += Math.exp(values[ij] - max); ++ij; } targetMessage.setTempValue(i, value); } } @Override protected void computeTemporaryMessage(int edge) { computeTemporaryMessage(edge, true); computeTemporaryMessage(edge, false); } @Override void initMessages(Pair<Integer, Integer> edgePair) { Node<LabelType> node = nodes.get(edgePair.getFirst()); node.link(edgePair.getSecond(), true); node = nodes.get(edgePair.getSecond()); node.link(edgePair.getFirst(), true); } }