/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept. This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://www.cs.umass.edu/~mccallum/mallet This software is provided under the terms of the Common Public License, version 1.0, as published by http://www.opensource.org. For further information, see the file `LICENSE' included with this distribution. */ package cc.mallet.grmm.inference.gbp; import java.util.Iterator; import cc.mallet.grmm.types.*; /** * Created: Jun 1, 2005 * * @author <A HREF="mailto:casutton@cs.umass.edu>casutton@cs.umass.edu</A> * @version $Id: SparseMessageSender.java,v 1.1 2007/10/22 21:37:58 mccallum Exp $ */ public class SparseMessageSender extends AbstractMessageStrategy { private double epsilon; public SparseMessageSender (double epsilon) { this.epsilon = epsilon; } public void sendMessage (RegionEdge edge) { Factor product = msgProduct (edge); for (Iterator it = edge.factorsToSend.iterator (); it.hasNext ();) { Factor ptl = (Factor) it.next (); product.multiplyBy (ptl); } TableFactor result = (TableFactor) product.marginalize (edge.to.vars); result.normalize (); TableFactor pruned; if (shouldPruneMessage (edge, result)) { // if (edge.to.vars.size() > 1) { pruned = Factors.retainMass (result, epsilon); pruned.normalize(); // System.err.println ("Potential pruning.\nPRE:"+result+"\nPOST:"+pruned); } else { // Only prune messages to leaves pruned = result; // System.err.println ("Message for edge "+edge+" not pruned."); } newMessages.setMessage (edge.from, edge.to, pruned); } public MessageArray averageMessages (RegionGraph rg, MessageArray a1, MessageArray a2, double inertiaWeight) { MessageArray arr = new MessageArray (rg); for (Iterator it = rg.edgeIterator (); it.hasNext ();) { RegionEdge edge = (RegionEdge) it.next (); Factor msg1 = a1.getMessage (edge.from, edge.to); Factor msg2 = a2.getMessage (edge.from, edge.to); if (msg1 != null) { TableFactor averaged = (TableFactor) Factors.average (msg1, msg2, inertiaWeight); TableFactor pruned; if (shouldPruneMessage (edge, averaged)) { pruned = Factors.retainMass (averaged, epsilon); } else { pruned = averaged; } arr.setMessage (edge.from, edge.to, pruned); } } // compute amount of sparsity int locs = 0; int idxs = 0; for (Iterator it = rg.edgeIterator (); it.hasNext ();) { RegionEdge edge = (RegionEdge) it.next (); DiscreteFactor msg = arr.getMessage (edge.from, edge.to); locs += msg.numLocations (); idxs += new HashVarSet (msg.varSet ()).weight (); } System.out.println ("Sparsity quotient = "+locs+" of "+idxs); return arr; } private boolean shouldPruneMessage (RegionEdge edge, Factor msg) { return edge.to.children.isEmpty (); } }