/*******************************************************************************
* Copyright (C) 2007-2012 Dominik Jain.
*
* This file is part of ProbCog.
*
* ProbCog is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ProbCog is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ProbCog. If not, see <http://www.gnu.org/licenses/>.
******************************************************************************/
package probcog.srl.directed;
import java.io.PrintStream;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import java.util.Vector;
import probcog.logic.Conjunction;
import probcog.logic.Formula;
import weka.classifiers.trees.J48;
import weka.classifiers.trees.j48.Rule;
import weka.classifiers.trees.j48.Rule.Condition;
import weka.core.Attribute;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import edu.ksu.cis.bnj.ver3.core.BeliefNode;
import edu.ksu.cis.bnj.ver3.core.CPF;
import edu.ksu.cis.bnj.ver3.core.Discrete;
import edu.ksu.cis.bnj.ver3.core.values.ValueDouble;
/**
* Converts conditional probability tables to MLN formulas.
* @author Dominik Jain
*/
public class CPT2MLNFormulas {
protected RelationalBeliefNetwork rbn;
protected RelationalNode mainNode;
protected String additionalPrecondition;
public CPT2MLNFormulas(RelationalNode node) {
this.mainNode = node;
this.rbn = node.getNetwork();
additionalPrecondition = null;
}
/**
* adds a precondition that must be added to each conjunction that is generated
* @param cond
*/
public void addPrecondition(String cond) {
if(additionalPrecondition == null)
additionalPrecondition = cond;
else
additionalPrecondition += " ^ " + cond;
}
public String getPrecondition() {
return additionalPrecondition;
}
/**
* executes the conversion of the CPT to MLN formulas and prints the formulas to the given stream
* @param out
*/
public void convert(PrintStream out) {
try {
CPT2Rules cpt2rules = new CPT2Rules(mainNode);
// if the main node has constant parents, then we must learn a separate decision tree for each
// configuration of the constants
for(HashMap<String, String> constantAssignment : mainNode.getConstantAssignments()) {
Rule[] rules = cpt2rules.learnRules(constantAssignment);
// write weighted formulas for each of the decision tree's leaf nodes/rules
for(Rule rule : rules) {
if(!rule.hasAntds()) // if the rule has no antecedent (it must be the only rule) and it means that the distribution must be a uniform distribution, so we do not need any formulas
continue;
StringBuffer conjunction = new StringBuffer();
int lits = 0;
// generate literals for each of the nodes along the path from the root to a leaf
// Note: it is not guaranteed that each rule contains a check against the leaf node, so keep track if we ran across it
boolean haveMainNode = false;
for(Condition c : rule.getAntecedent()) {
RelationalNode node = cpt2rules.getRelationalNode(c);
if(node == mainNode)
haveMainNode = true;
String literal = node.toLiteralString(rbn.getDomainIndex(node.node, c.getValue()), constantAssignment);
if(lits++ > 0)
conjunction.append(" ^ ");
conjunction.append(literal);
}
// add preconditions
// TODO handle decision parents
for(RelationalNode parent : rbn.getRelationalParents(mainNode)) {
if(parent.isPrecondition) {
if(lits++ > 0)
conjunction.append(" ^ ");
conjunction.append(parent.toLiteralString(rbn.getDomainIndex(parent.node, "True"), constantAssignment));
}
}
// if we did not come across the main node above, create one variant of the conjunction for each possible setting
Vector<String> conjunctions = new Vector<String>();
if(!haveMainNode) {
for(int i = 0; i < mainNode.node.getDomain().getOrder(); i++) {
conjunctions.add(conjunction.toString() + " ^ " + mainNode.toLiteralString(i, null));
}
}
else
conjunctions.add(conjunction.toString());
// write final formulas with weights
double prob = Double.parseDouble(rule.getConsequent().getValue());
double weight = prob == 0.0 ? -100 : Math.log(prob);
for(String conj : conjunctions) {
out.print(weight + " ");
out.print(conj);
if(additionalPrecondition != null)
out.print(" ^ " + additionalPrecondition);
out.println();
}
}
}
}
catch (Exception e) {
e.printStackTrace();
}
}
/**
* the task here is to learn a decision tree as a compact representation of a
* CPT, the predictors being the values of all nodes relevant to the CPT (parents and child)
* and the predicted class attribute being the probability value
* @author Dominik Jain
*/
public static class CPT2Rules {
/**
* maps attribute names to actual attributes. Note: except for the probability value attribute "prob", the names of attributes correspond to names of relational node
*/
protected HashMap<String, Attribute> attrs;
protected RelationalBeliefNetwork rbn;
protected CPF cpf;
BeliefNode[] nodes;
FastVector fvAttribs;
HashMap<Attribute, RelationalNode> relNodes;
RelationalNode mainNode;
int zerosInCPT;
public CPT2Rules(RelationalNode relNode) {
mainNode = relNode;
rbn = relNode.getNetwork();
cpf = relNode.node.getCPF();
nodes = cpf.getDomainProduct();
// the vector of attributes
fvAttribs = new FastVector(nodes.length+1);
attrs = new HashMap<String,Attribute>();
// generate the predictor attributes (one attribute for each of the parents and the node itself)
relNodes = new HashMap<Attribute, RelationalNode>();
for(BeliefNode node : nodes) {
ExtendedNode extNode = rbn.getExtendedNode(node);
if(extNode instanceof DecisionNode)
continue;
Discrete dom = (Discrete)node.getDomain();
FastVector attValues = new FastVector(dom.getOrder());
for(int i = 0; i < dom.getOrder(); i++)
attValues.addElement(dom.getName(i));
Attribute attr = new Attribute(node.getName(), attValues);
attrs.put(node.getName(), attr);
relNodes.put(attr, rbn.getRelationalNode(node));
fvAttribs.addElement(attr);
}
// add class (predicted) attribute, which here is the probability value
// - collect set of values
TreeSet<Double> probs = new TreeSet<Double>();
zerosInCPT = 0;
walkCPT4ValueSet(new int[nodes.length], 0, probs);
FastVector attrValues = new FastVector(probs.size());
for(Double d : probs)
attrValues.addElement(Double.toString(d));
// - add attribute
Attribute probAttr = new Attribute("prob", attrValues);
attrs.put("prob", probAttr);
fvAttribs.addElement(probAttr);
}
protected void walkCPT4ValueSet(int[] addr, int i, Set<Double> values) {
BeliefNode[] nodes = cpf.getDomainProduct();
if(i == addr.length) { // we have a complete address
// get the probability value
int realAddr = cpf.addr2realaddr(addr);
double value = ((ValueDouble)cpf.get(realAddr)).getValue();
if(value == 0.0)
zerosInCPT++;
values.add(value);
}
else { // the address is yet incomplete -> consider all ways of setting the next e
Discrete dom = (Discrete)nodes[i].getDomain();
ExtendedNode extNode = rbn.getExtendedNode(nodes[i]);
if(extNode instanceof DecisionNode) {
addr[i] = 0; // True
walkCPT4ValueSet(addr, i+1, values);
}
else {
RelationalNode n = (RelationalNode)extNode;
if(n.isPrecondition) {
addr[i] = dom.findName("True");
walkCPT4ValueSet(addr, i+1, values);
}
else {
for(int j = 0; j < dom.getOrder(); j++) {
addr[i] = j;
walkCPT4ValueSet(addr, i+1, values);
}
}
}
}
}
public int getZerosInCPT() {
return zerosInCPT;
}
/**
* collects instances for the given constant assignment and learns a decision tree for which it returns the set of rules
* @param constantAssignment
* @return
* @throws Exception
*/
public Rule[] learnRules(Map<String, String> constantAssignment) throws Exception {
// collect instances
Instances instances = new Instances("foo", fvAttribs, 60000);
walkCPT4InstanceCollection(new int[nodes.length], 0, constantAssignment, instances);
// learn a J48 decision tree from the instances
instances.setClass(attrs.get("prob"));
J48 j48 = new J48();
j48.setUnpruned(true);
j48.setMinNumObj(0); // there is no minimum number of objects that has to end up at each of the tree's leaf nodes
j48.buildClassifier(instances);
// output the decision tree
//System.out.println(j48);
return j48.getRules();
}
protected void walkCPT4InstanceCollection(int[] addr, int i, Map<String,String> constantSettings, Instances instances) throws Exception {
BeliefNode[] nodes = cpf.getDomainProduct();
if(i == addr.length) { // we have a complete address
// get the probability value
int realAddr = cpf.addr2realaddr(addr);
double value = ((ValueDouble)cpf.get(realAddr)).getValue();
// create a new instance
Instance inst = new Instance(nodes.length+1);
// translate the address to attribute settings
for(int j = 0; j < addr.length; j++) {
Attribute attr = attrs.get(nodes[j].getName());
if(attr != null) {
Discrete dom = (Discrete)nodes[j].getDomain();
inst.setValue(attr, dom.getName(addr[j]));
}
}
// add value of class (predicted) attribute - i.e. the probability value
inst.setValue(attrs.get("prob"), Double.toString(value));
// add the instance to our collection
instances.add(inst);
}
else { // the address is yet incomplete -> consider all ways of setting the next e
Discrete dom = (Discrete)nodes[i].getDomain();
ExtendedNode extNode = rbn.getExtendedNode(nodes[i]);
if(extNode instanceof DecisionNode) {
addr[i] = 0; // True
walkCPT4InstanceCollection(addr, i+1, constantSettings, instances);
}
else {
RelationalNode n = (RelationalNode)extNode;
if(n.isPrecondition) {
addr[i] = dom.findName("True");
if(addr[i] == -1)
throw new Exception("The node " + nodes[i] + " is set as a precondition, but its domain does not contain the value 'True'.");
walkCPT4InstanceCollection(addr, i+1, constantSettings, instances);
}
else if(n.isConstant) {
addr[i] = dom.findName(constantSettings.get(n.getName()));
walkCPT4InstanceCollection(addr, i+1, constantSettings, instances);
}
else {
for(int j = 0; j < dom.getOrder(); j++) {
addr[i] = j;
walkCPT4InstanceCollection(addr, i+1, constantSettings, instances);
}
}
}
}
}
/**
* gets the relational node that corresponds to the attribute that is being checked against in the given condition
* @param c
* @return
*/
public RelationalNode getRelationalNode(Condition c) {
return relNodes.get(c.getAttribute());
}
public Formula getConjunction(Rule rule, Map<String,String> constantAssignment) throws Exception {
Vector<Formula> conjuncts = new Vector<Formula>();
for(Condition c : rule.getAntecedent()) {
RelationalNode node = this.getRelationalNode(c);
int value = rbn.getDomainIndex(node.node, c.getValue());
Formula literal = node.toLiteral(value, constantAssignment);
conjuncts.add(literal);
}
return new Conjunction(conjuncts);
}
public double getProbability(Rule r) {
return Double.parseDouble(r.getConsequent().getValue());
}
}
}