/* * File: Node.java * Authors: Tu-Thach Quach * Company: Sandia National Laboratories * Project: Cognitive Foundry * * Copyright 2016, Sandia Corporation. * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive * license for use of this work by or on behalf of the U.S. Government. * Export of this program may require a license from the United States * Government. See CopyrightHistory.txt for complete details. * */ package gov.sandia.cognition.graph.inference; import java.util.ArrayList; import java.util.Collection; import java.util.Iterator; import java.util.List; /** * Package-private support class for BP. Specifically stores and computes intermediate * values for a node in the BP algorithm. Only made public as the * SumProductInferencingAlgorithm class does expose a protected member of this * type. * * @author tong * @param <LabelType> The type that labels can be for this node -- used only as * an enum */ class Node<LabelType> { /** * The BP-internal id for this node */ private final int id; /** * The labels this node can take on */ private final Collection<LabelType> labels; /** * Stores the messages that are sent to this node */ private final List<Message> messages; /** * The beliefs currently associated with each label on this node */ private final double[] beliefs; /** * Initialize this node with the specified id and possible labels * * @param id The BP-internal id for this node * @param labels The labels this node can take on */ public Node(int id, Collection<LabelType> labels) { this.id = id; this.labels = labels; beliefs = new double[labels.size()]; messages = new ArrayList<>(); } /** * Create a message from sourceNode to this * * @param sourceNode The source of a message to this node * @param ensureUniqueEdges Ensure that entered edges are unique */ public void link(int sourceNode, boolean ensureUniqueEdges) { if (ensureUniqueEdges) { // Make sure we do not have source node in the message list. for (Message msg : messages) { if (msg.getSourceNode() == sourceNode) { throw new IllegalArgumentException("Source node " + sourceNode + " to " + id + " appears more than once."); } } } Message msg = new Message(sourceNode, labels.size()); messages.add(msg); } /** * Helper that normalizes the messages following the sum-product methodology */ public void normalizeMessagesForSumProductAlgorithm() { for (Message msg : messages) { msg.normalizeTempValuesForSumProductAlgorithm(); } } /** * Update all of the messages received by this node * * @return The maximum change for any messages received by this node */ public double update() { double delta = 0; for (Message msg : messages) { delta = Math.max(delta, msg.update()); } return delta; } /** * Resets all the incoming messages to one (for sum-product algorithm) */ public void resetToOne() { for (Message msg : messages) { msg.resetToOne(); } } /** * Return the internal id for this node * * @return the internal id for this node */ public int getId() { return id; } /** * Returns the sum of the log of the messages excluding the message from * excludeNode * * @param nodeLabel The label id to compute the product for * @param excludeNode The node whose message should be excluded * @return The sum of the log of the messages */ public double getLogMessageSum(int nodeLabel, int excludeNode) { double v = 0; for (Message msg : messages) { if (msg.getSourceNode() != excludeNode) { v += msg.getLogValue(nodeLabel); } } return v; } /** * Returns the sum of the log of the messages incoming to this node * * @param nodeLabel The label id to compute the product for * @return The sum of the log of the messages */ public double getLogMessageSum(int nodeLabel) { double v = 0; for (Message msg : messages) { v += msg.getLogValue(nodeLabel); } return v; } /** * Returns the message to this node from the input source node * * @param sourceNode The node whose message to this is requested * @return the message to this node from the input source node */ public Message getMessageFromSource(int sourceNode) { for (Message msg : messages) { if (msg.getSourceNode() == sourceNode) { return msg; } } throw new IllegalArgumentException("Node " + id + " does not contain a message from source node " + sourceNode); } /** * Computes the beliefs for this node following the sum-product algorithm * * @param f The function to use for values */ public void computeBeliefsForSumProductAlgorithm(EnergyFunction<LabelType> f) { double max = -Double.MAX_VALUE; Iterator<LabelType> iter = labels.iterator(); for (int i = 0; i < labels.size(); ++i) { LabelType label = iter.next(); double belief = -f.getUnaryCost(id, label) + getLogMessageSum(i); beliefs[i] = belief; max = Math.max(belief, max); } double total = 0; for (int i = 0; i < labels.size(); ++i) { beliefs[i] = Math.exp(beliefs[i] - max); total += beliefs[i]; } // Normalize. for (int label = 0; label < labels.size(); label++) { beliefs[label] /= total; } } /** * Returns the belief solved for this node and the input label * * @param label The label whose associated belief is desired * @return The belief associated for the input label */ public double getBelief(int label) { return beliefs[label]; } @Override public String toString() { StringBuilder buffer = new StringBuilder(); buffer.append("Node "); buffer.append(id); buffer.append(":\r\n"); for (Message msg : messages) { buffer.append(msg.toString()); buffer.append("\r\n"); } return buffer.toString(); } }