/*******************************************************************************
* Copyright (C) 2008-2012 Dominik Jain.
*
* This file is part of ProbCog.
*
* ProbCog is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ProbCog is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ProbCog. If not, see <http://www.gnu.org/licenses/>.
******************************************************************************/
package probcog.bayesnets.inference;
import java.util.HashMap;
import java.util.Vector;
import probcog.bayesnets.core.BeliefNetworkEx;
import edu.ksu.cis.bnj.ver3.core.BeliefNode;
import edu.ksu.cis.bnj.ver3.core.CPF;
import edu.ksu.cis.bnj.ver3.core.Discrete;
import edu.ksu.cis.bnj.ver3.core.Domain;
/**
 * Backward sampling variant in which the distribution used to backward-sample the parents
 * of a node weights every parent configuration not only by P(node | parents) but also by
 * approximate prior probabilities of the parents.
 * <p>
 * The priors are computed once during initialization ({@link #computePriors(int[])}) by a
 * single forward pass over the network in topological order. That pass multiplies the
 * parents' priors, i.e. it treats the parents of each node as mutually independent.
 * Evidence nodes get a point-mass prior on their observed value. Evidence is not
 * propagated back to ancestors.
 *
 * @author Dominik Jain
 */
public class BackwardSamplingWithPriors extends BackwardSampling {
	/**
	 * Approximate prior distribution of each node, indexed by domain value.
	 * Filled by {@link #computePriors(int[])}; {@code null} before that has run.
	 */
	public HashMap<BeliefNode, double[]> priors;

	/**
	 * Backward sampling distribution in which the weight of each admissible parent
	 * configuration is P(child | parents) multiplied by the product of the parents' priors.
	 */
	public static class BackSamplingDistribution extends probcog.bayesnets.inference.BackwardSampling.BackSamplingDistribution {
		/**
		 * For every entry added via {@code addValue}, the product of the parent priors used
		 * for that entry, stored in the same order so that {@link #getWeightingFactor(int)}
		 * can undo the prior weighting.
		 */
		public Vector<Double> parentProbs;

		/**
		 * @param sampler the sampler whose {@link BackwardSamplingWithPriors#priors} are used
		 */
		public BackSamplingDistribution(BackwardSamplingWithPriors sampler) {
			super(sampler);
			parentProbs = new Vector<Double>();
		}

		/**
		 * Recursively builds the distribution to backward sample from.
		 * <p>
		 * The distribution is represented by the values added via {@code addValue}, each paired
		 * with a copy of the corresponding domain-product address.
		 * <p>
		 * Nodes that are already instantiated (non-negative entry in {@code nodeDomainIndices})
		 * are fixed; all others are enumerated over their domain. Zero-probability
		 * configurations are skipped.
		 *
		 * @param i the position in the CPF's domain product to instantiate next
		 * @param addr the current setting of node indices of the CPF's domain product
		 * @param cpf the conditional probability function of the node we are backward sampling
		 * @param nodeDomainIndices current domain index per node (negative if not yet set)
		 */
		@Override
		protected void construct(int i, int[] addr, CPF cpf, int[] nodeDomainIndices) {
			BeliefNode[] domProd = cpf.getDomainProduct();
			if(i == addr.length) {
				HashMap<BeliefNode, double[]> priors = ((BackwardSamplingWithPriors)sampler).priors;
				double childProb = cpf.getDouble(addr);
				// product of the priors of all parents (positions 1..n-1 of the domain product)
				double parentProb = 1.0;
				for(int j = 1; j < addr.length; j++)
					parentProb *= priors.get(domProd[j])[addr[j]];
				double p = childProb * parentProb;
				// skipping zeros also guarantees parentProbs never contains 0 (see getWeightingFactor)
				if(p != 0) {
					addValue(p, addr.clone());
					parentProbs.add(parentProb);
				}
				return;
			}
			int nodeIdx = sampler.nodeIndices.get(domProd[i]);
			if(nodeDomainIndices[nodeIdx] >= 0) {
				// node already instantiated: its value is fixed
				addr[i] = nodeDomainIndices[nodeIdx];
				construct(i+1, addr, cpf, nodeDomainIndices);
			}
			else {
				Discrete dom = (Discrete)domProd[i].getDomain();
				for(int j = 0; j < dom.getOrder(); j++) {
					addr[i] = j;
					construct(i+1, addr, cpf, nodeDomainIndices);
				}
			}
		}

		/**
		 * Returns the importance weight correction for a sampled entry.
		 * <p>
		 * The sampling probability of an entry is q = childProb * parentProb / Z, so
		 * childProb / q = Z / parentProb.
		 * <p>
		 * NOTE(review): assumes {@code Z} is the normalization constant maintained by the base
		 * class, and that {@code sampledValue} is the insertion index of the sampled entry.
		 * Confirm against BackwardSampling.
		 *
		 * @param sampledValue index of the sampled entry
		 * @return Z divided by the parent-prior product of that entry
		 */
		@Override
		public double getWeightingFactor(int sampledValue) {
			return Z / parentProbs.get(sampledValue);
		}
	}

	/**
	 * @param bn the network to sample from
	 * @throws Exception propagated from the superclass constructor
	 */
	public BackwardSamplingWithPriors(BeliefNetworkEx bn) throws Exception {
		super(bn);
	}

	/**
	 * Creates a prior-weighted backward sampling distribution for the parents of {@code node},
	 * given the node assignments currently held by sample {@code s}.
	 */
	@Override
	protected BackSamplingDistribution getBackSamplingDistribution(BeliefNode node, WeightedSample s) {
		BackSamplingDistribution d = new BackSamplingDistribution(this);
		d.construct(node, s.nodeDomainIndices);
		return d;
	}

	/**
	 * Performs the superclass initialization, then computes the node priors from the
	 * current evidence.
	 */
	@Override
	protected void _initialize() throws Exception {
		super._initialize();
		if(verbose) out.println("computing priors...");
		computePriors(evidenceDomainIndices);
	}

	/**
	 * Computes {@link #priors} with a single forward pass over the network in topological order.
	 * <p>
	 * An evidence node gets a point-mass distribution on its observed value. A non-evidence
	 * node gets the sum over parent configurations of P(node | parents) multiplied by the
	 * product of the parents' (already computed) priors. Evidence parents are clamped to
	 * their observed value.
	 *
	 * @param evidenceDomainIndices per-node evidence, indexed like {@code nodes}: the domain
	 *        index of the observed value, or a negative value if the node is unobserved.
	 *        This array is used consistently for both a node's own evidence and its
	 *        parents' evidence.
	 */
	protected void computePriors(int[] evidenceDomainIndices) {
		priors = new HashMap<BeliefNode, double[]>();
		for(int i : bn.getTopologicalOrder()) {
			BeliefNode node = nodes[i];
			double[] dist = new double[node.getDomain().getOrder()];
			int evidence = evidenceDomainIndices[i];
			if(evidence >= 0) {
				for(int j = 0; j < dist.length; j++)
					dist[j] = evidence == j ? 1.0 : 0.0;
			}
			else {
				CPF cpf = node.getCPF();
				computePrior(cpf, 0, new int[cpf.getDomainProduct().length], dist, evidenceDomainIndices);
			}
			priors.put(node, dist);
		}
	}

	/**
	 * Accumulates the prior of the CPF's node into {@code dist}, using the inherited
	 * {@code evidenceDomainIndices} field for evidence.
	 * <p>
	 * Kept for backward compatibility; delegates to
	 * {@link #computePrior(CPF, int, int[], double[], int[])}.
	 */
	protected void computePrior(CPF cpf, int i, int[] addr, double[] dist) {
		computePrior(cpf, i, addr, dist, evidenceDomainIndices);
	}

	/**
	 * Recursively enumerates the configurations of the CPF's domain product and adds, for each
	 * one, P(node setting | parent configuration) times the product of the parents' priors to
	 * {@code dist[node setting]}.
	 * <p>
	 * Nodes with evidence are clamped to their observed value instead of being enumerated.
	 * Requires the priors of all parents to be present in {@link #priors} already.
	 *
	 * @param cpf the CPF of the node whose prior is being computed (node at position 0)
	 * @param i the position in the domain product to instantiate next
	 * @param addr the current (partial) address into the CPF's domain product
	 * @param dist accumulator for the node's prior, indexed by domain value
	 * @param evidenceIndices per-node evidence (negative if unobserved), indexed via getNodeIndex
	 */
	protected void computePrior(CPF cpf, int i, int[] addr, double[] dist, int[] evidenceIndices) {
		BeliefNode[] domProd = cpf.getDomainProduct();
		if(i == addr.length) {
			double p = cpf.getDouble(addr); // p = P(node setting | parent configuration)
			for(int j = 1; j < addr.length; j++)
				p *= priors.get(domProd[j])[addr[j]];
			// p ~ P(node setting, parent configuration), treating the parents as independent
			dist[addr[0]] += p;
			return;
		}
		BeliefNode node = domProd[i];
		int evidence = evidenceIndices[getNodeIndex(node)];
		if(evidence >= 0) {
			addr[i] = evidence;
			computePrior(cpf, i+1, addr, dist, evidenceIndices);
		}
		else {
			Domain dom = node.getDomain();
			for(int j = 0; j < dom.getOrder(); j++) {
				addr[i] = j;
				computePrior(cpf, i+1, addr, dist, evidenceIndices);
			}
		}
	}
}