/******************************************************************************* * Copyright (C) 2010-2012 Stefan Waldherr, Dominik Jain. * * This file is part of ProbCog. * * ProbCog is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * ProbCog is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with ProbCog. If not, see <http://www.gnu.org/licenses/>. ******************************************************************************/ package probcog.bayesnets.inference; import java.util.HashMap; import java.util.Vector; import probcog.bayesnets.core.BeliefNetworkEx; import edu.ksu.cis.bnj.ver3.core.BeliefNode; import edu.tum.cs.util.datastruct.MutableDouble; /** * (iterative/loopy) belief propagation * @author Stefan Waldherr */ public class BeliefPropagation extends Sampler { protected BeliefNode[] nodes; protected int[] topOrder; protected HashMap<BeliefNode,double[]> lambda; protected HashMap<BeliefNode,double[]> pi; protected HashMap<BeliefNode, BeliefMessageContainer> messages; //links nodes to their message contaienr protected HashMap<BeliefNode,double[]> priors; public class BeliefMessageContainer{ public HashMap<BeliefNode, double[]> lambdaMessages; public HashMap<BeliefNode, double[]> piMessages; protected BeliefNode node; protected int nodeOrder; public BeliefMessageContainer(BeliefNode node){ // contains all pi and lambda messages that are sent from this node lambdaMessages = new HashMap<BeliefNode, double[]>(); piMessages = new HashMap<BeliefNode, double[]>(); this.node = node; nodeOrder = node.getDomain().getOrder(); //initialize empty messages for (BeliefNode n : bn.bn.getChildren(node)){ double[] initPi = new double[nodeOrder]; for (int i = 0; i < nodeOrder; i++){ initPi[i] = 1.0/nodeOrder; } piMessages.put(n,initPi); } for(BeliefNode n : bn.bn.getParents(node)){ int parentOrder = n.getDomain().getOrder(); double[] initLambda = new double[parentOrder]; for (int i = 0; i < parentOrder; i++){ initLambda[i] = 1.0/parentOrder; } lambdaMessages.put(n, initLambda); } } public void computePiMessages(BeliefNode n){ double normalize = 0.0; for (int i = 0; i < nodeOrder; i++){ double prod = 1.0; for (BeliefNode c : piMessages.keySet()){ if (c != n){ prod *= messages.get(c).lambdaMessages.get(node)[i]; } } double entry = prod * pi.get(node)[i]; piMessages.get(n)[i] = entry; normalize += entry; } // normalize if (normalize != 0.0){ if (normalize == 0.0) return; for (int i = 0; i < nodeOrder; i++){ piMessages.get(n)[i] /= normalize; } } } public void computeLambdaMessages(BeliefNode n, int[] nodeDomainIndices) { // determine the variables to sum over Vector<Integer> varsToSumOver = new Vector<Integer>(); for (BeliefNode p : lambdaMessages.keySet()){ if (p != n && nodeDomainIndices[getNodeIndex(p)] == -1){ // TODO for Stefan to ack: replaced bn.getNodeIndex by this.getNodeIndex, because the former is very inefficient, since it constructs the vector of nodes each time and does a linear search varsToSumOver.add(getNodeIndex(p)); } } // actual calculation of lambda message double normalize = 0.0; for (int i = 0; i < lambdaMessages.get(n).length; i++){ nodeDomainIndices[getNodeIndex(n)] = i; double sum = 0.0; for (int j = 0; j < nodeOrder; j++){ nodeDomainIndices[getNodeIndex(node)] = j; double prod = lambda.get(node)[j]; MutableDouble mutableSum = new MutableDouble(0.0); computeLambdaMessages(n, varsToSumOver,0,nodeDomainIndices,mutableSum); sum += prod * mutableSum.value; } lambdaMessages.get(n)[i] = sum; normalize += sum; } if (normalize != 0.0){ if (normalize == 0.0) return; for (int i = 0; i < lambdaMessages.get(n).length; i++){ lambdaMessages.get(n)[i] /= normalize; } } } protected void computeLambdaMessages(BeliefNode n, Vector<Integer> varsToSumOver, int i, int[] nodeDomainIndices, MutableDouble sum) { if (i == varsToSumOver.size()) { double result = getCPTProbability(node, nodeDomainIndices); // multiply with incoming pi messages for (BeliefNode p : bn.bn.getParents(node)){ if (n != p){ result *= messages.get(p).piMessages.get(node)[nodeDomainIndices[getNodeIndex(p)]]; } } sum.value += result; return; } int idxVar = varsToSumOver.get(i); for (int v = 0; v < nodes[idxVar].getDomain().getOrder(); v++) { nodeDomainIndices[idxVar] = v; computeLambdaMessages(n, varsToSumOver, i + 1, nodeDomainIndices, sum); } } public boolean sentPiMessageTo(BeliefNode c){ if (pi.containsKey(c)){ double sum = 0.0; for (double d : pi.get(c)){ sum += d; } return (sum != 0.0); } else{ return false; } } public boolean sentLambdaMessageTo(BeliefNode p){ if (lambdaMessages.containsKey(p)){ double sum = 0.0; for (double d : lambdaMessages.get(p)){ sum += d; } return (sum != 0.0); } else{ return false; } } } public void computePi(BeliefNode n, int[] nodeDomainIndices){ // determine the variables to sum over if (evidenceDomainIndices[getNodeIndex(n)] != -1) return; Vector<Integer> varsToSumOver = new Vector<Integer>(); // TODO varsToSumOver doesn't change; if need be, we can cache it for (BeliefNode p : bn.bn.getParents(n)){ if (nodeDomainIndices[getNodeIndex(p)] == -1){ varsToSumOver.add(getNodeIndex(p)); } } double normalize = 0.0; for (int i = 0; i < pi.get(n).length; i++){ nodeDomainIndices[getNodeIndex(n)] = i; MutableDouble mutableSum = new MutableDouble(0.0); computePi(n,varsToSumOver,0,nodeDomainIndices,mutableSum); pi.get(n)[i] = mutableSum.value; normalize += mutableSum.value; } if (normalize == 0.0) return; for (int i = 0; i < pi.get(n).length; i++){ pi.get(n)[i] /= normalize; } } protected void computePi(BeliefNode n, Vector<Integer> varsToSumOver, int i, int[] nodeDomainIndices, MutableDouble sum) { if (i == varsToSumOver.size()) { double result = getCPTProbability(n, nodeDomainIndices); // multiply with incoming pi messages for (BeliefNode p : bn.bn.getParents(n)){ result *= messages.get(p).piMessages.get(n)[nodeDomainIndices[getNodeIndex(p)]]; } sum.value += result; return; } int idxVar = varsToSumOver.get(i); for (int v = 0; v < nodes[idxVar].getDomain().getOrder(); v++) { nodeDomainIndices[idxVar] = v; computePi(n, varsToSumOver, i + 1, nodeDomainIndices, sum); } } public void computeLambda(BeliefNode n){ if (evidenceDomainIndices[getNodeIndex(n)] != -1) return; double normalize = 0.0; for (int i = 0; i < lambda.get(n).length; i++){ double prod = 1.0; for (BeliefNode c : bn.bn.getChildren(n)){ prod *= messages.get(c).lambdaMessages.get(n)[i]; } lambda.get(n)[i] = prod; normalize += prod; } if (normalize == 0.0) return; for (int i = 0; i < lambda.get(n).length; i++){ lambda.get(n)[i] /= normalize; } } public BeliefPropagation(BeliefNetworkEx bn) throws Exception { super(bn); // Initialization of BP nodes = bn.getNodes(); topOrder = bn.getTopologicalOrder(); lambda = new HashMap<BeliefNode,double[]>(); pi = new HashMap<BeliefNode,double[]>(); messages = new HashMap<BeliefNode, BeliefMessageContainer>(); } @Override public String getAlgorithmName() { return String.format("Belief Propagation"); } @Override protected void _infer() throws Exception { // initialization of lambda and pi. priors = bn.computePriors(evidenceDomainIndices); for (BeliefNode n : nodes){ int domSize = n.getDomain().getOrder(); int domIdx = evidenceDomainIndices[getNodeIndex(n)]; // build message function hashmap messages.put(n, new BeliefMessageContainer(n)); // initialize probability distribution double[] init = new double[domSize]; for (int i = 0; i < init.length; i++){ init[i] = 0.0; } // initialize evidence variables if (domIdx != -1){ init[domIdx] = 1.0; lambda.put(n, init.clone()); pi.put(n, init.clone()); } else{ // initialize nodes without parents if (bn.bn.getParents(n).length == 0){ double[] prior = priors.get(n); pi.put(n, prior); } // initialize nodes without children if (bn.bn.getChildren(n).length == 0){ double normalized = 1 / (double) domSize; double[] uniform = new double[domSize]; for (int i = 0; i < uniform.length; i++){ uniform[i] = normalized; } lambda.put(n, uniform); } if (!pi.containsKey(n)){ pi.put(n, init.clone()); } if (!lambda.containsKey(n)){ lambda.put(n, init.clone()); } } } if (debug){ out.println("After initialization process"); for (BeliefNode n : nodes){ out.println(" Node: " + n); out.println(" Pi(x):" + n); for(int i = 0; i < pi.get(n).length; i++){ out.println(" " + i + ": " + pi.get(n)[i]); } out.println(" Lambda(x):" + n); for(int i = 0; i < lambda.get(n).length; i++){ out.println(" " + i + ": " + lambda.get(n)[i]); } } } // Belief Propagation Steps for (int step = 1; step <= this.numSamples; step++) { if(verbose && step % this.infoInterval == 0) out.println("step " + step); // calculate pi(x) for (BeliefNode n : nodes){ int[] nodeDomainIndices = evidenceDomainIndices.clone(); // TODO why do we clone each time? Isn't it enough to have one copy to work with for all nodes. // Not cloning this lead to complications in the IJGP algo so I continued cloning here. boolean receivedAll = true; // check whether n has received all pi messages from its parents for (BeliefNode c : bn.bn.getParents(n)){ double sum = 0.0; for (double d : messages.get(c).piMessages.get(n)){ sum += d; } if (sum == 0.0){ receivedAll = false; } } if (receivedAll){ computePi(n, nodeDomainIndices); } } // calculate lambda(x): for (BeliefNode n : nodes){ boolean receivedAll = true; //check whether n has received all lambda messages from its children for (BeliefNode c : bn.bn.getChildren(n)){ double sum = 0.0; for (double d : messages.get(c).lambdaMessages.get(n)){ sum += d; } if (sum == 0.0){ receivedAll = false; break; } } if (receivedAll && bn.bn.getChildren(n).length > 0){ computeLambda(n); } } // calculate outgoing pi messages for every node for (BeliefNode n : nodes){ //if pi has been calculated... double sum = 0.0; for (double d : pi.get(n)){ sum += d; } if (sum != 0){ BeliefNode[] children = bn.bn.getChildren(n); // TODO for Stefan to ack: getChildren was called twice, should avoid because it's an expensive call (always reallocates the vector of children) for (BeliefNode c : children){ // ... and lambda received from all children except c boolean receivedAll = true; for (BeliefNode c2 : children){ if ((c2 != c) && !messages.get(c2).sentLambdaMessageTo(n)){ receivedAll = false; break; } } if (receivedAll){ messages.get(n).computePiMessages(c); } } } } // calculate outgoing lambda messages for every node for (BeliefNode n : nodes){ //if lambda has been calculated... double sum = 0.0; for (double d : lambda.get(n)){ sum += d; } if (sum != 0){ for (BeliefNode p : bn.bn.getParents(n)){ // TODO note: a slightly better way of getting the parents (because it doesn't allocate any memory) is to use getCPF().getDomainProduct and iterate over the elements 1 to end // ... and pi received from all parents except p boolean receivedAll = true; for (BeliefNode p2 : bn.bn.getParents(n)){ if ((p2 != p) && !messages.get(p2).sentPiMessageTo(n)){ receivedAll = false; break; } } if (receivedAll){ int[] nodeDomainIndices = evidenceDomainIndices.clone(); messages.get(n).computeLambdaMessages(p, nodeDomainIndices); } } } } if(debug){ out.println("\n\n****After step " + step + "****"); out.println("\n Pi and Lambda Functions"); for (BeliefNode n : nodes){ out.println(" Node: " + n); out.println(" Pi(x):" + n); for(int i = 0; i < pi.get(n).length; i++){ out.println(" " + i + ": " + pi.get(n)[i]); } out.println(" Lambda(x):" + n); for(int i = 0; i < lambda.get(n).length; i++){ out.println(" " + i + ": " + lambda.get(n)[i]); } } out.println("\n Message Functions"); for (BeliefNode n : nodes){ out.println("Node: " + n); for (BeliefNode c : messages.get(n).piMessages.keySet()){ out.println(" Pi-Message to " + c + ":"); for(int i = 0; i < messages.get(n).piMessages.get(c).length; i++){ out.println(" " + i + ": " + messages.get(n).piMessages.get(c)[i]); } } for (BeliefNode c : messages.get(n).lambdaMessages.keySet()){ out.println(" Lambda to (x):" + c); for(int i = 0; i < messages.get(n).lambdaMessages.get(c).length; i++){ out.println(" " + i + ": " + messages.get(n).lambdaMessages.get(c)[i]); } } } } } // compute probabilities and store results in distribution if(verbose) out.println("computing results...."); SampledDistribution dist = createDistribution(); dist.Z = 1.0; for (BeliefNode n : nodes) { int i = getNodeIndex(n); if (evidenceDomainIndices[i] >= 0) { dist.values[i][evidenceDomainIndices[i]] = 1.0; continue; } int domSize = dist.values[i].length; double normalize = 0.0; for (int j = 0; j < domSize; j++) { dist.values[i][j] = lambda.get(n)[j]*pi.get(n)[j]; normalize += dist.values[i][j]; } for (int j = 0; j < domSize; j++) { if (normalize == 0.0) continue; dist.values[i][j] /= normalize; } } ((ImmediateDistributionBuilder)distributionBuilder).setDistribution(dist); } protected IDistributionBuilder createDistributionBuilder() { return new ImmediateDistributionBuilder(); } }