/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.inference.gbp;

import java.util.Iterator;

import cc.mallet.grmm.types.Factor;
import cc.mallet.grmm.types.FactorGraph;
import cc.mallet.grmm.types.UndirectedGrid;
import cc.mallet.grmm.types.Variable;
import cc.mallet.util.ArrayUtils;

/**
 * Region-graph generator for grid-shaped models.  For every 2x2 square of
 * neighbouring grid variables it builds one four-variable region, links that
 * region to a region for each of the square's four pairwise factors, and
 * links each of those pairwise regions to a region for each of its two
 * variables.
 * <p>
 * Only {@link UndirectedGrid} models are accepted.
 * <p>
 * Created: May 31, 2005
 *
 * @author <A HREF="mailto:casutton@cs.umass.edu">casutton@cs.umass.edu</A>
 * @version $Id: Kikuchi4SquareRegionGenerator.java,v 1.1 2007/10/22 21:37:58 mccallum Exp $
 */
public class Kikuchi4SquareRegionGenerator implements RegionGraphGenerator {

  /**
   * Builds the region graph for the given grid model.
   *
   * @param mdl the model; must be an {@link UndirectedGrid}
   * @return the region graph, after {@code computeInferenceCaches()} has been
   *         called on it
   * @throws UnsupportedOperationException if {@code mdl} is not an
   *         {@link UndirectedGrid}
   */
  public RegionGraph constructRegionGraph (FactorGraph mdl)
  {
    // Guard: anything that is not a grid is rejected up front.
    if (!(mdl instanceof UndirectedGrid)) {
      throw new UnsupportedOperationException ("Kikuchi4SquareRegionGenerator requires that you use UndirectedGrid.");
    }

    RegionGraph rg = new RegionGraph ();
    UndirectedGrid grid = (UndirectedGrid) mdl;

    // (gx, gy) is the first corner of each 2x2 square; the "- 1" bounds keep
    // gx+1 and gy+1 inside the grid.
    for (int gx = 0; gx < grid.getWidth () - 1; gx++) {
      for (int gy = 0; gy < grid.getHeight () - 1; gy++) {

        // The four corners, listed so that consecutive entries are grid
        // neighbours (and the last is a neighbour of the first).
        Variable[] corners = new Variable[] {
          grid.get (gx, gy),
          grid.get (gx + 1, gy),
          grid.get (gx + 1, gy + 1),
          grid.get (gx, gy + 1),
        };

        // Pairwise factors along the four sides of the square.
        Factor[] sides = new Factor[] {
          mdl.factorOf (corners[0], corners[1]),
          mdl.factorOf (corners[1], corners[2]),
          mdl.factorOf (corners[2], corners[3]),
          mdl.factorOf (corners[0], corners[3]),
        };

        // Region covering the whole square.
        Region square = new Region (corners, sides);

        // Any single-variable factors on the corners belong to the square too.
        addUnaryFactors (square, mdl, corners);

        // One region per side, hung below the square; each side region is in
        // turn linked to the regions of its two variables.
        for (int s = 0; s < sides.length; s++) {
          // NOTE(review): the boolean presumably means "create if absent" -- confirm in RegionGraph.
          Region sideRgn = rg.findRegion (sides[s], true);
          rg.add (square, sideRgn);
          linkEndpoint (rg, mdl, sideRgn, 0);
          linkEndpoint (rg, mdl, sideRgn, 1);
        }
      }
    }

    rg.computeInferenceCaches ();
    return rg;
  }

  /**
   * Adds to {@code target.factors} the single-variable factor of each variable
   * in {@code vars}, skipping variables for which the model has none.
   */
  private void addUnaryFactors (Region target, FactorGraph mdl, Variable[] vars)
  {
    for (int k = 0; k < vars.length; k++) {
      Factor unary = mdl.factorOf (vars[k]);
      if (unary != null) {
        target.factors.add (unary);
      }
    }
  }

  /**
   * Looks up the region for the variable at position {@code which} of
   * {@code edgeRgn.vars}, copies that region's factors into the edge region,
   * and adds the (edge region, variable region) pair to the graph.
   */
  private void linkEndpoint (RegionGraph rg, FactorGraph mdl, Region edgeRgn, int which)
  {
    Variable endpoint = (Variable) edgeRgn.vars.get (which);
    Region nodeRgn = createVarRegion (rg, mdl, endpoint);
    edgeRgn.factors.addAll (nodeRgn.factors);
    rg.add (edgeRgn, nodeRgn);
  }

  /**
   * Returns the region for a single variable: looked up by the variable's own
   * factor when the model has one, otherwise by the variable itself.
   */
  private Region createVarRegion (RegionGraph rg, FactorGraph mdl, Variable v1)
  {
    Factor unary = mdl.factorOf (v1);
    return (unary == null) ? rg.findRegion (v1, true) : rg.findRegion (unary, true);
  }

  /**
   * Sanity check: every single-variable region in the graph must appear in
   * {@code nodeRegions}.
   *
   * @throws IllegalStateException if a single-variable region is missing from
   *         {@code nodeRegions}
   */
  private void checkAllSingles (RegionGraph rg, Region[] nodeRegions)
  {
    Iterator cursor = rg.iterator ();
    while (cursor.hasNext ()) {
      Region candidate = (Region) cursor.next ();
      boolean isSingle = candidate.vars.size () == 1;
      if (isSingle && ArrayUtils.indexOf (nodeRegions, candidate) < 0) {
        throw new IllegalStateException ("huh?");
      }
    }
  }

  /**
   * Sanity check: the number of two-variable regions in the graph must not
   * exceed the number of factors in the model.
   *
   * @throws IllegalStateException if there are more two-variable regions than
   *         model factors
   */
  private void checkTooManyDoubles (RegionGraph rg, FactorGraph mdl)
  {
    int factorCount = mdl.factors ().size ();

    int pairRegions = 0;
    Iterator cursor = rg.iterator ();
    while (cursor.hasNext ()) {
      Region candidate = (Region) cursor.next ();
      if (candidate.vars.size () == 2) {
        pairRegions++;
      }
    }

    if (pairRegions > factorCount) {
      throw new IllegalStateException ("huh? ");
    }
  }
}