/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept. This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://www.cs.umass.edu/~mccallum/mallet This software is provided under the terms of the Common Public License, version 1.0, as published by http://www.opensource.org. For further information, see the file `LICENSE' included with this distribution. */ package cc.mallet.grmm.inference; import java.util.Set; import java.util.HashSet; import java.util.Iterator; import java.util.Collection; import java.util.List; import java.util.Arrays; import cc.mallet.grmm.types.*; import gnu.trove.TIntObjectHashMap; import gnu.trove.THashSet; import gnu.trove.TIntObjectIterator; /** * Datastructure for a junction tree. * * Created: Tue Sep 30 10:30:25 2003 * * @author <a href="mailto:casutton@cs.umass.edu">Charles Sutton</a> * @version $Id: JunctionTree.java,v 1.1 2007/10/22 21:37:49 mccallum Exp $ */ public class JunctionTree extends Tree { private int numNodes; private static class Sepset { Sepset(Set s, Factor p) { set = s; ptl = p; } Set set; Factor ptl; } private TIntObjectHashMap sepsets; private Factor[] cpfs; public JunctionTree(int size) { super(); numNodes = size; sepsets = new TIntObjectHashMap(); cpfs = new Factor[size]; } // JunctionTree constructor public void addNode (Object parent1, Object child1) { super.addNode(parent1, child1); VarSet parent = (VarSet) parent1; VarSet child = (VarSet) child1; Set sepset = parent.intersection(child); int id1 = lookupIndex(parent); int id2 = lookupIndex(child); putSepset(id1, id2, new Sepset (sepset, newSepsetPtl (sepset))); } private Factor newSepsetPtl (Set sepset) { if (sepset.isEmpty ()) { // use identity factor return ConstantFactor.makeIdentityFactor (); } else { return new TableFactor (sepset); } } private int hashIdxIdx(int id1, int id2) { assert (id1 < 65536) && (id2 < 65536); int id; if (id1 < id2) { id = (id1 << 16) | id2; } else { id = (id2 << 16) | id1; } return id; } private void putSepset(int id1, int id2, Sepset sepset) { int id = hashIdxIdx(id1, id2); sepsets.put(id, sepset); } private Sepset getSepset(int id1, int id2) { int id = hashIdxIdx(id1, id2); return (Sepset) sepsets.get(id); } // CPF accessors public Factor getCPF(VarSet c) { return cpfs[lookupIndex(c)]; } public void setCPF(VarSet c, Factor pot) { cpfs[lookupIndex(c)] = pot; } void clearCPFs() { for (int i = 0; i < cpfs.length; i++) { cpfs[i] = new TableFactor ((VarSet) lookupVertex (i)); } TIntObjectIterator it = sepsets.iterator(); while (it.hasNext()) { it.advance(); Sepset sepset = (Sepset) it.value(); sepset.ptl = newSepsetPtl (sepset.set); } } public Set sepsetPotentials() { THashSet set = new THashSet(); TIntObjectIterator it = sepsets.iterator(); while (it.hasNext()) { it.advance(); Factor ptl = ((Sepset) it.value()).ptl; set.add(ptl); } return set; } void setSepsetPot(Factor pot, VarSet v1, VarSet v2) { int id1 = lookupIndex(v1); int id2 = lookupIndex(v2); getSepset(id1, id2).ptl = pot; } public Factor getSepsetPot(VarSet v1, VarSet v2) { int id1 = lookupIndex(v1); int id2 = lookupIndex(v2); return getSepset(id1, id2).ptl; } /** * Returns a collection of all the potentials of cliques in the junction tree. * (i.e., these are the terms in the numerator of the jounction tre theorem). * @see #sepsetPotentials() */ public Collection clusterPotentials () { HashSet h = new HashSet(); for (int i = 0; i < cpfs.length; i++) { if (cpfs[i] != null) { h.add(cpfs[i]); } } return h; } public Set getSepset(VarSet v1, VarSet v2) { int id1 = lookupIndex(v1); int id2 = lookupIndex(v2); return getSepset(id1, id2).set; } public Factor lookupMarginal(Variable var) { VarSet c = findParentCluster(var); Factor pot = getCPF(c); return pot.marginalize(var); } public double lookupLogJoint(Assignment assn) { double accum = 0; for (int i = 0; i < cpfs.length; i++) { if (cpfs[i] != null) { double phi = cpfs[i].logValue (assn); accum += phi; } } TIntObjectIterator it = sepsets.iterator(); while (it.hasNext()) { it.advance(); Factor ptl = ((Sepset) it.value()).ptl; double phi = ptl.logValue (assn); accum -= phi; } return accum; } /** Returns a cluster in the tree that contains var. */ public VarSet findParentCluster(Variable var) { int best = Integer.MAX_VALUE; VarSet retval = null; // xxx Inefficient for (Iterator it = getVerticesIterator(); it.hasNext();) { VarSet c = (VarSet) it.next(); if (c.contains(var) && c.weight() < best) { retval = c; best = c.weight(); } } return retval; } /** * Returns a cluster in the tree that contains all the vars in a * collection. */ public VarSet findParentCluster(Collection vars) { int best = Integer.MAX_VALUE; VarSet retval = null; // xxx Inefficient for (Iterator it = getVerticesIterator(); it.hasNext();) { VarSet c = (VarSet) it.next(); if (c.containsAll(vars) && c.weight() < best) { retval = c; best = c.weight(); } } return retval; } /** Returns a cluster in the tree that contains exactly the given * variables, or null if no such cluster exists. */ public VarSet findCluster(Variable[] vars) { List l = Arrays.asList(vars); for (Iterator it = getVerticesIterator(); it.hasNext();) { VarSet c2 = (VarSet) it.next(); if (c2.containsAll(l) && l.containsAll(c2)) return c2; } return null; } /** Normalizes all potentials in the tree, both node and sepset. */ public void normalizeAll() { int n = cpfs.length; for (int i = 0; i < n; i++) { if (cpfs[i] != null) { cpfs[i].normalize(); } } TIntObjectIterator it = sepsets.iterator(); while (it.hasNext()) { it.advance(); Factor ptl = ((Sepset) it.value()).ptl; ptl.normalize(); } } int getId(VarSet c) { return lookupIndex(c); } // Debugging functions public void dump () { int n = cpfs.length; // This will cause OpenJGraph to print all our nodes and edges System.out.println(dumpToString()); // Now lets print all the cpfs System.out.println("Vertex CPFs"); for (int i = 0; i < n; i++) { if (cpfs[i] != null) { System.out.println("CPF "+i+" "+cpfs[i].dumpToString ()); } } // And the sepset potentials System.out.println("sepset CPFs"); TIntObjectIterator it = sepsets.iterator(); while (it.hasNext()) { it.advance(); Factor ptl = ((Sepset) it.value()).ptl; System.out.println(ptl.dumpToString ()); } System.out.println ("/End JT"); } public double dumpLogJoint (Assignment assn) { double accum = 0; for (int i = 0; i < cpfs.length; i++) { if (cpfs[i] != null) { double phi = cpfs[i].logValue (assn); System.out.println ("CPF "+i+" accum = "+accum); } } TIntObjectIterator it = sepsets.iterator(); while (it.hasNext()) { it.advance(); Factor ptl = ((Sepset) it.value()).ptl; double phi = ptl.logValue (assn); System.out.println("Sepset "+ptl.varSet()+" accum "+accum); } return accum; } public boolean isNaN() { int n = cpfs.length; for (int i = 0; i < n; i++) if (cpfs[i].isNaN()) return true; // And the sepset potentials TIntObjectIterator it = sepsets.iterator(); while (it.hasNext()) { it.advance(); Factor ptl = ((Sepset) it.value()).ptl; if (ptl.isNaN()) return true; } return false; } public double entropy () { double entropy = 0; for (Iterator it = clusterPotentials ().iterator (); it.hasNext ();) { Factor ptl = (Factor) it.next (); entropy += ptl.entropy (); } for (Iterator it = sepsetPotentials ().iterator (); it.hasNext ();) { Factor ptl = (Factor) it.next (); entropy -= ptl.entropy (); } return entropy; } // Implementation of edu.umass.cs.mallet.users.casutton.graphical.Compactible public void decompact() { cpfs = new Factor[numNodes]; clearCPFs(); } public void compact() { cpfs = null; } } // JunctionTree