/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept. This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://www.cs.umass.edu/~mccallum/mallet This software is provided under the terms of the Common Public License, version 1.0, as published by http://www.opensource.org. For further information, see the file `LICENSE' included with this distribution. */ package cc.mallet.grmm.inference.gbp; import java.util.*; import java.util.logging.Logger; import cc.mallet.grmm.types.Factor; import cc.mallet.grmm.types.FactorGraph; import cc.mallet.grmm.types.UndirectedGrid; import cc.mallet.grmm.types.Variable; import cc.mallet.util.CollectionUtils; import cc.mallet.util.MalletLogger; /** * Created: Jun 1, 2005 * * @author <A HREF="mailto:casutton@cs.umass.edu>casutton@cs.umass.edu</A> * @version $Id: ClusterVariationalRegionGenerator.java,v 1.1 2007/10/22 21:37:58 mccallum Exp $ */ public class ClusterVariationalRegionGenerator implements RegionGraphGenerator { private static final Logger logger = MalletLogger.getLogger (ClusterVariationalRegionGenerator.class.getName()); private static final boolean debug = false; private BaseRegionComputer regionComputer; public ClusterVariationalRegionGenerator () { this (new ByFactorRegionComputer ()); } public ClusterVariationalRegionGenerator (BaseRegionComputer regionComputer) { this.regionComputer = regionComputer; } public RegionGraph constructRegionGraph (FactorGraph mdl) { RegionGraph rg = new RegionGraph (); int depth = 0; List baseRegions = regionComputer.computeBaseRegions (mdl); List theseRegions = baseRegions; while (!theseRegions.isEmpty ()) { if (debug) System.out.println ("Depth 0 regions:\n"+CollectionUtils.dumpToString (theseRegions, "\n ")); List overlaps = computeOverlaps (theseRegions); addEdgesForOverlaps (rg, theseRegions, overlaps); theseRegions = overlaps; depth++; } rg.computeInferenceCaches (); logger.info ("ClusterVariationalRegionGenerator: Number of regions "+rg.size()+" Number of edges:"+rg.numEdges()); return rg; } private List computeOverlaps (List regions) { List overlaps = new ArrayList (); for (Iterator it1 = regions.iterator (); it1.hasNext ();) { Region r1 = (Region) it1.next (); for (Iterator it2 = regions.iterator (); it2.hasNext ();) { Region r2 = (Region) it2.next (); if (r1 != r2) { Collection intersection = CollectionUtils.intersection (r1.vars, r2.vars); if (!intersection.isEmpty () && !anySubsumes (overlaps, intersection)) { Collection ptlSet = CollectionUtils.intersection (r1.factors, r2.factors); Variable[] vars = (Variable[]) intersection.toArray (new Variable[intersection.size ()]); Factor[] ptls = (Factor[]) ptlSet.toArray (new Factor [ptlSet.size ()]); Region r = new Region (vars, ptls); overlaps.add (r); } } } } // We can still have subsumed regions in the list if the smaller region was added first. for (ListIterator it = overlaps.listIterator (); it.hasNext ();) { Region region = (Region) it.next (); List otherRegions = overlaps.subList (it.nextIndex (), overlaps.size ()); if (anySubsumes (otherRegions, region.vars)) { it.remove (); } } return overlaps; } /** Returns true if any region in regions contains all the variables in vars. */ private boolean anySubsumes (List regions, Collection vars) { for (Iterator it = regions.iterator (); it.hasNext ();) { Region region = (Region) it.next (); if (region.vars.containsAll (vars)) return true; } return false; } private void addEdgesForOverlaps (RegionGraph rg, List fromList, List toList) { for (Iterator fromIt = fromList.iterator (); fromIt.hasNext ();) { Region from = (Region) fromIt.next (); for (Iterator toIt = toList.iterator (); toIt.hasNext ();) { Region to = (Region) toIt.next (); if (from.vars.containsAll (to.vars)) { rg.add (from, to); } } } } // computing base regions public static void removeSubsumedRegions (List regions) { for (ListIterator it = regions.listIterator (); it.hasNext ();) { Region region = (Region) it.next (); for (Iterator it2 = regions.iterator (); it2.hasNext();) { Region r2 = (Region) it2.next (); if (r2 != region && r2.vars.size() >= region.vars.size ()) { if (r2.vars.containsAll (region.vars)) { it.remove (); break; } } } } } public static void addAllFactors (FactorGraph mdl, List regions) { for (Iterator it = regions.iterator (); it.hasNext ();) { Region region = (Region) it.next (); for (Iterator pIt = mdl.factorsIterator (); pIt.hasNext();) { Factor ptl = (Factor) pIt.next (); if (region.vars.containsAll (ptl.varSet ())) { region.factors.add (ptl); } } } } public static interface BaseRegionComputer { /** * Returns a list of top-level regions for use in the cluster variational method. * @param mdl An undirected model. * @return A list of regions. No region in the list may subsume another. */ List computeBaseRegions (FactorGraph mdl); } /** * Region computer where each top-level region consists of a single factor node. * If the model is pairwise, this is equivalent to using the Bethe free energy. */ public static class ByFactorRegionComputer implements BaseRegionComputer { public List computeBaseRegions (FactorGraph mdl) { List regions = new ArrayList (mdl.factors ().size ()); for (Iterator it = mdl.factorsIterator (); it.hasNext ();) { Factor ptl = (Factor) it.next (); regions.add (new Region (ptl)); } removeSubsumedRegions (regions); addAllFactors (mdl, regions); return regions; } } public static class Grid2x2RegionComputer implements BaseRegionComputer { public List computeBaseRegions (FactorGraph mdl) { List regions = new ArrayList (); UndirectedGrid grid = (UndirectedGrid) mdl; for (int x = 0; x < grid.getWidth() - 1; x++) { for (int y = 0; y < grid.getHeight() - 1; y++) { Variable[] vars = new Variable[] { grid.get (x, y), grid.get (x, y+1), grid.get (x+1, y+1), grid.get (x+1, y), }; regions.add (new Region (vars, new Factor[0])); } } addAllFactors (mdl, regions); return regions; } } }