package cc.mallet.grmm.inference;
import java.util.logging.Logger;
import java.util.logging.Level;
import java.util.List;
import java.util.Iterator;
import java.io.*;
import cc.mallet.grmm.types.*;
import cc.mallet.util.MalletLogger;
/**
* Abstract base class for implementations of belief propagation for general factor graphs.
* This class manages arrays of messages, computing beliefs from messages, and convergence
* thresholds.
* <p/>
* How individual messages are computed (e.g., sum-product, max-product, etc.) is managed
* by instances of the interface {@link MessageStrategy}. Concrete subclasses decide
* in which order to send messages.
*
* @author Charles Sutton
* @version $Id: AbstractBeliefPropagation.java,v 1.1 2007/10/22 21:37:49 mccallum Exp $
*/
public abstract class AbstractBeliefPropagation extends AbstractInferencer {

  protected static Logger logger = MalletLogger.getLogger (AbstractBeliefPropagation.class.getName ());

  /** Compile-time switch: when true, {@link #hasConverged(double)} prints diagnostics to stderr. */
  private static final boolean diagnoseConvergence = false;

  /** Whether single-variable beliefs are normalized in {@link #lookupMarginal(Variable)}. */
  protected boolean normalizeBeliefs = true;

  /** Messages sent by all BP inferencers in this JVM (plain int, updated without synchronization). */
  private static int totalMessagesSent = 0;

  /** Messages sent by this inferencer since its creation. */
  transient private int myMessagesSent = 0;

  /** Value of {@link #myMessagesSent} recorded by {@link #resetMessagesSentAtStart()}. */
  transient private int messagesSentAtStart = 0;

  /** Default threshold used by {@link #hasConverged()}. */
  private double threshold = 0.00001;

  /** If true, messages are stored in and reused from the factor graph's inference cache. */
  protected boolean useCaching = false;

  /** Strategy that computes individual messages (sum-product by default). */
  private MessageStrategy messager;

  protected transient int iterUsed;

  /** Creates an inferencer that sends sum-product messages. */
  protected AbstractBeliefPropagation ()
  {
    this (new SumProductMessageStrategy ());
  }

  /**
   * Creates an inferencer that computes messages with the given strategy.
   *
   * @param messager strategy used to compute each individual message
   */
  protected AbstractBeliefPropagation (MessageStrategy messager)
  {
    this.messager = messager;
  }

  /** Returns the strategy used to compute individual messages. */
  public MessageStrategy getMessager ()
  {
    return messager;
  }

  /**
   * Replaces the strategy used to compute individual messages.
   *
   * @param messager the new message strategy
   * @return this inferencer, to allow call chaining
   */
  public AbstractBeliefPropagation setMessager (MessageStrategy messager)
  {
    this.messager = messager;
    return this;
  }

  /**
   * Returns the total number of messages all BP inferencers have sent in the current Java image.
   */
  public static int getTotalMessagesSent ()
  {
    return totalMessagesSent;
  }

  /** Returns the total number of messages this inferencer has sent since its creation. */
  public int getMessagesSent ()
  {
    return myMessagesSent;
  }

  /**
   * Returns the number of messages sent since the last call to
   * {@link #resetMessagesSentAtStart()} (i.e., during the last call to computeMarginals,
   * assuming subclasses reset the counter at its start).
   */
  public int getMessagesUsedLastTime ()
  {
    return myMessagesSent - messagesSentAtStart;
  }

  /** Records the current message count as the starting point for {@link #getMessagesUsedLastTime()}. */
  protected void resetMessagesSentAtStart ()
  {
    messagesSentAtStart = myMessagesSent;
  }

  /**
   * Current messages, keyed by (from, to): {@code messages.get(from, to)} is the message
   * sent from node {@code from} to node {@code to}.
   */
  transient private MessageArray messages;

  /**
   * Snapshot of earlier messages used for convergence checks and damping. Holds messages in
   * both directions (variable --> factor and factor --> variable), keyed like {@link #messages}.
   */
  transient private MessageArray oldMessages;

  /** Lazily computed single-variable beliefs, indexed by the variable's index in the graph. */
  transient private Factor[] bel;

  /** The factor graph currently being processed. */
  protected transient FactorGraph mdlCurrent;

  /** Loads the messages cached on {@code m} for this inferencer class. */
  private void retrieveCachedMessages (FactorGraph m)
  {
    messages = (MessageArray) m.getInferenceCache (getClass ());
  }

  /** Stores the current messages in {@code m}'s inference cache, keyed by this inferencer class. */
  private void cacheMessages (FactorGraph m)
  {
    m.setInferenceCache (getClass (), messages);
  }

  /** Drops the old-message snapshot so it can be garbage collected. */
  private void clearOldMessages ()
  {
    oldMessages = null;
  }

  /** Replaces the old-message snapshot with a duplicate of the current messages. */
  final protected void copyOldMessages ()
  {
    clearOldMessages ();
    oldMessages = messages.duplicate ();
  }

  /** Checks convergence against this inferencer's default threshold. */
  final protected boolean hasConverged ()
  {
    return hasConverged (this.threshold);
  }

  /**
   * Returns true if, for every non-null message in the old-message snapshot, the current
   * message differs from it by at most {@code threshold} (absolute difference) at every
   * assignment. The snapshot must have been initialized (it is cleared by doneWithGraph).
   *
   * @param threshold maximum allowed absolute difference per assignment
   */
  final protected boolean hasConverged (double threshold)
  {
    double maxDiff = Double.NEGATIVE_INFINITY;
    Factor bestOldMsg = null, bestNewMsg = null;

    for (MessageArray.Iterator msgIt = oldMessages.iterator (); msgIt.hasNext ();) {
      Factor oldMsg = (Factor) msgIt.next ();
      Object from = msgIt.from ();
      Object to = msgIt.to ();
      Factor newMsg = messages.get (from, to);

      if (oldMsg != null) {
        assert newMsg != null
                : "Message went from nonnull to null " + from + " --> " + to;

        for (java.util.Iterator it = oldMsg.assignmentIterator (); it.hasNext ();) {
          Assignment assn = (Assignment) it.next ();
          double val1 = oldMsg.value (assn);
          double val2 = newMsg.value (assn);
          double diff = Math.abs (val1 - val2);
          if (diff > threshold) {
            if (diagnoseConvergence) {
              System.err.println ("*** Not converged: Difference of : " + diff + " from " + oldMsg + " --> " + newMsg);
            }
            return false;
          } else if (diff > maxDiff) {
            // Only tracked for the diagnostic message below.
            maxDiff = diff;
            bestOldMsg = oldMsg;
            bestNewMsg = newMsg;
          }
        }
      }
    }

    if (diagnoseConvergence) {
      System.err.println (
              "*** CONVERGED: Max absolute difference : " + maxDiff + " from " + bestOldMsg + " --> " + bestNewMsg);
    }
    return true;
  }

  /**
   * Initializes the old-message snapshot for {@code fg}. When caching is enabled and a cached
   * message array exists, it is reloaded and duplicated; otherwise every variable/factor edge
   * gets a fresh {@link TableFactor} over the variable in both directions.
   */
  private void initOldMessages (FactorGraph fg)
  {
    oldMessages = new MessageArray (fg);
    if (useCaching && fg.getInferenceCache (getClass ()) != null) {
      logger.info ("AsyncLoopyBP: Reusing previous marginals");
      retrieveCachedMessages (fg);
      copyOldMessages ();
    } else {
      for (java.util.Iterator it = fg.factorsIterator (); it.hasNext ();) {
        Factor factor = (Factor) it.next ();
        VarSet varset = factor.varSet ();
        for (java.util.Iterator vit = varset.iterator (); vit.hasNext ();) {
          Variable var = (Variable) vit.next ();
          oldMessages.put (var, factor, new TableFactor (var));
          oldMessages.put (factor, var, new TableFactor (var));
        }
      }
    }
  }

  transient protected int[] assignedVertexPtls;

  /**
   * Prepares per-graph state: clears the belief array, sets up the current message array
   * (reused from the graph's inference cache when caching is enabled), initializes the
   * old-message snapshot, and hands both arrays to the message strategy.
   *
   * @param mdl the factor graph about to be processed
   */
  protected void initForGraph (FactorGraph mdl)
  {
    mdlCurrent = mdl;
    int numV = mdl.numVariables ();
    bel = new Factor [numV];

    Object cache = mdl.getInferenceCache (getClass ());
    if (useCaching && (cache != null)) {
      messages = (MessageArray) cache;
    } else {
      messages = new MessageArray (mdl);
    }

    initOldMessages (mdl);
    messager.setMessageArray (messages, oldMessages);
  }

  /** Sends a variable --> factor message via the current strategy and updates the counters. */
  protected void sendMessage (FactorGraph mdl, Variable from, Factor to)
  {
    totalMessagesSent++;
    myMessagesSent++;
    messager.sendMessage (mdl, from, to);
  }

  /** Sends a factor --> variable message via the current strategy and updates the counters. */
  protected void sendMessage (FactorGraph mdl, Factor from, Variable to)
  {
    totalMessagesSent++;
    myMessagesSent++;
    messager.sendMessage (mdl, from, to);
  }

  /** Releases the old-message snapshot and, if caching is enabled, caches the messages on {@code mdl}. */
  protected void doneWithGraph (FactorGraph mdl)
  {
    clearOldMessages (); // free up memory
    if (useCaching) cacheMessages (mdl);
  }

  /** Returns the iteration count recorded by the subclass in {@link #iterUsed}. */
  public int iterationsUsed ()
  {
    return iterUsed;
  }

  /** Computes individual BP messages (e.g., sum-product or max-product). */
  public interface MessageStrategy {

    /** Supplies the current and previous message arrays that the strategy reads and writes. */
    void setMessageArray (MessageArray msgs, MessageArray oldMsgs);

    /** Computes and stores the message from factor {@code from} to variable {@code to}. */
    void sendMessage (FactorGraph mdl, Factor from, Variable to);

    /** Computes and stores the message from variable {@code from} to factor {@code to}. */
    void sendMessage (FactorGraph mdl, Variable from, Factor to);

    /**
     * Multiplies into {@code product} the messages sent to node {@code idx}, skipping the one
     * from node {@code excludeMsgFrom}. Callers pass a null product to request a new factor.
     */
    Factor msgProduct (Factor product, int idx, int excludeMsgFrom);
  }

  /** Base strategy that holds the message arrays and implements the message product. */
  public abstract static class AbstractMessageStrategy implements MessageStrategy {

    protected MessageArray messages;
    protected MessageArray oldMessages;

    public void setMessageArray (MessageArray msgs, MessageArray oldMsgs)
    {
      messages = msgs;
      oldMessages = oldMsgs;
    }

    /**
     * Multiplies all current messages sent to node {@code idx}, except the one from node
     * {@code excludeMsgFrom}, into {@code product} (modified in place).
     *
     * @param product factor to multiply into, or null to start from an empty factor over the
     *                variable at {@code idx}
     * @param idx index of the receiving node in the message array
     * @param excludeMsgFrom index of the sender to skip; pass an index no node has
     *                       (e.g., {@code Integer.MIN_VALUE}) to include all messages
     * @return the accumulated product
     */
    public Factor msgProduct (Factor product, int idx, int excludeMsgFrom)
    {
      if (product == null) {
        product = createEmptyFactorForVar (idx);
      }

      for (MessageArray.ToMsgsIterator it = messages.toMessagesIterator (idx); it.hasNext ();) {
        it.next ();
        int j = it.currentFromIdx ();
        Factor msg = it.currentMessage ();
        if (j != excludeMsgFrom) {
          product.multiplyBy (msg);
        }
      }

      return product;
    }

    /**
     * Creates an empty factor over the variable at {@code idx}: a {@link LogTableFactor} if the
     * message array is in log space, otherwise a {@link TableFactor}.
     */
    private Factor createEmptyFactorForVar (int idx)
    {
      Variable var = (Variable) messages.idx2obj (idx);
      if (messages.isInLogSpace ()) {
        return new LogTableFactor (var);
      } else {
        return new TableFactor (var);
      }
    }
  }

  /**
   * Sum-product messages with optional damping. With {@code damping < 1}, each
   * factor --> variable message becomes
   * {@code damping * newMsg + (1 - damping) * normalize(oldMsg)}, where the old message is
   * taken from the old-message snapshot. {@code damping == 1} (the default) means no damping.
   */
  public static class SumProductMessageStrategy extends AbstractMessageStrategy implements Serializable {

    private double damping = 1.0;

    public SumProductMessageStrategy ()
    {
    }

    /** @param damping weight given to the newly computed message (1.0 disables damping) */
    public SumProductMessageStrategy (double damping)
    {
      this.damping = damping;
    }

    /** Factor --> variable: marginalizes the factor times incoming messages onto {@code to}. */
    public void sendMessage (FactorGraph mdl, Factor from, Variable to)
    {
      int fromIdx = messages.getIndex (from);
      int toIdx = messages.getIndex (to);

      Factor product = from.duplicate ();
      msgProduct (product, fromIdx, toIdx);
      Factor msg = product.marginalize (to);
      msg.normalize ();

      if (logger.isLoggable (Level.FINEST)) {
        // Emit at FINEST to match the guard; these factor dumps are very verbose.
        logger.finest ("MSG "+from+" --> "+to);
        logger.finest ("FACTOR: "+from.dumpToString());
        logger.finest ("MSG: "+msg.dumpToString ());
        logger.finest ("END MSG "+from+" --> "+to);
      }

      assert msg.varSet ().size () == 1;
      assert msg.varSet ().contains (to);

      makeDampedUpdate (fromIdx, toIdx, msg);
    }

    /** Variable --> factor: product of the variable's other incoming messages. */
    public void sendMessage (FactorGraph mdl, Variable from, Factor to)
    {
      int fromIdx = messages.getIndex (from);
      int toIdx = messages.getIndex (to);

      Factor msg = msgProduct (null, fromIdx, toIdx);
      msg.normalize ();

      assert msg.varSet ().size () == 1;
      assert msg.varSet ().contains (from);

      messages.put (fromIdx, toIdx, msg);
    }

    /**
     * Stores {@code msg} as the message fromIdx --> toIdx, first blending it with the old
     * message when damping is active and an old message exists. Assumes table-based factors.
     */
    private void makeDampedUpdate (int fromIdx, int toIdx, Factor msg)
    {
      if (damping < 1.0) {
        Factor oldMsg = oldMessages.get (fromIdx, toIdx);
        if (oldMsg != null) {
          AbstractTableFactor oldTbl = (AbstractTableFactor) oldMsg.duplicate ();
          oldTbl.normalize ();
          oldTbl.timesEquals (1 - damping);
          AbstractTableFactor tbl = (AbstractTableFactor) msg;
          tbl.timesEquals (damping);
          tbl.plusEquals (oldTbl);
          msg = tbl;
        }
      }

      messages.put (fromIdx, toIdx, msg);
    }

    // Serialization
    private static final long serialVersionUID = 1;
    private static final int CURRENT_SERIAL_VERSION = 2;

    // Note: damping is non-transient, so defaultWriteObject already writes it; the explicit
    // copy after the version number is part of the established stream format and is kept.
    private void writeObject (ObjectOutputStream out) throws IOException
    {
      out.defaultWriteObject ();
      out.writeInt (CURRENT_SERIAL_VERSION);
      out.writeDouble (damping);
    }

    private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException
    {
      in.defaultReadObject ();
      int version = in.readInt ();
      if (2 <= version) {
        damping = in.readDouble ();
      }
    }
  }

  /** Max-product messages: factor --> variable messages use {@code extractMax} instead of marginalization. */
  public static class MaxProductMessageStrategy extends AbstractMessageStrategy implements Serializable {

    /** Factor --> variable: maximizes the factor times incoming messages onto {@code to}. */
    public void sendMessage (FactorGraph mdl, Factor from, Variable to)
    {
      int fromIdx = messages.getIndex (from);
      int toIdx = messages.getIndex (to);

      Factor product = from.duplicate ();
      msgProduct (product, fromIdx, toIdx);
      Factor msg = product.extractMax (to);
      msg.normalize ();

      assert msg.varSet ().size () == 1;
      assert msg.varSet ().contains (to);

      messages.put (fromIdx, toIdx, msg);
    }

    /** Variable --> factor: product of the variable's other incoming messages. */
    public void sendMessage (FactorGraph mdl, Variable from, Factor to)
    {
      int fromIdx = messages.getIndex (from);
      int toIdx = messages.getIndex (to);

      Factor msg = msgProduct (null, fromIdx, toIdx);
      msg.normalize ();

      assert msg.varSet ().size () == 1;
      assert msg.varSet ().contains (from);

      messages.put (fromIdx, toIdx, msg);
    }

    // Serialization
    private static final long serialVersionUID = 1;
    private static final int CURRENT_SERIAL_VERSION = 1;

    private void writeObject (ObjectOutputStream out) throws IOException
    {
      out.defaultWriteObject ();
      out.writeInt (CURRENT_SERIAL_VERSION);
    }

    private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException
    {
      in.defaultReadObject ();
      in.readInt (); // version
    }
  }

  /**
   * Returns the belief for a single variable: the product of all messages sent to it,
   * normalized if {@link #normalizeBeliefs} is set. The result is cached per graph.
   *
   * @throws IllegalArgumentException if {@code var} is not a variable of the current graph
   */
  public Factor lookupMarginal (Variable var)
  {
    int idx = mdlCurrent.getIndex (var);
    // Valid indices are 0 .. bel.length - 1; idx == bel.length would overflow the array.
    if ((idx < 0) || (idx >= bel.length)) {
      throw new IllegalArgumentException ("Cannot find variable "+var+" in factor graph "+mdlCurrent);
    }

    if (bel[idx] == null) {
      // Integer.MIN_VALUE as the excluded sender means: include every incoming message.
      Factor marg = messager.msgProduct (null, idx, Integer.MIN_VALUE);
      if (normalizeBeliefs) {
        marg.normalize ();
      }
      assert marg.varSet ().size () == 1
              :"Invalid marginal for var " + var + ": " + marg;
      assert marg.varSet ().contains (var)
              :"Invalid marginal for var " + var + ": " + marg;
      bel[idx] = marg;
    }

    return bel[idx];
  }

  /** Dumps the current messages (delegates to {@code MessageArray.dump()}). */
  public void dump ()
  {
    messages.dump ();
  }

  /** Prints the total number of messages sent by all BP inferencers to stderr. */
  public void reportTime ()
  {
    System.err.println ("AbstractBeliefPropagation: Total messages sent = "+totalMessagesSent);
  }

  /** Dumps the current messages to {@code writer}. */
  public void dump (PrintWriter writer)
  {
    messages.dump (writer);
  }

  // Serialization
  private static final long serialVersionUID = 1;

  // If serialization-incompatible changes are made to these classes,
  // then smarts can be added to these methods for backward compatibility.
  private void writeObject (ObjectOutputStream out) throws IOException
  {
    out.defaultWriteObject ();
  }

  private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException
  {
    in.defaultReadObject ();
  }

  /**
   * Returns the marginal over {@code c}. A single variable is delegated to
   * {@link #lookupMarginal(Variable)}; a larger set must be the scope of at least one factor.
   *
   * @throws UnsupportedOperationException if {@code c} has several variables and no factor covers it
   */
  public Factor lookupMarginal (VarSet c)
  {
    if (c.size () == 1) {
      return lookupMarginal (c.get (0));
    }

    List factors = mdlCurrent.allFactorsOf (c);
    if (factors.isEmpty ()) {
      throw new UnsupportedOperationException
              ("Cannot compute marginal of " + c + ": Must be either a single variable or a factor in the graph.");
    }
    return lookupMarginal (c, factors);
  }

  /**
   * Computes a factor belief: the product of {@code factors} times every variable --> factor
   * message for the variables in {@code vs}, normalized.
   */
  private Factor lookupMarginal (VarSet vs, List factors)
  {
    Factor marginal = Factors.multiplyAll (factors);

    for (Iterator fit = factors.iterator (); fit.hasNext ();) {
      Factor factor = (Factor) fit.next ();
      for (java.util.Iterator it = vs.iterator (); it.hasNext ();) {
        Variable var = (Variable) it.next ();
        Factor msg = messages.get (var, factor);
        if (msg != null) { // if the inferencer was stopped early, there may be no message
          marginal.multiplyBy (msg);
        }
      }
    }

    marginal.normalize ();
    return marginal;
  }

  /**
   * Returns the log joint of {@code assn} under the BP factorization
   * {@code prod_C p(x_C) * prod_s p(x_s)^-(deg(s)-1)}, using the current beliefs.
   */
  public double lookupLogJoint (Assignment assn)
  {
    double accum = 0.0;

    // prod_s (p(x_s))^-(deg(s)-1) * ...
    for (java.util.Iterator it = mdlCurrent.variablesIterator (); it.hasNext ();) {
      Variable var = (Variable) it.next ();
      Factor ptl = lookupMarginal (var);
      int deg = mdlCurrent.getDegree (var);
      // Note that below works correctly for degree 0!
      if (deg != 1) {
        accum -= (deg - 1) * ptl.logValue (assn);
      }
    }

    // ... * prod_{c} p(x_C)
    for (java.util.Iterator it = mdlCurrent.varSetIterator (); it.hasNext ();) {
      VarSet varSet = (VarSet) it.next ();
      Factor p12 = lookupMarginal (varSet);
      double logphi = p12.logValue (assn);
      accum += logphi;
    }

    return accum;
  }
}