/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept. This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://www.cs.umass.edu/~mccallum/mallet This software is provided under the terms of the Common Public License, version 1.0, as published by http://www.opensource.org. For further information, see the file `LICENSE' included with this distribution. */ package cc.mallet.grmm.inference.gbp; import java.util.Iterator; import cc.mallet.grmm.types.*; /** * A first implementation of MessageStrategy that assumes that a BP region graph * is being used. * * Created: May 29, 2005 * * @author <A HREF="mailto:casutton@cs.umass.edu>casutton@cs.umass.edu</A> * @version $Id: FullMessageStrategy.java,v 1.1 2007/10/22 21:37:58 mccallum Exp $ */ public class FullMessageStrategy extends AbstractMessageStrategy { private static final boolean debug = false; private static final boolean debugLite = false; public FullMessageStrategy () { } public void sendMessage (RegionEdge edge) { if (debugLite) { System.err.println ("Sending message "+edge); } Factor product = msgProduct (edge); Region from = edge.from; Region to = edge.to; if (debug) System.err.println ("Message "+from+" --> "+to+" after msgProduct: "+product); for (Iterator it = edge.factorsToSend.iterator (); it.hasNext ();) { Factor ptl = (Factor) it.next (); product.multiplyBy (ptl); } TableFactor result = (TableFactor) product.marginalize (to.vars); result.normalize (); if (debug) { System.err.println ("Final message "+edge+":"+result); } newMessages.setMessage (from, to, result); } /* static void multiplyEdgeFactors (RegionEdge edge, DiscretePotential product) { for (Iterator it = edge.factorsToSend.iterator (); it.hasNext ();) { DiscretePotential ptl = (DiscretePotential) it.next (); if (debug) System.err.println ("Message "+edge+" multiplying by: "+ptl); product.multiplyBy (ptl); } } */ // debugging function private boolean willBeNaN (Factor product, Factor otherMsg) { Factor p2 = product.duplicate (); p2.divideBy (otherMsg); return p2.isNaN (); } // debugging function private boolean willBeNaN2 (Factor product, Factor otherMsg) { Factor p2 = product.duplicate (); p2.multiplyBy (otherMsg); return p2.isNaN (); } public MessageArray averageMessages (RegionGraph rg, MessageArray a1, MessageArray a2, double inertiaWeight) { MessageArray arr = new MessageArray (rg); for (Iterator it = rg.edgeIterator (); it.hasNext ();) { RegionEdge edge = (RegionEdge) it.next (); DiscreteFactor msg1 = a1.getMessage (edge.from, edge.to); DiscreteFactor msg2 = a2.getMessage (edge.from, edge.to); if (msg1 != null) { TableFactor averaged = (TableFactor) Factors.average (msg1, msg2, inertiaWeight); arr.setMessage (edge.from, edge.to, averaged); } } return arr; } }