/******************************************************************************* * Copyright (C) 2009-2012 Dominik Jain. * * This file is part of ProbCog. * * ProbCog is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * ProbCog is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with ProbCog. If not, see <http://www.gnu.org/licenses/>. ******************************************************************************/ package probcog.srl.directed.inference; import java.util.Vector; import probcog.bayesnets.core.BeliefNetworkEx; import probcog.bayesnets.inference.ITimeLimitedInference; import probcog.bayesnets.inference.SampledDistribution; import probcog.logic.Disjunction; import probcog.logic.Formula; import probcog.logic.GroundLiteral; import probcog.logic.sat.weighted.WeightedClausalKB; import probcog.logic.sat.weighted.WeightedClause; import probcog.logic.sat.weighted.WeightedFormula; import probcog.logic.sat.weighted.MCSAT.GroundAtomDistribution; import probcog.srl.directed.bln.GroundBLN; import edu.ksu.cis.bnj.ver3.core.BeliefNode; import edu.ksu.cis.bnj.ver3.core.CPF; /** * MC-SAT inference for Bayesian logic networks * @author Dominik Jain */ public class MCSAT extends Sampler implements ITimeLimitedInference { protected GroundBLN gbln; protected WeightedClausalKB kb; protected double maxWeight = 0; /** * temporary collection of hard constraints appearing in the CPTs of the ground BN */ protected Vector<Disjunction> hardConstraintsInCPTs = new Vector<Disjunction>(); protected probcog.logic.sat.weighted.MCSAT sampler; public MCSAT(GroundBLN gbln) throws Exception { super(gbln); this.gbln = gbln; } @Override protected void _initialize() throws Exception { kb = new WeightedClausalKB(); // add weighted clauses for probabilistic constraints for(BeliefNode n : gbln.getRegularVariables()) { CPF cpf = n.getCPF(); BeliefNode[] domProd = cpf.getDomainProduct(); walkCPT4ClauseCollection(cpf, domProd, new int[domProd.length], 0); } // add weighted clauses for hard constraints double hardWeight = maxWeight + 100; for(Formula f : gbln.getKB()) { kb.addFormula(new WeightedFormula(f, hardWeight, true), false); } for(Disjunction f : hardConstraintsInCPTs) kb.addClause(new WeightedClause(f, hardWeight, true)); // clean up hardConstraintsInCPTs = null; // construct sampler sampler = new probcog.logic.sat.weighted.MCSAT(kb, gbln.getWorldVars(), gbln.getDatabase()); // pass on parameter handling paramHandler.addSubhandler(sampler.getParameterHandler()); } protected void walkCPT4ClauseCollection(CPF cpf, BeliefNode[] domProd, int[] domainIndices, int i) throws Exception { if(i == domainIndices.length) { // create disjunction of negated literals corresponding to domain index configuration GroundLiteral[] lits = new GroundLiteral[domainIndices.length]; for(int j = 0; j < domainIndices.length; j++) { lits[j] = gbln.getGroundLiteral(domProd[j], domainIndices[j]); lits[j].negate(); } Disjunction f = new Disjunction(lits); // obtain probability value and add to collection double p = cpf.getDouble(domainIndices); if(p == 0.0) { // this constraint is actually hard, so remember it for later hardConstraintsInCPTs.add(f); } else { // it is a soft constraint, whose negation we add to the KB double weight = -Math.log(p); kb.addClause(new WeightedClause(f, weight, false)); if(weight > maxWeight) maxWeight = weight; } return; } // recurse for(int j = 0; j < domProd[i].getDomain().getOrder(); j++) { domainIndices[i] = j; walkCPT4ClauseCollection(cpf, domProd, domainIndices, i+1); } } @Override public SampledDistribution _infer() throws Exception { sampler.setDebugMode(this.debug); sampler.setVerbose(true); sampler.setInfoInterval(infoInterval); GroundAtomDistribution gad = sampler.run(numSamples); return getSampledDistribution(gad); } protected SampledDistribution getSampledDistribution(GroundAtomDistribution gad) throws Exception { gad.normalize(); BeliefNetworkEx bn = gbln.getGroundNetwork(); SampledDistribution dist = new SampledDistribution(bn); for(BeliefNode n : gbln.getRegularVariables()) { int idx = bn.getNodeIndex(n); for(int k = 0; k < n.getDomain().getOrder(); k++) { GroundLiteral lit = gbln.getGroundLiteral(n, k); dist.values[idx][k] = gad.getResult(lit.gndAtom.index); if(!lit.isPositive) dist.values[idx][k] = 1-dist.values[idx][k]; } } for(BeliefNode n : gbln.getAuxiliaryVariables()) { int idx = bn.getNodeIndex(n); dist.values[idx][0] = 1.0; dist.values[idx][1] = 0.0; } dist.Z = 1.0; dist.trials = dist.steps = gad.numSamples; return dist; } public SampledDistribution pollResults() throws Exception { return getSampledDistribution(sampler.pollResults()); } }