/******************************************************************************* * Copyright (C) 2008-2012 Dominik Jain. * * This file is part of ProbCog. * * ProbCog is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * ProbCog is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with ProbCog. If not, see <http://www.gnu.org/licenses/>. ******************************************************************************/ package probcog.srl.directed.inference; import java.util.HashMap; import java.util.Vector; import probcog.bayesnets.inference.SATIS_BSampler; import probcog.logic.Formula; import probcog.logic.Negation; import probcog.logic.TrueFalse; import probcog.logic.sat.ClausalKB; import probcog.srl.directed.RelationalBeliefNetwork; import probcog.srl.directed.RelationalNode; import probcog.srl.directed.CPT2MLNFormulas.CPT2Rules; import probcog.srl.directed.bln.GroundBLN; import probcog.srl.directed.bln.coupling.VariableLogicCoupling; import weka.classifiers.trees.j48.Rule; import edu.ksu.cis.bnj.ver3.core.BeliefNode; import edu.tum.cs.util.datastruct.Map2D; /** * Extended version of the SAT-IS algorithm, where the knowledge base is augmented with formulas * based on 0 entries in probabilistic constraints, which factually represent deterministic * constraints. * @author Dominik Jain */ public class SATISEx extends SATIS { /** * whether to exploit context-specific independence (CSI) when extending the KB */ boolean exploitCSI = false; public SATISEx(GroundBLN bln) throws Exception { super(bln); this.paramHandler.add("useCSI", "useCSI"); } public void useCSI(boolean active) { exploitCSI = active; } @Override public ClausalKB getClausalKB() throws Exception { ClausalKB ckb = super.getClausalKB(); // extend the KB with formulas based on a CPD analysis if(!exploitCSI) SATIS_BSampler.extendKBWithDeterministicConstraintsInCPTs(gbln.getGroundNetwork(), gbln.getCoupling(), ckb, gbln.getDatabase()); else { // gather the hard constraints for each fragment System.out.println("CSI analysis..."); Map2D<RelationalNode, String, Vector<Formula>> constraints = new Map2D<RelationalNode, String, Vector<Formula>>(); int numFormulas = 0; int numZeros = 0; int numDirectTranslations = 0; RelationalBeliefNetwork rbn = this.gbln.getRBN(); for(RelationalNode relNode : rbn.getRelationalNodes()) { if(!relNode.isFragment()) continue; //System.out.println(relNode); CPT2Rules cpt2rules = null; for(HashMap<String,String> constantAssignment : relNode.getConstantAssignments()) { Vector<Formula> v = new Vector<Formula>(); if(relNode.hasAggregator()) { Formula f = relNode.toFormula(constantAssignment); if(f == null) throw new Exception("Relational node " + relNode + " could not be translated to a formula"); // TODO could fall back to direct reading of CPT in ground network v.add(f); numDirectTranslations++; } else { if(cpt2rules == null) { cpt2rules = new CPT2Rules(relNode); numZeros += cpt2rules.getZerosInCPT(); } // create formulas from rules Rule[] rules = cpt2rules.learnRules(constantAssignment); for(Rule rule : rules) { if(cpt2rules.getProbability(rule) == 0.0) { Formula f = cpt2rules.getConjunction(rule, constantAssignment); v.add(new Negation(f)); numFormulas++; } } } // create key for this constant assignment StringBuffer sb = new StringBuffer(); for(Integer i : relNode.getIndicesOfConstantParams()) sb.append(constantAssignment.get(relNode.params[i])); String constantKey = sb.toString(); // store constraints.put(relNode, constantKey, v); } } System.out.printf("reduced %d zeros in CPTs to %d formulas; %d direct translations\n", numZeros, numFormulas, numDirectTranslations); // ground the constraints for the actual variables System.out.println("grounding constraints..."); VariableLogicCoupling coupling = gbln.getCoupling(); int sizeBefore = ckb.size(); for(BeliefNode node : gbln.getRegularVariables()) { RelationalNode template = gbln.getTemplateOf(node); //System.out.println(node + " from " + template); Iterable<String> params = coupling.getOriginalParams(node); // get the constant key StringBuffer sb = new StringBuffer(); int i = 0; Vector<Integer> constIndices = template.getIndicesOfConstantParams(); for(String p : params) { if(constIndices.contains(i)) sb.append(p); i++; } String constantKey = sb.toString(); // check if there are any hard constraints for this template Vector<Formula> vf = constraints.get(template, constantKey); if(vf != null) { // get parameter binding i = 0; String[] actualParams = new String[template.params.length]; for(String param : params) actualParams[i++] = param; HashMap<String,String> binding = template.getParameterBinding(actualParams, gbln.getDatabase()); // ground the formulas and add them to the KB for(Formula f : vf) { //System.out.println("grounding " + f + " with " + binding); Formula gf = f.ground(binding, coupling.getWorldVars(), gbln.getDatabase()); Formula gfs = gf.simplify(gbln.getDatabase()); if(gfs instanceof TrueFalse) { TrueFalse tf = (TrueFalse)gfs; if(!tf.isTrue()) System.err.println("unsatisfiable formula" + gf); continue; } ckb.addFormula(gfs); } } } System.out.printf("added %d constraints\n", ckb.size()-sizeBefore); } return ckb; } }