/*******************************************************************************
* Copyright (C) 2009-2012 Stefan Waldherr, Dominik Jain.
*
* This file is part of ProbCog.
*
* ProbCog is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ProbCog is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ProbCog. If not, see <http://www.gnu.org/licenses/>.
******************************************************************************/
package probcog.bayesnets.inference;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.PrintStream;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Vector;
import probcog.bayesnets.core.BeliefNetworkEx;
import probcog.bayesnets.inference.IJGP.JoinGraph.Arc;
import edu.ksu.cis.bnj.ver3.core.BeliefNode;
import edu.ksu.cis.bnj.ver3.core.CPF;
import edu.tum.cs.util.StringTool;
import edu.tum.cs.util.datastruct.MutableDouble;
/**
* The Iterative Join-Graph Propagation algorithm as described by Dechter, Kask and Mateescu (2002)
* @author Stefan Waldherr
* @author Dominik Jain
*/
public class IJGP extends Sampler {
// the join graph over which messages are propagated
protected JoinGraph jg;
// join-graph nodes in topological order (filled in _infer)
Vector<JoinGraph.Node> jgNodes;
// all belief nodes of the network
// NOTE(review): never assigned within this class as shown — presumably
// initialized by the superclass or elsewhere; verify before relying on it
protected BeliefNode[] nodes;
// enables verbose debug output of clusters and orders
protected final boolean debug = false;
// the i-bound: maximum cluster size used for the join-graph decomposition
protected int ibound;
// enables progress output
protected boolean verbose = true;
/**
 * Constructs an IJGP inference engine for the given network.
 * @param bn the Bayesian network to run inference on
 * @throws Exception if superclass initialization fails
 */
public IJGP(BeliefNetworkEx bn) throws Exception {
    super(bn);
    // initialize the node array explicitly: the 'nodes' member declared in
    // this class is otherwise never assigned in this class (and would shadow
    // any like-named superclass field), causing an NPE in _initialize()
    this.nodes = bn.bn.getNodes();
}
@Override
protected void _initialize() {
    // the i-bound must be no smaller than the largest family in the
    // network, i.e. the largest CPT domain product
    ibound = 1;
    for (BeliefNode n : nodes)
        ibound = Math.max(ibound, n.getCPF().getDomainProduct().length);
    // build the join-graph decomposition for that bound
    if (verbose)
        out.printf("constructing join-graph with i-bound %d...\n", ibound);
    jg = new JoinGraph(bn, ibound);
    //jg.writeDOT(new File("jg.dot"));
}
@Override
public String getAlgorithmName() {
    // identify the algorithm along with the i-bound that was used
    return "IJGP[i-bound " + this.ibound + "]";
}
/*
* public IJGP(BeliefNetworkEx bn, int bound) { super(bn); this.nodes =
* bn.bn.getNodes(); jg = new JoinGraph(bn, bound); jg.print(out);
* jgNodes = jg.getTopologicalorder(); // construct join-graph }
*/
/**
 * Creates the distribution builder used to collect results; IJGP computes
 * the complete result distribution in one pass, so an
 * ImmediateDistributionBuilder is sufficient.
 */
protected IDistributionBuilder createDistributionBuilder() {
    return new ImmediateDistributionBuilder();
}
/**
 * Runs iterative join-graph propagation: repeatedly passes messages along
 * the join graph's arcs in topological order and back (numSamples sweeps),
 * then reads off the (normalized) marginals for all non-evidence variables.
 */
@Override
public void _infer() throws Exception {
    // Create topological order over the join-graph nodes
    if(verbose) out.println("determining order...");
    jgNodes = jg.getTopologicalorder();
    if(debug) {
        out.println("Topological Order: ");
        for (int i = 0; i < jgNodes.size(); i++) {
            out.println(jgNodes.get(i).getShortName());
        }
    }
    // process observed variables: remove evidence variables from each
    // join-graph node's variable set, since their values are fixed
    if(verbose) out.println("processing observed variables...");
    for (JoinGraph.Node n : jgNodes) {
        // iterate over a copy, as n.nodes is modified in the loop
        Vector<BeliefNode> nodes = new Vector<BeliefNode>(n.getNodes());
        for (BeliefNode belNode : nodes) {
            int nodeIdx = bn.getNodeIndex(belNode);
            int domainIdx = evidenceDomainIndices[nodeIdx];
            if (domainIdx > -1)
                n.nodes.remove(belNode);
        }
    }
    out.printf("running propagation (%d steps)...\n", this.numSamples);
    for (int step = 1; step <= this.numSamples; step++) {
        out.printf("step %d\n", step);
        // for every node in JG in topological order and back:
        // one sweep forward (j < s) followed by one sweep backward
        int s = jgNodes.size();
        boolean direction = true;
        for (int j = 0; j < 2 * s; j++) {
            int i;
            if (j < s)
                i = j;
            else {
                i = 2 * s - j - 1;
                direction = false;
            }
            JoinGraph.Node u = jgNodes.get(i);
            //out.printf("step %d, %d/%d: %-60s\r", step, j, 2*s, u.getShortName());
            int topIndex = jgNodes.indexOf(u);
            for (JoinGraph.Node v : u.getNeighbors()) {
                // on the forward sweep send only to neighbors later in the
                // order; on the backward sweep only to earlier ones
                if ((direction && jgNodes.indexOf(v) < topIndex)
                        || (!direction && jgNodes.indexOf(v) > topIndex)) {
                    continue;
                }
                Arc arc = u.getArcToNode(v);
                arc.clearOutMessages(u);
                // construct cluster_v(u): u's CPTs plus all incoming
                // messages except those received from v
                Cluster cluster_u = new Cluster(u, v);
                // Include in cluster_H each function in cluster_u which
                // scope does not contain variables in elim(u,v);
                // elim(u,v) = variables of u not in the separator with v
                HashSet<BeliefNode> elim = new HashSet<BeliefNode>(u.nodes);
                // out.println("  Node " + u.getShortName() +
                // " and node " + v.getShortName() + " have separator " +
                // StringTool.join(", ", arc.separator));
                elim.removeAll(arc.separator);
                Cluster cluster_H = cluster_u.getReducedCluster(elim);
                // denote by cluster_A the remaining functions (those whose
                // scope does intersect the eliminator)
                Cluster cluster_A = cluster_u.copy();
                cluster_A.subtractCluster(cluster_H);
                // DEBUG OUTPUT
                if (debug) {
                    out.println(" cluster_v(u): \n" + cluster_u);
                    out.println(" A: \n" + cluster_A);
                    out.println(" H_(u,v): \n" + cluster_H);
                }
                // convert eliminator into varsToSumOver (node indices)
                int[] varsToSumOver = new int[elim.size()];
                int k = 0;
                for (BeliefNode n : elim)
                    varsToSumOver[k++] = bn.getNodeIndex(n);
                // create message function (product of cluster_A summed over
                // the eliminator), precompute its table and send it to v
                MessageFunction m = new MessageFunction(arc.separator,
                        varsToSumOver, cluster_A);
                m.calcuSave(evidenceDomainIndices.clone());
                arc.addOutMessage(u, m);
                // functions in H are passed along unchanged
                for (MessageFunction mf : cluster_H.functions) {
                    mf.calcuSave(evidenceDomainIndices.clone());
                    arc.addOutMessage(u, mf);
                }
                for (BeliefNode n : cluster_H.cpts) {
                    arc.addCPTOutMessage(u, n);
                }
            }
        }
    }
    // compute probabilities and store results in distribution
    out.println("computing results...");
    SampledDistribution dist = createDistribution();
    dist.Z = 1.0;
    for (int i = 0; i < nodes.length; i++) {
        //out.println("Computing: " + nodes[i].getName() + "\n");
        // evidence variables get all mass on their observed value
        if (evidenceDomainIndices[i] >= 0) {
            dist.values[i][evidenceDomainIndices[i]] = 1.0;
            continue;
        }
        // For every node X let u be a vertex in the join graph such that X
        // is in u
        // out.println(nodes[i]);
        JoinGraph.Node u = null;
        for (JoinGraph.Node node : jgNodes) {
            if (node.nodes.contains(nodes[i])) {
                u = node;
                break;
            }
        }
        if (u == null)
            throw new Exception(
                    "Could not find vertex in join graph containing variable "
                            + nodes[i].getName());
        // out.println("\nCalculating results for " + nodes[i]);
        // out.println(u);
        // compute sum for each domain value of i-th node by marginalizing
        // u's cluster product over all of u's variables except the query
        int domSize = dist.values[i].length;
        double Z = 0.0;
        int[] nodeDomainIndices = evidenceDomainIndices.clone();
        for (int j = 0; j < domSize; j++) {
            nodeDomainIndices[i] = j;
            MutableDouble sum = new MutableDouble(0.0);
            BeliefNode[] nodesToSumOver = u.nodes
                    .toArray(new BeliefNode[u.nodes.size()]);
            computeSum(0, nodesToSumOver, nodes[i], new Cluster(u),
                    nodeDomainIndices, sum);
            Z += (dist.values[i][j] = sum.value);
        }
        // normalize
        for (int j = 0; j < domSize; j++)
            dist.values[i][j] /= Z;
    }
    // dist.print(out);
    ((ImmediateDistributionBuilder)distributionBuilder).setDistribution(dist);
}
/**
 * Recursively sums the cluster's function product over all assignments of
 * the given variables (the excluded node keeps its preset value),
 * accumulating the total in result.
 */
protected void computeSum(int i, BeliefNode[] varsToSumOver,
        BeliefNode excludedNode, Cluster u, int[] nodeDomainIndices,
        MutableDouble result) {
    // base case: every variable has a value -> add the cluster product
    if (i == varsToSumOver.length) {
        result.value += u.product(nodeDomainIndices);
        return;
    }
    BeliefNode current = varsToSumOver[i];
    if (current == excludedNode) {
        // the query variable's value is fixed by the caller; skip it
        computeSum(i + 1, varsToSumOver, excludedNode, u, nodeDomainIndices,
                result);
        return;
    }
    // enumerate all values of the current variable
    int idx = this.getNodeIndex(current);
    int domSize = current.getDomain().getOrder();
    for (int value = 0; value < domSize; value++) {
        nodeDomainIndices[idx] = value;
        computeSum(i + 1, varsToSumOver, excludedNode, u, nodeDomainIndices,
                result);
    }
}
/**
 * A cluster groups the functions associated with a join-graph node:
 * the node's own CPTs (represented by their owning belief nodes) and the
 * message functions it has received.
 */
protected class Cluster {
    // CPTs in the cluster, each represented by the belief node that owns it
    HashSet<BeliefNode> cpts = new HashSet<BeliefNode>();
    // message functions in the cluster
    HashSet<MessageFunction> functions = new HashSet<MessageFunction>();
    // the join-graph node this cluster belongs to
    JoinGraph.Node node;

    /**
     * Constructs cluster(u): all CPTs of u plus all incoming messages.
     */
    public Cluster(JoinGraph.Node u) {
        this.node = u;
        // add to the cluster all CPTs of the given node
        for (CPF cpf : u.functions)
            cpts.add(cpf.getDomainProduct()[0]);
        // add all incoming messages of u (addAll on an empty set is a no-op,
        // so no emptiness check is needed)
        for (JoinGraph.Node nb : u.getNeighbors()) {
            JoinGraph.Arc arc = u.arcs.get(nb);
            functions.addAll(arc.getInMessage(u));
            cpts.addAll(arc.getCPTInMessage(u));
        }
    }

    /**
     * Constructs cluster_v(u): like cluster(u), but excluding any messages
     * that were received from v.
     */
    public Cluster(JoinGraph.Node u, JoinGraph.Node v) {
        this.node = u;
        // add to the cluster all CPTs of the given node
        for (CPF cpf : u.functions)
            cpts.add(cpf.getDomainProduct()[0]);
        // add all incoming messages of u except those from v
        for (JoinGraph.Node nb : u.getNeighbors()) {
            if (nb.equals(v))
                continue;
            JoinGraph.Arc arc = u.arcs.get(nb);
            functions.addAll(arc.getInMessage(u));
            cpts.addAll(arc.getCPTInMessage(u));
        }
    }

    /** Constructs an empty cluster. */
    public Cluster() {
    }

    public String toString() {
        StringBuffer sb = new StringBuffer();
        sb.append(StringTool.join(", ", this.cpts));
        sb.append("; ");
        sb.append(StringTool.join(", ", this.functions));
        return sb.toString();
    }

    /**
     * Removes from this cluster all messages that were received from node n.
     */
    public void excludeMessagesFrom(JoinGraph.Node n) {
        JoinGraph.Arc arc = node.arcs.get(n);
        functions.removeAll(arc.getInMessage(node));
        // bug fix: this previously called functions.remove(bn) — removing a
        // BeliefNode from the set of MessageFunctions is always a no-op, so
        // CPT messages from n were never actually excluded
        cpts.removeAll(arc.getCPTInMessage(node));
    }

    /** Returns a shallow copy (new sets, same function objects). */
    public Cluster copy() {
        Cluster copyCluster = new Cluster();
        copyCluster.cpts.addAll(cpts);
        copyCluster.functions.addAll(functions);
        return copyCluster;
    }

    /**
     * Returns a copy of this cluster from which every function whose scope
     * contains any of the given nodes has been removed.
     * @param nodes the variables to eliminate from the cluster's scopes
     * @throws CloneNotSupportedException retained for signature compatibility
     *         with callers; not actually thrown by this implementation
     */
    public Cluster getReducedCluster(HashSet<BeliefNode> nodes)
            throws CloneNotSupportedException {
        Cluster redCluster = this.copy();
        for (BeliefNode bn : nodes) {
            // we iterate over this cluster's sets and remove from the copy,
            // so no defensive clones are required
            for (BeliefNode n : cpts) {
                BeliefNode[] domProd = n.getCPF().getDomainProduct();
                for (int i = 0; i < domProd.length; i++) {
                    if (bn.equals(domProd[i])) {
                        redCluster.cpts.remove(n);
                        break;
                    }
                }
            }
            for (MessageFunction m : functions) {
                if (m.scope.contains(bn))
                    redCluster.functions.remove(m);
            }
        }
        return redCluster;
    }

    /**
     * Removes from this cluster all functions that also appear in c2.
     */
    public void subtractCluster(Cluster c2) {
        // removeAll does not modify c2, so no clones are needed
        cpts.removeAll(c2.cpts);
        functions.removeAll(c2.functions);
    }

    /**
     * Computes the product of all of the cluster's functions for the given
     * complete assignment of network variables.
     */
    public double product(int[] nodeDomainIndices) {
        double ret = 1.0;
        for (BeliefNode n : cpts)
            ret *= getCPTProbability(n, nodeDomainIndices);
        for (MessageFunction f : this.functions)
            ret *= f.compute(nodeDomainIndices);
        return ret;
    }
}
/**
 * A message function sent across a join-graph arc: the product of a set of
 * CPTs and previously received messages, summed over the eliminator
 * variables. Values over the separator scope can be precomputed and cached
 * in a MessageTable via calcuSave.
 */
protected class MessageFunction {
    // indices of the variables summed out when evaluating the message
    protected int[] varsToSumOver;
    // cache of precomputed values over the scope (null until calcuSave runs)
    protected MessageTable table;
    // CPTs multiplied into the message, represented by their owning nodes
    HashSet<BeliefNode> cpts;
    // previously received messages that are multiplied into this one
    Iterable<MessageFunction> childFunctions;
    // the scope (separator variables) of this message function
    HashSet<BeliefNode> scope;

    /**
     * @param scope the separator variables the message is defined over
     * @param varsToSumOver network indices of the eliminator variables
     * @param cluster the cluster whose CPTs/messages form the product
     */
    public MessageFunction(HashSet<BeliefNode> scope, int[] varsToSumOver,
            Cluster cluster) {
        this.scope = scope;
        this.varsToSumOver = varsToSumOver;
        this.cpts = cluster.cpts;
        this.childFunctions = cluster.functions;
        this.table = null;
    }

    /**
     * Precomputes and caches the message value for every assignment of the
     * scope variables; unassigned variables keep the given (evidence) values.
     */
    public void calcuSave(int[] nodeDomainIndices) {
        table = new MessageTable(new Vector<BeliefNode>(scope), 0);
        int[] scopeToSumOver = new int[scope.size()];
        int k = 0;
        for (BeliefNode n : scope)
            scopeToSumOver[k++] = bn.getNodeIndex(n);
        calcuSave(scopeToSumOver, 0, nodeDomainIndices.clone());
    }

    /**
     * Recursive helper: enumerates all assignments of the scope variables
     * (positions i..end of scopeToSumOver) and stores each computed value.
     */
    public void calcuSave(int[] scopeToSumOver, int i, int[] nodeDomainIndices) {
        if (i == scope.size()) {
            table.addEntry(nodeDomainIndices, compute(nodeDomainIndices));
            return;
        } else {
            int idxVar = scopeToSumOver[i];
            for (int v = 0; v < nodes[idxVar].getDomain().getOrder(); v++) {
                nodeDomainIndices[idxVar] = v;
                calcuSave(scopeToSumOver, i + 1, nodeDomainIndices);
            }
        }
    }

    /**
     * Returns the message value for the given assignment, using the cache
     * when available.
     * NOTE(review): on a cache miss the freshly computed value is returned
     * but not stored back into the table — presumably intentional since
     * calcuSave fills the table up front; verify if misses are expected.
     */
    public double compute(int[] nodeDomainIndices) {
        if (!table.containsEntry(nodeDomainIndices)) {
            MutableDouble sum = new MutableDouble(0.0);
            compute(varsToSumOver, 0, nodeDomainIndices, sum);
            return sum.value;
        } else {
            return table.getEntry(nodeDomainIndices);
        }
    }

    /**
     * Recursive helper: sums the product of all CPTs and child messages over
     * every assignment of the eliminator variables (positions i..end).
     */
    protected void compute(int[] varsToSumOver, int i,
            int[] nodeDomainIndices, MutableDouble sum) {
        if (i == varsToSumOver.length) {
            double result = 1.0;
            for (BeliefNode node : cpts)
                result *= getCPTProbability(node, nodeDomainIndices);
            for (MessageFunction h : childFunctions)
                result *= h.compute(nodeDomainIndices);
            sum.value += result;
            return;
        }
        int idxVar = varsToSumOver[i];
        for (int v = 0; v < nodes[idxVar].getDomain().getOrder(); v++) {
            nodeDomainIndices[idxVar] = v;
            compute(varsToSumOver, i + 1, nodeDomainIndices, sum);
        }
    }

    public String toString() {
        StringBuffer sb = new StringBuffer("MF[");
        sb.append("scope: " + StringTool.join(", ", scope));
        sb.append("; CPFs:");
        int i = 0;
        for (BeliefNode n : this.cpts) {
            if (i++ > 0)
                sb.append("; ");
            sb.append(n.getCPF().toString());
        }
        sb.append("; children: ");
        sb.append(StringTool.join("; ", this.childFunctions));
        sb.append("]");
        return sb.toString();
    }

    /**
     * A trie-like lookup table over the scope variables: each level indexes
     * one scope variable's domain value; leaves store the cached results
     * (null meaning "not yet computed").
     */
    protected class MessageTable {
        protected Vector<BeliefNode> scope;
        // true if this level is the last scope variable (stores results)
        protected boolean leaf;
        // child tables, indexed by the current variable's domain value
        protected MessageTable[] map;
        // cached values at the leaf level (null = absent)
        protected Double[] result;

        /** Builds the table level for scope position i (recursively). */
        public MessageTable(Vector<BeliefNode> scope, int i) {
            int domSize = scope.get(i).getDomain().getOrder();
            this.map = new MessageTable[domSize];
            this.scope = scope;
            if (i == scope.size() - 1) {
                leaf = true;
                result = new Double[domSize];
            } else {
                leaf = false;
                result = null;
                for (int j = 0; j < scope.get(i).getDomain().getOrder(); j++) {
                    map[j] = new MessageTable(scope, i + 1);
                }
            }
        }

        /** Stores an entry for the scope assignment within domainIndices. */
        public void addEntry(int[] domainIndices, double entry) {
            addEntry(domainIndices, 0, entry);
        }

        public void addEntry(int[] domainIndices, int i, double entry) {
            if (i != scope.size() - 1) {
                int idx = domainIndices[bn.getNodeIndex(scope.get(i))];
                map[idx].addEntry(domainIndices, i + 1, entry);
            } else {
                int idx = domainIndices[bn.getNodeIndex(scope.get(i))];
                result[idx] = entry;
            }
        }

        /** Retrieves the entry for the scope assignment within domainIndices. */
        public double getEntry(int[] domainIndices) {
            return getEntry(domainIndices, 0);
        }

        public double getEntry(int[] domainIndices, int i) {
            if (i != scope.size() - 1) {
                int idx = domainIndices[bn.getNodeIndex(scope.get(i))];
                return map[idx].getEntry(domainIndices, i + 1);
            }
            else {
                int idx = domainIndices[bn.getNodeIndex(scope.get(i))];
                return result[idx];
            }
        }

        /** Returns whether an entry exists for the given assignment. */
        public boolean containsEntry(int[] domainIndices){
            return containsEntry(domainIndices, 0);
        }

        public boolean containsEntry(int[] domainIndices, int i){
            if (i != scope.size() - 1){
                int idx = domainIndices[bn.getNodeIndex(scope.get(i))];
                if (map[idx] == null){
                    return false;
                }
                else{
                    return map[idx].containsEntry(domainIndices, i+1);
                }
            }
            else{
                int idx = domainIndices[bn.getNodeIndex(scope.get(i))];
                return (result[idx] != null);
            }
        }
    }
}
/**
 * A bucket variable of the schematic mini-bucket procedure: a function
 * scope (a set of belief nodes), optionally carrying the original CPF it
 * stands for, plus the minibuckets it was derived from.
 */
protected static class BucketVar {
    // the scope of the function
    public HashSet<BeliefNode> nodes;
    // the original CPT (null for derived scope functions)
    public CPF cpf = null;
    // minibuckets from which this variable was derived (incoming arrows)
    public Vector<MiniBucket> parents;
    // NOTE(review): never assigned within this class — presumably set
    // externally or vestigial; verify before use
    public BeliefNode idxVar;

    public BucketVar(HashSet<BeliefNode> nodes) {
        this(nodes, null);
    }

    /**
     * @param nodes the (non-empty) scope
     * @param parent the minibucket this variable originates from (may be null)
     */
    public BucketVar(HashSet<BeliefNode> nodes, MiniBucket parent) {
        this.nodes = nodes;
        if (nodes.size() == 0)
            throw new RuntimeException(
                    "Must provide non-empty set of nodes.");
        this.parents = new Vector<MiniBucket>();
        if (parent != null)
            parents.add(parent);
    }

    public void setFunction(CPF cpf) {
        this.cpf = cpf;
    }

    /** Registers an additional originating minibucket. */
    public void addInArrow(MiniBucket parent) {
        parents.add(parent);
    }

    /**
     * Returns the member node that appears latest in the network's
     * topological order (the scan starts from the end of the order and
     * returns on the first match; null only if the scope were empty,
     * which the constructor forbids).
     */
    public BeliefNode getMaxNode(BeliefNetworkEx bn) {
        // returns the BeliefNode of a bucket variable highest in the
        // topological order
        BeliefNode maxNode = null;
        int[] topOrder = bn.getTopologicalOrder();
        for (int i = topOrder.length - 1; i > -1; i--) {
            for (BeliefNode node : nodes) {
                if (bn.getNodeIndex(node) == topOrder[i]) {
                    return node;
                }
            }
        }
        return maxNode;
    }

    public String toString() {
        return "[" + StringTool.join(" ", this.nodes) + "]";
    }

    /**
     * Scope equality: same number of nodes and same members.
     * NOTE(review): this OVERLOADS rather than overrides Object.equals (and
     * there is no matching hashCode), so HashSet membership remains
     * identity-based; only explicit calls such as Bucket.addVar use this
     * scope comparison. Changing it to a true override would alter the
     * algorithm's set semantics — left as is deliberately.
     */
    public boolean equals(BucketVar other) {
        if (other.nodes.size() != this.nodes.size())
            return false;
        for (BeliefNode n : nodes)
            if (!other.nodes.contains(n))
                return false;
        return true;
    }
}
/**
 * A minibucket: a subset of a bucket's variables whose combined scope is
 * limited by the i-bound, linked to the minibuckets its variables came from.
 */
protected static class MiniBucket {
    // the bucket variables placed in this minibucket
    public HashSet<BucketVar> items;
    // the bucket this minibucket partitions
    public Bucket bucket;
    // minibuckets that feed into this one (union over all items' parents)
    public HashSet<MiniBucket> parents;
    // outgoing scope function (unused here; kept for external access)
    public BucketVar child;

    public MiniBucket(Bucket bucket) {
        this.bucket = bucket;
        this.items = new HashSet<BucketVar>();
        this.parents = new HashSet<MiniBucket>();
        this.child = null;
    }

    /** Adds a variable and absorbs its parent links. */
    public void addVar(BucketVar bv) {
        items.add(bv);
        parents.addAll(bv.parents);
    }

    public String toString() {
        return "Minibucket[" + StringTool.join(" ", items) + "]";
    }
}
/**
 * A bucket of the schematic mini-bucket procedure: collects all bucket
 * variables whose highest-ordered node is bucketNode, and partitions them
 * into minibuckets of bounded scope size.
 */
protected static class Bucket {
    // the variable this bucket is responsible for
    public BeliefNode bucketNode;
    // the bucket variables assigned to this bucket
    public HashSet<BucketVar> vars = new HashSet<BucketVar>();
    // the partition produced by partition()
    public Vector<MiniBucket> minibuckets = new Vector<MiniBucket>();

    public Bucket(BeliefNode bucketNode) {
        this.bucketNode = bucketNode;
    }

    /**
     * Adds a bucket variable; if one with the same scope already exists,
     * its parent links are merged into the existing variable instead.
     */
    public void addVar(BucketVar bv) {
        for (BucketVar existing : vars) {
            if (existing.equals(bv)) {
                for (MiniBucket p : bv.parents)
                    existing.addInArrow(p);
                return;
            }
        }
        vars.add(bv);
    }

    /**
     * Partitions the bucket's variables into minibuckets such that each
     * minibucket covers at most 'bound' distinct nodes.
     *
     * @param bound the i-bound (maximum number of covered nodes)
     */
    public void partition(int bound) {
        minibuckets.add(new MiniBucket(this));
        HashSet<BeliefNode> covered = new HashSet<BeliefNode>();
        for (BucketVar bv : vars) {
            // count how many of bv's nodes are not yet covered
            int unseen = 0;
            for (BeliefNode n : bv.nodes)
                if (!covered.contains(n))
                    unseen++;
            // start a fresh minibucket if adding bv would exceed the bound
            if (covered.size() + unseen > bound) {
                minibuckets.add(new MiniBucket(this));
                covered.clear();
            }
            covered.addAll(bv.nodes);
            minibuckets.lastElement().addVar(bv);
        }
    }

    /**
     * For each minibucket, builds the scope function that is passed on:
     * the union of its variables' scopes minus the bucket's own node.
     */
    public HashSet<BucketVar> createScopeFunctions() {
        HashSet<BucketVar> newVars = new HashSet<BucketVar>();
        for (MiniBucket mb : minibuckets) {
            HashSet<BeliefNode> scope = new HashSet<BeliefNode>();
            for (BucketVar bv : mb.items)
                for (BeliefNode n : bv.nodes)
                    if (n != bucketNode)
                        scope.add(n);
            // an empty scope yields no function
            if (!scope.isEmpty())
                newVars.add(new BucketVar(scope, mb));
        }
        return newVars;
    }

    public String toString() {
        return StringTool.join(" ", vars);
    }
}
/**
 * The schematic mini-bucket procedure: assigns each CPT to the bucket of
 * its highest-ordered variable, partitions each bucket into minibuckets of
 * bounded scope, and propagates the resulting scope functions downward.
 */
protected static class SchematicMiniBucket {
    // maps each belief node to its bucket
    public HashMap<BeliefNode, Bucket> bucketMap;
    public BeliefNetworkEx bn;

    /**
     * @param bn the network to decompose
     * @param bound the i-bound limiting minibucket scope size
     */
    public SchematicMiniBucket(BeliefNetworkEx bn, int bound) {
        this.bn = bn;
        bucketMap = new HashMap<BeliefNode, Bucket>();
        // order the variables from X_1 to X_n
        BeliefNode[] nodes = bn.bn.getNodes();
        int[] topOrder = bn.getTopologicalOrder();
        // place each CPT in the bucket of the highest index
        for (int i = topOrder.length - 1; i > -1; i--) {
            Bucket bucket = new Bucket(nodes[topOrder[i]]);
            int[] cpt = bn.getDomainProductNodeIndices(nodes[topOrder[i]]);
            HashSet<BeliefNode> cptNodes = new HashSet<BeliefNode>();
            for (int j : cpt) {
                cptNodes.add(nodes[j]);
            }
            BucketVar bv = new BucketVar(cptNodes);
            bv.setFunction(nodes[topOrder[i]].getCPF());
            bucket.addVar(bv);
            bucketMap.put(nodes[topOrder[i]], bucket);
        }
        // partition buckets (processed from the end of the topological
        // order) and create arcs by placing each derived scope function in
        // the bucket of its highest-ordered variable
        for (int i = topOrder.length - 1; i > -1; i--) {
            Bucket oldVar = bucketMap.get(nodes[topOrder[i]]);
            oldVar.partition(bound);
            HashSet<BucketVar> scopes = oldVar.createScopeFunctions();
            for (BucketVar bv : scopes) {
                // add new variables to the bucket with the highest index
                BeliefNode node = bv.getMaxNode(bn);
                bucketMap.get(node).addVar(bv);
            }
        }
    }

    /** Prints each node's bucket, from the end of the topological order. */
    public void print(PrintStream out) {
        BeliefNode[] nodes = bn.bn.getNodes();
        int[] order = bn.getTopologicalOrder();
        for (int i = nodes.length - 1; i >= 0; i--) {
            BeliefNode n = nodes[order[i]];
            out.printf("%s: %s\n", n.toString(), bucketMap.get(n));
        }
    }

    /** Returns all minibuckets across all buckets. */
    public Vector<MiniBucket> getMiniBuckets() {
        Vector<MiniBucket> mb = new Vector<MiniBucket>();
        for (Bucket b : bucketMap.values()) {
            mb.addAll(b.minibuckets);
        }
        return mb;
    }

    /** Returns all buckets. */
    public Vector<Bucket> getBuckets() {
        return new Vector<Bucket>(bucketMap.values());
    }
}
protected static class JoinGraph {
// all supernodes (clusters) of the join graph
HashSet<Node> nodes;
// maps each minibucket to the join-graph node created from it
HashMap<MiniBucket, Node> bucket2node = new HashMap<MiniBucket, Node>();
/**
 * Builds the join graph for the given network and i-bound: one node per
 * minibucket of the schematic mini-bucket decomposition, arcs along the
 * minibucket parent structure (labeled by their separators), plus arcs
 * connecting minibuckets of the same bucket.
 */
public JoinGraph(BeliefNetworkEx bn, int bound) {
    nodes = new HashSet<Node>();
    // apply procedure schematic mini-bucket(bound)
    SchematicMiniBucket smb = new SchematicMiniBucket(bn, bound);
    // out.println("\nJoin graph decomposition:");
    // smb.print(out);
    Vector<MiniBucket> minibuckets = smb.getMiniBuckets();
    // associate each minibucket with a node
    // out.println("\nJoin graph nodes:");
    for (MiniBucket mb : minibuckets) {
        // out.println(mb);
        Node newNode = new Node(mb);
        // out.println(newNode);
        nodes.add(newNode);
        bucket2node.put(mb, newNode);
    }
    // copy parent structure from minibuckets to join-graph nodes
    for (MiniBucket mb : minibuckets) {
        for (MiniBucket p : mb.parents) {
            bucket2node.get(mb).parents.add(bucket2node.get(p));
        }
    }
    // keep the arcs and label them by regular separator
    // (the Arc constructor registers itself with both endpoints)
    for (MiniBucket mb : minibuckets) {
        for (MiniBucket par : mb.parents) {
            Node n1 = bucket2node.get(par);
            Node n2 = bucket2node.get(mb);
            new Arc(n1, n2);
        }
    }
    // connect the mini-bucket clusters belonging to the same bucket
    for (MiniBucket mb1 : minibuckets) {
        for (MiniBucket mb2 : minibuckets) {
            if (mb1 != mb2 && mb1.bucket == mb2.bucket) {
                new Arc(bucket2node.get(mb1), bucket2node.get(mb2));
            }
        }
    }
}
/** Prints each join-graph node's variables and CPFs to the given stream. */
public void print(PrintStream out) {
    int idx = 0;
    for (Node node : nodes) {
        out.printf("Node%d: %s\n", idx++, StringTool.join(", ", node.nodes));
        for (CPF cpf : node.functions)
            out.printf(" CPFS: %s | %s\n", cpf.getDomainProduct()[0],
                    StringTool.join(", ", cpf.getDomainProduct()));
    }
}
/**
 * Writes the join graph to the given file in DOT (graphviz) format.
 * Note: since the neighbor relation is symmetric, each undirected edge is
 * emitted twice (once per direction).
 * @param f the output file
 * @throws FileNotFoundException if the file cannot be opened for writing
 */
public void writeDOT(File f) throws FileNotFoundException {
    PrintStream ps = new PrintStream(f);
    try {
        ps.println("graph {");
        for (Node n : nodes) {
            for (Node n2 : n.getNeighbors()) {
                ps.printf("\"%s\" -- \"%s\";\n", n.getShortName(), n2
                        .getShortName());
            }
        }
        ps.println("}");
    } finally {
        // the stream was previously never closed, leaking the file handle
        // (and potentially losing buffered output)
        ps.close();
    }
}
/**
 * Returns the join-graph nodes in a topological order with respect to the
 * parent relation (every node appears after all of its parents).
 * @throws IllegalStateException if no progress can be made, i.e. the
 *         parent structure is cyclic (previously this would loop forever:
 *         the loop bound 'i &lt; 10' was dead because the increment was
 *         commented out)
 */
public Vector<Node> getTopologicalorder() {
    Vector<Node> topOrder = new Vector<Node>();
    HashSet<Node> nodesLeft = new HashSet<Node>();
    nodesLeft.addAll(nodes);
    // seed the order with all parentless nodes
    for (Node n : nodes) {
        if (n.parents.isEmpty()) {
            topOrder.add(n);
            nodesLeft.remove(n);
        }
    }
    // repeatedly append any node whose parents are all already ordered
    while (!nodesLeft.isEmpty()) {
        HashSet<Node> removeNodes = new HashSet<Node>();
        for (Node n : nodesLeft) {
            if (topOrder.containsAll(n.parents)) {
                topOrder.add(n);
                removeNodes.add(n);
            }
        }
        // guard against an infinite loop on cyclic parent structures
        if (removeNodes.isEmpty())
            throw new IllegalStateException(
                    "cannot compute topological order: cyclic parent structure in join graph");
        nodesLeft.removeAll(removeNodes);
    }
    return topOrder;
}
/**
 * An undirected arc between two join-graph nodes, labeled by their
 * separator (shared variables). Messages sent across the arc are stored
 * per sending endpoint.
 */
public static class Arc {
    // the separator: variables shared by the two endpoints
    HashSet<BeliefNode> separator = new HashSet<BeliefNode>();
    // the two endpoint nodes (exactly two entries)
    Vector<Node> nodes = new Vector<Node>();
    // message functions sent across the arc, keyed by the sending node
    HashMap<Node, HashSet<MessageFunction>> outMessage = new HashMap<Node, HashSet<MessageFunction>>();
    // CPT messages sent across the arc, keyed by the sending node
    HashMap<Node, HashSet<BeliefNode>> outCPTMessage = new HashMap<Node, HashSet<BeliefNode>>();

    /**
     * Creates an arc between two distinct nodes and registers it with both.
     * @throws RuntimeException if both endpoints are the same node
     */
    public Arc(Node n0, Node n1) {
        if (n0 != n1) {
            // the separator is the intersection of the endpoints' variables
            separator = (HashSet<BeliefNode>) n0.nodes.clone();
            separator.retainAll(n1.nodes);
            // arc information
            nodes.add(n0);
            nodes.add(n1);
            n0.addArc(n1, this);
            n1.addArc(n0, this);
            outMessage.put(n0, new HashSet<MessageFunction>());
            outMessage.put(n1, new HashSet<MessageFunction>());
            outCPTMessage.put(n0, new HashSet<BeliefNode>());
            outCPTMessage.put(n1, new HashSet<BeliefNode>());
        } else
            throw new RuntimeException("1-node loop in graph");
    }

    /**
     * Returns the endpoint opposite to n.
     * @throws IllegalArgumentException if n is not an endpoint of this arc
     *         (previously, an unknown node silently yielded the first
     *         endpoint because indexOf returned -1)
     */
    public Node getNeighbor(Node n) {
        int idx = nodes.indexOf(n);
        if (idx < 0)
            throw new IllegalArgumentException(
                    "given node is not an endpoint of this arc");
        return nodes.get((idx + 1) % 2);
    }

    /** Records a message function sent by n across this arc. */
    public void addOutMessage(Node n, MessageFunction m) {
        outMessage.get(n).add(m);
    }

    public HashSet<MessageFunction> getOutMessages(Node n) {
        return outMessage.get(n);
    }

    /** Messages received by n = messages sent by n's neighbor. */
    public HashSet<MessageFunction> getInMessage(Node n) {
        return this.getOutMessages(this.getNeighbor(n));
    }

    /** Records a CPT message sent by n across this arc. */
    public void addCPTOutMessage(Node n, BeliefNode bn) {
        outCPTMessage.get(n).add(bn);
    }

    public HashSet<BeliefNode> getCPTOutMessages(Node n) {
        return outCPTMessage.get(n);
    }

    /** CPT messages received by n = CPT messages sent by n's neighbor. */
    public HashSet<BeliefNode> getCPTInMessage(Node n) {
        return this.getCPTOutMessages(this.getNeighbor(n));
    }

    /** Clears all messages previously sent by n across this arc. */
    public void clearOutMessages(Node n) {
        outMessage.get(n).clear();
        outCPTMessage.get(n).clear();
    }
}
/**
 * A join-graph supernode: the variables and CPFs of one minibucket,
 * together with its arcs to neighboring supernodes.
 */
public static class Node {
    // the minibucket this supernode was created from
    MiniBucket mb;
    // the CPFs attached to this supernode
    Vector<CPF> functions = new Vector<CPF>();
    // the variables covered by this supernode
    HashSet<BeliefNode> nodes = new HashSet<BeliefNode>();
    // parent supernodes (mirrors the minibucket parent structure)
    HashSet<Node> parents;
    // arcs to neighboring supernodes, keyed by the neighbor
    HashMap<Node, Arc> arcs = new HashMap<Node, Arc>();

    public Node(MiniBucket mb) {
        this.mb = mb;
        this.parents = new HashSet<Node>();
        // collect variables and original CPFs from the minibucket's items
        for (BucketVar bv : mb.items) {
            nodes.addAll(bv.nodes);
            if (bv.cpf != null)
                functions.add(bv.cpf);
        }
    }

    /** Registers an arc leading to the given neighbor. */
    public void addArc(Node neighbor, Arc arc) {
        arcs.put(neighbor, arc);
    }

    /** Returns a fresh set of all neighboring supernodes. */
    public HashSet<Node> getNeighbors() {
        return new HashSet<Node>(arcs.keySet());
    }

    /** Returns the arc to the given neighbor, or null if not adjacent. */
    public Arc getArcToNode(Node neighbor) {
        return arcs.get(neighbor);
    }

    /** Returns the (live) set of variables covered by this supernode. */
    public Collection<BeliefNode> getNodes() {
        return nodes;
    }

    public String toString() {
        return "Supernode[" + StringTool.join(",", nodes) + "; "
                + StringTool.join("; ", this.functions) + "]";
    }

    /** Short label: the comma-joined variable names. */
    public String getShortName() {
        return StringTool.join(",", nodes);
    }
}
}
}