/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept. This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://www.cs.umass.edu/~mccallum/mallet This software is provided under the terms of the Common Public License, version 1.0, as published by http://www.opensource.org. For further information, see the file `LICENSE' included with this distribution. */ package cc.mallet.grmm.inference; import gnu.trove.THashSet; import gnu.trove.THashMap; import gnu.trove.TIntObjectHashMap; import java.util.logging.Logger; import java.util.logging.Level; import java.util.*; import java.io.*; import org._3pq.jgrapht.UndirectedGraph; import org._3pq.jgrapht.Graph; import org._3pq.jgrapht.Edge; import org._3pq.jgrapht.traverse.BreadthFirstIterator; import org._3pq.jgrapht.graph.SimpleGraph; import org.jdom.Document; import org.jdom.JDOMException; import org.jdom.Element; import org.jdom.input.SAXBuilder; import cc.mallet.grmm.types.*; import cc.mallet.util.MalletLogger; /** * Implementation of Wainwright's TRP schedule for loopy BP * in general graphical models. * * @author Charles Sutton * @version $Id: TRP.java,v 1.1 2007/10/22 21:37:49 mccallum Exp $ */ public class TRP extends AbstractBeliefPropagation { private static Logger logger = MalletLogger.getLogger (TRP.class.getName ()); private static final boolean reportSpanningTrees = false; private TreeFactory factory; private TerminationCondition terminator; private Random random = new Random (); /* Make sure that we've included all edges before we terminate. */ transient private TIntObjectHashMap factorTouched; transient private boolean hasConverged; transient private File verboseOutputDirectory = null; public TRP () { this (null, null); } public TRP (TreeFactory f) { this (f, null); } public TRP (TerminationCondition cond) { this (null, cond); } public TRP (TreeFactory f, TerminationCondition cond) { factory = f; terminator = cond; } public static TRP createForMaxProduct () { TRP trp = new TRP (); trp.setMessager (new MaxProductMessageStrategy ()); return trp; } // Accessors public TRP setTerminator (TerminationCondition cond) { terminator = cond; return this; } public TRP setFactory (TreeFactory factory) { this.factory = factory; return this; } // xxx should this be static? public void setRandomSeed (long seed) { random = new Random (seed); } public void setVerboseOutputDirectory (File verboseOutputDirectory) { this.verboseOutputDirectory = verboseOutputDirectory; } public boolean isConverged () { return hasConverged; } protected void initForGraph (FactorGraph m) { super.initForGraph (m); int numNodes = m.numVariables (); factorTouched = new TIntObjectHashMap (numNodes); hasConverged = false; if (factory == null) { factory = new AlmostRandomTreeFactory (); } if (terminator == null) { terminator = new DefaultConvergenceTerminator (); } else { terminator.reset (); } } private static cc.mallet.grmm.types.Tree graphToTree (Graph g) throws Exception { // Perhaps handle gracefully?? -cas if (g.vertexSet ().size () <= 0) { throw new RuntimeException ("Empty graph."); } Tree tree = new cc.mallet.grmm.types.Tree (); Object root = g.vertexSet ().iterator ().next (); tree.add (root); for (Iterator it1 = new BreadthFirstIterator (g, root); it1.hasNext();) { Object v1 = it1.next (); for (Iterator it2 = g.edgesOf (v1).iterator (); it2.hasNext ();) { Edge edge = (Edge) it2.next (); Object v2 = edge.oppositeVertex (v1); if (tree.getParent (v1) != v2) { tree.addNode (v1, v2); assert tree.getParent (v2) == v1; } } } return tree; } /** * Interface for tree-generation strategies for TRP. * <p/> * TRP works by repeatedly doing exact inference over spanning tree * of the original graph. But the trees can be chosen arbitrarily. * In fact, they don't need to be spanning trees; any acyclic * substructure will do. Users of TRP can tell it which strategy * to use by passing in an implementation of TreeFactory. */ public interface TreeFactory extends Serializable { public cc.mallet.grmm.types.Tree nextTree (FactorGraph mdl); } // This works around what appears to be a bug in OpenJGraph // connected sets. private static class SimpleUnionFind { private Map obj2set = new THashMap (); private Set findSet (Object obj) { Set container = (Set) obj2set.get (obj); if (container != null) { return container; } else { Set newSet = new THashSet (); newSet.add (obj); obj2set.put (obj, newSet); return newSet; } } private void union (Object obj1, Object obj2) { Set set1 = findSet (obj1); Set set2 = findSet (obj2); set1.addAll (set2); for (Iterator it = set2.iterator (); it.hasNext ();) { Object obj = it.next (); obj2set.put (obj, set1); } } public boolean noPairConnected (VarSet varSet) { for (int i = 0; i < varSet.size (); i++) { for (int j = i + 1; j < varSet.size (); j++) { Variable v1 = varSet.get (i); Variable v2 = varSet.get (j); if (findSet (v1) == findSet (v2)) { return false; } } } return true; } public void unionAll (Factor factor) { VarSet varSet = factor.varSet (); for (int i = 0; i < varSet.size (); i++) { Variable var = varSet.get (i); union (var, factor); } } } /** * Always adds edges that have not been touched, after that * adds random edges. */ public class AlmostRandomTreeFactory implements TreeFactory { public Tree nextTree (FactorGraph fullGraph) { SimpleUnionFind unionFind = new SimpleUnionFind (); ArrayList edges = new ArrayList (fullGraph.factors ()); ArrayList goodEdges = new ArrayList (fullGraph.numVariables ()); Collections.shuffle (edges, random); // First add all edges that haven't been used so far try { for (Iterator it = edges.iterator (); it.hasNext ();) { Factor factor = (Factor) it.next (); VarSet varSet = factor.varSet (); if (!isFactorTouched (factor) && unionFind.noPairConnected (varSet)) { goodEdges.add (factor); unionFind.unionAll (factor); it.remove (); } } // Now add as many other edges as possible for (Iterator it = edges.iterator (); it.hasNext ();) { Factor factor = (Factor) it.next (); VarSet varSet = factor.varSet (); if (unionFind.noPairConnected (varSet)) { goodEdges.add (factor); unionFind.unionAll (factor); } } for (Iterator it = goodEdges.iterator (); it.hasNext ();) { Factor factor = (Factor) it.next (); touchFactor (factor); } UndirectedGraph g = new SimpleGraph (); for (Iterator it = fullGraph.variablesIterator (); it.hasNext ();) { Variable var = (Variable) it.next (); g.addVertex (var); } for (Iterator it = goodEdges.iterator (); it.hasNext ();) { Factor factor = (Factor) it.next (); g.addVertex (factor); for (Iterator vit = factor.varSet ().iterator (); vit.hasNext ();) { Variable var = (Variable) vit.next (); g.addEdge (factor, var); } } Tree tree = graphToTree (g); if (reportSpanningTrees) { System.out.println ("********* SPANNING TREE *************"); System.out.println (tree.dumpToString ()); System.out.println ("********* END TREE *************"); } return tree; } catch (Exception e) { e.printStackTrace (); throw new RuntimeException (e); } } private static final long serialVersionUID = -7461763414516915264L; } /** * Generates spanning trees cyclically from a predefined collection. */ static public class TreeListFactory implements TreeFactory { private List lst; private Iterator it; public TreeListFactory (List l) { lst = l; it = lst.iterator (); } public TreeListFactory (cc.mallet.grmm.types.Tree[] arr) { lst = new ArrayList (java.util.Arrays.asList (arr)); it = lst.iterator (); } public static TreeListFactory makeFromReaders (FactorGraph fg, List readerList) { List treeList = new ArrayList (); for (Iterator it = readerList.iterator (); it.hasNext ();) { try { Reader reader = (Reader) it.next (); Document doc = new SAXBuilder ().build (reader); Element treeElt = doc.getRootElement (); Element rootElt = (Element) treeElt.getChildren ().get (0); Tree tree = readTreeRec (fg, rootElt); System.out.println (tree.dumpToString ()); treeList.add (tree); } catch (JDOMException e) { throw new RuntimeException (e); } catch (IOException e) { throw new RuntimeException (e); } } return new TreeListFactory (treeList); } /** @param fileList List of File objects. Each file should be an XML document describing a tree. */ public static TreeListFactory readFromFiles (FactorGraph fg, List fileList) { List treeList = new ArrayList (); for (Iterator it = fileList.iterator (); it.hasNext ();) { try { File treeFile = (File) it.next (); Document doc = new SAXBuilder ().build (treeFile); Element treeElt = doc.getRootElement (); Element rootElt = (Element) treeElt.getChildren ().get (0); treeList. add (readTreeRec (fg, rootElt)); } catch (JDOMException e) { throw new RuntimeException (e); } catch (IOException e) { throw new RuntimeException (e); } } return new TreeListFactory (treeList); } private static Tree readTreeRec (FactorGraph fg, Element elt) { List subtrees = new ArrayList (); for (Iterator it = elt.getChildren ().iterator (); it.hasNext ();) { Element child = (Element) it.next (); Tree subtree = readTreeRec (fg, child); subtrees.add (subtree); } Object parent = objFromElt (fg, elt); return Tree.makeFromSubtree (parent, subtrees); } private static Object objFromElt (FactorGraph fg, Element elt) { String type = elt.getName (); if (type.equals ("VAR")) { String vname = elt.getAttributeValue ("NAME"); return fg.findVariable (vname); } else if (type.equals("FACTOR")) { String varSetStr = elt.getAttributeValue ("VARS"); String[] vnames = varSetStr.split ("\\s+"); Variable[] vars = new Variable [vnames.length]; for (int i = 0; i < vnames.length; i++) { vars[i] = fg.findVariable (vnames[i]); } return fg.factorOf (new HashVarSet (vars)); } else { throw new RuntimeException ("Can't figure out element "+elt); } } public cc.mallet.grmm.types.Tree nextTree (FactorGraph mdl) { // If no more trees, rewind. if (!it.hasNext ()) { it = lst.iterator (); } return (cc.mallet.grmm.types.Tree) it.next (); } } // Termination conditions // will this need to be subclassed from outside? Will such // subclasses need access to the private state of TRP? static public interface TerminationCondition extends Cloneable, Serializable { // This takes the instances of trp as a parameter so that if a // TRP instance is cloned, and the terminator copied over, it // will still work. public boolean shouldContinue (TRP trp); public void reset (); // boy do I hate Java cloning public Object clone () throws CloneNotSupportedException; } static public class IterationTerminator implements TerminationCondition { int current; int max; public void reset () { current = 0; } public IterationTerminator (int m) { max = m; reset (); } public boolean shouldContinue (TRP trp) { current++; if (current >= max) { logger.finest ("***TRP quitting: Iteration " + current + " >= " + max); } return current <= max; } public Object clone () throws CloneNotSupportedException { return super.clone (); } } //xxx Delta is currently ignored. public static class ConvergenceTerminator implements TerminationCondition { double delta = 0.01; public ConvergenceTerminator () {} public ConvergenceTerminator (double delta) { this.delta = delta; } public void reset () { } public boolean shouldContinue (TRP trp) { /* if (oldMessages != null) retval = !checkForConvergence (trp); copyMessages(trp); return retval; */ boolean retval = !trp.hasConverged (delta); trp.copyOldMessages (); return retval; } public Object clone () throws CloneNotSupportedException { return super.clone (); } } // Runs until convergence, but doesn't stop until all edges have // been used at least once, and always stops after 1000 iterations. public static class DefaultConvergenceTerminator implements TerminationCondition { ConvergenceTerminator cterminator; IterationTerminator iterminator; String msg; public DefaultConvergenceTerminator () { this (0.001, 1000); } public DefaultConvergenceTerminator (double delta, int maxIter) { cterminator = new ConvergenceTerminator (delta); iterminator = new IterationTerminator (maxIter); msg = "***TRP quitting: over " + maxIter + " iterations"; } public void reset () { iterminator.reset (); cterminator.reset (); } // Terminate if converged or at insanely high # of iterations public boolean shouldContinue (TRP trp) { boolean notAllTouched = !trp.allEdgesTouched (); if (!iterminator.shouldContinue (trp)) { logger.warning (msg); if (notAllTouched) { logger.warning ("***TRP warning: Not all edges used!"); } return false; } if (notAllTouched) { return true; } else { return cterminator.shouldContinue (trp); } } public Object clone () throws CloneNotSupportedException { DefaultConvergenceTerminator dup = (DefaultConvergenceTerminator) super.clone (); dup.iterminator = (IterationTerminator) iterminator.clone (); dup.cterminator = (ConvergenceTerminator) cterminator.clone (); return dup; } } // And now, the heart of TRP: public void computeMarginals (FactorGraph m) { resetMessagesSentAtStart (); initForGraph (m); int iter = 0; while (terminator.shouldContinue (this)) { logger.finer ("TRP iteration " + (iter++)); cc.mallet.grmm.types.Tree tree = factory.nextTree (m); propagate (tree); dumpForIter (iter, tree); } iterUsed = iter; logger.info ("TRP used " + iter + " iterations."); doneWithGraph (m); } private void dumpForIter (int iter, Tree tree) { if (verboseOutputDirectory != null) { try { // output messages FileWriter writer = new FileWriter (new File (verboseOutputDirectory, "iter" + iter + ".txt")); dump (new PrintWriter (writer, true)); writer.close (); FileWriter bfWriter = new FileWriter (new File (verboseOutputDirectory, "beliefs" + iter + ".txt")); dumpBeliefs (new PrintWriter (bfWriter, true)); bfWriter.close (); // output spanning tree FileWriter treeWriter = new FileWriter (new File (verboseOutputDirectory, "tree" + iter + ".txt")); treeWriter.write (tree.toString ()); treeWriter.write ("\n"); treeWriter.close (); } catch (IOException e) { e.printStackTrace (); } } } private void dumpBeliefs (PrintWriter writer) { for (int vi = 0; vi < mdlCurrent.numVariables (); vi++) { Variable var = mdlCurrent.get (vi); Factor mrg = lookupMarginal (var); writer.println (mrg.dumpToString ()); writer.println (); } } private void propagate (cc.mallet.grmm.types.Tree tree) { Object root = tree.getRoot (); lambdaPropagation (tree, root); piPropagation (tree, root); } /** Sends BP messages starting from children to parents. This version uses constant stack space. */ private void lambdaPropagation (cc.mallet.grmm.types.Tree tree, Object root) { LinkedList openList = new LinkedList (); LinkedList closedList = new LinkedList (); openList.addAll (tree.getChildren (root)); while (!openList.isEmpty ()) { Object var = openList.removeFirst (); openList.addAll (tree.getChildren (var)); closedList.addFirst (var); } // Now open list contains all of the nodes (except the root) in reverse topological order. Send the messages. for (Iterator it = closedList.iterator (); it.hasNext ();) { Object child = it.next (); Object parent = tree.getParent (child); sendMessage (mdlCurrent, child, parent); } } /** Sends BP messages starting from parents to children. This version uses constant stack space. */ private void piPropagation (cc.mallet.grmm.types.Tree tree, Object root) { LinkedList openList = new LinkedList (); openList.add (root); while (!openList.isEmpty ()) { Object current = openList.removeFirst (); List children = tree.getChildren (current); for (Iterator it = children.iterator (); it.hasNext ();) { Object child = it.next (); sendMessage (mdlCurrent, current, child); openList.add (child); } } } private void sendMessage (FactorGraph fg, Object parent, Object child) { if (logger.isLoggable (Level.FINER)) logger.finer ("Sending message: "+parent+" --> "+child); if (parent instanceof Factor) { sendMessage (fg, (Factor) parent, (Variable) child); } else if (parent instanceof Variable) { sendMessage (fg, (Variable) parent, (Factor) child); } } private boolean allEdgesTouched () { Iterator it = mdlCurrent.factorsIterator (); while (it.hasNext ()) { Factor factor = (Factor) it.next (); int idx = mdlCurrent.getIndex (factor); int numTouches = getNumTouches (idx); if (numTouches == 0) { logger.finest ("***TRP continuing: factor " + idx + " not touched."); return false; } } return true; } private void touchFactor (Factor factor) { int idx = mdlCurrent.getIndex (factor); incrementTouches (idx); } private boolean isFactorTouched (Factor factor) { int idx1 = mdlCurrent.getIndex (factor); return (getNumTouches (idx1) > 0); } private int getNumTouches (int idx1) { Integer integer = (Integer) factorTouched.get (idx1); return (integer == null) ? 0 : integer.intValue (); } private void incrementTouches (int idx1) { int nt = getNumTouches (idx1); factorTouched.put (idx1, new Integer (nt + 1)); } public Factor query (DirectedModel m, Variable var) { throw new UnsupportedOperationException ("GRMM doesn't yet do directed models."); } //xxx could get moved up to AbstractInferencer, if mdlCurrent did. public Assignment bestAssignment () { int[] outcomes = new int [mdlCurrent.numVariables ()]; for (int i = 0; i < outcomes.length; i++) { Variable var = mdlCurrent.get (i); TableFactor ptl = (TableFactor) lookupMarginal (var); outcomes[i] = ptl.argmax (); } return new Assignment (mdlCurrent, outcomes); } // Deep copy termination condition public Object clone () { try { TRP dup = (TRP) super.clone (); if (terminator != null) { dup.terminator = (TerminationCondition) terminator.clone (); } return dup; } catch (CloneNotSupportedException e) { // should never happen throw new RuntimeException (e); } } // Serialization private static final long serialVersionUID = 1; // If seralization-incompatible changes are made to these classes, // then smarts can be added to these methods for backward compatibility. private void writeObject (ObjectOutputStream out) throws IOException { out.defaultWriteObject (); } private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException { in.defaultReadObject (); } }