/*
* File: Message.java
* Authors: Tu-Thach Quach
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright 2016, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government.
* Export of this program may require a license from the United States
* Government. See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.graph.inference;
/**
* Package-private class used for BP message passing. Stores the message passed
* from a source node to the destination node. The destination node stores this
* message, so there's no need to store its id internally here.
*
* @author tong
*
*/
class Message
{

    /**
     * The minimum value a belief can take on for numerical precision reasons
     */
    private static final double MIN_BELIEF = 1e-6;

    /**
     * The id for the node that sends this message
     */
    private final int sourceNode;

    /**
     * The message's values from the last complete iteration
     */
    private final double[] values;

    /**
     * The message's values from the current iteration being computed
     */
    private final double[] tempValues;

    /**
     * The log of the values from the last complete iteration. Kept in step
     * with {@link #values} by {@link #update()} and {@link #resetToOne()}.
     */
    private final double[] logValues;

    /**
     * Initialize this message with the source id and number of labels. All
     * three internal arrays start out zero-filled; call {@link #resetToOne()}
     * or {@link #update()} to give the message meaningful values.
     *
     * @param source The id for the source node for this message
     * @param numLabels The number of labels this message will contain
     */
    public Message(int source,
        int numLabels)
    {
        this.sourceNode = source;
        this.values = new double[numLabels];
        this.tempValues = new double[numLabels];
        this.logValues = new double[numLabels];
    }

    /**
     * Returns the id for the source node for this message
     *
     * @return The id of the node that sends this message
     */
    public int getSourceNode()
    {
        return sourceNode;
    }

    /**
     * Normalize the temporary values for this message following the
     * sum-product algorithm, i.e. divide every temporary value by the sum of
     * all temporary values.
     * <p>
     * If the temporary values sum to exactly zero, the division would turn
     * the entries into NaN or infinity, which would then be copied into the
     * message by {@link #update()}. In that case the temporary values are
     * replaced with a uniform message (1 / number of labels) instead.
     */
    public void normalizeTempValuesForSumProductAlgorithm()
    {
        // Normalize the messages so that we don't converge to zero.
        double total = 0;
        for (double value : tempValues)
        {
            total += value;
        }

        if (total == 0)
        {
            // Nothing to scale by: fall back to an uninformative (uniform)
            // message rather than producing 0/0 = NaN entries.
            for (int label = 0; label < tempValues.length; label++)
            {
                tempValues[label] = 1.0 / tempValues.length;
            }
            return;
        }

        for (int label = 0; label < tempValues.length; label++)
        {
            tempValues[label] /= total;
        }
    }

    /**
     * Update this message at the completion of an iteration: each temporary
     * value (raised to {@code MIN_BELIEF} if it is smaller) becomes the
     * message's value, and its natural log is cached alongside.
     *
     * @return The maximum change for any label on this message, measured
     * between the previously stored value and the newly stored value
     */
    public double update()
    {
        double delta = 0;
        for (int label = 0; label < values.length; label++)
        {
            // Numerical precision issues: never store less than MIN_BELIEF
            double updated = Math.max(tempValues[label], MIN_BELIEF);

            // Compare against the value that is actually stored. Comparing
            // against the unclamped temp value would report a spurious
            // change on every iteration for a label that sits at the floor.
            delta = Math.max(delta, Math.abs(values[label] - updated));

            values[label] = updated;
            logValues[label] = Math.log(updated);
        }

        return delta;
    }

    /**
     * Set the temporary value for this message for this iteration
     *
     * @param label The index whose value should be set
     * @param value The value to set it to
     */
    public void setTempValue(int label,
        double value)
    {
        tempValues[label] = value;
    }

    /**
     * Returns the real-value for this message
     *
     * @param label The index whose value should be returned
     * @return The value for the input label
     */
    public double getValue(int label)
    {
        return values[label];
    }

    /**
     * Returns the log of the value for this message
     *
     * @param label The index whose value should be returned
     * @return The log of the value for the input label
     */
    public double getLogValue(int label)
    {
        return logValues[label];
    }

    /**
     * Resets all values for this message to default (1.0). The cached logs
     * are reset to log(1.0) = 0.0 at the same time so that
     * {@link #getLogValue(int)} stays consistent with
     * {@link #getValue(int)}. The temporary values are left untouched.
     */
    public void resetToOne()
    {
        for (int label = 0; label < values.length; label++)
        {
            values[label] = 1;
            logValues[label] = 0;
        }
    }

    /**
     * Renders the source node id followed by the current values. Every value
     * (including the last) is followed by ", ".
     *
     * @return A human-readable description of this message
     */
    @Override
    public String toString()
    {
        StringBuilder buffer = new StringBuilder();
        buffer.append("Source ");
        buffer.append(sourceNode);
        buffer.append(" [");
        for (double v : values)
        {
            buffer.append(v);
            buffer.append(", ");
        }
        buffer.append(']');

        return buffer.toString();
    }

}