/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */
package cc.mallet.grmm.inference.gbp;

import java.util.logging.Logger;
import java.util.logging.Level;
import java.util.*;

import cc.mallet.grmm.inference.AbstractInferencer;
import cc.mallet.grmm.types.*;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Timing;

/**
 * Inferencer that builds a region graph for a factor graph and repeatedly
 * sends messages along the region-graph edges until the messages stop
 * changing (within {@link #THRESHOLD}) or {@link #MAX_ITER} iterations have
 * been run.  Beliefs over regions are then assembled from the final messages.
 * <p>
 * The region graph is produced by a {@link RegionGraphGenerator}; the way an
 * individual message is computed is delegated to a {@link MessageStrategy}.
 * <p>
 * The query methods ({@code lookupMarginal}, {@code lookupLogJoint},
 * {@code dump}) require that {@link #computeMarginals} has been called first.
 * <p>
 * Created: May 27, 2005
 *
 * @author <A HREF="mailto:casutton@cs.umass.edu">casutton@cs.umass.edu</A>
 * @version $Id: ParentChildGBP.java,v 1.1 2007/10/22 21:37:58 mccallum Exp $
 */
public class ParentChildGBP extends AbstractInferencer {

  private static final Logger logger = MalletLogger.getLogger (ParentChildGBP.class.getName());

  /** When true, free-energy diagnostics are printed to System.err. */
  private static final boolean debug = false;

  private final RegionGraphGenerator regioner;
  private final MessageStrategy sender;

  /** Whether new messages are averaged with the previous iteration's messages. */
  private boolean useInertia = true;

  /** Weight handed to {@link MessageStrategy#averageMessages} when inertia is on. */
  private double inertiaWeight = 0.5;

  // convergence criteria
  private static final double THRESHOLD = 1e-3;
  private static final int MAX_ITER = 500;

  // current inferencing state (all null until computeMarginals has run)
  private MessageArray oldMessages;
  private MessageArray newMessages;
  private RegionGraph rg;
  private FactorGraph mdl;

  /**
   * Creates an inferencer that uses the given region generator and a
   * {@link FullMessageStrategy}.
   *
   * @param regioner builds the region graph for each model
   */
  public ParentChildGBP (RegionGraphGenerator regioner)
  {
    this (regioner, new FullMessageStrategy ());
  }

  /**
   * Creates an inferencer with an explicit message strategy.
   *
   * @param regioner builds the region graph for each model
   * @param sender   computes (and, with inertia, averages) messages
   */
  public ParentChildGBP (RegionGraphGenerator regioner, MessageStrategy sender)
  {
    this.regioner = regioner;
    this.sender = sender;
  }

  /**
   * @return an inferencer using a {@link BPRegionGenerator} and a
   *         {@link FullMessageStrategy}
   */
  public static ParentChildGBP makeBPInferencer ()
  {
    return new ParentChildGBP (new BPRegionGenerator (), new FullMessageStrategy ());
  }

  /**
   * @return an inferencer using a {@link Kikuchi4SquareRegionGenerator} and a
   *         {@link FullMessageStrategy}
   */
  public static ParentChildGBP makeKikuchiInferencer ()
  {
    return new ParentChildGBP (new Kikuchi4SquareRegionGenerator (), new FullMessageStrategy ());
  }

  // accessors

  public boolean getUseInertia ()
  {
    return useInertia;
  }

  public void setUseInertia (boolean useInertia)
  {
    this.useInertia = useInertia;
  }

  public double getInertiaWeight ()
  {
    return inertiaWeight;
  }

  /**
   * Sets the weight passed to the message strategy when averaging old and new
   * messages.  How the weight is applied is defined by the
   * {@link MessageStrategy}; no range check is performed here.
   */
  public void setInertiaWeight (double inertiaWeight)
  {
    this.inertiaWeight = inertiaWeight;
  }

  // inferencer interface

  /**
   * Returns the belief over a single variable, obtained by marginalizing the
   * belief of a region that contains it.
   *
   * @throws IllegalStateException    if {@link #computeMarginals} has not been called
   * @throws IllegalArgumentException if no region contains the variable
   */
  public Factor lookupMarginal (Variable variable)
  {
    requireInferenceRun ();
    Region region = rg.findContainingRegion (variable);
    if (region == null)
      throw new IllegalArgumentException ("Could not find region containing variable "+variable+" in region graph "+rg);

    return computeBelief (region).marginalize (variable);
  }

  /**
   * Returns the belief over a set of variables, obtained by marginalizing the
   * belief of a region that contains all of them.
   *
   * @throws IllegalStateException    if {@link #computeMarginals} has not been called
   * @throws IllegalArgumentException if no region contains the variable set
   */
  public Factor lookupMarginal (VarSet varSet)
  {
    requireInferenceRun ();
    Region region = rg.findContainingRegion (varSet);
    if (region == null)
      throw new IllegalArgumentException ("Could not find region containing clique "+varSet +" in region graph "+rg);

    return computeBelief (region).marginalize (varSet);
  }

  /** Fails fast, with a descriptive message, if no inference state exists yet. */
  private void requireInferenceRun ()
  {
    if (rg == null || newMessages == null)
      throw new IllegalStateException ("computeMarginals must be called before querying ParentChildGBP");
  }

  /** Belief of a region under the most recent messages. */
  private Factor computeBelief (Region region)
  {
    return computeBelief (region, newMessages);
  }

  /**
   * Computes the normalized belief of a region from a message array.
   * The belief is the product of
   * <ol>
   * <li>the factors attached to the region,</li>
   * <li>the message from each parent of the region into the region, and</li>
   * <li>for every descendant of the region, the message into that descendant
   *     from each of its parents that is neither the region itself nor one of
   *     the region's descendants.</li>
   * </ol>
   *
   * @param region   region whose belief is wanted
   * @param messages messages to read from
   * @return a normalized factor over {@code region.vars}
   */
  static Factor computeBelief (Region region, MessageArray messages)
  {
    DiscreteFactor result = new LogTableFactor (region.vars);

    // (1) local factors
    for (Iterator it = region.factors.iterator (); it.hasNext ();) {
      Factor factor = (Factor) it.next ();
      result.multiplyBy (factor);
    }

    // (2) messages arriving from the region's own parents
    for (Iterator it = region.parents.iterator (); it.hasNext ();) {
      Region parent = (Region) it.next ();
      result.multiplyBy (messages.getMessage (parent, region));
    }

    // (3) messages entering the region's descendants from outside the region
    for (Iterator it = region.descendants.iterator (); it.hasNext ();) {
      Region child = (Region) it.next ();
      for (Iterator it2 = child.parents.iterator (); it2.hasNext ();) {
        Region uncle = (Region) it2.next ();
        boolean outside = (uncle != region) && !region.descendants.contains (uncle);
        if (outside) {
          result.multiplyBy (messages.getMessage (uncle, child));
        }
      }
    }

    result.normalize ();
    return result;
  }

  /**
   * Returns the model's log value for the assignment plus the region-graph
   * free energy computed from the current beliefs.
   * NOTE(review): presumably the free energy stands in for the negative log
   * partition function here -- confirm against callers.
   *
   * @throws IllegalStateException if {@link #computeMarginals} has not been called
   */
  public double lookupLogJoint (Assignment assn)
  {
    requireInferenceRun ();
    double factorProduct = mdl.logValue (assn);
    double F = computeFreeEnergy (rg);
    double value = factorProduct + F;

    if (debug)
      System.err.println ("GBP factor product:"+factorProduct+" + free energy: "+F+" = value:"+value);

    return value;
  }

  /**
   * Computes the region-graph free energy: the counting-number-weighted sum
   * of each region's average energy minus the counting-number-weighted sum of
   * each region's entropy, both taken under the region's current belief.
   */
  private double computeFreeEnergy (RegionGraph rg)
  {
    double avgEnergy = 0;
    double entropy = 0;

    for (Iterator it = rg.iterator (); it.hasNext ();) {
      Region region = (Region) it.next ();
      Factor belief = computeBelief (region);

      double thisEntropy = belief.entropy ();
      if (debug)
        System.err.println("Region " + region + " c:" + region.countingNumber + " entropy:" + thisEntropy);
      entropy += region.countingNumber * thisEntropy;

      // Product of the region's own factors; its negated log value is the energy.
      DiscreteFactor product = new LogTableFactor (belief.varSet ());
      for (Iterator ptlIt = region.factors.iterator (); ptlIt.hasNext ();) {
        Factor ptl = (Factor) ptlIt.next ();
        product.multiplyBy (ptl);
      }

      double thisAvgEnergy = 0;
      for (AssignmentIterator assnIt = belief.assignmentIterator (); assnIt.hasNext ();) {
        Assignment assn = assnIt.assignment ();
        // Note: Do not use assnIt here before fixing variable ordering issues.
        double thisEnergy = -product.logValue (assn);
        double thisBel = belief.value (assn);
        // An assignment with zero belief contributes nothing to the expectation.
        // Skipping it avoids 0 * Infinity = NaN when a factor rules the
        // assignment out (log value -Infinity, hence energy +Infinity).
        if (thisBel != 0) {
          thisAvgEnergy += thisBel * thisEnergy;
        }
        assnIt.advance ();
      }

      if (debug) {
        System.err.println("Region " + region + " c:" + region.countingNumber + " avgEnergy: " + thisAvgEnergy);
      }

      avgEnergy += region.countingNumber * thisAvgEnergy;
    }

    if (debug)
      System.err.println ("GBP computeFreeEnergy: avgEnergy:"+avgEnergy+" entropy:"+entropy+" free energy:"+(avgEnergy-entropy));

    return avgEnergy - entropy;
  }

  /**
   * Runs message passing on the given model.  A region graph is constructed,
   * and on every iteration one message is sent along each edge (in the order
   * chosen by {@link #chooseMessageSendingOrder}); if inertia is enabled the
   * new messages are then averaged with the previous ones.  Iteration stops
   * once {@link #hasConverged} holds or after {@link #MAX_ITER} iterations.
   * A warning is logged only if the last iteration had not converged.
   *
   * @param mdl model to run inference on
   */
  public void computeMarginals (FactorGraph mdl)
  {
    Timing timing = new Timing ();

    this.mdl = mdl;
    rg = regioner.constructRegionGraph (mdl);
    RegionEdge[] pairs = chooseMessageSendingOrder ();
    newMessages = new MessageArray (rg);

    timing.tick ("GBP Region Graph construction");

    int iter = 0;
    boolean converged;
    do {
      oldMessages = newMessages;
      newMessages = oldMessages.duplicate ();
      sender.setMessageArray (oldMessages, newMessages);

      for (int i = 0; i < pairs.length; i++) {
        sender.sendMessage (pairs[i]);
      }

      if (logger.isLoggable (Level.FINER)) {
        timing.tick ("GBP iteration "+iter);
      }

      iter++;

      if (useInertia)
        newMessages = sender.averageMessages (rg, oldMessages, newMessages, inertiaWeight);

      // Remember the outcome: converging on the very last permitted iteration
      // must not be reported as a failure below.
      converged = hasConverged ();
    } while (!converged && (iter < MAX_ITER));

    logger.info ("GBP: Used "+iter+" iterations.");
    if (!converged) {
      logger.warning ("***WARNING: GBP not converged!");
    }
  }

  /**
   * Returns all region-graph edges, sorted in ascending order of the number
   * of variables in the edge's destination region.  Edges with equally sized
   * destinations keep the order in which the edge iterator produced them
   * (Collections.sort is stable).
   */
  private RegionEdge[] chooseMessageSendingOrder ()
  {
    List l = new ArrayList ();
    for (Iterator it = rg.edgeIterator (); it.hasNext ();) {
      l.add (it.next ());
    }

    Collections.sort (l, new Comparator () {
      public int compare (Object o1, Object o2)
      {
        int l1 = ((RegionEdge) o1).to.vars.size ();
        int l2 = ((RegionEdge) o2).to.vars.size ();
        return (l1 < l2) ? -1 : ((l1 == l2) ? 0 : 1);
      }
    });

    return (RegionEdge[]) l.toArray (new RegionEdge [l.size ()]);
  }

  /**
   * Reports whether, on every edge, the previous and current messages agree
   * to within {@link #THRESHOLD} according to {@code Factor.almostEquals}.
   * An edge whose previous message is null is treated as unchanged (and the
   * current message is asserted to be null as well).
   */
  private boolean hasConverged ()
  {
    for (Iterator it = rg.edgeIterator (); it.hasNext ();) {
      RegionEdge edge = (RegionEdge) it.next ();
      Factor oldMsg = oldMessages.getMessage (edge.from, edge.to);
      Factor newMsg = newMessages.getMessage (edge.from, edge.to);

      if (oldMsg == null) {
        assert newMsg == null;
      } else if (!oldMsg.almostEquals (newMsg, THRESHOLD)) {
        return false;
      }
    }

    return true;
  }

  /**
   * Prints the current message on every region-graph edge to System.out.
   *
   * @throws IllegalStateException if {@link #computeMarginals} has not been called
   */
  public void dump ()
  {
    requireInferenceRun ();
    for (Iterator it = rg.edgeIterator (); it.hasNext ();) {
      RegionEdge edge = (RegionEdge) it.next ();
      Factor newMsg = newMessages.getMessage (edge.from, edge.to);
      System.out.println ("Message: "+edge.from+" --> "+edge.to+" "+newMsg);
    }
  }

}