/******************************************************************************* * Copyright (C) 2009-2012 Dominik Jain. * * This file is part of ProbCog. * * ProbCog is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * ProbCog is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with ProbCog. If not, see <http://www.gnu.org/licenses/>. ******************************************************************************/ package probcog.logic.sat.weighted; import java.util.Map.Entry; import java.util.Random; import java.util.Vector; import probcog.inference.IParameterHandler; import probcog.inference.ParameterHandler; import probcog.logic.Formula; import probcog.logic.GroundAtom; import probcog.logic.GroundLiteral; import probcog.logic.PossibleWorld; import probcog.logic.WorldVariables; import probcog.logic.sat.SampleSAT; import probcog.srl.Database; /** * Implementation of the MC-SAT inference algorithm (Poon and Domingos 2006). * Also includes extensions for soft evidence, MC-SAT-PC (Jain and Beetz 2010). 
* @author Dominik Jain
 */
public class MCSAT implements IParameterHandler {
    /** Weighted clausal knowledge base whose formulas and clauses are sampled over. */
    protected WeightedClausalKB kb;
    /** The ground atoms (world variables) of the problem. */
    protected WorldVariables vars;
    /** Evidence database; its entries are handed to the SampleSAT sampler. */
    protected Database db;
    /** Random source used to decide which satisfied soft formulas become constraints. */
    protected Random rand;
    /** Accumulated sample statistics; access from several threads is guarded by locking on this object. */
    protected GroundAtomDistribution dist;
    protected boolean verbose = true, debug = false;
    /** Number of sampling steps between two progress messages (only used when verbose). */
    protected int infoInterval = 100;
    protected ParameterHandler paramHandler;
    /** The SampleSAT instance used to sample a state satisfying the selected constraints. */
    protected SampleSAT sat;
    /** Soft evidence clauses registered via {@link #addSoftEvidence(GroundAtom, double)}. */
    protected Vector<SoftEvidence> softEvidence;

    /**
     * A soft evidence item: a clause together with the probability it should
     * hold with, plus a counter that {@link MCSAT#run(int)} increments whenever
     * the clause is true in the sampler state at the start of a step.
     */
    public class SoftEvidence {
        public WeightedClause wc;
        public double p;
        public double count;

        /**
         * @param wc the clause representing the evidence
         * @param p the probability associated with the clause
         */
        public SoftEvidence(WeightedClause wc, double p) {
            count = 0;
            this.wc = wc;
            this.p = p;
        }
    }

    /**
     * Creates the sampler, including the underlying SampleSAT instance (which
     * is initialised with the entries of the database), and registers the
     * parameters "infoInterval" and "verbose" with the parameter handler.
     *
     * @param kb the weighted clausal knowledge base
     * @param vars the world variables
     * @param db the evidence database
     * @throws Exception if setting up the sampler or the parameter handling fails
     */
    public MCSAT(WeightedClausalKB kb, WorldVariables vars, Database db) throws Exception {
        this.kb = kb;
        this.vars = vars;
        this.db = db;
        this.rand = new Random();
        this.dist = new GroundAtomDistribution(vars);
        this.paramHandler = new ParameterHandler(this);
        this.softEvidence = new Vector<SoftEvidence>();
        PossibleWorld state = new PossibleWorld(vars);
        sat = new SampleSAT(state, vars, db.getEntries());
        paramHandler.addSubhandler(sat.getParameterHandler());
        paramHandler.add("infoInterval", "setInfoInterval");
        paramHandler.add("verbose", "setVerbose");

        // Disabled examples of soft evidence configurations:
        /*
        0.95 similarPos(Square1,SquareN1)
        0.95 similarPos(Square2,SquareN2)
        0.05 similarPos(Square1,SquareN2)
        0.05 similarPos(Square2,SquareN1)
        0.05 similarPos(Square3,SquareN1)
        0.05 similarPos(Square3,SquareN2)
        0.05 similarPos(Circle,SquareN1)
        0.05 similarPos(Circle,SquareN2)
        */
        /*
        System.out.println("setting soft ev");
        addSoftEvidence(vars.get("similarPos(Square1,SquareN1)"), 0.95);
        addSoftEvidence(vars.get("similarPos(Square2,SquareN2)"), 0.95);
        addSoftEvidence(vars.get("similarPos(Square1,SquareN2)"), 0.05);
        addSoftEvidence(vars.get("similarPos(Square2,SquareN1)"), 0.05);
        addSoftEvidence(vars.get("similarPos(Square3,SquareN1)"), 0.05);
        addSoftEvidence(vars.get("similarPos(Square3,SquareN2)"), 0.05);
        addSoftEvidence(vars.get("similarPos(Circle,SquareN1)"), 0.05);
        addSoftEvidence(vars.get("similarPos(Circle,SquareN2)"), 0.05);
        */
        /*
        0.9 similarApp(O1,N1)
        0.9 similarPos(O1,N1)
        0.9 similarApp(G2,N2)
        0.0 similarPos(G2,N2)
        0.9 similarApp(G3,N2)
        0.0 similarPos(G3,N2)
        0.9 similarApp(G4,N2)
        0.0 similarPos(G4,N2)
        */
    }

    /** @return the knowledge base this sampler operates on */
    public WeightedClausalKB getKB() {
        return kb;
    }

    /** Enables or disables progress output. */
    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    /** Enables or disables debug output; debug mode implies verbose output during {@link #run(int)}. */
    public void setDebugMode(boolean active) {
        this.debug = active;
    }

    /**
     * Sets the number of steps between two progress messages.
     * A value of zero or less disables the periodic messages.
     */
    public void setInfoInterval(int interval) {
        this.infoInterval = interval;
    }

    /**
     * Runs the sampler: first finds an initial state satisfying the clauses of
     * all hard formulas, then performs the given number of MC-SAT steps, adding
     * the state reached after each step as a sample, and finally normalizes
     * the collected distribution.
     *
     * @param steps the number of sampling steps to perform
     * @return the normalized distribution (the same object that is held internally)
     * @throws Exception if the underlying SampleSAT sampler fails
     */
    public GroundAtomDistribution run(int steps) throws Exception {
        if (debug) {
            System.out.println("\nMC-SAT constraints:");
            for (WeightedClause wc : kb) {
                System.out.println(" " + wc);
            }
            System.out.println();
        }
        verbose = verbose || debug;
        if (verbose) {
            System.out.printf("%s sampling (%d weighted formulas)...\n", this.getAlgorithmName(), this.kb.size());
        }

        // find initial state satisfying all hard constraints
        if (verbose) {
            System.out.println("finding initial state...");
        }
        Vector<WeightedClause> M = new Vector<WeightedClause>();
        for (Entry<WeightedFormula, Vector<WeightedClause>> e : kb.getFormulasAndClauses()) {
            WeightedFormula wf = e.getKey();
            if (wf.isHard) {
                M.addAll(e.getValue());
            }
        }
        sat.setDebugMode(debug);
        sat.initConstraints(M);
        sat.run();

        // actual MC-SAT sampling
        for (int i = 0; i < steps; i++) {
            M.clear();
            for (Entry<WeightedFormula, Vector<WeightedClause>> e : kb.getFormulasAndClauses()) {
                WeightedFormula wf = e.getKey();
                if (wf.formula.isTrue(sat.getState())) {
                    // u * exp(w) > 1 is the same as u > exp(-w): for u drawn uniformly
                    // from [0,1) a satisfied soft formula is kept with probability 1 - exp(-w)
                    boolean satisfy = wf.isHard || rand.nextDouble() * Math.exp(wf.weight) > 1.0;
                    if (satisfy) {
                        M.addAll(e.getValue());
                    }
                }
            }

            // soft evidence clauses: a clause that is currently true is kept as a
            // constraint as long as its observed frequency is below the target probability
            if (i > 0) {
                for (SoftEvidence se : this.softEvidence) {
                    if (se.wc.isTrue(sat.getState())) {
                        se.count += 1;
                        if (se.count / i < se.p) {
                            M.add(se.wc);
                        }
                    }
                }
            }

            // infoInterval > 0 guards the modulo against division by zero
            if (verbose && infoInterval > 0 && (i + 1) % infoInterval == 0) {
                System.out.printf("MC-SAT step %d: %d constraints to be satisfied\n", i + 1, M.size());
                if (debug) {
                    for (WeightedClause wc : M) {
                        System.out.println(" " + wc);
                    }
                }
            }

            sat.initConstraints(M);
            sat.run();

            synchronized (dist) {
                dist.addSample(sat.getState(), 1.0);
            }
        }

        synchronized (dist) {
            dist.normalize();
        }
        return dist;
    }

    /** Passes the parameter p on to the SampleSAT sampler via {@code setPSampleSAT}. */
    public void setP(double p) {
        sat.setPSampleSAT(p);
    }

    /**
     * Accumulates weighted samples of possible worlds: for every ground atom
     * (addressed by its index) the total weight of the samples in which the
     * atom was true, along with the total weight of all samples.
     */
    public static class GroundAtomDistribution implements Cloneable {
        /** Per-atom sum of the weights of the samples in which the atom was true (indexed by atom index). */
        public double[] sums;
        /** Sum of the weights of all samples added (1.0 after normalization). */
        public double Z;
        /** Number of samples added. */
        public int numSamples;

        /** Creates an empty distribution with one entry per world variable. */
        public GroundAtomDistribution(WorldVariables vars) {
            this.Z = 0.0;
            this.numSamples = 0;
            this.sums = new double[vars.size()];
        }

        /**
         * Adds a sample.
         *
         * @param w the sampled world
         * @param weight the weight of the sample
         */
        public void addSample(PossibleWorld w, double weight) {
            for (GroundAtom ga : w.getVariables()) {
                if (w.isTrue(ga)) {
                    sums[ga.index] += weight;
                }
            }
            Z += weight;
            numSamples++;
        }

        /**
         * Divides all sums by the total weight Z and sets Z to 1.0, so that a
         * repeated call has no further effect. Note that no check is made for
         * Z being zero (i.e. no samples); the division is carried out regardless.
         */
        public void normalize() {
            if (Z != 1.0) {
                for (int i = 0; i < sums.length; i++) {
                    sums[i] /= Z;
                }
                Z = 1.0;
            }
        }

        /** @return the accumulated (or, after normalization, normalized) value for the atom with the given index */
        public double getResult(int indx) {
            return sums[indx];
        }

        /**
         * Returns an independent copy of this distribution. The sums array is
         * duplicated (Object.clone alone would share it), so subsequent changes
         * to this distribution do not affect the copy and vice versa.
         */
        @Override
        public GroundAtomDistribution clone() throws CloneNotSupportedException {
            GroundAtomDistribution copy = (GroundAtomDistribution) super.clone();
            copy.sums = this.sums.clone();
            return copy;
        }
    }

    /**
     * @return the value currently stored in the distribution for the given
     *         ground atom (normalized once {@link #run(int)} has completed)
     */
    public double getResult(GroundAtom ga) {
        return dist.getResult(ga.index);
    }

    /**
     * Returns a copy of the current state of the distribution, taken while
     * holding the lock on the distribution. The copy is returned as is,
     * i.e. it is not normalized by this method.
     */
    public GroundAtomDistribution pollResults() throws CloneNotSupportedException {
        GroundAtomDistribution ret = null;
        synchronized (dist) {
            ret = this.dist.clone();
        }
        return ret;
    }

    public ParameterHandler getParameterHandler() {
        return paramHandler;
    }

    /** @return the simple class name followed by the name of the SAT sampler in square brackets */
    public String getAlgorithmName() {
        return String.format("%s[%s]", this.getClass().getSimpleName(), sat.getAlgorithmName());
    }

    /**
     * Registers soft evidence on a ground atom: one item for the atom itself
     * with probability p and one for its negation with probability 1-p.
     *
     * @param ga the ground atom
     * @param p the probability with which the atom should be true
     * @throws Exception if the clauses cannot be constructed
     */
    public void addSoftEvidence(GroundAtom ga, double p) throws Exception {
        Formula nga = new GroundLiteral(false, ga);
        this.softEvidence.add(new SoftEvidence(new WeightedClause(ga, 0.0, false), p));
        this.softEvidence.add(new SoftEvidence(new WeightedClause(nga, 0.0, false), 1.0 - p));
    }
}