/*******************************************************************************
* Copyright (C) 2008-2012 Dominik Jain.
*
* This file is part of ProbCog.
*
* ProbCog is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ProbCog is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ProbCog. If not, see <http://www.gnu.org/licenses/>.
******************************************************************************/
package probcog.bayesnets.inference;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashSet;
import java.util.PriorityQueue;
import java.util.Vector;
import probcog.bayesnets.core.BeliefNetworkEx;
import probcog.bayesnets.util.TopologicalOrdering;
import probcog.bayesnets.util.TopologicalSort;
import edu.ksu.cis.bnj.ver3.core.BeliefNode;
import edu.ksu.cis.bnj.ver3.core.CPF;
import edu.ksu.cis.bnj.ver3.core.Discrete;
import edu.tum.cs.util.Stopwatch;
import edu.tum.cs.util.StringTool;
/**
* an implementation of the backward simulation algorithm as described by Robert Fung and Brendan Del Favero
* in "Backward Simulation in Bayesian Networks" (UAI 1994)
*
* @author Dominik Jain
*/
public class BackwardSampling extends Sampler {
/** nodes from which we sample backward (instantiating their parents), in the order in which they are processed; filled by getOrdering */
protected Vector<BeliefNode> backwardSampledNodes;
/** nodes left uninstantiated after backward sampling has been scheduled; they are forward sampled in topological order; filled by getOrdering */
protected Vector<BeliefNode> forwardSampledNodes;
/** instantiated nodes that are neither backward nor forward sampled; their CPT probability is multiplied into the sample weight in getSample */
protected HashSet<BeliefNode> outsideSamplingOrder;
/** index of the sampling step currently being processed by _infer (starts at 1); used in debug messages */
protected int currentStep;
/**
 * Unnormalised distribution over the CPF addresses (joint settings of a node and its
 * parents) that are compatible with a partial assignment. It is built from the
 * non-zero CPF entries and is the distribution a node is backward sampled from.
 */
public static class BackSamplingDistribution {
    /** unnormalised probability of each collected state; parallel to {@link #states} */
    public Vector<Double> distribution;
    /** CPF addresses (one domain index per entry of the CPF's domain product); parallel to {@link #distribution} */
    public Vector<int[]> states;
    /** sum of all values added so far, i.e. the normalising constant */
    double Z;
    /** the sampler whose node indexing is used to look up values in the assignment */
    protected Sampler sampler;

    /**
     * creates an empty distribution
     * @param sampler the sampler providing the mapping from nodes to indices
     */
    public BackSamplingDistribution(Sampler sampler) {
        this.sampler = sampler;
        this.states = new Vector<int[]>();
        this.distribution = new Vector<Double>();
        this.Z = 0.0;
    }

    /**
     * appends an entry to the distribution and adds its value to the normalising constant
     * @param p the (unnormalised) probability value of the state
     * @param state the CPF address the value belongs to
     */
    public void addValue(double p, int[] state) {
        Z += p;
        states.add(state);
        distribution.add(p);
    }

    /**
     * gets the factor with which a sample's weight is to be multiplied
     * @param sampledValue index of the entry that was sampled (not used by this implementation)
     * @return the normalising constant Z
     */
    public double getWeightingFactor(int sampledValue) {
        return Z;
    }

    /**
     * multiplies the weight of the given sample with the weighting factor
     * @param s the sample whose weight is updated
     * @param sampledValue index of the entry that was sampled
     */
    public void applyWeight(WeightedSample s, int sampledValue) {
        double factor = getWeightingFactor(sampledValue);
        s.weight *= factor;
    }

    /**
     * builds the distribution for the given node, whose own value is taken from the assignment
     * @param node the node to backward sample from
     * @param nodeDomainIndices the current assignment (domain index per node, indexed by the sampler's node indices)
     */
    public void construct(BeliefNode node, int[] nodeDomainIndices) {
        CPF cpf = node.getCPF();
        int[] addr = new int[cpf.getDomainProduct().length];
        // position 0 of the address is the node itself
        addr[0] = nodeDomainIndices[sampler.nodeIndices.get(node)];
        construct(1, addr, cpf, nodeDomainIndices);
    }

    /**
     * recursively enumerates all CPF addresses that are compatible with the assignment
     * and collects those with non-zero probability
     * @param i the node to instantiate next (as an index into the CPF's domain product)
     * @param addr the current setting of node indices of the CPF's domain product
     * @param cpf the conditional probability function of the node we are backward sampling
     * @param nodeDomainIndices the current assignment; entries &gt;= 0 are kept fixed, all values of the remaining nodes are enumerated
     */
    protected void construct(int i, int[] addr, CPF cpf, int[] nodeDomainIndices) {
        // address complete: look up the entry and keep it if it is non-zero
        if(i == addr.length) {
            double p = cpf.getDouble(addr);
            if(p != 0)
                addValue(p, addr.clone());
            return;
        }
        BeliefNode current = cpf.getDomainProduct()[i];
        int fixedValue = nodeDomainIndices[sampler.nodeIndices.get(current)];
        if(fixedValue < 0) {
            // no value assigned: branch over the entire domain
            Discrete dom = (Discrete)current.getDomain();
            for(int value = 0; value < dom.getOrder(); value++) {
                addr[i] = value;
                construct(i+1, addr, cpf, nodeDomainIndices);
            }
            return;
        }
        // value already assigned: keep it
        addr[i] = fixedValue;
        construct(i+1, addr, cpf, nodeDomainIndices);
    }
}
/**
 * constructs the sampler for the given network by delegating to the superclass constructor
 * @param bn the belief network that is passed on to the superclass
 * @throws Exception if the superclass constructor throws
 */
public BackwardSampling(BeliefNetworkEx bn) throws Exception {
super(bn);
}
/**
 * for ordering belief nodes in descending order of the tier they are in (as indicated by the topological ordering),
 * i.e. a node in a higher tier compares as smaller and is therefore retrieved first from a priority queue
 * @author Dominik Jain
 *
 */
public static class TierComparator implements Comparator<BeliefNode> {
    /** the topological ordering that provides the tier of each node */
    TopologicalOrdering topOrder;

    /**
     * @param topOrder the topological ordering from which tiers are read
     */
    public TierComparator(TopologicalOrdering topOrder) {
        this.topOrder = topOrder;
    }

    /**
     * compares two nodes by tier in descending order
     * @param o1 the first node
     * @param o2 the second node
     * @return a negative value if o1 is in a higher tier than o2, a positive value if it is in a lower tier, 0 if the tiers are equal
     */
    public int compare(BeliefNode o1, BeliefNode o2) {
        int tier1 = topOrder.getTier(o1);
        int tier2 = topOrder.getTier(o2);
        // explicit comparisons instead of a negated subtraction, which could overflow and yield the wrong sign
        if(tier1 == tier2)
            return 0;
        return tier1 > tier2 ? -1 : 1;
    }
}
/**
 * determines the sampling order: fills the members holding the backward sampled nodes, the
 * forward sampled nodes and the set of nodes that are not in the sampling order
 * @param evidenceDomainIndices the evidence (one domain index per node; entries &gt;= 0 indicate that the node has evidence)
 * @throws Exception
 */
protected void getOrdering(int[] evidenceDomainIndices) throws Exception {
    backwardSampledNodes = new Vector<BeliefNode>();
    forwardSampledNodes = new Vector<BeliefNode>();
    outsideSamplingOrder = new HashSet<BeliefNode>();
    TopologicalOrdering topOrder = new TopologicalSort(bn.bn).run(true);
    // nodes that have not yet been scheduled for instantiation
    HashSet<BeliefNode> pending = new HashSet<BeliefNode>(Arrays.asList(nodes));
    // candidates for backward sampling; nodes in higher tiers are retrieved first
    PriorityQueue<BeliefNode> candidates = new PriorityQueue<BeliefNode>(1, new TierComparator(topOrder));

    // nodes with evidence are instantiated and are candidates for backward sampling
    for(int i = 0; i < evidenceDomainIndices.length; i++) {
        if(evidenceDomainIndices[i] < 0)
            continue;
        BeliefNode evidenceNode = nodes[i];
        candidates.add(evidenceNode);
        pending.remove(evidenceNode);
    }

    // process all backward sampling candidates
    BeliefNode node;
    while((node = candidates.poll()) != null) {
        BeliefNode[] domProd = node.getCPF().getDomainProduct();
        // every parent that was still pending will be instantiated by backward sampling
        // from this node and becomes a candidate itself
        int claimedParents = 0;
        for(int j = 1; j < domProd.length; j++) {
            BeliefNode parent = domProd[j];
            if(pending.remove(parent)) {
                candidates.add(parent);
                claimedParents++;
            }
        }
        if(claimedParents > 0)
            backwardSampledNodes.add(node);
        else // no uninstantiated parents: the node is instantiated but not part of the sampling order
            outsideSamplingOrder.add(node);
    }

    // everything that is still pending is forward sampled in topological order
    for(int i : topOrder) {
        BeliefNode candidate = nodes[i];
        if(pending.contains(candidate))
            forwardSampledNodes.add(candidate);
    }
}
/**
 * samples backward from the given node, thereby instantiating its parents
 * @param node the node to sample backward from
 * @param s the sample in which the sampled values are stored; its weight is updated using the sampled distribution
 * @return true if sampling succeeded, false if no state could be sampled or the weight dropped to zero
 */
protected boolean sampleBackward(BeliefNode node, WeightedSample s) {
    // build the distribution over parent settings and draw from it
    BackSamplingDistribution dist = getBackSamplingDistribution(node, s);
    int choice = sample(dist.distribution, generator);
    if(choice == -1)
        return false;
    int[] sampledState = dist.states.get(choice);

    // update the sample's weight
    dist.applyWeight(s, choice);
    if(s.weight == 0.0)
        return false;

    // write the sampled parent values into the sample (position 0 is the node itself)
    BeliefNode[] domProd = node.getCPF().getDomainProduct();
    for(int pos = 1; pos < sampledState.length; pos++) {
        BeliefNode parent = domProd[pos];
        s.nodeDomainIndices[this.nodeIndices.get(parent)] = sampledState[pos];
    }
    return true;
}
/**
 * creates the distribution from which the given node is backward sampled
 * @param node the node to sample backward from
 * @param s the sample whose current assignment constrains the distribution
 * @return the constructed distribution
 */
protected BackSamplingDistribution getBackSamplingDistribution(BeliefNode node, WeightedSample s) {
    BackSamplingDistribution dist = new BackSamplingDistribution(this);
    dist.construct(node, s.nodeDomainIndices);
    return dist;
}
/**
 * computes the sampling order for the current evidence and, in debug mode, prints it
 */
@Override
protected void _initialize() throws Exception {
    getOrdering(evidenceDomainIndices);
    if(!debug)
        return;
    out.println("sampling backward: " + this.backwardSampledNodes);
    out.println("sampling forward: " + this.forwardSampledNodes);
    out.println("not in order: " + this.outsideSamplingOrder);
}
/**
 * draws samples until numSamples samples have been collected or converged() reports
 * convergence, adding each one to the distribution, and reports timing statistics
 * @throws Exception if obtaining or adding a sample fails
 */
@Override
public void _infer() throws Exception {
    Stopwatch sw = new Stopwatch();
    sw.start();
    if(verbose) out.println("sampling...");
    WeightedSample s = new WeightedSample(this.bn, evidenceDomainIndices.clone(), 1.0, null, 0);
    // number of samples actually collected; differs from numSamples if we stop early on convergence
    int samplesDrawn = 0;
    for(currentStep = 1; currentStep <= this.numSamples; currentStep++) {
        if(verbose && currentStep % infoInterval == 0)
            out.println(" step " + currentStep);
        getSample(s);
        this.addSample(s);
        samplesDrawn++;
        onAddedSample(s);
        if(converged())
            break;
    }
    sw.stop();
    SampledDistribution dist = distributionBuilder.getDistribution();
    double elapsed = sw.getElapsedTimeSecs();
    // average over the samples that were really drawn; guard against division by zero if there were none
    double perSample = samplesDrawn > 0 ? elapsed / samplesDrawn : 0.0;
    report(String.format("time taken: %.2fs (%.4fs per sample, %.1f trials/step)\n", elapsed, perSample, dist.getTrialsPerStep()));
}
/**
 * gets one full sample of all of the nodes: backward samples, then forward samples, then
 * multiplies the CPT probabilities of the nodes outside the sampling order into the weight.
 * If any of these steps fails (or the weight becomes zero), a new trial is started.
 * @param s the sample to fill; on success its trials member holds the number of trials that were needed
 * @throws Exception if the weight becomes zero although the last factor was non-zero (precision loss)
 * @throws RuntimeException if no sample could be obtained within maxTrials trials (a value of 0 means no limit)
 */
public void getSample(WeightedSample s) throws Exception {
    final int trialLimit = this.maxTrials;
    loop1: for(int t = 1; t <= trialLimit || trialLimit == 0; t++) {
        // initialize sample
        initSample(s);
        // backward sampling
        for(BeliefNode node : backwardSampledNodes) {
            if(!sampleBackward(node, s)) {
                if(debug) out.println("!!! backward sampling failed at " + node + " in step " + currentStep);
                continue loop1;
            }
        }
        // forward sampling
        for(BeliefNode node : forwardSampledNodes) {
            if(!sampleForward(node, s)) {
                if(debug) out.println("!!! forward sampling failed at " + node + " in step " + currentStep + "; cond: " + s.getCPDLookupString(node));
                continue loop1;
            }
        }
        // nodes outside the sampling order: adjust weight
        for(BeliefNode node : outsideSamplingOrder) {
            double p = this.getCPTProbability(node, s.nodeDomainIndices);
            s.weight *= p;
            if(s.weight == 0.0) {
                if(p != 0.0)
                    throw new Exception("Precision loss in weight calculation");
                // error diagnosis
                if(debug) out.println("!!! weight became zero at unordered node " + node + " in step " + currentStep + "; cond: " + s.getCPDLookupString(node));
                if(debug && this instanceof BackwardSamplingWithPriors) {
                    double[] prior = ((BackwardSamplingWithPriors)this).priors.get(node);
                    out.println("prior: " + StringTool.join(", ", prior) + " value=" + s.nodeDomainIndices[getNodeIndex(node)]);
                    CPF cpf = node.getCPF();
                    BeliefNode[] domProd = cpf.getDomainProduct();
                    int[] addr = new int[domProd.length];
                    for(int i = 1; i < addr.length; i++)
                        addr[i] = s.nodeDomainIndices[getNodeIndex(domProd[i])];
                    // collect the CPD column in a separate array: writing into the array obtained
                    // from the priors would overwrite the stored prior with CPD values
                    double[] cpd = new double[prior.length];
                    for(int i = 0; i < cpd.length; i++) {
                        addr[0] = i;
                        cpd[i] = cpf.getDouble(addr);
                    }
                    out.println("cpd: " + StringTool.join(", ", cpd));
                }
                continue loop1;
            }
        }
        // sample could be obtained in this trial (t)
        s.trials = t;
        return;
    }
    throw new RuntimeException("Maximum number of trials exceeded.");
}
/**
 * resets the given sample for a new trial: the assignment is set to a copy of the evidence,
 * the weight to 1, the number of trials to 1 and the operation count to 0
 * @param s the sample to reset
 * @throws Exception
 */
public void initSample(WeightedSample s) throws Exception {
    int[] freshAssignment = evidenceDomainIndices.clone();
    s.nodeDomainIndices = freshAssignment;
    s.operations = 0;
    s.trials = 1;
    s.weight = 1.0;
}
/**
 * forward samples the given node using the superclass implementation and stores the
 * sampled value in the sample
 * @param node the node to sample
 * @param s the sample providing the current assignment and receiving the sampled value
 * @return true if a value was sampled, false if the superclass returned -1
 */
protected boolean sampleForward(BeliefNode node, WeightedSample s) {
    int sampledValue = super.sampleForward(node, s.nodeDomainIndices);
    boolean succeeded = sampledValue != -1;
    if(succeeded)
        s.nodeDomainIndices[this.nodeIndices.get(node)] = sampledValue;
    return succeeded;
}
/**
 * hook that _infer calls after each sample has been added via addSample;
 * this implementation does nothing
 * @param s the sample that was just added
 * @throws Exception
 */
protected void onAddedSample(WeightedSample s) throws Exception {
}
}