/* Copyright (C) 2006 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://mallet.cs.umass.edu/
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package cc.mallet.grmm.inference;

import gnu.trove.THashSet;

import java.util.Iterator;
import java.util.logging.Level;

import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.FactorGraph;
import cc.mallet.grmm.types.Variable;

/**
 * Implements the tree-based schedule of belief propagation for exact inference
 * in trees.  Can be used either for sum-product or max-product.
 * <p>
 * The schedule has two passes.  The "lambda" pass walks the graph depth-first
 * and sends messages from the leaves towards the root; the "pi" pass then
 * sends messages from the root back out towards the leaves.  If the factor
 * graph is a forest, the same two passes are run once per connected
 * component.
 * <p>
 * Created: Feb 1, 2006
 *
 * @author <A HREF="mailto:casutton@cs.umass.edu">casutton@cs.umass.edu</A>
 * @version $Id: TreeBP.java,v 1.1 2007/10/22 21:37:49 mccallum Exp $
 */
public class TreeBP extends AbstractBeliefPropagation {

  /** Variables and factors already visited during the pass currently running. */
  transient private THashSet marked;

  /** First variable of the current graph, or null if the graph has no variables. */
  transient private Variable root;

  /**
   * Creates a TreeBP whose messager is a {@link MaxProductMessageStrategy}.
   *
   * @return a new TreeBP configured with the max-product message strategy
   */
  public static TreeBP createForMaxProduct ()
  {
    return (TreeBP) new TreeBP ().setMessager (new MaxProductMessageStrategy ());
  }

  /**
   * Runs the lambda (leaves-to-root) pass followed by the pi (root-to-leaves)
   * pass over the given graph.
   * <p>
   * The traversal starts at the root chosen by {@link #initForGraph}.  Any
   * variable that is still unvisited afterwards belongs to a different
   * connected component and is used as the root of that component, so that
   * messages are sent in every component and not only in the first one.
   * A graph without variables is a no-op after initialization.
   *
   * @param fg the factor graph to run inference on
   */
  public void computeMarginals (FactorGraph fg)
  {
    initForGraph (fg);
    if (root == null) {
      // No variables, hence no edges to send messages along.
      return;
    }

    // Upward pass.
    marked = new THashSet ();
    lambdaPropagation (fg, null, root);
    for (Iterator it = fg.variablesIterator (); it.hasNext ();) {
      Variable componentRoot = (Variable) it.next ();
      if (!marked.contains (componentRoot)) {
        lambdaPropagation (fg, null, componentRoot);
      }
    }

    // Downward pass, started from the same roots in the same order.
    marked = new THashSet ();
    piPropagation (fg, root);
    for (Iterator it = fg.variablesIterator (); it.hasNext ();) {
      Variable componentRoot = (Variable) it.next ();
      if (!marked.contains (componentRoot)) {
        piPropagation (fg, componentRoot);
      }
    }
  }

  /**
   * Performs the superclass initialization and then picks the traversal root.
   *
   * @param fg the factor graph about to be processed
   */
  protected void initForGraph (FactorGraph fg)
  {
    super.initForGraph (fg);

    // Pick a root arbitrarily: the first variable the graph hands out.
    // An empty graph has no root.
    Iterator vars = fg.variablesIterator ();
    root = vars.hasNext () ? (Variable) vars.next () : null;
  }

  /**
   * Lambda pass at a variable: recurses into every not-yet-visited factor
   * touching <tt>child</tt>, then sends the message from <tt>child</tt> to
   * <tt>parent</tt>.
   *
   * @param mdl    the factor graph being processed
   * @param parent the factor this variable was reached from, or null if
   *               <tt>child</tt> is the root of the traversal
   * @param child  the variable being visited
   */
  private void lambdaPropagation (FactorGraph mdl, Factor parent, Variable child)
  {
    // Guarded so the message string is only built when FINER is enabled.
    if (logger.isLoggable (Level.FINER)) {
      logger.finer ("lambda propagation "+parent+" , "+child);
    }
    marked.add (child);

    for (Iterator it = mdl.allFactorsContaining (child).iterator (); it.hasNext ();) {
      Factor gchild = (Factor) it.next ();
      if (!marked.contains (gchild)) {
        lambdaPropagation (mdl, child, gchild);
      }
    }

    // Sent only after the whole subtree below child has been processed.
    if (parent != null) {
      sendMessage (mdl, child, parent);
    }
  }

  /**
   * Lambda pass at a factor: recurses into every not-yet-visited variable of
   * <tt>child</tt>, then sends the message from <tt>child</tt> to
   * <tt>parent</tt>.
   *
   * @param mdl    the factor graph being processed
   * @param parent the variable this factor was reached from; the message is
   *               skipped if it is null
   * @param child  the factor being visited
   */
  private void lambdaPropagation (FactorGraph mdl, Variable parent, Factor child)
  {
    if (logger.isLoggable (Level.FINER)) {
      logger.finer ("lambda propagation "+parent+" , "+child);
    }
    marked.add (child);

    for (Iterator it = child.varSet ().iterator (); it.hasNext ();) {
      Variable gchild = (Variable) it.next ();
      if (!marked.contains (gchild)) {
        lambdaPropagation (mdl, child, gchild);
      }
    }

    // Sent only after the whole subtree below child has been processed.
    if (parent != null) {
      sendMessage (mdl, child, parent);
    }
  }

  /**
   * Pi pass at a variable: for every not-yet-visited factor touching
   * <tt>var</tt>, sends the message from <tt>var</tt> to that factor and then
   * continues the pass from the factor.
   *
   * @param mdl the factor graph being processed
   * @param var the variable to propagate from
   */
  private void piPropagation (FactorGraph mdl, Variable var)
  {
    if (logger.isLoggable (Level.FINER)) {
      logger.finer ("Pi propagation from "+var);
    }
    marked.add (var);

    for (Iterator it = mdl.allFactorsContaining (var).iterator (); it.hasNext ();) {
      Factor child = (Factor) it.next ();
      if (!marked.contains (child)) {
        // The message goes out before descending into the child.
        sendMessage (mdl, var, child);
        piPropagation (mdl, child);
      }
    }
  }

  /**
   * Pi pass at a factor: for every not-yet-visited variable of
   * <tt>factor</tt>, sends the message from <tt>factor</tt> to that variable
   * and then continues the pass from the variable.
   *
   * @param mdl    the factor graph being processed
   * @param factor the factor to propagate from
   */
  private void piPropagation (FactorGraph mdl, Factor factor)
  {
    if (logger.isLoggable (Level.FINER)) {
      logger.finer ("Pi propagation from "+factor);
    }
    marked.add (factor);

    for (Iterator it = factor.varSet ().iterator (); it.hasNext ();) {
      Variable child = (Variable) it.next ();
      if (!marked.contains (child)) {
        // The message goes out before descending into the child.
        sendMessage (mdl, factor, child);
        piPropagation (mdl, child);
      }
    }
  }

}