/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept. This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://www.cs.umass.edu/~mccallum/mallet This software is provided under the terms of the Common Public License, version 1.0, as published by http://www.opensource.org. For further information, see the file `LICENSE' included with this distribution. */ package cc.mallet.grmm.inference; import java.util.Iterator; import java.util.Collections; import java.util.ArrayList; import java.util.Random; import java.io.ObjectOutputStream; import java.io.IOException; import java.io.ObjectInputStream; import cc.mallet.grmm.types.*; /** * The loopy belief propagation algorithm for approximate inference in * general graphical models. * * Created: Wed Nov 5 19:30:15 2003 * * @author <a href="mailto:casutton@cs.umass.edu">Charles Sutton</a> * @version $Id: LoopyBP.java,v 1.1 2007/10/22 21:37:49 mccallum Exp $ */ public class LoopyBP extends AbstractBeliefPropagation { public static final int DEFAULT_MAX_ITER = 1000; private int maxIter; private Random rand = new Random (); public void setUseCaching (boolean useCaching) { this.useCaching = useCaching; } // Note this does not have the sophisticated Terminator interface // that we've got in TRP. public LoopyBP () { this (new AbstractBeliefPropagation.SumProductMessageStrategy (), DEFAULT_MAX_ITER); } public LoopyBP (int maxIter) { this (new AbstractBeliefPropagation.SumProductMessageStrategy (), maxIter); } public LoopyBP (AbstractBeliefPropagation.MessageStrategy messager, int maxIter) { super (messager); this.maxIter = maxIter; } public static Inferencer createForMaxProduct () { return new LoopyBP (new MaxProductMessageStrategy (), DEFAULT_MAX_ITER); } public LoopyBP setRand (Random rand) { this.rand = rand; return this; } public void computeMarginals (FactorGraph mdl) { super.initForGraph (mdl); int iter; for (iter = 0; iter < maxIter; iter++) { logger.finer ("***AsyncLoopyBP iteration "+iter); propagate (mdl); if (hasConverged ()) break; copyOldMessages (); } iterUsed = iter; if (iter >= maxIter) { logger.info ("***Loopy BP quitting: not converged after "+maxIter+" iterations."); } else { iterUsed++; // there's an off-by-one b/c of location of above break logger.info ("***AsyncLoopyBP converged: "+iterUsed+" iterations"); } doneWithGraph (mdl); } private void propagate (FactorGraph mdl) { // Send all messages in random order. ArrayList factors = new ArrayList (mdl.factors()); Collections.shuffle (factors, rand); for (Iterator it = factors.iterator(); it.hasNext();) { Factor factor = (Factor) it.next(); for (Iterator vit = factor.varSet ().iterator (); vit.hasNext ();) { Variable from = (Variable) vit.next (); sendMessage (mdl, from, factor); } } for (Iterator it = factors.iterator(); it.hasNext();) { Factor factor = (Factor) it.next(); for (Iterator vit = factor.varSet ().iterator (); vit.hasNext ();) { Variable to = (Variable) vit.next (); sendMessage (mdl, factor, to); } } } // Serialization private static final long serialVersionUID = 1; // If seralization-incompatible changes are made to these classes, // then smarts can be added to these methods for backward compatibility. private void writeObject (ObjectOutputStream out) throws IOException { out.defaultWriteObject (); } private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException { in.defaultReadObject (); } }