/*******************************************************************************
* Copyright (C) 2009-2012 Dominik Jain.
*
* This file is part of ProbCog.
*
* ProbCog is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ProbCog is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ProbCog. If not, see <http://www.gnu.org/licenses/>.
******************************************************************************/
package probcog.bayesnets.inference;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.PriorityQueue;
import java.util.Vector;
import probcog.bayesnets.core.BeliefNetworkEx;
import probcog.bayesnets.util.TopologicalOrdering;
import probcog.bayesnets.util.TopologicalSort;
import probcog.logic.Disjunction;
import probcog.logic.Formula;
import probcog.logic.GroundLiteral;
import probcog.logic.PossibleWorld;
import probcog.logic.TrueFalse;
import probcog.logic.WorldVariables;
import probcog.logic.sat.ClausalKB;
import probcog.logic.sat.Clause;
import probcog.logic.sat.SampleSAT;
import probcog.srl.Database;
import probcog.srl.StringVariable;
import probcog.srl.directed.bln.coupling.VariableLogicCoupling;
import edu.ksu.cis.bnj.ver3.core.BeliefNode;
import edu.ksu.cis.bnj.ver3.core.CPF;
import edu.ksu.cis.bnj.ver3.core.Discrete;
public class SATIS_BSampler extends BackwardSampling {
/**
* the coupling between the network's belief nodes and the logical ground atoms/literals
* that stand for their values; created in _initialize() if it was not passed to the constructor
*/
VariableLogicCoupling coupling;
/**
* the SAT sampler used to sample a (sub-)state in each iteration;
* created in _initialize() if it was not passed to the constructor
*/
SampleSAT sat;
/**
* variables whose values are determined by the SAT sampler;
* if not passed to the constructor, _initialize() collects them from the literals of the clausal KB
*/
Collection<BeliefNode> determinedVars;
/**
* clausal KB of constraints that must be satisfied by the SAT sampler
* (never supplied from outside; left null by the constructors)
*/
ClausalKB ckb;
/**
 * Creates a SAT-IS backward sampler from an existing SAT sampler, an existing logical coupling
 * and a known set of variables that are assigned by the SAT sampler.
 * This is the construction method used for BLNs.
 * @param bn the Bayesian network this sampler operates on
 * @param sat the SAT sampler to run in each iteration
 * @param coupling the logical coupling of the BN's variables
 * @param determinedVars the variables affected by the SAT sampler, i.e. the variables that are set whenever the SAT sampler is run
 * @throws Exception if the superclass constructor fails
 */
public SATIS_BSampler(BeliefNetworkEx bn, SampleSAT sat, VariableLogicCoupling coupling, Collection<BeliefNode> determinedVars) throws Exception {
	super(bn);
	this.sat = sat;
	this.determinedVars = determinedVars;
	this.coupling = coupling;
	// no clausal KB is handed in through this construction path
	this.ckb = null;
}
/**
* constructs a SAT-IS backward sampler for use with (propositional) Bayesian networks.
* Nothing but the network is set here: the logical coupling, the clausal KB (built from all
* deterministic constraints in the CPTs) and the SAT sampler are left unset and are created
* automatically in _initialize().
* @param bn the Bayesian network this sampler operates on
* @throws Exception if the superclass constructor fails
*/
public SATIS_BSampler(BeliefNetworkEx bn) throws Exception {
super(bn);
}
/**
 * Completes the sampler's setup by creating every component that was not supplied at construction time:
 * the variable-logic coupling, the clausal KB of deterministic CPT constraints, the set of variables
 * determined by the SAT sampler, and the SAT sampler itself. Finally, the debug flag is passed on to the
 * SAT sampler.
 * <p>
 * The clausal KB is built only when it is actually needed, i.e. when the determined variables or the
 * SAT sampler still have to be derived from it. If both were handed to the constructor, the walk over
 * all CPTs is skipped and the KB remains null.
 * @throws Exception if a literal in the KB refers to a ground atom that has no coupled node, or if
 *         building any of the components fails
 */
@Override
protected void _initialize() throws Exception {
	// build the variable-logic coupling if we don't have it yet:
	// every node becomes a block variable named after the node, without parameters
	if(coupling == null) {
		coupling = new VariableLogicCoupling();
		for(BeliefNode n : nodes) {
			coupling.addBlockVariable(n, (Discrete)n.getDomain(), n.getName(), new String[0]);
		}
	}

	// gather clausal KB based on deterministic constraints in CPTs, but only if one of the
	// two consumers below (determined variables, SAT sampler construction) will read it
	boolean kbRequired = determinedVars == null || sat == null;
	if(ckb == null && kbRequired) {
		ckb = new ClausalKB();
		extendKBWithDeterministicConstraintsInCPTs(bn, coupling, ckb, null);
	}

	// get the set of variables that is determined by the sat sampler:
	// every variable that occurs in some clause of the KB
	if(determinedVars == null) {
		determinedVars = new HashSet<BeliefNode>();
		for(Clause c : ckb) {
			for(GroundLiteral lit : c.lits) {
				BeliefNode var = coupling.getVariable(lit.gndAtom);
				if(var == null) {
					throw new Exception("Could not find node corresponding to ground atom '" + lit.gndAtom.toString() + "' with index " + lit.gndAtom.index + "; set of mapped ground atoms is " + coupling.getCoupledGroundAtoms());
				}
				determinedVars.add(var);
			}
		}
	}

	// build SAT sampler if we don't have it yet
	if(this.sat == null) {
		// build evidence database; a non-negative domain index marks an evidence node
		// (same test as used in getOrdering)
		Vector<PropositionalVariable> evidence = new Vector<PropositionalVariable>();
		for(int i = 0; i < evidenceDomainIndices.length; i++) {
			if(evidenceDomainIndices[i] >= 0) {
				evidence.add(new PropositionalVariable(nodes[i].getName(), nodes[i].getDomain().getName(evidenceDomainIndices[i])));
			}
		}
		// construct sampler
		WorldVariables worldVars = this.coupling.getWorldVars();
		sat = new SampleSAT(ckb, new PossibleWorld(worldVars), worldVars, evidence);
	}

	// pass on parameters
	sat.setDebugMode(this.debug);
}
/**
 * Extends a clausal KB with the hard constraints found in the CPFs of a network: every coupled
 * node's CPF is walked completely, and each zero-probability entry contributes one formula.
 * Progress and the number of added constraints are printed to standard output.
 * @param bn the Bayesian network whose nodes are inspected
 * @param coupling the logical coupling; nodes without a coupling are skipped
 * @param ckb the clausal KB to extend
 * @param db an evidence database with which to simplify the formulas obtained, or null if no simplification is to take place
 * @throws Exception if walking a CPF or adding a formula fails
 */
public static void extendKBWithDeterministicConstraintsInCPTs(BeliefNetworkEx bn, VariableLogicCoupling coupling, ClausalKB ckb, Database db) throws Exception {
	final int sizeBefore = ckb.size();
	System.out.print("gathering deterministic constraints from CPDs... ");
	for(BeliefNode node : bn.bn.getNodes()) {
		// only nodes that are coupled to logical variables can contribute constraints
		if(coupling.hasCoupling(node)) {
			CPF cpf = node.getCPF();
			// one address component per variable in the CPF's domain product
			int[] address = new int[cpf.getDomainProduct().length];
			walkCPF4HardConstraints(coupling, cpf, address, 0, ckb, db);
		}
	}
	int added = ckb.size() - sizeBefore;
	System.out.println(added + " constraints added");
}
/**
 * Recursively enumerates all addresses of a CPF and, for every entry whose value is exactly 0.0,
 * adds to the KB the disjunction of the negated literals of that address (i.e. a constraint
 * stating that this combination of values must not occur).
 * @param coupling the logical coupling used to obtain the ground literal for a (node, value index) pair
 * @param cpf the conditional probability function to walk
 * @param addr the address buffer; components before index i are already fixed, the rest is filled in here
 * @param i the index of the next address component to assign; the CPF entry is inspected once i equals addr.length
 * @param ckb the clausal KB to extend
 * @param db an evidence database used to simplify each formula, or null for no simplification;
 *        formulas that simplify to a truth constant are not added
 * @throws Exception if a formula cannot be built, simplified or added
 */
protected static void walkCPF4HardConstraints(VariableLogicCoupling coupling, CPF cpf, int[] addr, int i, ClausalKB ckb, Database db) throws Exception {
	BeliefNode[] domProd = cpf.getDomainProduct();

	// recursive case: try every value of the i-th variable and descend
	if(i != addr.length) {
		for(int value = 0; value < domProd[i].getDomain().getOrder(); value++) {
			addr[i] = value;
			walkCPF4HardConstraints(coupling, cpf, addr, i+1, ckb, db);
		}
		return;
	}

	// base case: the address is complete; only zero entries are hard constraints
	double p = cpf.getDouble(addr);
	if(!(p == 0.0))
		return;

	// the forbidden combination becomes a disjunction of the negated literals
	GroundLiteral[] negated = new GroundLiteral[domProd.length];
	for(int k = 0; k < domProd.length; k++) {
		negated[k] = coupling.getGroundLiteral(domProd[k], addr[k]);
		negated[k].negate();
	}
	Formula constraint = new Disjunction(negated);

	// optionally simplify against the evidence; a truth constant carries no clause to add
	if(db != null) {
		constraint = constraint.simplify(db);
		if(constraint instanceof TrueFalse)
			return;
	}
	ckb.addFormula(constraint);
}
/**
 * Evidence entry for a propositional network variable: a string variable without parameters
 * whose function name is the variable's name and whose value is the name of a domain element.
 */
protected static class PropositionalVariable extends StringVariable {
	/**
	 * @param varName the name of the variable (used as the function name)
	 * @param value the name of the value assigned to the variable
	 */
	public PropositionalVariable(String varName, String value) {
		super(varName, new String[0], value);
	}

	/**
	 * @return the predicate string in the form functionName(value)
	 */
	@Override
	public String getPredicate() {
		final String predicate = this.functionName + "(" + value + ")";
		return predicate;
	}

	/**
	 * @return always false
	 */
	@Override
	public boolean isBoolean() {
		return false;
	}

	/**
	 * @return always false
	 */
	@Override
	public boolean pertainsToEvidenceFunction() {
		return false;
	}
}
/**
 * Initialises a sample via the superclass and then runs the SAT sampler, copying the value
 * of every SAT-determined variable from the resulting state into the sample.
 * @param s the sample to initialise
 * @throws Exception if the superclass initialisation or the SAT sampler fails
 */
@Override
public void initSample(WeightedSample s) throws Exception {
	super.initSample(s);

	// let SampleSAT find a configuration that satisfies all logical constraints
	sat.run();
	final PossibleWorld satState = sat.getState();

	// transfer the values of the determined variables from that configuration to the sample
	for(BeliefNode determined : determinedVars) {
		int value = coupling.getVariableValue(determined, satState);
		int nodeIdx = this.getNodeIndex(determined);
		s.nodeDomainIndices[nodeIdx] = value;
	}
}
/**
 * Determines the sampling order by (re-)filling the members for backward sampled nodes, forward
 * sampled nodes and nodes outside the sampling order, and prints a one-line summary of their sizes.
 * Variables determined by the SAT sampler are placed outside the sampling order; evidence nodes
 * are the initial candidates for backward sampling.
 * @param evidenceDomainIndices per-node evidence domain indices; a non-negative entry marks an evidence node
 * @throws Exception if the topological sort fails
 */
protected void getOrdering(int[] evidenceDomainIndices) throws Exception {
	HashSet<BeliefNode> pending = new HashSet<BeliefNode>(Arrays.asList(nodes));
	backwardSampledNodes = new Vector<BeliefNode>();
	forwardSampledNodes = new Vector<BeliefNode>();
	outsideSamplingOrder = new HashSet<BeliefNode>();
	TopologicalOrdering topOrder = new TopologicalSort(bn.bn).run(true);
	PriorityQueue<BeliefNode> candidates = new PriorityQueue<BeliefNode>(1, new TierComparator(topOrder));

	// logically determined nodes are already instantiated by the SAT sampler; keeping them
	// outside the sampling order makes their conditional probability count in the sample weight
	for(BeliefNode determined : determinedVars) {
		outsideSamplingOrder.add(determined);
		pending.remove(determined);
	}

	// evidence nodes are instantiated and serve as the initial backward sampling candidates
	for(int idx = 0; idx < evidenceDomainIndices.length; idx++) {
		if(evidenceDomainIndices[idx] < 0)
			continue;
		BeliefNode evidenceNode = nodes[idx];
		candidates.add(evidenceNode);
		pending.remove(evidenceNode);
	}

	// process candidates in queue order until none are left
	BeliefNode child;
	while((child = candidates.poll()) != null) {
		// entries from index 1 onwards of the domain product are treated as the node's parents
		BeliefNode[] family = child.getCPF().getDomainProduct();
		boolean hasUninstantiatedParent = false;
		for(int j = 1; j < family.length; j++) {
			BeliefNode parent = family[j];
			// a parent that was still pending gets instantiated by backward sampling the child
			// and becomes a candidate itself
			if(pending.remove(parent)) {
				hasUninstantiatedParent = true;
				candidates.add(parent);
			}
		}
		if(hasUninstantiatedParent) {
			backwardSampledNodes.add(child);
		}
		else {
			// all parents are instantiated already: the node is instantiated but not sampled
			outsideSamplingOrder.add(child);
		}
	}

	// whatever is still pending is forward sampled, in topological order
	for(int i : topOrder) {
		BeliefNode candidate = nodes[i];
		if(pending.contains(candidate))
			forwardSampledNodes.add(candidate);
	}

	out.println("node ordering: " + outsideSamplingOrder.size() + " outside order, " + backwardSampledNodes.size() + " backward, " + forwardSampledNodes.size() + " forward");
}
/**
 * Returns the name of this algorithm in the form ClassName[satAlgorithmName].
 * <p>
 * On an instance created with the single-argument constructor, the SAT sampler does not exist
 * until _initialize() has run. In that case the simple class name of SampleSAT is used as a
 * placeholder instead of failing with a NullPointerException.
 * @return the algorithm name including the name of the SAT sampling algorithm
 */
@Override
public String getAlgorithmName() {
	String satName = (sat != null) ? sat.getAlgorithmName() : SampleSAT.class.getSimpleName();
	return String.format("%s[%s]", getClass().getSimpleName(), satName);
}
}