/**
* Copyright (c) 2011 Michael Kutschke.
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the Eclipse Public License v1.0
* which accompanies this distribution, and is available at
* http://www.eclipse.org/legal/epl-v10.html
*
* Contributors:
* Michael Kutschke - initial API and implementation.
*/
package org.eclipse.recommenders.jayes.lbp;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Deque;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import org.eclipse.recommenders.jayes.BayesNet;
import org.eclipse.recommenders.jayes.BayesNode;
import org.eclipse.recommenders.jayes.factor.AbstractFactor;
import org.eclipse.recommenders.jayes.factor.DenseFactor;
import org.eclipse.recommenders.jayes.factor.arraywrapper.DoubleArrayWrapper;
import org.eclipse.recommenders.jayes.inference.AbstractInferrer;
import org.eclipse.recommenders.jayes.util.MathUtils;
import org.eclipse.recommenders.jayes.util.Pair;
/**
 * An implementation of Loopy Belief Propagation. Not ready for production use; it only serves to check the
 * correctness of the other inference algorithms on simple networks.
 * <p>
 * All messages are stored in log space: messages are combined by adding them, and beliefs are converted back to
 * probabilities in {@link #getBeliefs(BayesNode)}.
 *
 * @author Michael Kutschke
 */
public class LoopyBeliefPropagation extends AbstractInferrer {

    /**
     * The factor/node graph is represented as a bipartite graph, with nodes being edge sources and factors being the
     * edge targets. {@code graph.get(i)} holds the edges whose source is the node with id {@code i}.
     * NOTE(review): indexing by id assumes node ids match the iteration order of {@code BayesNet.getNodes()}.
     */
    private final List<List<Edge>> graph = new ArrayList<List<Edge>>();
    /**
     * The same edges as {@link #graph}, grouped by target: {@code transposedGraph.get(i)} holds the edges whose
     * target is the factor of the node with id {@code i}.
     */
    private final List<List<Edge>> transposedGraph = new ArrayList<List<Edge>>();
    /**
     * For each edge, the index mapping prepared by {@code AbstractFactor.prepareMultiplication} between the edge's
     * target factor and a one-dimensional factor over the edge's source node.
     */
    private final Map<Edge, int[]> preparedOps = new HashMap<Edge, int[]>();
    private BayesNet net;

    /**
     * An edge of the bipartite node/factor graph, storing the (log-space) messages flowing in both directions.
     */
    private static class Edge {
        /** Message from node to factor. */
        double[] messageXF;
        /** Message from factor to node. */
        double[] messageFX;
        /** Id of the source node. */
        int source;
        /** Id of the node whose factor is the target of this edge. */
        int target;
        /**
         * Cleared while {@code updateFactor} recomputes {@link #messageFX} and set again afterwards;
         * {@code updateNode} skips edges whose flag is cleared.
         */
        boolean dirty = true;

        @Override
        public String toString() {
            return source + "->" + target;
        }
    }

    /**
     * Adds evidence for a single node. Overridden only to keep the signature declared on this class; all
     * bookkeeping is done by the superclass.
     * <p>
     * NOTE(review): this relies on {@code AbstractInferrer.addEvidence} to invalidate the cached beliefs; confirm.
     *
     * @param node the observed node
     * @param outcome the observed outcome of {@code node}
     */
    @Override
    public void addEvidence(final BayesNode node, final String outcome) {
        super.addEvidence(node, outcome);
    }

    /**
     * Returns the belief (posterior distribution) of {@code node}.
     * <p>
     * For an observed node this is the indicator distribution of the observed outcome. Otherwise the log-space
     * messages of all adjacent factors are summed (i.e. multiplied in probability space), normalized with
     * {@code MathUtils.normalizeLog} and exponentiated. Messages are recomputed first if the cached beliefs are
     * invalid.
     *
     * @param node the node to query
     * @return a fresh array with one entry per outcome of {@code node}
     */
    @Override
    public double[] getBeliefs(final BayesNode node) {
        if (!beliefsValid) {
            updateBeliefs();
            // mark the beliefs valid only after a successful update, so a failed update is retried on the next call
            // instead of silently serving stale messages
            beliefsValid = true;
        }
        if (evidence.containsKey(node)) {
            final double[] observed = new double[node.getOutcomeCount()];
            observed[node.getOutcomeIndex(evidence.get(node))] = 1.0;
            return observed;
        }
        double[] p = new double[node.getOutcomeCount()];
        for (final Edge e : graph.get(node.getId())) {
            for (int i = 0; i < p.length; i++) {
                p[i] += e.messageFX[i];
            }
        }
        p = MathUtils.normalizeLog(p);
        MathUtils.exp(p);
        return p;
    }

    /**
     * Builds the bipartite node/factor graph for {@code bn}: every node is connected to its own factor and to the
     * factors of its children. Any graph built for a previously set network is discarded, and all messages start
     * out zeroed.
     *
     * @param bn the network to run inference on
     */
    @Override
    public void setNetwork(final BayesNet bn) {
        super.setNetwork(bn);
        this.net = bn;
        // the structures are indexed by node id; without clearing, a second call would append after the previous
        // network's entries and lookups by id would keep returning the stale edges
        graph.clear();
        transposedGraph.clear();
        preparedOps.clear();
        for (final BayesNode n : bn.getNodes()) {
            transposedGraph.add(new ArrayList<Edge>());
        }
        for (final BayesNode n : bn.getNodes()) {
            final List<Edge> adjacency = new ArrayList<Edge>();
            graph.add(adjacency);
            // a node is connected to its own factor ...
            connect(n, n, adjacency);
            // ... and, being a parent, to the factor of each of its children
            for (final BayesNode c : n.getChildren()) {
                connect(n, c, adjacency);
            }
        }
        resetMessages();
    }

    /**
     * Connects node {@code source} to the factor of {@code factorOwner}. It creates the edge with zeroed messages,
     * appends it to {@code sourceAdjacency} and to the transposed adjacency list of {@code factorOwner}, and prepares
     * the index mapping used to multiply messages into, and sum them out of, that factor.
     *
     * @param source the node at the source end of the edge
     * @param factorOwner the node whose factor is the target of the edge
     * @param sourceAdjacency the adjacency list of {@code source} in {@link #graph}
     */
    private void connect(final BayesNode source, final BayesNode factorOwner, final List<Edge> sourceAdjacency) {
        final Edge e = new Edge();
        e.source = source.getId();
        e.target = factorOwner.getId();
        e.messageFX = new double[source.getOutcomeCount()];
        e.messageXF = new double[source.getOutcomeCount()];
        sourceAdjacency.add(e);
        transposedGraph.get(factorOwner.getId()).add(e);
        // messages range over the source node only, so the mapping is prepared against a one-dimensional factor over
        // that node; it is reused for both multiplyPrepared and sumPrepared in updateFactor
        final AbstractFactor messageShape = new DenseFactor();
        messageShape.setDimensions(new int[] { source.getOutcomeCount() });
        messageShape.setDimensionIDs(new int[] { source.getId() });
        preparedOps.put(e, factorOwner.getFactor().prepareMultiplication(messageShape));
    }

    /**
     * Replaces the evidence for every node of the network. Nodes contained in {@code evidence} get the given
     * outcome; all other nodes lose any evidence they had. Cached beliefs are invalidated.
     *
     * @param evidence the new evidence, mapping observed nodes to their outcomes
     */
    @Override
    public void setEvidence(final Map<BayesNode, String> evidence) {
        beliefsValid = false;
        for (final BayesNode n : net.getNodes()) {
            if (evidence.containsKey(n)) {
                this.evidence.put(n, evidence.get(n));
            } else {
                this.evidence.remove(n);
            }
        }
    }

    /**
     * Resets all messages and runs one pass of message passing. For every node id used as a root, the order
     * produced by {@link #postOrder} on the transposed graph followed by {@link #preOrder} on the graph is appended
     * to the schedule. For each scheduled factor, its factor-to-node messages are recomputed, then the node-to-factor
     * messages of every node adjacent to it.
     */
    @Override
    protected void updateBeliefs() {
        resetMessages();
        final List<Integer> messagePassingOrder = new ArrayList<Integer>();
        for (int root = 0; root < graph.size(); root++) {
            messagePassingOrder.addAll(postOrder(transposedGraph, root, 1));
            messagePassingOrder.addAll(preOrder(graph, root, 1));
        }
        for (final int n : messagePassingOrder) {
            updateFactor(n);
            for (final Edge e : transposedGraph.get(n)) {
                updateNode(e.source);
            }
        }
    }

    /**
     * Iterative depth-first traversal starting at {@code root}, following each edge to its {@code target} and
     * returning the visited ids in post-order.
     * <p>
     * Self-edges are skipped, as are edges whose source depth counter is not greater than the target's. The counter
     * of a target is incremented every time it is reached, and its edges are only expanded while the counter is at
     * most {@code maxDepth}.
     * <p>
     * NOTE(review): when called with {@link #transposedGraph}, every edge in {@code adjacencyLists.get(root)}
     * targets {@code root} itself, so every edge is skipped and the result appears to be just {@code [root]};
     * confirm this is intended.
     */
    private List<Integer> postOrder(final List<List<Edge>> adjacencyLists, final int root, final int maxDepth) {
        final Deque<Pair<Integer, Iterator<Edge>>> deque = new ArrayDeque<Pair<Integer, Iterator<Edge>>>();
        final int[] depth = new int[adjacencyLists.size()];
        final List<Integer> result = new ArrayList<Integer>();
        deque.add(new Pair<Integer, Iterator<Edge>>(root, adjacencyLists.get(root).iterator()));
        depth[root] = 1;
        while (!deque.isEmpty()) {
            final Pair<Integer, Iterator<Edge>> pair = deque.peek();
            final Iterator<Edge> it = pair.getSecond();
            if (!it.hasNext()) {
                // all edges of this entry are exhausted: emit it (post-order)
                deque.pop();
                result.add(pair.getFirst());
                continue;
            }
            final Edge next = it.next();
            if (next.source == next.target || depth[next.source] <= depth[next.target]) {
                continue;
            }
            depth[next.target]++;
            if (depth[next.target] <= maxDepth) {
                deque.push(new Pair<Integer, Iterator<Edge>>(next.target,
                        adjacencyLists.get(next.target).iterator()));
            }
        }
        return result;
    }

    /**
     * Iterative depth-first traversal starting at {@code root}, following each edge to its {@code target} and
     * returning the visited ids in pre-order (a target is recorded every time it is reached, so ids may repeat).
     * <p>
     * Self-edges are skipped, as are edges whose source depth counter is not greater than the target's. The counter
     * of a target is incremented every time it is reached, and its edges are only expanded while the counter is at
     * most {@code maxDepth}.
     */
    private List<Integer> preOrder(final List<List<Edge>> adjacencyLists, final int root, final int maxDepth) {
        final Deque<Iterator<Edge>> deque = new ArrayDeque<Iterator<Edge>>();
        final int[] depth = new int[adjacencyLists.size()];
        final List<Integer> result = new ArrayList<Integer>();
        result.add(root);
        deque.add(adjacencyLists.get(root).iterator());
        depth[root] = 1;
        while (!deque.isEmpty()) {
            final Iterator<Edge> it = deque.peek();
            if (!it.hasNext()) {
                deque.pop();
                continue;
            }
            final Edge next = it.next();
            if (next.source == next.target || depth[next.source] <= depth[next.target]) {
                continue;
            }
            result.add(next.target);
            depth[next.target]++;
            if (depth[next.target] <= maxDepth) {
                deque.push(adjacencyLists.get(next.target).iterator());
            }
        }
        return result;
    }

    /**
     * Zeroes every message of every edge, i.e. sets all log-space messages to the uniform message.
     */
    private void resetMessages() {
        for (final List<Edge> ad : graph) {
            for (final Edge e : ad) {
                Arrays.fill(e.messageFX, 0.0);
                Arrays.fill(e.messageXF, 0.0);
            }
        }
    }

    /**
     * Recomputes the node-to-factor messages of node {@code source}. Each outgoing message (for edges whose
     * {@code dirty} flag is set) is the sum of the log-space messages from all other adjacent factors, normalized
     * with {@code MathUtils.normalizeLog}.
     *
     * @param source id of the node whose outgoing messages are updated
     */
    private void updateNode(final int source) {
        for (final Edge e : graph.get(source)) {
            if (!e.dirty) {
                continue;
            }
            // a freshly allocated double[] is already zero-initialised
            double[] result = new double[e.messageFX.length];
            for (final Edge e2 : graph.get(source)) {
                if (e2 != e) {
                    for (int i = 0; i < result.length; i++) {
                        result[i] += e2.messageFX[i];
                    }
                }
            }
            result = MathUtils.normalizeLog(result);
            e.messageXF = result;
        }
    }

    /**
     * Recomputes every factor-to-node message of the factor belonging to node {@code index}. For each incoming edge,
     * a log-scale copy of the factor is created and the current evidence selection is applied (see
     * {@link #selectEvidence}). The messages of all other incoming edges are multiplied in, the result is summed
     * onto the edge's source node, and the sum is normalized.
     *
     * @param index id of the node whose factor is processed
     */
    private void updateFactor(final int index) {
        for (final Edge e : transposedGraph.get(index)) {
            e.dirty = false;
            final BayesNode n = net.getNode(e.target);
            // clone so that converting to log scale and multiplying messages in leave the network's factor untouched
            final AbstractFactor f = n.getFactor().clone();
            MathUtils.log(f.getValues());
            f.setLogScale(true);
            selectEvidence(f);
            for (final Edge e2 : transposedGraph.get(index)) {
                if (e2 != e) {
                    // TODO(review): the original carried a bare "TODO" here without any description
                    f.multiplyPrepared(new DoubleArrayWrapper(e2.messageXF), preparedOps.get(e2));
                }
            }
            double[] result = new double[net.getNode(e.source).getOutcomeCount()];
            f.sumPrepared(new DoubleArrayWrapper(result), preparedOps.get(e));
            result = MathUtils.normalizeLog(result);
            e.messageFX = result;
            e.dirty = true;
        }
    }

    /**
     * For every dimension of {@code f}, selects the observed outcome index if the corresponding node has evidence,
     * and calls {@code select(dim, -1)} otherwise (NOTE(review): presumably this clears the selection; confirm in
     * {@code AbstractFactor}).
     */
    private void selectEvidence(final AbstractFactor f) {
        for (final int dim : f.getDimensionIDs()) {
            final BayesNode node = net.getNode(dim);
            if (evidence.containsKey(node)) {
                f.select(dim, node.getOutcomeIndex(evidence.get(node)));
            } else {
                f.select(dim, -1);
            }
        }
    }
}