/*******************************************************************************
* Copyright (C) 2012 Gregor Wylezich, Dominik Jain and Paul Maier.
*
* This file is part of ProbCog.
*
* ProbCog is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ProbCog is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ProbCog. If not, see <http://www.gnu.org/licenses/>.
******************************************************************************/
package probcog.wcsp;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map.Entry;
import java.util.TreeSet;
import java.util.Vector;
import probcog.inference.IParameterHandler;
import probcog.inference.ParameterHandler;
import probcog.logic.ComplexFormula;
import probcog.logic.Conjunction;
import probcog.logic.Disjunction;
import probcog.logic.Formula;
import probcog.logic.GroundAtom;
import probcog.logic.GroundLiteral;
import probcog.logic.IPossibleWorld;
import probcog.logic.Negation;
import probcog.logic.PossibleWorld;
import probcog.logic.WorldVariables;
import probcog.logic.WorldVariables.Block;
import probcog.logic.sat.weighted.WeightedFormula;
import probcog.srl.BooleanDomain;
import probcog.srl.Database;
import probcog.srl.Signature;
import probcog.srl.mln.MarkovLogicNetwork;
import probcog.srl.mln.MarkovRandomField;
import probcog.wcsp.Constraint.ArrayKey;
import probcog.wcsp.Constraint.Tuple;
import edu.tum.cs.util.StringTool;
/**
 * Converts an instantiated MLN (i.e. a ground MRF) into a weighted constraint satisfaction
 * problem ({@link WCSP}) suitable for the Toulbar2 WCSP format.
 * <p>
 * Every ground atom that is not part of a block becomes a boolean WCSP variable
 * (domain index 0 = true, 1 = false). Every block of mutually exclusive ground atoms becomes a
 * single WCSP variable, whose domain is the domain of the predicate's functionally determined
 * argument. Variables whose ground atoms are all given by the evidence are dropped.
 * MLN weights are scaled by a divisor and rounded to obtain integer costs.
 * @author Gregor Wylezich
 * @author Dominik Jain
 * @author Paul Maier
 */
public class WCSPConverter implements IParameterHandler {
    protected MarkovLogicNetwork mln;
    protected MarkovRandomField mrf;
    /**
     * scratch world whose atom states are overwritten while evaluating formulas under candidate assignments
     */
    protected PossibleWorld world;
    /**
     * divisor by which absolute MLN weights are divided (before rounding) to obtain WCSP costs
     */
    protected Double divisor;
    /**
     * domains of the evidence database, keyed by domain (type) name
     */
    protected HashMap<String, HashSet<String>> doms;
    /**
     * not used within this class
     */
    protected HashMap<Integer, Integer> gndID_BlockID;
    /**
     * list of WCSP variable names
     */
    protected ArrayList<String> vars;
    /**
     * maps a ground atom index to a WCSP variable index
     */
    protected HashMap<Integer, Integer> gndAtomIdx2varIdx;
    /**
     * maps WCSP variable indices to sets of ground atoms encompassed by the variable
     */
    protected HashMap<Integer, Vector<GroundAtom>> varIdx2groundAtoms;
    /**
     * maps a WCSP variable name to the name of its domain ("boolean" for non-block variables)
     */
    protected HashMap<String, String> func_dom;
    /**
     * formulas (and evidence literals) with the cost incurred when they are false;
     * only filled if cacheConstraints is set, and read by getWorldCosts
     */
    protected HashMap<Formula, Long> wcspConstraints = new HashMap<Formula, Long>();
    protected PrintStream ps;
    /**
     * cost of violating a hard constraint (WCSP "top"); computed in initialize()
     */
    protected long hardCost = -1;
    protected boolean debug = false, verbose = false;
    protected Database db;
    protected boolean cacheConstraints = false;
    protected ParameterHandler paramHandler;

    /**
     * @param mrf the ground Markov random field to convert
     * @throws Exception if the parameter handler cannot be set up
     */
    public WCSPConverter(MarkovRandomField mrf) throws Exception {
        this.mln = mrf.mln;
        this.mrf = mrf;
        this.paramHandler = new ParameterHandler(this);
        paramHandler.add("verbose", "setVerbose");
        paramHandler.add("debug", "setDebug");
    }

    /**
     * @param cache whether generated constraints are remembered so that getWorldCosts can evaluate worlds
     */
    public void setCacheConstraints(boolean cache) {
        this.cacheConstraints = cache;
    }

    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    public void setDebug(boolean debug) {
        this.debug = debug;
    }

    /**
     * Computes the divisor that is used to convert MLN weights to WCSP costs.
     * The divisor is 1, multiplied by the smallest non-zero absolute weight if that is below 1,
     * and by the smallest difference between distinct absolute weights if that is below 1.
     * @return the divisor (1.0 if the MLN contains no formulas)
     */
    protected double computeDivisor() {
        // get minimum weight and build sorted tree set of weights
        TreeSet<Double> weights = new TreeSet<Double>();
        double minWeight = Double.MAX_VALUE;
        for(WeightedFormula wf : mln.getFormulas()) {
            double w = Math.abs(wf.weight);
            weights.add(w);
            if(w < minWeight && w != 0)
                minWeight = w;
        }
        // no formulas: nothing to scale (and the iteration below needs at least one element)
        if(weights.isEmpty())
            return 1.0;
        // calculate the smallest difference between consecutive weights
        double deltaMin = Double.MAX_VALUE;
        Iterator<Double> iter = weights.iterator();
        double prev = iter.next();
        while(iter.hasNext()) {
            double cur = iter.next();
            double diff = cur - prev;
            if(diff < deltaMin)
                deltaMin = diff;
            prev = cur;
        }
        double result = 1.0;
        if(minWeight < 1.0)
            result *= minWeight;
        if(deltaMin < 1.0)
            result *= deltaMin;
        return result;
    }

    /**
     * Converts a weight into an integer cost using the current divisor.
     * Used both for the constraints and for computing the hard cost, so that the two agree exactly.
     * @param weight an MLN weight (its sign is ignored)
     * @return round(|weight| / divisor)
     */
    private long weightToCost(double weight) {
        return Math.round(Math.abs(weight) / divisor);
    }

    /**
     * Returns the number of values of a WCSP variable.
     * This is the size of the variable's domain in the evidence database,
     * or 2 if there is no such domain (boolean variables).
     */
    private int getDomainSize(int wcspVarIdx) {
        HashSet<String> domSet = doms.get(func_dom.get(vars.get(wcspVarIdx)));
        return domSet == null ? 2 : domSet.size();
    }

    /**
     * Performs the conversion of the ground MRF into a WCSP.
     * Constraints generated for different formulas over the same set of variables are merged into one constraint.
     * @return the WCSP containing the evidence constraints and the constraints of all weighted formulas
     * @throws Exception if initialization or constraint generation fails
     */
    public WCSP run() throws Exception {
        initialize();

        // instantiate WCSP
        int[] domSizes = new int[vars.size()];
        for(int i = 0; i < vars.size(); i++)
            domSizes[i] = getDomainSize(i);
        long top = hardCost;
        WCSP wcsp = new WCSP(domSizes, top);

        // generate evidence constraints
        if(verbose)
            System.out.println("generating evidence constraints...");
        generateEvidenceConstraints(wcsp);

        // generate constraints for weighted formulas, merging constraints with the same domains
        if(verbose) System.out.printf("generating constraints for %d weighted formulas...\n", mrf.getNumFormulas());
        HashMap<ArrayKey, Constraint> collectedConstraints = new HashMap<ArrayKey, Constraint>();
        for(WeightedFormula wf : mrf) {
            Constraint c = generateConstraint(wf);
            if(c == null)
                continue;
            // check if we have a previous constraint with the same domain
            ArrayKey key = new ArrayKey(c.getVarIndices());
            Constraint prevConstraint = collectedConstraints.get(key);
            if(prevConstraint != null)
                prevConstraint.merge(c);
            else {
                collectedConstraints.put(key, c);
                wcsp.addConstraint(c);
            }
        }
        if(verbose)
            System.out.printf("constructed %d constraints in total\n", wcsp.size());
        return wcsp;
    }

    /**
     * Generates a WCSP variable for each ground atom.
     * For each block of mutually exclusive ground atoms, only one variable is created.
     * Fills vars, gndAtomIdx2varIdx, func_dom and varIdx2groundAtoms.
     */
    protected void createVariables() {
        vars = new ArrayList<String>(); // list of new variables
        gndAtomIdx2varIdx = new HashMap<Integer, Integer>(); // maps ground atom indices to WCSP variable indices
        func_dom = new HashMap<String, String>(); // maps variable to domain the variable uses
        varIdx2groundAtoms = new HashMap<Integer, Vector<GroundAtom>>(); // maps a variable index to all ground atoms that are set by this variable
        HashSet<Block> handledBlocks = new HashSet<Block>();
        WorldVariables ww = world.getVariables();
        for(int i = 0; i < ww.size(); i++) {
            GroundAtom ga = ww.get(i);
            // check whether ground atom is in a block
            Block block = ww.getBlock(ga.index);
            if(block != null) {
                if(handledBlocks.contains(block))
                    continue;
                handledBlocks.add(block);
                // the variable name is the atom with the functionally determined argument left out, e.g. p(a,b)
                StringBuilder shortened = new StringBuilder(ga.predicate);
                int funcArgIdx = mln.getFunctionallyDeterminedArgument(ga.predicate);
                shortened.append('(');
                int k = 0;
                for(int j = 0; j < ga.args.length; j++) {
                    if(j == funcArgIdx)
                        continue;
                    if(k++ > 0)
                        shortened.append(',');
                    shortened.append(ga.args[j]);
                }
                shortened.append(')');
                String varName = shortened.toString();
                int varIdx = vars.size();
                vars.add(varName);
                // the variable's domain is the type of the functionally determined argument
                Signature sig = mln.getSignature(ga.predicate);
                func_dom.put(varName, sig.argTypes[funcArgIdx]);
                Vector<GroundAtom> tmp = new Vector<GroundAtom>();
                for(GroundAtom gndAtom : block) {
                    gndAtomIdx2varIdx.put(gndAtom.index, varIdx);
                    tmp.add(gndAtom);
                }
                varIdx2groundAtoms.put(varIdx, tmp);
            }
            else { // it's a boolean variable
                String varName = ga.toString();
                int varIdx = vars.size();
                vars.add(varName);
                gndAtomIdx2varIdx.put(ga.index, varIdx);
                // in this case, the mapping of this variable is set to "boolean" domain
                func_dom.put(varName, "boolean");
                // in this case, the list of ground atoms only contains the selected world variable
                Vector<GroundAtom> tmp = new Vector<GroundAtom>();
                tmp.add(ga);
                varIdx2groundAtoms.put(varIdx, tmp);
            }
        }
        if(debug) {
            System.out.println("WCSP variables:");
            for(Entry<Integer,Vector<GroundAtom>> e : varIdx2groundAtoms.entrySet()) {
                System.out.printf("%s %s\n", e.getKey(), StringTool.join(", ", e.getValue()));
            }
        }
    }

    /**
     * Removes variables all of whose ground atoms are given by the evidence and renumbers the remaining ones.
     * Such variables are not needed in the WCSP.
     * Replaces vars, gndAtomIdx2varIdx and varIdx2groundAtoms.
     * @param db evidence database
     * @throws Exception
     */
    protected void simplifyVars(Database db) throws Exception {
        ArrayList<String> simplifiedVars = new ArrayList<String>(); // list of simplified variables
        HashMap<Integer, Integer> sf_gndAtomIdx2varIdx = new HashMap<Integer, Integer>(); // mapping of ground atom indices to simplified variable indices
        HashMap<Integer, Vector<GroundAtom>> sf_varIdx2groundAtoms = new HashMap<Integer, Vector<GroundAtom>>(); // mapping of simplified variable to ground atoms
        for(int i = 0; i < vars.size(); i++) {
            // count the ground atoms of this variable that have an evidence entry
            int evidenceAtoms = 0;
            Vector<GroundAtom> gndAtoms = varIdx2groundAtoms.get(i);
            for(GroundAtom g : gndAtoms) {
                if(db.getVariableValue(g.toString(), false) != null) // evidence entry exists
                    evidenceAtoms++;
            }
            // if all ground atoms are set by the evidence, the variable need not be handled anymore
            if(gndAtoms.size() != evidenceAtoms) {
                int idx = simplifiedVars.size();
                simplifiedVars.add(vars.get(i));
                // map the ground atoms to the new (renumbered) variable index
                for(GroundAtom g : gndAtoms)
                    sf_gndAtomIdx2varIdx.put(g.index, idx);
                // the atom list is shared with the old mapping, not copied
                sf_varIdx2groundAtoms.put(idx, gndAtoms);
            }
        }
        if(verbose) System.out.printf("simplification: reduced %d to %d variables\n", vars.size(), simplifiedVars.size());
        this.vars = simplifiedVars;
        this.gndAtomIdx2varIdx = sf_gndAtomIdx2varIdx;
        this.varIdx2groundAtoms = sf_varIdx2groundAtoms;
    }

    /**
     * Generates a WCSP constraint for a weighted formula.
     * @param wf the weighted formula
     * @return the constraint, or null if it has no relevant tuples (its truth value does not depend on the variables)
     * @throws Exception if a ground atom of the formula has no WCSP variable
     */
    protected Constraint generateConstraint(WeightedFormula wf) throws Exception {
        // if the weight is negative, negate the formula and its weight
        Formula f = wf.formula;
        double weight = wf.weight;
        if(weight < 0) {
            f = new Negation(f);
            weight *= -1;
        }
        // convert to negation normal form so we get many flat conjunctions or disjunctions, which can be efficiently converted
        f = f.toNNF();
        // get all ground atoms of the formula
        HashSet<GroundAtom> gndAtoms = new HashSet<GroundAtom>();
        f.getGroundAtoms(gndAtoms);
        // get corresponding (duplicate-free) set of WCSP variables
        HashSet<Integer> setVarIndices = new HashSet<Integer>(gndAtoms.size());
        for(GroundAtom g : gndAtoms) {
            Integer idx = gndAtomIdx2varIdx.get(g.index);
            if(idx == null)
                throw new Exception("Variable index for '" + g + "' is null");
            setVarIndices.add(idx);
        }
        int[] referencedVarIndices = new int[setVarIndices.size()];
        int i = 0;
        for(Integer varIdx : setVarIndices)
            referencedVarIndices[i++] = varIdx;
        Arrays.sort(referencedVarIndices); // have the array sorted to simplify constraint unification
        // get cost value for this constraint
        long cost;
        if(wf.isHard)
            cost = hardCost;
        else
            cost = weightToCost(weight);
        ArrayList<Tuple> relevantSettings = null;
        long defaultCosts = -1;
        // try the simplified conversion method
        boolean generateAllPossibilities = true;
        boolean isConjunction = f instanceof Conjunction;
        if(isConjunction || f instanceof Disjunction) {
            generateAllPossibilities = false;
            try {
                relevantSettings = new ArrayList<Tuple>();
                this.gatherConstraintTuplesSimplified((ComplexFormula)f, referencedVarIndices, cost, relevantSettings, isConjunction);
                // conjunction: the single listed tuple is the satisfying one, all others cost; disjunction: vice versa
                defaultCosts = isConjunction ? cost : 0;
            }
            catch(SimplifiedConversionNotSupportedException e) {
                if(debug) System.out.printf("No simplified conversion (%s): %s\n", e.getMessage(), f.toString());
                generateAllPossibilities = true;
            }
        }
        // if necessary, use the complex conversion method which looks at all possible settings
        if(generateAllPossibilities) {
            ArrayList<Tuple> settingsZero = new ArrayList<Tuple>();
            ArrayList<Tuple> settingsOther = new ArrayList<Tuple>();
            gatherConstraintTuples(f, referencedVarIndices, 0, world, new int[referencedVarIndices.length], cost, settingsZero, settingsOther);
            if(settingsOther.size() < settingsZero.size()) { // fewer settings with costs than without
                relevantSettings = settingsOther;
                defaultCosts = 0;
            }
            else { // fewer (or equally many) settings without costs than with costs
                relevantSettings = settingsZero;
                defaultCosts = cost;
            }
        }
        // if the smaller set contains no lines, this constraint is either unsatisfiable or a tautology, so it need not be considered at all
        if(relevantSettings.size() == 0)
            return null;
        // construct the constraint
        Constraint c = new Constraint(defaultCosts, referencedVarIndices, relevantSettings.size());
        for(Tuple tuple : relevantSettings) {
            c.addTuple(tuple);
        }
        if(this.cacheConstraints)
            wcspConstraints.put(f, cost);
        return c;
    }

    /**
     * Adds a unary constraint with hard cost for each evidence entry whose variable was not removed by simplification.
     * A true evidence atom forces its variable to the corresponding value; a false evidence atom of a block forbids it.
     * @param wcsp the WCSP to add the constraints to
     * @throws Exception if an evidence entry does not correspond to a ground atom of the MRF
     */
    protected void generateEvidenceConstraints(WCSP wcsp) throws Exception {
        long top = wcsp.top;
        String[][] entries = db.getEntriesAsArray();
        WorldVariables worldVars = world.getVariables();
        for(String[] entry : entries) {
            String varName = entry[0];
            boolean isTrue = entry[1].equals(BooleanDomain.True);
            GroundAtom gndAtom = worldVars.get(varName);
            if(gndAtom == null)
                throw new Exception("Evidence variable '" + varName + "' is not a ground atom of the MRF");
            Integer iVar = this.gndAtomIdx2varIdx.get(gndAtom.index);
            if(iVar == null)
                continue; // variable was removed due to simplification
            Vector<GroundAtom> block = this.varIdx2groundAtoms.get(iVar);
            int iValue;
            long tupleCost, defaultCost;
            if(block.size() == 1) {
                // boolean variable: force the observed value (index 0 = true, 1 = false)
                iValue = isTrue ? 0 : 1;
                defaultCost = top;
                tupleCost = 0;
            }
            else {
                // block variable: the value index is the atom's position within the block
                iValue = block.indexOf(gndAtom);
                if(isTrue) { // force this value
                    defaultCost = top;
                    tupleCost = 0;
                }
                else { // forbid this value
                    defaultCost = 0;
                    tupleCost = top;
                }
            }
            int[] varIndices = new int[]{iVar};
            Constraint c = new Constraint(defaultCost, varIndices, 1);
            c.addTuple(new int[]{iValue}, tupleCost);
            wcsp.addConstraint(c);
            // NOTE(review): getWorldCosts charges a cached formula when it is FALSE; if the first
            // GroundLiteral constructor argument means "is positive", this literal has the opposite
            // polarity of the evidence and worlds consistent with the evidence would be charged top - confirm.
            if(this.cacheConstraints)
                wcspConstraints.put(new GroundLiteral(!isTrue, gndAtom), top);
        }
    }

    /**
     * Recursively enumerates all assignments of the given WCSP variables and classifies them by cost.
     * @param f formula (for this formula all possibilities are generated)
     * @param wcspVarIndices indices of the WCSP variables referenced by the formula
     * @param i index of the variable to assign next (recursion ends when all variables are assigned)
     * @param w possible world in which the assignment is applied to evaluate the formula
     * @param domIndices current variable assignment (domain index per referenced variable)
     * @param cost cost associated with the formula being false
     * @param settingsZero receives all assignments under which the formula is true (cost 0)
     * @param settingsOther receives all assignments under which the formula is false (cost {@code cost})
     * @throws Exception
     */
    protected void gatherConstraintTuples(Formula f, int[] wcspVarIndices, int i, PossibleWorld w, int[] domIndices, long cost, ArrayList<Tuple> settingsZero, ArrayList<Tuple> settingsOther) throws Exception {
        if(i == wcspVarIndices.length) {
            // all variables assigned: evaluate the formula
            if(!f.isTrue(w))
                settingsOther.add(new Tuple(domIndices.clone(), cost));
            else
                settingsZero.add(new Tuple(domIndices.clone(), 0L));
            return;
        }
        int wcspVarIdx = wcspVarIndices[i];
        int domSize = getDomainSize(wcspVarIdx);
        for(int j = 0; j < domSize; j++) {
            domIndices[i] = j;
            setGroundAtomState(w, wcspVarIdx, j);
            gatherConstraintTuples(f, wcspVarIndices, i + 1, w, domIndices, cost, settingsZero, settingsOther);
        }
    }

    /**
     * Computes the total cost of a world with respect to the cached constraints.
     * Only meaningful if cacheConstraints was enabled before conversion.
     * @param world the world to evaluate
     * @return the sum of the costs of all cached formulas that are false in the world
     * @throws Exception if the sum overflows
     */
    public long getWorldCosts(IPossibleWorld world) throws Exception {
        long costs = 0;
        for(Entry<Formula, Long> e : wcspConstraints.entrySet()) {
            if(!e.getKey().isTrue(world)) {
                long newCosts = costs + e.getValue();
                if(newCosts < costs)
                    throw new Exception("Numeric overflow in costs");
                costs = newCosts;
            }
        }
        return costs;
    }

    /**
     * Converts a flat conjunction or disjunction of literals into a single tuple.
     * For a conjunction, the tuple is the one satisfying assignment (cost 0); for a disjunction,
     * it is the one violating assignment (cost {@code cost}).
     * @throws SimplifiedConversionNotSupportedException if a child is not a literal, a block variable
     *         would have to be negated, or the same variable would need two different values
     */
    protected void gatherConstraintTuplesSimplified(ComplexFormula f, int[] wcspVarIndices, long cost, ArrayList<Tuple> settings, boolean isConjunction) throws Exception {
        // gather assignment
        HashMap<Integer,Integer> assignment = new HashMap<Integer,Integer>();
        for(Formula child : f.children) {
            boolean isTrue;
            GroundAtom gndAtom;
            if(child instanceof GroundLiteral) {
                GroundLiteral lit = (GroundLiteral) child;
                gndAtom = lit.gndAtom;
                isTrue = lit.isPositive;
            }
            else if(child instanceof GroundAtom) {
                gndAtom = (GroundAtom) child;
                isTrue = true;
            }
            else
                throw new SimplifiedConversionNotSupportedException("Child is not a literal");
            if(!isConjunction) // for disjunction, consider the case where the child is false
                isTrue = !isTrue;
            int wcspVarIdx = this.gndAtomIdx2varIdx.get(gndAtom.index);
            Integer value = getVariableSettingFromGroundAtomSetting(wcspVarIdx, gndAtom, isTrue);
            Integer oldValue = assignment.put(wcspVarIdx, value);
            // compare by value: reference comparison of Integers fails outside the small-value cache
            if(oldValue != null && !oldValue.equals(value)) // formula contains the same variable twice with different value
                throw new SimplifiedConversionNotSupportedException("Multiple appearances of the same variable");
        }
        int[] domIndices = new int[wcspVarIndices.length];
        for(int i = 0; i < wcspVarIndices.length; i++)
            domIndices[i] = assignment.get(wcspVarIndices[i]);
        // if the formula is true, we have no costs; if the formula is false, costs apply.
        // for conjunction, we considered the true case; for disjunction, we considered the false case
        settings.add(new Tuple(domIndices, isConjunction ? 0 : cost));
    }

    /**
     * Sets the state of the ground atoms in w that correspond to the given WCSP variable.
     * For a boolean variable, index 0 sets the atom true and any other index sets it false.
     * For a block variable, the index selects an element of the variable's domain set
     * (in the set's iteration order), and the atom with that value becomes true.
     * NOTE(review): evidence constraints and the simplified conversion use the atom's position in the
     * block instead; both encodings agree only if the block lists atoms in the domain set's order - confirm.
     * @param w the world to modify
     * @param wcspVarIdx index of the WCSP variable
     * @param domIdx index into the WCSP variable's domain
     */
    public void setGroundAtomState(PossibleWorld w, int wcspVarIdx, int domIdx) {
        Vector<GroundAtom> atoms = varIdx2groundAtoms.get(wcspVarIdx);
        if(atoms.size() == 1) { // var is boolean
            w.set(atoms.iterator().next(), domIdx == 0);
        }
        else { // var corresponds to block
            Object[] dom = doms.get(func_dom.get(vars.get(wcspVarIdx))).toArray();
            setBlockState(w, atoms, dom[domIdx].toString());
        }
    }

    /**
     * Determines the WCSP value index that corresponds to a ground atom having the given truth value.
     * @return 0/1 (true/false) for boolean variables, the atom's position in the block for block variables
     * @throws SimplifiedConversionNotSupportedException if a block atom would have to be false
     */
    protected int getVariableSettingFromGroundAtomSetting(int wcspVarIdx, GroundAtom gndAtom, boolean isTrue) throws SimplifiedConversionNotSupportedException {
        Vector<GroundAtom> atoms = varIdx2groundAtoms.get(wcspVarIdx);
        if(atoms.size() == 1) {
            return isTrue ? 0 : 1;
        }
        else {
            if(!isTrue)
                throw new SimplifiedConversionNotSupportedException("Blocked variable appears negated");
            int idx = atoms.indexOf(gndAtom);
            if(idx == -1)
                throw new IllegalArgumentException("Ground atom does not appear in list");
            return idx;
        }
    }

    /**
     * signals that a formula cannot be converted with the simplified (single-tuple) method
     */
    protected class SimplifiedConversionNotSupportedException extends Exception {
        public SimplifiedConversionNotSupportedException(String message) { super(message); }
        private static final long serialVersionUID = 1L;
    }

    /**
     * Sets the truth values of a block of mutually exclusive ground atoms.
     * The first atom whose functionally determined argument equals the value is set true; all others are set false.
     * @param w the world to modify
     * @param block atoms within the block
     * @param value value indicating the atom to set to true
     */
    protected void setBlockState(PossibleWorld w, Vector<GroundAtom> block, String value) {
        int detArgIdx = this.mln.getFunctionallyDeterminedArgument(block.iterator().next().predicate);
        Iterator<GroundAtom> it = block.iterator();
        while(it.hasNext()) {
            GroundAtom g = it.next();
            boolean v = g.args[detArgIdx].equals(value);
            w.set(g.index, v);
            if(v)
                break;
        }
        // set all remaining atoms false
        while(it.hasNext())
            w.set(it.next().index, false);
    }

    /**
     * Prepares the conversion.
     * Creates and simplifies the variables, computes the divisor, and sets the hard cost to
     * one more than the sum of all soft costs.
     * @throws Exception if the sum of soft costs overflows
     */
    protected void initialize() throws Exception {
        this.db = mrf.getDb();
        this.world = new PossibleWorld(mrf.getWorldVariables());
        doms = mrf.getDb().getDomains();
        createVariables();
        simplifyVars(mrf.getDb());
        divisor = computeDivisor();
        if(verbose) System.out.printf("divisor: %g\n", divisor);
        // sum the soft costs exactly as generateConstraint computes them, so that top exceeds any soft total
        long sumSoftCosts = 0;
        for(WeightedFormula wf : mrf) {
            if(!wf.isHard) {
                long cost = weightToCost(wf.weight);
                long newSum = sumSoftCosts + cost;
                if(newSum < sumSoftCosts)
                    throw new Exception(String.format("Numeric overflow in sumSoftCosts (%d < %d)", newSum, sumSoftCosts));
                sumSoftCosts = newSum;
            }
        }
        hardCost = sumSoftCosts + 1;
        if(hardCost <= sumSoftCosts)
            throw new Exception("Numeric overflow in sumSoftCosts");
    }

    @Override
    public ParameterHandler getParameterHandler() {
        return paramHandler;
    }
}