/*******************************************************************************
* Copyright (C) 2008-2012 Dominik Jain.
*
* This file is part of ProbCog.
*
* ProbCog is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ProbCog is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ProbCog. If not, see <http://www.gnu.org/licenses/>.
******************************************************************************/
package probcog.bayesnets.inference;
import java.util.HashSet;
import probcog.bayesnets.core.BeliefNetworkEx;
import edu.ksu.cis.bnj.ver3.core.BeliefNode;
import edu.ksu.cis.bnj.ver3.core.CPF;
import edu.ksu.cis.bnj.ver3.core.Discrete;
import edu.ksu.cis.bnj.ver3.core.Domain;
import edu.tum.cs.util.Stopwatch;
import edu.tum.cs.util.datastruct.Cache2D;
import edu.tum.cs.util.datastruct.MutableDouble;
/**
* a backward sampling algorithm that, to sample the parents of an instantiated node N, considers
* not only the conditional probability of N given its parents but also the children of N's parents
* and their parents (using existing instantiations and, where nodes are yet uninstantiated,
* the prior probability of the nodes)
*
* @author Dominik Jain
*/
public class BackwardSamplingWithChildren extends BackwardSamplingWithPriors {
	/** cache of probabilities computed by {@link BackSamplingDistribution#getProb(CPF, int[])}, keyed on the CPF and an int encoding of the evidence configuration */
	protected Cache2D<CPF, Integer, Double> probCache;
	/** cache of back-sampling distributions, keyed on the sampled node and a long encoding of the locally relevant evidence */
	protected Cache2D<BeliefNode, Long, BackSamplingDistribution> distCache;
	/** stopwatches measuring the total time spent computing probabilities and constructing distributions, reported after inference */
	protected Stopwatch probSW, distSW;

	public class BackSamplingDistribution extends probcog.bayesnets.inference.BackwardSamplingWithPriors.BackSamplingDistribution {

		public BackSamplingDistribution(BackwardSamplingWithPriors sampler) {
			super(sampler);
		}

		/**
		 * recursively gets a distribution to backward sample from, enumerating all configurations
		 * of the uninstantiated parents; each configuration is weighted by the sampled node's CPF
		 * entry, the parents' priors and the probabilities of instantiated children of the parents
		 * @param i the node to instantiate next (as an index into the CPF's domain product)
		 * @param addr the current setting of node indices of the CPF's domain product
		 * @param cpf the conditional probability function of the node we are backward sampling
		 * @param nodeDomainIndices the current evidence (-1 for uninstantiated nodes)
		 */
		@Override
		protected void construct(int i, int[] addr, CPF cpf, int[] nodeDomainIndices) {
			BeliefNode[] domProd = cpf.getDomainProduct();
			if(i == addr.length) { // all nodes in the domain product are instantiated in addr
				double child_prob = cpf.getDouble(addr);
				// temporarily set the parent configuration as evidence, so that the children of
				// the parents are conditioned on it in the getProb calls below
				boolean[] tempEvidence = new boolean[addr.length];
				for(int k = 1; k < addr.length; k++) {
					int nodeIdx = sampler.nodeIndices.get(domProd[k]);
					tempEvidence[k] = nodeDomainIndices[nodeIdx] == -1;
					if(tempEvidence[k])
						nodeDomainIndices[nodeIdx] = addr[k];
				}
				// consider the parent configuration: multiply in each parent's prior and, for each
				// parent, the probabilities of its instantiated children (handling each child once)
				double parent_prob = 1.0;
				HashSet<BeliefNode> handledChildren = new HashSet<BeliefNode>();
				handledChildren.add(domProd[0]);
				for(int j = 1; j < addr.length; j++) {
					double[] parentPrior = ((BackwardSamplingWithPriors)sampler).priors.get(domProd[j]);
					parent_prob *= parentPrior[addr[j]];
					// consider children of parents with evidence
					// get child probability
					BeliefNode[] children = sampler.bn.bn.getChildren(domProd[j]);
					for(BeliefNode child : children) {
						if(nodeDomainIndices[sampler.getNodeIndex(child)] >= 0 && !handledChildren.contains(child)) {
							CPF childCPF = child.getCPF();
							double p = getProb(childCPF, nodeDomainIndices);
							parent_prob *= p;
							handledChildren.add(child);
						}
					}
				}
				// unset the temporary evidence again
				for(int k = 1; k < addr.length; k++) {
					if(tempEvidence[k])
						nodeDomainIndices[sampler.nodeIndices.get(domProd[k])] = -1;
				}
				// add the configuration to the distribution (zero-probability entries are skipped)
				double p = child_prob * parent_prob;
				if(p != 0) {
					addValue(p, addr.clone());
					parentProbs.add(parent_prob);
				}
				return;
			}
			// instantiate the next node: use the evidence value if there is one,
			// otherwise branch over the node's entire domain
			int nodeIdx = sampler.nodeIndices.get(domProd[i]);
			if(nodeDomainIndices[nodeIdx] >= 0) {
				addr[i] = nodeDomainIndices[nodeIdx];
				construct(i+1, addr, cpf, nodeDomainIndices);
			}
			else {
				Discrete dom = (Discrete)domProd[i].getDomain();
				for(int j = 0; j < dom.getOrder(); j++) {
					addr[i] = j;
					construct(i+1, addr, cpf, nodeDomainIndices);
				}
			}
		}

		/**
		 * computes the probability of the node whose CPF is given, conditioned on the evidence in
		 * nodeDomainIndices, summing out all uninstantiated parents; results are memoized in probCache
		 * @param cpf the conditional probability function of the node
		 * @param nodeDomainIndices evidence (mapping of all nodes to domain indices, -1 for no evidence)
		 * @return the probability value
		 */
		protected double getProb(CPF cpf, int[] nodeDomainIndices) {
			final boolean debugCache = false;
			probSW.start();
			// get the key in the CPF-specific cache: a mixed-radix encoding of the evidence
			// configuration, with one extra digit value per node meaning "uninstantiated"
			Double cacheValue = null;
			BeliefNode[] domProd = cpf.getDomainProduct();
			int[] addr = new int[domProd.length];
			boolean allSet = true;
			int key = 0;
			for(int i = 0; i < addr.length; i++) {
				int idx = nodeDomainIndices[sampler.getNodeIndex(domProd[i])];
				allSet = allSet && idx >= 0;
				addr[i] = idx;
				key *= cpf._SizeBuffer[i]+1;
				key += idx == -1 ? cpf._SizeBuffer[i] : idx;
			}
			// if all nodes are instantiated, the value can be read off the CPF directly
			if(allSet) {
				probSW.stop();
				return cpf.getDouble(addr);
			}
			// check if we already have the value in the cache
			Double value = cacheValue = probCache.get(cpf, key);
			if(!debugCache && value != null) {
				probSW.stop();
				return value;
			}
			// not in the cache, so calculate the value
			MutableDouble p = new MutableDouble(0.0);
			getProb(cpf, 0, addr, nodeDomainIndices, p);
			// store in cache (Cache2D.put presumably stores under the key of the preceding get — TODO confirm)
			probCache.put(p.value);
			// with debugCache enabled, verify that the recomputed value matches the cached one
			if(cacheValue != null && p.value != cacheValue) {
				throw new RuntimeException("cache mismatch");
			}
			probSW.stop();
			return p.value;
		}

		/**
		 * gets the probability indicated by the given CPF for the given domain indices, summing over all parents whose values are not set (i.e. set to -1) in nodeDomainIndices;
		 * i.e. computes the probability of the node whose CPF is provided given the evidence set in nodeDomainIndices
		 * @param cpf the conditional probability function
		 * @param i index of the next node to instantiate
		 * @param addr the address (list of node domain indices relevant to the CPF)
		 * @param nodeDomainIndices evidences (mapping of all nodes in the network to domain indices, -1 for no evidence)
		 * @param ret variable in which to store the result (initialize to 0.0, because we are summing probability values)
		 */
		protected void getProb(CPF cpf, int i, int[] addr, int[] nodeDomainIndices, MutableDouble ret) {
			BeliefNode[] domProd = cpf.getDomainProduct();
			// if all nodes have been instantiated, add this configuration's contribution to the sum
			if(i == addr.length) {
				double p = cpf.getDouble(addr);
				for(int j = 1; j < addr.length; j++) {
					// weight by the prior only for parents that are being summed out, i.e. without
					// evidence (bugfix: a stray ';' previously terminated this 'if', so the prior
					// was erroneously applied to instantiated parents as well)
					if(nodeDomainIndices[sampler.getNodeIndex(domProd[j])] == -1) {
						double[] parentPrior = ((BackwardSamplingWithPriors)sampler).priors.get(domProd[j]);
						p *= parentPrior[addr[j]];
					}
				}
				ret.value += p;
				return;
			}
			// otherwise instantiate the next node
			BeliefNode node = domProd[i];
			int nodeIdx = sampler.getNodeIndex(node);
			// - if we have evidence, use it
			if(nodeDomainIndices[nodeIdx] >= 0) {
				addr[i] = nodeDomainIndices[nodeIdx];
				getProb(cpf, i+1, addr, nodeDomainIndices, ret);
			}
			// - otherwise sum over all settings
			else {
				Domain dom = node.getDomain();
				for(int j = 0; j < dom.getOrder(); j++) {
					addr[i] = j;
					getProb(cpf, i+1, addr, nodeDomainIndices, ret);
				}
			}
		}
	}

	/**
	 * gets the distribution to backward sample the parents of the given node from,
	 * using the distCache to avoid recomputing distributions for identical local
	 * evidence configurations
	 * @param node the instantiated node whose parents are to be sampled
	 * @param s the current sample, providing the evidence in s.nodeDomainIndices
	 * @return the back-sampling distribution
	 */
	@Override
	protected BackSamplingDistribution getBackSamplingDistribution(BeliefNode node, WeightedSample s) {
		BackSamplingDistribution d;
		long key = 0;
		final boolean useCache = true;
		distSW.start();
		if(useCache) { // TODO optimize this further (semi-lifted): because the distributions of many nodes are identical, use some index that combines the relational node's index plus possible constant node settings
			// calculate the cache key: a mixed-radix encoding of the evidence settings of all
			// nodes the distribution depends on
			// (NOTE(review): the long may overflow for large Markov blankets, which could cause
			// cache collisions — confirm acceptable)
			BeliefNode[] domProd = node.getCPF().getDomainProduct();
			// - consider node itself and all parents
			for(int i = 0; i < domProd.length; i++) {
				BeliefNode n = domProd[i];
				int idx = s.nodeDomainIndices[getNodeIndex(n)];
				int order = n.getDomain().getOrder();
				key *= order + 1;
				key += idx == -1 ? order : idx;
				// - children of parents
				if(i != 0) {
					BeliefNode[] children = bn.bn.getChildren(n);
					for(int j = 0; j < children.length; j++) {
						if(children[j] != node) {
							n = children[j];
							idx = s.nodeDomainIndices[getNodeIndex(n)];
							order = n.getDomain().getOrder();
							key *= order + 1;
							key += idx == -1 ? order : idx;
							// - parents of children
							BeliefNode[] parentsofchildren = children[j].getCPF().getDomainProduct();
							for(int k = 1; k < parentsofchildren.length; k++) {
								n = parentsofchildren[k];
								idx = s.nodeDomainIndices[getNodeIndex(n)];
								order = n.getDomain().getOrder();
								key *= order + 1;
								key += idx == -1 ? order : idx;
							}
						}
					}
				}
			}
			// check if we have a cache value
			d = distCache.get(node, key);
			if(d != null) {
				// bugfix: previously the stopwatch was left running on a cache hit
				distSW.stop();
				return d;
			}
		}
		// obtain new distribution
		d = new BackSamplingDistribution(this);
		d.construct(node, s.nodeDomainIndices);
		// store in cache (Cache2D.put presumably stores under the key of the preceding get — TODO confirm)
		if(useCache)
			distCache.put(d);
		distSW.stop();
		return d;
	}

	/**
	 * @param bn the Bayesian network to perform inference on
	 * @throws Exception
	 */
	public BackwardSamplingWithChildren(BeliefNetworkEx bn) throws Exception {
		super(bn);
	}

	@Override
	public void _initialize() throws Exception {
		// create fresh caches before the superclass initialization
		probCache = new Cache2D<CPF, Integer, Double>();
		distCache = new Cache2D<BeliefNode, Long, BackSamplingDistribution>();
		super._initialize();
	}

	@Override
	public void _infer() throws Exception {
		probSW = new Stopwatch();
		distSW = new Stopwatch();
		super._infer();
		// report timing and cache statistics
		report("prob time: " + probSW.getElapsedTimeSecs());
		report(String.format(" cache hit ratio: %f (%d accesses)", this.probCache.getHitRatio(), this.probCache.getNumAccesses()));
		report("dist time: " + distSW.getElapsedTimeSecs());
		report(String.format(" cache hit ratio: %f (%d accesses)", this.distCache.getHitRatio(), this.distCache.getNumAccesses()));
	}
}