/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept. This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://www.cs.umass.edu/~mccallum/mallet This software is provided under the terms of the Common Public License, version 1.0, as published by http://www.opensource.org. For further information, see the file `LICENSE' included with this distribution. */ package cc.mallet.grmm.inference.gbp; import gnu.trove.THashSet; import java.util.*; import cc.mallet.grmm.types.Factor; import cc.mallet.grmm.types.VarSet; import cc.mallet.grmm.types.Variable; /** * Created: May 27, 2005 * * @author <A HREF="mailto:casutton@cs.umass.edu>casutton@cs.umass.edu</A> * @version $Id: RegionGraph.java,v 1.1 2007/10/22 21:37:58 mccallum Exp $ */ class RegionGraph { private Set regions = new THashSet (); private List edges = new ArrayList (); public RegionGraph () { } void add (Region parent, Region child) { if (!isConnected (parent, child)) { addRegion (parent); addRegion (child); child.isRoot = false; if (parent.children == null) parent.children = new ArrayList (); parent.children.add (child); if (child.parents == null) child.parents = new ArrayList (); child.parents.add (parent); edges.add (new RegionEdge (parent, child)); } } private boolean isConnected (Region parent, Region child) { return (parent.children.contains (child)); } private void addRegion (Region region) { if (regions.add (region)) { if (region.index != -1) { throw new IllegalArgumentException ("Region "+region+" has already been added to a different region graph."); } region.index = regions.size() - 1; } } int size () { return regions.size (); } Iterator iterator () { return regions.iterator (); } Iterator edgeIterator () { return edges.iterator (); } public void computeInferenceCaches () { computeDescendants (); includeDescendantFactors (); computeFactorsToSend (); computeCountingNumbers (); computeCousins (); computeNeighboringParents (); computeLoopingMessages (); // todo: Compute D(P,R) as well } private void includeDescendantFactors () { // Slightly inefficient: A recursive soln would be more efficient for (Iterator it = iterator (); it.hasNext();) { Region region = (Region) it.next (); for (Iterator dIt = region.descendants.iterator (); dIt.hasNext ();) { Region descendant = (Region) dIt.next (); // factors is a set, so it avoids duplicates region.factors.addAll (descendant.factors); } } } private void computeLoopingMessages () { for (Iterator it = edgeIterator (); it.hasNext();) { RegionEdge edge = (RegionEdge) it.next (); Region to = edge.to; List result = new ArrayList (); for (Iterator cousinIt = edge.cousins.iterator (); cousinIt.hasNext ();) { Region cousin = (Region) cousinIt.next (); if (cousin == edge.from) continue; for (Iterator edgeIt = cousin.children.iterator (); edgeIt.hasNext();) { Region cousinChild = (Region) edgeIt.next (); if (cousinChild == to || to.descendants.contains (cousinChild)) { result.add (findEdge (cousin, cousinChild)); } } } edge.loopingMessages = result; } } // computes region graph counting numbers as defined in Yedidia et al. private void computeCountingNumbers () { LinkedList queue = new LinkedList (); for (Iterator it = regions.iterator (); it.hasNext ();) { Region region = (Region) it.next (); if (region.isRoot) queue.add (region); } while (!queue.isEmpty()) { Region region = (Region) queue.removeFirst (); int parentCnt = 0; for (Iterator it = region.parents.iterator (); it.hasNext ();) { Region parent = (Region) it.next (); parentCnt += parent.countingNumber; } region.countingNumber = 1 - parentCnt; queue.addAll (region.children); } } private void computeFactorsToSend () { for (Iterator it = edges.iterator (); it.hasNext ();) { RegionEdge edge = (RegionEdge) it.next (); edge.initializeFactorsToSend (); } } private void computeCousins () { for (Iterator it = edgeIterator (); it.hasNext();) { RegionEdge edge = (RegionEdge) it.next (); Set cousins = new THashSet (edge.from.descendants); cousins.removeAll (edge.to.descendants); cousins.remove (edge.to); cousins.add (edge.from); edge.cousins = cousins; } } private void computeDescendants () { for (Iterator it = regions.iterator (); it.hasNext ();) { Region region = (Region) it.next (); if (region.isRoot) { computeDescendantsRec (region); } } } private void computeDescendantsRec (Region region) { Set descendants = new THashSet (region.children.size ()); // all region graphs are DAGs, so no infinite regress for (Iterator it = region.children.iterator (); it.hasNext();) { Region child = (Region) it.next (); computeDescendantsRec (child); descendants.add (child); descendants.addAll (child.descendants); } region.descendants = descendants; } private void computeNeighboringParents () { for (Iterator it = edgeIterator (); it.hasNext();) { RegionEdge edge = (RegionEdge) it.next (); edge.neighboringParents = new ArrayList (); List l = new LinkedList (regions); l.removeAll (edge.from.descendants); l.remove (edge.from); for (Iterator uncleIt = l.iterator (); uncleIt.hasNext ();) { Region uncle = (Region) uncleIt.next (); for (Iterator childIt = uncle.children.iterator (); childIt.hasNext();) { Region cousin = (Region) childIt.next (); if (edge.cousins.contains (cousin)) { edge.neighboringParents.add (findEdge (uncle, cousin)); } } } } } // horrifically inefficient private RegionEdge findEdge (Region uncle, Region cousin) { int idx = edges.indexOf (new RegionEdge (uncle, cousin)); return (RegionEdge) edges.get (idx); } public String toString () { StringBuffer buf = new StringBuffer (); buf.append ("REGION GRAPH\nRegions:\n"); for (Iterator it = regions.iterator (); it.hasNext ();) { Region region = (Region) it.next (); buf.append ("\n "); buf.append (region); } buf.append ("\nEdges:"); for (Iterator it = edges.iterator (); it.hasNext ();) { RegionEdge edge = (RegionEdge) it.next (); buf.append ("\n "); buf.append (edge.from); buf.append (" --> "); buf.append (edge.to); } buf.append ("\n"); return buf.toString (); } public boolean contains (Region region) { return regions.contains (region); } /** Returns the region in this graph whose factor list contains only * a given potential. * @param ptl * @param doCreate If true, an appropriate region will be created and added * to graph if none is found. * @return A region, or null if no region found and doCreate false. */ public Region findRegion (Factor ptl, boolean doCreate) { Set allVars = ptl.varSet (); for (Iterator it = regions.iterator (); it.hasNext ();) { Region region = (Region) it.next (); if (region.vars.size() == allVars.size() && region.vars.containsAll (allVars)) return region; } if (doCreate) { Region region = new Region (ptl); addRegion (region); return region; } else { return null; } } /** Returns the region in this graph whose variable list contains only * a given variable. * @param var * @param doCreate If true, an appropriate region will be created and added * to graph if none is found. * @return A region, or null if no region found and doCreate false. */ public Region findRegion (Variable var, boolean doCreate) { for (Iterator it = regions.iterator (); it.hasNext ();) { Region region = (Region) it.next (); if ((region.vars.size() == 1) && (region.vars.contains (var))) { return region; } } if (doCreate) { Region region = new Region (var); addRegion (region); return region; } else { return null; } } /** Finds the smallest region containing a given variable. * This might return a region that contains many extraneous variables. * @param variable * @return */ public Region findContainingRegion (Variable variable) { Region ret = null; for (Iterator it = regions.iterator (); it.hasNext ();) { Region region = (Region) it.next (); if (region.vars.contains (variable)) { if (ret == null || region.vars.size() < ret.vars.size ()) ret = region; } } return ret; } /** Finds the smallest region containing all the variables in a given set. * This might return a region that contains many extraneous variables. * @param varSet * @return */ public Region findContainingRegion (VarSet varSet) { Region ret = null; for (Iterator it = regions.iterator (); it.hasNext ();) { Region region = (Region) it.next (); if (region.vars.containsAll (varSet)) { if (ret == null || region.vars.size() < ret.vars.size ()) ret = region; } } return ret; } public int numEdges () { return edges.size (); } }