/*******************************************************************************
* Copyright (C) 2010-2012 Dominik Jain.
*
* This file is part of ProbCog.
*
* ProbCog is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ProbCog is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ProbCog. If not, see <http://www.gnu.org/licenses/>.
******************************************************************************/
package probcog.bayesnets.inference;
import java.util.HashSet;
import java.util.Vector;
import probcog.bayesnets.core.BeliefNetworkEx;
import edu.ksu.cis.bnj.ver3.core.BeliefNode;
import edu.ksu.cis.bnj.ver3.core.CPF;
import edu.ksu.cis.bnj.ver3.core.CPT;
import edu.ksu.cis.bnj.ver3.core.values.ValueDouble;
import edu.tum.cs.util.Stopwatch;
import edu.tum.cs.util.StringTool;
/**
* The variable elimination algorithm for exact inference in Bayesian networks (see, e.g., AIMA ch. 14)
* @author Dominik Jain
*/
public class VariableElimination extends Sampler {
    /** Node indices in the order returned by {@code bn.getTopologicalOrder()}. */
    protected int[] nodeOrder;
    /** Not read or written anywhere in this class. */
    protected Stopwatch timer;
    /**
     * Working assignment (one domain index per node). Initialised in {@link #_infer()} as a
     * copy of {@code evidenceDomainIndices}; entries are overwritten while factors are
     * joined and while the final marginal is read out.
     */
    protected int[] nodeDomainIndices;
    /** Distribution object that receives the computed marginals. */
    protected SampledDistribution dist;

    /**
     * Creates the inference object and stores the network's topological node order.
     *
     * @param bn the network to run inference on
     * @throws Exception propagated from the superclass constructor or the order computation
     */
    public VariableElimination(BeliefNetworkEx bn) throws Exception {
        super(bn);
        nodeOrder = bn.getTopologicalOrder();
    }

    /**
     * A factor, i.e. a function over a set of nodes, stored as a CPF.
     * This is an inner class because it reads the enclosing object's
     * {@code evidenceDomainIndices} and uses {@code getNodeIndex}.
     */
    protected class Factor {
        CPF cpf;

        /**
         * Builds the factor for a node's CPF. If any node in the CPF's domain product
         * carries evidence, the CPF is replaced by a reduced copy that ranges over the
         * non-evidence nodes only.
         *
         * @param n the node whose CPF is to be used
         */
        public Factor(BeliefNode n) {
            cpf = n.getCPF();
            BeliefNode[] domprod = cpf.getDomainProduct();
            for(int i = 0; i < domprod.length; i++) {
                // -1 is the "no evidence" marker used throughout this class; any other
                // value (including index 0) is an observed domain index
                if(evidenceDomainIndices[getNodeIndex(domprod[i])] != -1) {
                    cpf = removeEvidence(cpf);
                    break;
                }
            }
        }

        /**
         * Wraps an existing CPF without modifying it.
         *
         * @param cpf the function to wrap
         */
        public Factor(CPF cpf) {
            this.cpf = cpf;
        }

        /**
         * Creates a new CPT over those nodes of the given CPF that have no evidence
         * and fills it with the entries of the given CPF that agree with the evidence.
         *
         * @param cpf the function to reduce
         * @return the reduced function
         */
        protected CPF removeEvidence(CPF cpf) {
            BeliefNode[] domprod = cpf.getDomainProduct();
            Vector<BeliefNode> freeNodes = new Vector<BeliefNode>();
            for(BeliefNode node : domprod) {
                if(evidenceDomainIndices[getNodeIndex(node)] == -1)
                    freeNodes.add(node);
            }
            CPF cpf2 = new CPT(freeNodes.toArray(new BeliefNode[freeNodes.size()]));
            removeEvidence(cpf, cpf2, 0, new int[domprod.length], 0, new int[freeNodes.size()]);
            return cpf2;
        }

        /**
         * Recursive helper: enumerates addresses of {@code cpf} (position {@code i}) in
         * lockstep with addresses of {@code cpf2} (position {@code j}) and copies values.
         *
         * @param cpf   source function
         * @param cpf2  target function, whose domain product is an order-preserving
         *              subsequence of the source's domain product
         * @param i     current position in the source address
         * @param addr  source address being built
         * @param j     current position in the target address
         * @param addr2 target address being built
         */
        protected void removeEvidence(CPF cpf, CPF cpf2, int i, int[] addr, int j, int[] addr2) {
            if(i == addr.length) {
                cpf2.put(addr2, cpf.get(addr));
                return;
            }
            BeliefNode node = cpf.getDomainProduct()[i];
            BeliefNode[] domprod2 = cpf2.getDomainProduct();
            // does the current source node also appear (at position j) in the target?
            boolean transfer = j < domprod2.length && domprod2[j] == node;
            int nextJ = transfer ? j+1 : j;
            int evidence = evidenceDomainIndices[getNodeIndex(node)];
            if(evidence != -1) {
                // evidence node: fix its value instead of enumerating its domain
                addr[i] = evidence;
                if(transfer)
                    addr2[j] = evidence;
                removeEvidence(cpf, cpf2, i+1, addr, nextJ, addr2);
                return;
            }
            int domSize = node.getDomain().getOrder();
            for(int domIdx = 0; domIdx < domSize; domIdx++) {
                addr[i] = domIdx;
                if(transfer)
                    addr2[j] = domIdx;
                removeEvidence(cpf, cpf2, i+1, addr, nextJ, addr2);
            }
        }

        /**
         * Looks up this factor's value for a full assignment.
         *
         * @param nodeDomainIndices one domain index per node, indexed by node index
         * @return the function value at the address induced by the assignment
         */
        public double getValue(int[] nodeDomainIndices) {
            BeliefNode[] domProd = cpf.getDomainProduct();
            int[] addr = new int[domProd.length];
            for(int i = 0; i < addr.length; i++)
                addr[i] = nodeDomainIndices[getNodeIndex(domProd[i])];
            return cpf.getDouble(addr);
        }

        /**
         * Returns a new factor over this factor's domain without {@code n}, in which
         * each entry is the sum of this factor's entries over the values of {@code n}.
         *
         * @param n the node to sum out; must be part of this factor's domain
         * @return the resulting factor
         * @throws IllegalArgumentException if {@code n} is not in this factor's domain
         */
        public Factor sumOut(BeliefNode n) {
            BeliefNode[] domprod = cpf.getDomainProduct();
            boolean found = false;
            for(BeliefNode node : domprod) {
                if(node == n) {
                    found = true;
                    break;
                }
            }
            if(!found)
                throw new IllegalArgumentException("Cannot sum out " + n + " from " + this + ": node is not in the factor's domain");
            BeliefNode[] domprod2 = new BeliefNode[domprod.length-1];
            int j = 0;
            for(int i = 0; i < domprod.length; i++)
                if(domprod[i] != n)
                    domprod2[j++] = domprod[i];
            // NOTE(review): the accumulation in the helper assumes that a freshly
            // constructed CPT reads as 0 everywhere - confirm against the CPT implementation
            CPF cpf2 = new CPT(domprod2);
            sumOut(cpf2, n, 0, new int[domprod.length], 0, new int[domprod2.length]);
            return new Factor(cpf2);
        }

        /**
         * Recursive helper: enumerates all addresses of this factor's CPF and adds each
         * value to the entry of {@code cpf2} obtained by dropping the component of {@code n}.
         *
         * @param cpf2  target function (this factor's domain without {@code n})
         * @param n     the node being summed out
         * @param i     current position in the source address
         * @param addr  source address being built
         * @param j     current position in the target address
         * @param addr2 target address being built
         */
        protected void sumOut(CPF cpf2, BeliefNode n, int i, int[] addr, int j, int[] addr2) {
            if(i == addr.length) {
                int realaddr2 = cpf2.addr2realaddr(addr2);
                double sum = cpf2.getDouble(realaddr2) + cpf.getDouble(addr);
                cpf2.put(realaddr2, new ValueDouble(sum));
                return;
            }
            BeliefNode node = this.cpf.getDomainProduct()[i];
            boolean keep = node != n;
            int nextJ = keep ? j+1 : j;
            int evidence = evidenceDomainIndices[getNodeIndex(node)];
            if(evidence != -1) {
                // only reached if this factor still contains an evidence node
                addr[i] = evidence;
                if(keep)
                    addr2[j] = evidence;
                sumOut(cpf2, n, i+1, addr, nextJ, addr2);
                return;
            }
            int domSize = node.getDomain().getOrder();
            for(int domIdx = 0; domIdx < domSize; domIdx++) {
                addr[i] = domIdx;
                if(keep)
                    addr2[j] = domIdx;
                sumOut(cpf2, n, i+1, addr, nextJ, addr2);
            }
        }

        @Override
        public String toString() {
            return "F(" + StringTool.join(",", cpf.getDomainProduct()) + ")";
        }
    }

    /**
     * Multiplies the given factors into a single factor over the union of their domains.
     *
     * @param factors the factors to multiply
     * @return the product factor
     * @throws RuntimeException if the product table cannot be allocated
     */
    protected Factor join(Iterable<Factor> factors) {
        HashSet<BeliefNode> domain = new HashSet<BeliefNode>();
        for(Factor f : factors) {
            for(BeliefNode n : f.cpf.getDomainProduct())
                domain.add(n);
        }
        BeliefNode[] domProd = domain.toArray(new BeliefNode[domain.size()]);
        CPF cpf;
        try {
            cpf = new CPT(domProd);
        }
        catch(OutOfMemoryError e) {
            // report the table size that was requested and keep the original error as cause
            double size = 1;
            for(int i = 0; i < domProd.length; i++)
                size *= domProd[i].getDomain().getOrder();
            throw new RuntimeException("Out of memory: Needed at least " + size*8 + " bytes to represent function", e);
        }
        fillCPF(factors, cpf, 0, new int[domProd.length]);
        return new Factor(cpf);
    }

    /**
     * Recursive helper of {@link #join(Iterable)}: enumerates all addresses of {@code cpf}
     * and stores, for each one, the product of the factors' values. The current address is
     * mirrored into {@link #nodeDomainIndices} so that the factors can be evaluated; the
     * entries written there are not restored afterwards.
     *
     * @param factors the factors being multiplied
     * @param cpf     the function to fill
     * @param i       current position in the address
     * @param addr    address being built
     */
    protected void fillCPF(Iterable<Factor> factors, CPF cpf, int i, int[] addr) {
        if(i == addr.length) {
            double product = 1.0;
            for(Factor f : factors)
                product *= f.getValue(nodeDomainIndices);
            try {
                cpf.put(addr, new ValueDouble(product));
            }
            catch(Exception e) {
                System.err.println(StringTool.join(", ", cpf.getDomainProduct()));
                throw new RuntimeException(e);
            }
            return;
        }
        BeliefNode current = cpf.getDomainProduct()[i];
        int currentIdx = getNodeIndex(current);
        int domSize = current.getDomain().getOrder();
        for(int j = 0; j < domSize; j++) {
            addr[i] = j;
            nodeDomainIndices[currentIdx] = j;
            fillCPF(factors, cpf, i+1, addr);
        }
    }

    /**
     * Eliminates a node from a collection of factors: all factors mentioning the node are
     * multiplied, the node is summed out of the product, and the result is returned along
     * with the untouched factors.
     *
     * @param factors the current factors
     * @param n       the node to eliminate
     * @return a new collection of factors, none of which mentions {@code n}
     */
    protected Vector<Factor> sumout(Vector<Factor> factors, BeliefNode n) {
        Vector<Factor> newFacs = new Vector<Factor>();
        Vector<Factor> joinFacs = new Vector<Factor>();
        for(Factor f : factors) {
            boolean mentionsN = false;
            for(BeliefNode node : f.cpf.getDomainProduct()) {
                if(node == n)
                    mentionsN = true;
            }
            (mentionsN ? joinFacs : newFacs).add(f);
        }
        Factor joinedFac = join(joinFacs);
        if(debug) out.println("Summing out " + n + " from " + joinedFac);
        newFacs.add(joinedFac.sumOut(n));
        return newFacs;
    }

    /**
     * Computes the marginal of a query node and stores it in {@link #dist}.
     * Nodes are processed in reverse {@link #nodeOrder}; after a node's factor has been
     * added, the node is eliminated unless it has evidence or is the query node.
     *
     * @param Q the query node
     */
    protected void computeMarginal(BeliefNode Q) {
        Vector<Factor> factors = new Vector<Factor>();
        for(int i = nodeOrder.length-1; i >= 0; i--) {
            if(!debug) out.printf(" %s %d \r", Q.getName(), i);
            int nodeIdx = nodeOrder[i];
            BeliefNode node = nodes[nodeIdx];
            if(debug) out.println("Current node: " + node);
            factors.add(new Factor(node));
            if(debug) out.println(factors);
            boolean hasEvidence = evidenceDomainIndices[nodeIdx] != -1;
            if(!hasEvidence && node != Q)
                factors = sumout(factors, node);
        }
        if(!debug) out.println();
        if(debug) out.printf("%d final factors: %s\n", factors.size(), StringTool.join(", ", factors));

        // save results to distribution
        int queryIdx = getNodeIndex(Q);
        double[] marginal = new double[Q.getDomain().getOrder()];
        double Z = 0.0;
        for(int i = 0; i < marginal.length; i++) {
            nodeDomainIndices[queryIdx] = i;
            double p = 1.0;
            for(Factor f : factors)
                p *= f.getValue(nodeDomainIndices);
            marginal[i] = p;
            Z += p;
        }
        // normalise; if Z is 0, the entries become NaN
        for(int i = 0; i < marginal.length; i++)
            marginal[i] /= Z;
        dist.values[queryIdx] = marginal;
    }

    /**
     * Runs the algorithm: computes the marginal of every query variable and hands the
     * resulting distribution to the distribution builder.
     *
     * @throws Exception if creating the distribution fails
     */
    public void _infer() throws Exception {
        // the elapsed time is measured but not read within this method
        Stopwatch sw = new Stopwatch();
        sw.start();
        dist = createDistribution();
        dist.Z = 1.0;
        nodeDomainIndices = evidenceDomainIndices.clone();
        for(Integer nodeIdx : queryVars)
            computeMarginal(nodes[nodeIdx]);
        ((ImmediateDistributionBuilder)distributionBuilder).setDistribution(dist);
        sw.stop();
    }

    /**
     * @return a builder that is handed the finished distribution directly (see {@link #_infer()})
     */
    protected IDistributionBuilder createDistributionBuilder() {
        return new ImmediateDistributionBuilder();
    }
}