/******************************************************************************* * Copyright (C) 2007-2012 Dominik Jain. * * This file is part of ProbCog. * * ProbCog is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * ProbCog is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with ProbCog. If not, see <http://www.gnu.org/licenses/>. ******************************************************************************/ package probcog.bayesnets.conversion; import java.io.PrintStream; import java.util.HashMap; import java.util.HashSet; import java.util.Random; import probcog.bayesnets.core.BeliefNetworkEx; import probcog.bayesnets.learning.CPTLearner; import probcog.srldb.Database; import probcog.srldb.Object; import probcog.srldb.datadict.DDAttribute; import probcog.srldb.datadict.DDException; import probcog.srldb.datadict.DDObject; import probcog.srldb.datadict.DataDictionary; import probcog.srldb.datadict.domain.AutomaticDomain; import probcog.srldb.datadict.domain.BooleanDomain; import edu.ksu.cis.bnj.ver3.core.BeliefNode; import edu.ksu.cis.bnj.ver3.core.Domain; import edu.ksu.cis.util.graph.core.Graph; import edu.ksu.cis.util.graph.core.Vertex; /** * creates an srldb.Database by sampling a Bayesian network * @author Dominik Jain */ public class BN2SRLDB { protected BeliefNetworkEx bn; protected Database db; protected HashSet<String> booleanConversion; protected HashMap<String,String> undoConversion; public BN2SRLDB(BeliefNetworkEx bn) { this.bn = bn; this.db = null; this.booleanConversion = null; } public void setBooleanConversion(String attrName) { if(booleanConversion == null) { booleanConversion = new HashSet<String>(); } booleanConversion.add(attrName); } public Database getDB(int numSamples) throws DDException, Exception { return getDB(numSamples, new Random()); } protected boolean isBooleanNode(BeliefNode node) { Domain nodeDomain = node.getDomain(); return nodeDomain.getOrder() == 2 && (nodeDomain.getName(0).equalsIgnoreCase("true") || nodeDomain.getName(0).equalsIgnoreCase("false")); } public Database getDB(int numSamples, Random generator) throws DDException, Exception { // create data dictionary with a single object DataDictionary datadict = new DataDictionary(); DDObject ddObj = new DDObject(Object.class.getSimpleName()); // - add all nodes as attributes BeliefNode[] nodes = bn.bn.getNodes(); for(int i = 0; i < nodes.length; i++) { probcog.srldb.datadict.domain.Domain domain; // check if the node is boolean... if(isBooleanNode(nodes[i])) { domain = BooleanDomain.getInstance(); ddObj.addAttribute(new DDAttribute(nodes[i].getName(), domain)); } else { // it's not boolean String name = nodes[i].getName(); // add an attribute with an automatic domain domain = new AutomaticDomain("dom" + nodes[i].getName()); DDAttribute ddAttr = new DDAttribute(nodes[i].getName(), domain); ddObj.addAttribute(ddAttr); // check if we need to convert this node's values to boolean attributes if(booleanConversion != null && booleanConversion.contains(name)) { // add a boolean attribute for each outcome in the node's domain Domain nodeDomain = nodes[i].getDomain(); for(int j = 0; j < nodeDomain.getOrder(); j++) { ddObj.addAttribute(new DDAttribute(nodeDomain.getName(j), BooleanDomain.getInstance())); } // mark the original attribute as discarded so it doesn't get included in any outputs // (we still keep the attribute added because we may need to use it for CPT learning) ddAttr.discard(); } } } datadict.addObject(ddObj); // create database db = new Database(datadict); // generate samples and add as objects for(int i = 0; i < numSamples; i++) { HashMap<String,String> sample = bn.getSample(generator); Object obj = new Object(db, "object"); if(booleanConversion != null) { for(String attrName : booleanConversion) { String value = sample.get(attrName); sample.put(value, "true"); //sample.remove(attrName); } } obj.addAttributes(sample); obj.commit(); System.out.println(sample); } db.check(); return db; } public void relearnBN() throws Exception { if(db == null) throw new Exception("No sampled data available for learning; call getDB() first!"); // relearn new Bayesian network CPTs from the samples CPTLearner cptLearner = new CPTLearner(bn); for(Object obj : db.getObjects()) { cptLearner.learn(obj.getAttributes()); } cptLearner.finish(); } protected void writeNodeLiteralAllCombs(PrintStream out, BeliefNode n, int varidx) { if(isBooleanNode(n)) out.print("*" + Database.stdPredicateName(n.getName()) + "(o)"); else out.print(Database.stdPredicateName(n.getName()) + "(o,+a" + varidx + ")"); } public void writeMLNFormulas(PrintStream out) { Graph g = bn.bn.getGraph(); Vertex[] vertices = g.getVertices(); BeliefNode[] nodes = bn.bn.getNodes(); for(int i = 0; i < vertices.length; i++) { Vertex[] parents = g.getParents(vertices[i]); if(parents.length == 0) continue; int varidx = 0; for(int j = 0; j < parents.length; j++) { BeliefNode n = nodes[parents[j].loc()]; if(j > 0) out.print(" ^ "); writeNodeLiteralAllCombs(out, n, varidx++); } //if(parents.length > 0) out.print(" => "); writeNodeLiteralAllCombs(out, nodes[vertices[i].loc()], varidx); out.println(); } } }