/*******************************************************************************
* Copyright (C) 2008-2012 Dominik Jain.
*
* This file is part of ProbCog.
*
* ProbCog is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ProbCog is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ProbCog. If not, see <http://www.gnu.org/licenses/>.
******************************************************************************/
package probcog.logic.sat;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.Random;
import java.util.Vector;
import probcog.inference.IParameterHandler;
import probcog.inference.ParameterHandler;
import probcog.logic.GroundAtom;
import probcog.logic.GroundLiteral;
import probcog.logic.PossibleWorld;
import probcog.logic.WorldVariables;
import probcog.logic.WorldVariables.Block;
import probcog.srl.AbstractVariable;
import probcog.srl.Database;
import probcog.srl.directed.bln.BayesianLogicNetwork;
import probcog.srl.directed.bln.GroundBLN;
import edu.tum.cs.util.Stopwatch;
import edu.tum.cs.util.StringTool;
/**
* Implementation of the stochastic SAT sampling algorithm SampleSAT by Wei et al.
* It samples a solution near-uniformly from the set of all solutions of the given constraints.
*
* @author Dominik Jain
*/
public class SampleSAT implements IParameterHandler {
/** maps the index of a ground atom to the constraints for which that atom is currently registered as the bottleneck (see addBottleneck) */
protected HashMap<Integer,Vector<Constraint>> bottlenecks;
/** maps the index of a ground atom to the constraints in which the atom occurs (see addGAOccurrence) */
protected HashMap<Integer,Vector<Constraint>> GAOccurrences;
/** the possible world the sampler operates on; it is modified in place */
protected PossibleWorld state;
/** the constraints that are currently recorded as unsatisfied */
protected Vector<Constraint> unsatisfiedConstraints;
/** all constraints of the problem; null until initConstraints has been called */
protected Vector<Constraint> constraints;
/** random number generator used for all stochastic choices */
protected Random rand;
/** the set of variables (ground atoms) the problem is defined on */
protected WorldVariables vars;
/** whether debug output is printed (see setDebugMode) */
protected boolean debug = false;
/** provides the evidence map and writes evidence/random values into the state */
protected EvidenceHandler evidenceHandler;
/** maps the indices of ground atoms that have evidence to their truth values; atoms contained here are skipped when choosing variables to flip */
protected HashMap<Integer,Boolean> evidence;
/** whether unit propagation is applied when the constraints are initialized (see enableUnitPropagation) */
protected boolean useUnitPropagation = false;
/** the clauses from which the constraints are built; null if the sampler was constructed without a KB */
Iterable<? extends probcog.logic.sat.Clause> kb;
/** parameter handler with which the parameters pSampleSAT and pWalkSAT are registered */
protected ParameterHandler paramHandler;
/**
* SampleSAT's p parameter: probability of performing a greedy (WalkSAT-style) move rather than a simulated annealing-style move
*/
protected double pSampleSAT = 0.9; // 0.5
/**
* WalkSAT's p parameter: random walk parameter, probability of non-greedy move (random flip in unsatisfied clause) rather than greedy (locally optimal) move.
* According to the WalkSAT paper, optimal values were always between 0.5 and 0.6
*/
protected double pWalkSAT = 0.5; // 0.5
/**
 * Creates a sampler for the given clauses. The constraints themselves are not built here;
 * they are instantiated on the first call to run() or by an explicit call to initConstraints.
 *
 * @param kb a collection of clauses to satisfy (such as a ClausalKB)
 * @param state a possible world to write to (can be arbitrarily initialized, as it is completely reinitialized)
 * @param vars the set of variables the SAT problem is defined on
 * @param db an evidence database indicating truth values of evidence atoms (which are to be respected by the algorithm)
 * @throws Exception
 */
public SampleSAT(Iterable<? extends probcog.logic.sat.Clause> kb, PossibleWorld state, WorldVariables vars, Iterable<? extends AbstractVariable<?>> db) throws Exception {
    this.kb = kb;
    this.vars = vars;
    this.state = state;
    // the constraints are built later from kb
    this.constraints = null;
    this.rand = new Random();

    // make the two probabilities settable by name
    this.paramHandler = new ParameterHandler(this);
    this.paramHandler.add("pSampleSAT", "setPSampleSAT");
    this.paramHandler.add("pWalkSAT", "setPWalkSAT");

    // obtain the evidence map from the database entries
    this.evidenceHandler = new EvidenceHandler(vars, db);
    this.evidence = this.evidenceHandler.getEvidence();
}
/**
* initializes the sampler without a set of constraints (the KB is passed on as null);
* initConstraints must therefore be called with the actual clauses before the sampler is run
* @param state a possible world to write to (can be arbitrarily initialized, as it is completely reinitialized)
* @param vars the set of variables the SAT problem is defined on
* @param db an evidence database indicating truth values of evidence atoms (which are to be respected by the algorithm); the state is initialized to respect it and the respective variables are never touched again
* @throws Exception
*/
public SampleSAT(PossibleWorld state, WorldVariables vars, Iterable<? extends AbstractVariable<?>> db) throws Exception {
// delegate to the main constructor with no KB
this(null, state, vars, db);
}
/**
 * switches the printing of debug output on or off
 * @param active true to enable debug output
 */
public void setDebugMode(boolean active) {
    this.debug = active;
}
/**
 * turns on unit propagation, which is then applied as a preprocessing step
 * whenever the set of constraints is initialized
 */
public void enableUnitPropagation() {
    this.useUnitPropagation = true;
}
/**
 * prepares this sampler for a new set of constraints. NOTE: This method only needs to be called explicitly
 * when switching to a new set of constraints or when using the construction method without the KB
 *
 * @param kb the clauses to satisfy; must not be null
 * @throws IllegalArgumentException if kb is null
 * @throws Exception if constraints were already instantiated and unit propagation is enabled
 */
public void initConstraints(Iterable<? extends probcog.logic.sat.Clause> kb) throws Exception {
    // if constraints were previously instantiated, check whether a reinstantiation is allowed
    if(constraints != null && useUnitPropagation)
        throw new Exception("Resetting the set of constraints is not allowed when using unit propagation, because unit propagation extends the evidence database, which currently cannot be reversed.");
    // validate before touching any member, so that a failed call cannot leave the sampler half-initialized
    // (a non-null but empty 'constraints' would make initialize() skip this method on the next run)
    if(kb == null)
        throw new IllegalArgumentException("No set of clauses was given: a non-null kb must be passed to the constructor or to initConstraints.");
    this.kb = kb;
    // initialize data structures for constraints (used during algorithm)
    unsatisfiedConstraints = new Vector<Constraint>();
    bottlenecks = new HashMap<Integer,Vector<Constraint>>();
    // build constraint data (constructing a constraint registers its ground atom occurrences)
    constraints = new Vector<Constraint>();
    GAOccurrences = new HashMap<Integer,Vector<Constraint>>();
    for(probcog.logic.sat.Clause c : kb)
        constraints.add(makeConstraint(c));
    // preprocessing
    if(useUnitPropagation)
        unitPropagation(); // may extend evidence
    // set evidence in state
    evidenceHandler.setEvidenceInState(state);
}
/**
 * creates the internal constraint representation for a clause of the KB
 * @param c the clause whose literals are to be represented
 * @return a new Clause constraint over the literals of c
 */
protected Constraint makeConstraint(probcog.logic.sat.Clause c) {
    Constraint constraint = new Clause(c.lits);
    return constraint;
}
/**
 * performs unit propagation on clauses to simplify the set of constraints:
 * the literal of every unit clause is added to the evidence, clauses that contain
 * the same literal are removed (they are always true) and the complementary literal
 * is deleted from all other clauses, which may create further unit clauses.
 *
 * NOTE(review): a unit literal whose atom already has evidence overwrites the
 * existing evidence value without any consistency check.
 *
 * @throws IllegalStateException if the empty clause is derived, i.e. the constraints are contradictory
 */
protected void unitPropagation() {
    int oldSize = constraints.size();
    // collect the clauses that are unit clauses from the start
    LinkedList<Clause> unitClauses = new LinkedList<Clause>();
    for(Constraint c : constraints) {
        if(c instanceof Clause) {
            Clause cl = (Clause)c;
            if(cl.size() == 1)
                unitClauses.add(cl);
        }
    }
    // process unit clauses until no new ones are derived
    while(!unitClauses.isEmpty()) {
        Clause cl = unitClauses.remove();
        GroundLiteral lit = cl.getLiterals()[0];
        // the literal must hold, so it becomes evidence
        evidence.put(lit.gndAtom.index, lit.isPositive);
        Vector<Constraint> affected = GAOccurrences.get(lit.gndAtom.index);
        if(affected != null) {
            Vector<Clause> scheduledForRemoval = new Vector<Clause>();
            for(Constraint c : affected) {
                if(c instanceof Clause) {
                    Clause acl = (Clause)c;
                    // note: the array is fetched once, so replacing acl's literals below does not disturb this loop
                    for(GroundLiteral l : acl.getLiterals()) {
                        if(l.gndAtom.index == lit.gndAtom.index) {
                            if(l.isPositive == lit.isPositive) // the affected clause is always true because the unit clause appears as a subset
                                scheduledForRemoval.add(acl); // schedule for removal to avoid ConcurrentModificationExceptions
                            else { // otherwise the literal in the clause is false and we can remove it
                                acl.removeLiteral(lit.gndAtom.index);
                                if(acl.size() == 1)
                                    unitClauses.add(acl);
                                else if(acl.size() == 0)
                                    // the clause has no literal left that could make it true; previously it was
                                    // dropped here but remained queued as a unit clause, which caused an
                                    // ArrayIndexOutOfBoundsException once it was dequeued
                                    throw new IllegalStateException("Unit propagation derived the empty clause: the set of constraints is unsatisfiable (conflict on " + lit.gndAtom + ")");
                            }
                        }
                    }
                }
            }
            for(Clause acl : scheduledForRemoval)
                removeClause(acl);
        }
        // remove the unit clause from the set of constraints
        constraints.remove(cl);
        // we no longer need the occurrences entry
        GAOccurrences.remove(lit.gndAtom.index);
    }
    int newSize = constraints.size();
    // report only in debug mode (the condition used to be "debug || true", i.e. always)
    if(debug) System.out.println("unit propagation removed " + (oldSize-newSize) + " constraints");
}
/**
 * removes a clause from the set of constraints and from the occurrence lists
 * of all ground atoms that appear in it
 * @param c the clause to remove
 */
protected void removeClause(Clause c) {
    constraints.remove(c);
    // drop the references from the per-atom occurrence lists
    GroundLiteral[] literals = c.getLiterals();
    for(int i = 0; i < literals.length; i++) {
        Vector<Constraint> occurrences = GAOccurrences.get(literals[i].gndAtom.index);
        occurrences.remove(c);
    }
}
/**
 * records a constraint as unsatisfied
 * @param c the constraint to append to the list of unsatisfied constraints
 */
protected void addUnsatisfiedConstraint(Constraint c) {
    this.unsatisfiedConstraints.add(c);
}
/**
 * registers a ground atom as the bottleneck of a constraint
 * @param a the ground atom
 * @param c the constraint to add to the atom's bottleneck list (the list is created on demand)
 */
protected void addBottleneck(GroundAtom a, Constraint c) {
    Vector<Constraint> registered = bottlenecks.get(a.index);
    if(registered != null) {
        registered.add(c);
        return;
    }
    // first bottleneck for this atom: create its list
    registered = new Vector<Constraint>();
    registered.add(c);
    bottlenecks.put(a.index, registered);
}
/**
 * records that a ground atom occurs in a constraint
 * @param a the ground atom
 * @param c the constraint to add to the atom's occurrence list (the list is created on demand)
 */
protected void addGAOccurrence(GroundAtom a, Constraint c) {
    Vector<Constraint> occurrences = GAOccurrences.get(a.index);
    boolean isFirstOccurrence = (occurrences == null);
    if(isFirstOccurrence) {
        occurrences = new Vector<Constraint>();
        GAOccurrences.put(a.index, occurrences);
    }
    occurrences.add(c);
}
/**
 * prepares a run: instantiates the constraints if that has not happened yet, resets the
 * bookkeeping, sets a random state and lets every constraint initialize itself for that state
 * @throws Exception
 */
protected void initialize() throws Exception {
    // build the constraints on first use
    boolean needConstraints = (constraints == null);
    if(needConstraints) {
        initConstraints(kb);
    }
    // reset the data gathered in a previous run
    unsatisfiedConstraints.clear();
    bottlenecks.clear();
    if(debug) {
        System.out.println("setting random state...");
    }
    setRandomState();
    if(debug) {
        state.print();
    }
    // every constraint registers itself as unsatisfied or as a bottleneck as appropriate
    for(int i = 0; i < constraints.size(); i++) {
        constraints.get(i).initState();
    }
}
/**
 * solves the SAT problem: the state is first initialized randomly (respecting the evidence),
 * then moves are made (see makeMove) until no constraint is left unsatisfied
 * @throws Exception
 */
public void run() throws Exception {
    initialize();
    for(int step = 1; unsatisfiedConstraints.size() > 0; step++) {
        if(debug) {
            int numUnsatisfied = unsatisfiedConstraints.size();
            System.out.println("SAT step " + step + ", " + numUnsatisfied + " constraints unsatisfied");
            // list the unsatisfied constraints only if there are few of them
            if(numUnsatisfied < 30) {
                for(Constraint c : unsatisfiedConstraints)
                    System.out.println(" unsatisfied: " + c);
            }
            checkIntegrity();
        }
        makeMove();
    }
}
/**
 * checks the integrity of internal data structures against the current state:
 * the true-literal sets of the clauses, the list of unsatisfied constraints and the bottlenecks
 * @throws Exception if an inconsistency is found
 */
protected void checkIntegrity() throws Exception {
    // part 1: true-literal bookkeeping and the list of unsatisfied constraints
    for(Constraint c : this.constraints) {
        if(!(c instanceof Clause))
            continue;
        Clause cl = (Clause)c;
        int numTrue = 0;
        for(GroundLiteral lit : cl.lits) {
            if(!lit.isTrue(state))
                continue;
            numTrue++;
            if(!cl.trueOnes.contains(lit.gndAtom))
                throw new Exception("Clause.trueOnes corrupted (1)");
        }
        if(numTrue != cl.trueOnes.size())
            throw new Exception("Clause.trueOnes corrupted (2)");
        // a clause must be listed as unsatisfied exactly if it has no true literal
        boolean satisfied = numTrue > 0;
        boolean listedAsUnsatisfied = unsatisfiedConstraints.contains(c);
        if(listedAsUnsatisfied == satisfied)
            throw new Exception("Unsatisfied constraints corrupted");
    }
    // part 2: every registered bottleneck must be the only true literal of its clause
    for(java.util.Map.Entry<Integer,Vector<Constraint>> entry : bottlenecks.entrySet()) {
        GroundAtom ga = this.vars.get(entry.getKey());
        for(Constraint c : entry.getValue()) {
            if(!(c instanceof Clause))
                continue;
            Clause cl = (Clause)c;
            boolean haveTrueOne = false;
            for(GroundLiteral lit : cl.lits) {
                boolean litIsTrue = lit.isTrue(state);
                boolean isBottleneckAtom = (lit.gndAtom == ga);
                if(litIsTrue) {
                    if(haveTrueOne)
                        throw new Exception("Bottlenecks corrupted: Clause " + cl + " contains a second true literal.");
                    if(!isBottleneckAtom)
                        throw new Exception("Bottlenecks corrupted: Clause " + cl + " contains a true literal that isn't the bottleneck.");
                    haveTrueOne = true;
                }
                else if(isBottleneckAtom)
                    throw new Exception("Bottlenecks corrupted: Clause " + cl + " has " + ga + " as a bottleneck but contains a literal with " + ga + " that is false; it is likely that the clause is a tautology which should never have bottlenecks.");
            }
        }
    }
}
/**
 * @return the possible world this sampler operates on
 */
public PossibleWorld getState() {
    return this.state;
}
/**
 * sets a random state for non-evidence atoms (delegated to the evidence handler)
 * @throws Exception
 */
protected void setRandomState() throws Exception {
    this.evidenceHandler.setRandomState(this.state);
}
/**
 * performs one step of the algorithm: a WalkSAT move with probability pSampleSAT,
 * an SA move otherwise
 */
protected void makeMove() {
    boolean doWalkSAT = rand.nextDouble() < this.pSampleSAT;
    if(!doWalkSAT) {
        if(debug) System.out.println(" SA move:");
        SAMove();
        return;
    }
    if(debug) System.out.println(" WalkSAT move:");
    walkSATMove();
}
/**
 * performs a WalkSAT move: picks one of the unsatisfied constraints at random and
 * satisfies it randomly with probability pWalkSAT, greedily otherwise
 */
protected void walkSATMove() {
    // choose the constraint to work on
    int idx = rand.nextInt(unsatisfiedConstraints.size());
    Constraint chosen = unsatisfiedConstraints.get(idx);
    // choose how to satisfy it
    boolean randomWalk = rand.nextDouble() < this.pWalkSAT;
    if(randomWalk) {
        chosen.satisfyRandomly();
    }
    else {
        chosen.satisfyGreedily();
    }
}
/**
 * performs an SA move: repeatedly draws a ground atom at random until one is found
 * that has no evidence and can be flipped (together with a second atom if it is in a block)
 */
protected void SAMove() {
    while(true) {
        // draw a candidate
        int idxGA = rand.nextInt(vars.size());
        GroundAtom gndAtom = vars.get(idxGA);
        // accept it only if it has no evidence and the flip succeeds
        if(!evidence.containsKey(idxGA) && pickSecondAtRandomAndFlip(gndAtom))
            return;
    }
}
/**
 * attempts to flip the given variable. If the variable is in a block, a second variable of
 * the block is flipped along with it: the block's true one if the given variable is not the
 * true one, otherwise a randomly chosen other variable of the block that has no evidence.
 * @param gndAtom the variable to flip
 * @return true if the variable could be flipped, false if no admissible second variable exists (nothing is flipped then)
 */
protected boolean pickSecondAtRandomAndFlip(GroundAtom gndAtom) {
    Block block = vars.getBlock(gndAtom.index);
    // simple case: not in a block, just flip
    if(block == null) {
        flipGndAtom(gndAtom);
        return true;
    }
    // block case: determine the partner that has to be flipped as well
    GroundAtom partner;
    GroundAtom trueOne = block.getTrueOne(state);
    if(gndAtom != trueOne) {
        // the partner must be the true one, which may not be fixed by evidence
        if(evidence.containsKey(trueOne.index))
            return false;
        partner = trueOne;
    }
    else {
        // we are flipping the true one: choose the partner at random among the others
        Vector<GroundAtom> eligible = new Vector<GroundAtom>();
        for(GroundAtom ga : block) {
            if(!(ga == trueOne || evidence.containsKey(ga.index)))
                eligible.add(ga);
        }
        if(eligible.isEmpty())
            return false;
        partner = eligible.get(rand.nextInt(eligible.size()));
    }
    flipGndAtom(gndAtom);
    flipGndAtom(partner);
    return true;
}
/**
 * greedily flips one of the given candidate variables: the candidate whose flip yields the
 * largest delta-cost (see deltaCost) is flipped. Candidates with evidence are ignored. For a
 * candidate in a block, a second variable of the block must be flipped as well, and its
 * delta-cost is added to the candidate's.
 *
 * Ties are broken by replacing the current best with probability 1/2, so among several equally
 * good candidates the later ones are more likely to be chosen.
 *
 * @param candidates the variables to choose from
 * @throws IllegalStateException if none of the candidates can be flipped (all have evidence or
 *         lack an admissible second variable); previously this case ended in a NullPointerException
 */
protected void pickAndFlipVar(Iterable<GroundAtom> candidates) {
    // find the best candidate
    GroundAtom bestGA = null, bestGASecond = null;
    double bestDelta = Double.NEGATIVE_INFINITY;
    for(GroundAtom gndAtom : candidates) {
        // if we have evidence, skip this ground atom
        if(evidence.containsKey(gndAtom.index))
            continue;
        // calculate delta-cost
        double delta = deltaCost(gndAtom);
        // - if the atom is in a block, we must consider the cost of flipping the second atom
        Block block = vars.getBlock(gndAtom.index);
        GroundAtom secondGA = null;
        if(block != null) {
            GroundAtom trueOne = block.getTrueOne(state);
            double delta2 = Double.NEGATIVE_INFINITY;
            if(gndAtom != trueOne) { // the second one to flip must be the true one
                secondGA = trueOne;
                delta2 = deltaCost(secondGA);
            }
            else { // as the second flip any one of the others (that has no evidence)
                for(GroundAtom ga2 : block) {
                    if(evidence.containsKey(ga2.index) || ga2 == gndAtom)
                        continue;
                    double d = deltaCost(ga2);
                    if(d > delta2) {
                        delta2 = d;
                        secondGA = ga2;
                    }
                }
            }
            // no admissible second variable: this candidate cannot be flipped
            if(secondGA == null)
                continue;
            delta += delta2; // TODO additivity ignores possibility of first and second GA appearing in same formula (make temporary change!)
        }
        // is it better?
        boolean newBest = false;
        if(delta > bestDelta)
            newBest = true;
        else if(delta == bestDelta && rand.nextInt(2) == 1)
            newBest = true;
        if(newBest) {
            bestGA = gndAtom;
            bestGASecond = secondGA;
            bestDelta = delta;
        }
    }
    // every flippable candidate has a finite delta and thus beats the initial negative infinity,
    // so bestGA is null only if there was no flippable candidate at all
    if(bestGA == null)
        throw new IllegalStateException("None of the candidate variables can be flipped: all of them are fixed by the evidence.");
    // perform the flip
    flipGndAtom(bestGA);
    if(bestGASecond != null)
        flipGndAtom(bestGASecond);
}
/**
 * inverts the truth value of a ground atom in the state and updates the bookkeeping:
 * the constraints for which the atom was a bottleneck become unsatisfied, and every
 * constraint in which the atom occurs is notified of the flip
 * @param gndAtom the ground atom to flip
 */
protected void flipGndAtom(GroundAtom gndAtom) {
    if(debug) System.out.println(" flipping " + gndAtom);
    // invert the value in the state
    state.set(gndAtom, !state.isTrue(gndAtom));
    // constraints that relied on this atom alone are no longer satisfied
    Vector<Constraint> broken = this.bottlenecks.get(gndAtom.index);
    if(broken != null) {
        this.unsatisfiedConstraints.addAll(broken);
        broken.clear();
    }
    // the remaining updates are done by the constraints themselves
    Vector<Constraint> occurrences = this.GAOccurrences.get(gndAtom.index);
    if(occurrences == null)
        return;
    for(Constraint c : occurrences) {
        c.handleFlip(gndAtom);
    }
}
/**
 * computes the change in the number of satisfied constraints that flipping the given atom would cause
 * @param gndAtom the ground atom whose flip is evaluated
 * @return the number of constraints the flip would newly satisfy minus the number it would newly violate
 */
protected double deltaCost(GroundAtom gndAtom) {
    // constraints for which the atom is the bottleneck would become unsatisfied
    Vector<Constraint> bn = this.bottlenecks.get(gndAtom.index);
    int newlyUnsatisfied = (bn == null) ? 0 : bn.size();
    // count the constraints the flip would satisfy
    int newlySatisfied = 0;
    Vector<Constraint> occs = this.GAOccurrences.get(gndAtom.index);
    if(occs != null) {
        for(Constraint c : occs) {
            if(c.flipSatisfies(gndAtom))
                newlySatisfied++;
        }
    }
    return newlySatisfied - newlyUnsatisfied;
}
/**
 * sets the probability of performing a WalkSAT-style move (rather than an SA move) in each step
 * @param p the new value of pSampleSAT
 */
public void setPSampleSAT(double p) {
    pSampleSAT = p;
}
/**
 * sets the probability with which a WalkSAT move satisfies the chosen constraint randomly rather than greedily
 * @param p the new value of pWalkSAT
 */
public void setPWalkSAT(double p) {
    pWalkSAT = p;
}
/**
 * Base class of the constraints the sampler tries to satisfy. The sampler invokes these
 * operations when initializing a run, when choosing moves and after flipping a ground atom.
 */
protected abstract class Constraint {
    /** called for every constraint when a run is initialized, after the random state has been set */
    public abstract void initState();
    /** @return whether the constraint is true in the given possible world */
    public abstract boolean isTrue(PossibleWorld w);
    /** @return whether flipping the given ground atom would turn this constraint from unsatisfied into satisfied (used by deltaCost) */
    public abstract boolean flipSatisfies(GroundAtom gndAtom);
    /** called after the given ground atom, which occurs in this constraint, was flipped in the state */
    public abstract void handleFlip(GroundAtom gndAtom);
    /** called by a WalkSAT move to satisfy this (unsatisfied) constraint with a greedy choice */
    public abstract void satisfyGreedily();
    /** called by a WalkSAT move to satisfy this (unsatisfied) constraint with a random choice */
    public abstract void satisfyRandomly();
}
/**
 * A disjunction of ground literals. The clause keeps track of the ground atoms whose
 * literals are currently true (trueOnes) in order to maintain the sampler's list of
 * unsatisfied constraints and its bottlenecks incrementally.
 */
protected class Clause extends Constraint {
    /** the literals of the clause */
    protected GroundLiteral[] lits;
    /** the ground atoms of the literals, in the order of lits */
    protected Vector<GroundAtom> gndAtoms;
    /** the ground atoms whose literals are true in the current state */
    protected HashSet<GroundAtom> trueOnes;

    /**
     * creates the clause and registers it in the occurrence list of each of its ground atoms
     * @param lits the literals of the clause (the array is not copied)
     */
    public Clause(GroundLiteral[] lits) {
        this.lits = lits;
        // collect ground atom occurrences
        gndAtoms = new Vector<GroundAtom>(lits.length);
        trueOnes = new HashSet<GroundAtom>((lits.length+1)/2);
        for(GroundLiteral lit : lits) {
            GroundAtom gndAtom = lit.gndAtom;
            gndAtoms.add(gndAtom);
            addGAOccurrence(gndAtom, this);
        }
    }

    /**
     * @return true if at least one literal of the clause is true in the given world
     */
    @Override
    public boolean isTrue(PossibleWorld w) {
        for(GroundLiteral lit : lits)
            if(lit.isTrue(w))
                return true;
        return false;
    }

    /**
     * satisfies the clause by flipping the best of its ground atoms (see pickAndFlipVar)
     */
    @Override
    public void satisfyGreedily() {
        pickAndFlipVar(gndAtoms);
    }

    /**
     * satisfies the clause by flipping one of its ground atoms chosen uniformly at random
     * among those that can be flipped (no evidence; an admissible second atom exists if the
     * atom is in a block).
     * @throws IllegalStateException if none of the clause's ground atoms can be flipped;
     *         the previous rejection loop never terminated in that case
     */
    @Override
    public void satisfyRandomly() {
        // atoms with evidence are never eligible
        Vector<GroundAtom> candidates = new Vector<GroundAtom>(gndAtoms.size());
        for(GroundAtom ga : gndAtoms)
            if(!evidence.containsKey(ga.index))
                candidates.add(ga);
        // A failed flip attempt leaves the state unchanged, so retrying the same atom would fail again.
        // Discarding failed candidates therefore selects among the flippable atoms with the same
        // probabilities as redrawing would, but it is guaranteed to terminate.
        while(!candidates.isEmpty()) {
            int idx = rand.nextInt(candidates.size());
            if(pickSecondAtRandomAndFlip(candidates.get(idx)))
                return;
            candidates.remove(idx);
        }
        throw new IllegalStateException("Clause " + this + " cannot be satisfied: none of its ground atoms can be flipped given the evidence.");
    }

    /**
     * @return true if the clause currently has no true literal (flipping any atom that occurs
     *         in it then makes one of its literals true)
     */
    @Override
    public boolean flipSatisfies(GroundAtom gndAtom) {
        return trueOnes.size() == 0;
    }

    /**
     * updates the clause's set of true atoms, the list of unsatisfied constraints and the
     * bottlenecks after the given ground atom was flipped
     */
    @Override
    public void handleFlip(GroundAtom gndAtom) {
        int numTrueLits = trueOnes.size();
        if(trueOnes.contains(gndAtom)) { // the lit was true and is now false, remove it from the clause's list of true lits
            trueOnes.remove(gndAtom);
            numTrueLits--;
            // if no more true lits are left, the clause is now unsatisfied; this is handled in flipGndAtom
        }
        else { // the lit was false and is now true, add it to the clause's list of true lits
            if(numTrueLits == 0) // the clause was previously unsatisfied, it is now satisfied
                unsatisfiedConstraints.remove(this);
            else if(numTrueLits == 1) // we are adding a second true lit, so the first one is no longer a bottleneck of this clause
                bottlenecks.get(trueOnes.iterator().next().index).remove(this);
            trueOnes.add(gndAtom);
            numTrueLits++;
        }
        // a single remaining true literal is the clause's bottleneck
        if(numTrueLits == 1)
            addBottleneck(trueOnes.iterator().next(), this);
    }

    /**
     * @return the literals joined by " v "
     */
    @Override
    public String toString() {
        return StringTool.join(" v ", lits);
    }

    /**
     * recomputes the set of true atoms for the current state and registers the clause as
     * unsatisfied or with its bottleneck as appropriate
     */
    @Override
    public void initState() {
        trueOnes.clear();
        // find out which lits are true
        for(GroundLiteral lit : lits)
            if(lit.isTrue(state))
                trueOnes.add(lit.gndAtom);
        // if there are no true ones, this constraint is unsatisfied
        if(trueOnes.size() == 0)
            addUnsatisfiedConstraint(this);
        // if there is exactly one true literal, it is a bottleneck
        // (unless the clause also contains the negated literal,
        // but a sat.Clause guarantees that this cannot be the case)
        else if(trueOnes.size() == 1)
            addBottleneck(trueOnes.iterator().next(), this);
    }

    /**
     * @return the number of literals in the clause
     */
    public int size() {
        return this.lits.length;
    }

    /**
     * @return the clause's literal array (not a copy)
     */
    public GroundLiteral[] getLiterals() {
        return lits;
    }

    /**
     * removes the literal with the given ground atom from the clause by replacing the literal
     * array with a new one that is one element shorter; the list of ground atoms is rebuilt
     * @param idxGndAtom index of the ground atom whose literal is to be removed; the clause is
     *        expected to contain exactly one literal with that atom
     */
    public void removeLiteral(int idxGndAtom) {
        GroundLiteral[] newlits = new GroundLiteral[this.lits.length-1];
        gndAtoms.clear();
        for(int i = 0, j = 0; i < lits.length; i++)
            if(lits[i].gndAtom.index != idxGndAtom) {
                newlits[j++] = lits[i];
                gndAtoms.add(lits[i].gndAtom);
            }
        lits = newlits;
    }
}
/**
 * test program: grounds a Bayesian logic network for an example database, runs SampleSAT
 * on the resulting clausal KB and prints the state found along with the time taken
 * @param args not used
 * @throws Exception
 */
public static void main(String[] args) throws Exception {
    // model and evidence files (alternative example: relxy.blog, relxy.xml, relxy.bln, relxy.blogdb)
    final String blogFile = "meals_any_for.blog";
    final String networkFile = "meals_any_for_functional.xml";
    final String blnFile = "meals_any_for_functional.bln";
    final String evidenceFile = "lorenzExample.blogdb";

    BayesianLogicNetwork bln = new BayesianLogicNetwork(blogFile, networkFile, blnFile);

    // read evidence
    Database db = new Database(bln.rbn);
    db.readBLOGDB(evidenceFile);

    // ground model
    GroundBLN gbln = new GroundBLN(bln, db);
    gbln.instantiateGroundNetwork();

    // run algorithm (construction of the sampler is included in the timing)
    PossibleWorld state = new PossibleWorld(gbln.getWorldVars());
    ClausalKB ckb = new ClausalKB(gbln.getKB());
    Stopwatch sw = new Stopwatch();
    sw.start();
    SampleSAT sampler = new SampleSAT(ckb, state, gbln.getWorldVars(), gbln.getDatabase().getEntries());
    sampler.run();
    sw.stop();

    // report
    System.out.println("done");
    state.print();
    System.out.println("time taken: " + sw.getElapsedTimeSecs());
}
/**
 * @return the parameter handler with which this sampler's parameters are registered
 */
public ParameterHandler getParameterHandler() {
    return this.paramHandler;
}
/**
 * @return the simple class name followed by the values of pSampleSAT and pWalkSAT in square brackets
 */
public String getAlgorithmName() {
    String className = this.getClass().getSimpleName();
    return String.format("%s[%f;%f]", className, pSampleSAT, pWalkSAT);
}
}