/*******************************************************************************
* Copyright (C) 2007-2012 Dominik Jain.
*
* This file is part of ProbCog.
*
* ProbCog is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ProbCog is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ProbCog. If not, see <http://www.gnu.org/licenses/>.
******************************************************************************/
package probcog.srl.directed;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.PrintStream;
import java.util.Arrays;
import java.util.Collection;
import java.util.Vector;
import java.util.Map.Entry;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import probcog.bayesnets.core.BeliefNetworkEx;
import probcog.srl.BooleanDomain;
import probcog.srl.Database;
import probcog.srl.RealDomain;
import probcog.srl.RelationKey;
import probcog.srl.Signature;
import probcog.srl.directed.learning.CPTLearner;
import probcog.srl.directed.learning.DomainLearner;
import probcog.srl.taxonomy.Concept;
import probcog.srl.taxonomy.Taxonomy;
import edu.ksu.cis.bnj.ver3.core.BeliefNode;
import edu.ksu.cis.bnj.ver3.core.CPF;
import edu.ksu.cis.bnj.ver3.core.Discrete;
import edu.ksu.cis.bnj.ver3.core.values.ValueDouble;
import edu.tum.cs.util.FileUtil;
import edu.tum.cs.util.StringTool;
/**
* Advanced Bayesian Logical (ABL) Model
*
* This class contains reading and writing methods that are specific to an implementation
* of a relational belief network.
*
* @author Dominik Jain
*/
public class ABLModel extends RelationalBeliefNetwork {
/** the fragment network file; set from a constructor argument, a "fragments" declaration or {@link #setNetworkFilename(String)}; null until one of these provides it */
protected File networkFile = null;
/** the declarations files handed to the constructor (stored as absolute files); null if no declarations files were given */
protected File[] declsFiles = null;
/** pattern that a valid function name must fully match (one or more word characters) */
public static Pattern regexFunctionName = Pattern.compile("[\\w]+"); // NOTE: should actually start with lower-case (because of Prolog compatibility), but left this way for backward comp. with older models
/** pattern that a valid type name must fully match (the very same pattern object as for function names) */
public static Pattern regexTypeName = regexFunctionName;
/** pattern that a valid entity name must fully match: either a letter followed by word characters, or an unsigned integer/decimal number */
public static Pattern regexEntity = Pattern.compile("(?:[a-zA-Z][\\w]*|[0-9]+(?:\\.[0-9]+)?)");
/**
 * Creates a model whose node data comes from a fragment network and whose
 * declarations come from one or more declarations files.
 *
 * @param declarationsFiles
 *            paths of the declarations files; if null, no declarations are
 *            read and the signatures are guessed from the network instead
 * @param networkFile
 *            path of the fragment network file; may be null if one of the
 *            declarations files names the network in a "fragments" declaration
 * @throws Exception
 *             if no fragment network is known after reading the declarations,
 *             or if reading/interpreting any of the files fails
 */
public ABLModel(String[] declarationsFiles, String networkFile) throws Exception {
    this.init(declarationsFiles, networkFile);
}
/**
* constructs a BLOG model by obtaining the node data from a Bayesian
* network template and function signatures from a BLOG file.
*
* @param declarationsFile
* @param networkFile
* @throws Exception
*/
public ABLModel(String declarationsFile, String networkFile) throws Exception {
String[] decls = null;
if(declarationsFile != null)
decls = new String[] { declarationsFile };
init(decls, networkFile);
}
public ABLModel(String declarationsFile) throws Exception {
if(declarationsFile == null)
throw new Exception("Declarations file cannot be null");
init(new String[]{ declarationsFile }, null);
}
/**
 * Checks whether a string is a valid entity name.
 *
 * @param s the candidate name
 * @return true iff the entire string matches {@link #regexEntity}
 */
public static boolean isValidEntityName(String s) {
    final Matcher entityMatcher = regexEntity.matcher(s);
    return entityMatcher.matches();
}
/**
 * Checks whether a string is a valid function name.
 *
 * @param s the candidate name
 * @return true iff the entire string matches {@link #regexFunctionName}
 */
public static boolean isValidFunctionName(String s) {
    final Matcher functionMatcher = regexFunctionName.matcher(s);
    return functionMatcher.matches();
}
/**
 * Checks whether a string is a valid type name.
 *
 * @param s the candidate name
 * @return true iff the entire string matches {@link #regexTypeName}
 */
public static boolean isValidTypeName(String s) {
    final Matcher typeMatcher = regexTypeName.matcher(s);
    return typeMatcher.matches();
}
/**
 * Common initialisation used by all constructors: reads the declarations
 * (if any), loads the fragment network and then either checks the declared
 * signatures or guesses them from the network.
 *
 * @param declarationsFiles paths of the declarations files; may be null
 * @param networkFile path of the fragment network file; may be null if the declarations name it
 * @throws Exception if no fragment network is known once the declarations have been read, or if reading fails
 */
private void init(String[] declarationsFiles, String networkFile) throws Exception {
    // an explicitly given network file must be set first, because it takes
    // precedence over a "fragments" declaration encountered while reading
    if (networkFile != null) {
        this.networkFile = new File(networkFile);
    }

    final boolean haveDeclarations = (declarationsFiles != null);
    if (haveDeclarations) {
        final int numFiles = declarationsFiles.length;
        declsFiles = new File[numFiles];
        for (int idx = 0; idx < numFiles; idx++) {
            declsFiles[idx] = new File(declarationsFiles[idx]).getAbsoluteFile();
        }
        readDeclarations(readBlogContent(declsFiles));
    }

    // by now the network file must be known (either given or declared)
    if (this.networkFile == null) {
        throw new Exception("No fragment network was given");
    }
    initNetwork(this.networkFile);

    // with declarations the signatures are verified, without them they are guessed
    if (haveDeclarations) {
        checkSignatures();
    }
    else {
        guessSignatures();
    }
}
/**
 * Interprets the given declarations text: comments are stripped and every
 * remaining non-empty line is passed to {@link #readDeclaration(String)}.
 *
 * @param decls the complete declarations text (one declaration per line)
 * @throws Exception if a line is neither a known declaration nor contains a '~'
 */
protected void readDeclarations(String decls) throws Exception {
    // strip "//" line comments and "/* ... */" block comments
    final Pattern commentPattern = Pattern.compile("//.*?$|/\\*.*?\\*/", Pattern.MULTILINE | Pattern.DOTALL);
    final String withoutComments = commentPattern.matcher(decls).replaceAll("");

    // process the remaining text one line at a time
    for (String rawLine : withoutComments.split("\n")) {
        final String line = rawLine.trim();
        final boolean blank = (line.length() == 0);
        if (blank || readDeclaration(line)) {
            continue;
        }
        // lines containing '~' (such as the "name ~ TabularCPD[...]" lines
        // emitted by write) are tolerated and skipped
        if (line.contains("~")) {
            continue;
        }
        throw new Exception("Could not interpret the line '" + line + "'");
    }
}
// precompiled patterns for the individual declaration kinds understood by readDeclaration
private static final Pattern PAT_SIGNATURE = Pattern.compile("(random|logical|utility)\\s+(\\w+)\\s+(\\w+)\\s*\\((.*)\\)\\s*;?", Pattern.CASE_INSENSITIVE);
private static final Pattern PAT_GUARANTEED = Pattern.compile("guaranteed\\s+(\\w+)\\s+(.*?)\\s*;?");
private static final Pattern PAT_RELATION_KEY = Pattern.compile("[Rr]elationKey\\s+(\\w+)\\s*\\((.*)\\)\\s*;?");
private static final Pattern PAT_TYPE_LINE = Pattern.compile("[Tt]ype\\s+(.*?)\\s*;?$");
private static final Pattern PAT_TYPE_DECL = Pattern.compile("(\\w+)(?:\\s+isa\\s+(\\w+))?");
private static final Pattern PAT_COMBINING_RULE = Pattern.compile("combining-rule\\s+(\\w+)\\s+([-\\w]+)\\s*;?");
private static final Pattern PAT_UNIFORM_DEFAULT = Pattern.compile("uniform-default\\s+([-\\w]+(?:\\s*,\\s*[-\\w]+)*)\\s*;?");
private static final Pattern PAT_FRAGMENTS = Pattern.compile("fragments\\s+([^;\\s]+)\\s*;?");

/**
 * Interprets a single (trimmed, non-empty) declaration line. The supported
 * declarations are:
 * <ul>
 * <li><code>random|logical|utility RetType name(ArgType, ...);</code> - function signature</li>
 * <li><code>guaranteed Domain e1, e2, i..j;</code> - guaranteed domain elements</li>
 * <li><code>relationKey rel(args);</code> - functional dependency among relation arguments</li>
 * <li><code>type T1, T2 isa T1;</code> - type declarations</li>
 * <li><code>prolog rule</code> - Prolog rule</li>
 * <li><code>combining-rule function rule;</code> - combining rule for a function</li>
 * <li><code>uniform-default f1, f2;</code> - functions with uniform default distribution</li>
 * <li><code>fragments file;</code> - reference to the fragment network file</li>
 * </ul>
 *
 * @param line the declaration line
 * @return true if the line was recognised and processed, false if it could not be interpreted
 * @throws Exception if the line is recognised but its content is invalid
 *         (invalid type declaration, combining rule for an unknown function,
 *         unknown combining rule, fragments file that cannot be found)
 */
protected boolean readDeclaration(String line) throws Exception {
    // function signature
    // TODO: logical Boolean required - split this into random / logical w/o Boolean / utility?
    if (line.startsWith("random") || line.startsWith("logical") || line.startsWith("utility")) {
        Matcher matcher = PAT_SIGNATURE.matcher(line);
        if (!matcher.matches())
            return false;
        boolean isLogical = matcher.group(1).equals("logical");
        boolean isUtility = matcher.group(1).equals("utility");
        String retType = matcher.group(2);
        // NOTE(review): for an empty argument list "()" this yields a single empty type name - confirm whether zero-argument functions are meant to be supported
        String[] argTypes = matcher.group(4).trim().split("\\s*,\\s*");
        Signature sig = new Signature(matcher.group(3), retType, argTypes, isLogical, isUtility);
        addSignature(sig);
        // functions declared as logical are always given (either implicitly through the closed-world assumption which assumes false)
        // or explicitly (in the explicit case, we do not insist that the variable must be Boolean, which is why no exception is thrown
        // for non-Boolean logical functions).
        // ensure types used in signature exist, adding them if necessary
        addType(sig.returnType, false);
        for (String t : sig.argTypes)
            addType(t, false);
        return true;
    }

    // obtain guaranteed domain elements
    if (line.startsWith("guaranteed")) {
        Matcher matcher = PAT_GUARANTEED.matcher(line);
        if (!matcher.matches())
            return false;
        String domName = matcher.group(1);
        String[] elems = makeDomainElements(matcher.group(2).split("\\s*,\\s*"));
        guaranteedDomElements.put(domName, Arrays.asList(elems));
        return true;
    }

    // read functional dependencies among relation arguments
    if (line.startsWith("relationKey") || line.startsWith("RelationKey")) {
        Matcher matcher = PAT_RELATION_KEY.matcher(line);
        if (!matcher.matches())
            return false;
        String relation = matcher.group(1);
        String[] arguments = matcher.group(2).trim().split("\\s*,\\s*");
        addRelationKey(new RelationKey(relation, arguments));
        return true;
    }

    // read type information
    if (line.startsWith("type") || line.startsWith("Type")) {
        Matcher matcher = PAT_TYPE_LINE.matcher(line);
        if (!matcher.matches())
            return false;
        for (String d : matcher.group(1).split("\\s*,\\s*")) {
            Matcher m = PAT_TYPE_DECL.matcher(d);
            if (!m.matches())
                throw new Exception("The type declaration '" + d + "' is invalid");
            Concept c = addType(m.group(1), true);
            if (m.group(2) != null) {
                // addType returns null for the built-in Boolean/real types, which have no concept a parent could be attached to
                if (c == null)
                    throw new Exception("The type declaration '" + d + "' is invalid: the built-in type '" + m.group(1) + "' cannot be declared as a subtype");
                Concept parent = addType(m.group(2), false);
                c.setParent(parent);
            }
        }
        return true;
    }

    // prolog rule
    if (line.startsWith("prolog")) {
        String rule = line.substring(6).trim();
        // drop the optional declaration terminator and make sure the rule ends with a period
        if (rule.endsWith(";"))
            rule = rule.substring(0, rule.length() - 1);
        if (!rule.endsWith("."))
            rule += ".";
        prologRules.add(rule);
        return true;
    }

    // combining rule
    if (line.startsWith("combining-rule")) {
        Matcher matcher = PAT_COMBINING_RULE.matcher(line);
        if (!matcher.matches())
            return false;
        String function = matcher.group(1);
        String strRule = matcher.group(2);
        if (getSignature(function) == null)
            throw new Exception("Defined combining rule for unknown function '" + function + "'");
        CombiningRule rule;
        try {
            rule = CombiningRule.fromString(strRule);
        }
        catch (IllegalArgumentException e) {
            // report the valid options, keeping the original exception as the cause
            Vector<String> v = new Vector<String>();
            for (CombiningRule cr : CombiningRule.values())
                v.add(cr.stringRepresention);
            throw new Exception("Invalid combining rule '" + strRule + "'; valid options: " + StringTool.join(", ", v), e);
        }
        this.combiningRules.put(function, rule);
        return true;
    }

    // declaration of uniform default distribution if no fragment applicable
    if (line.startsWith("uniform-default")) {
        Matcher matcher = PAT_UNIFORM_DEFAULT.matcher(line);
        // a malformed declaration is reported as uninterpretable rather than silently ignored
        if (!matcher.matches())
            return false;
        for (String f : matcher.group(1).split("\\s*,\\s*"))
            this.uniformDefaultFunctions.add(f);
        return true;
    }

    // fragment network file reference
    if (line.startsWith("fragments")) {
        Matcher matcher = PAT_FRAGMENTS.matcher(line);
        if (!matcher.matches())
            return false;
        String filename = matcher.group(1);
        File f = findReferencedFile(filename);
        if (f == null)
            throw new Exception("Declared fragments file " + filename + " could not be found");
        if (networkFile != null) { // if we already have another network file, then the one that is declared here is not used
            if (!networkFile.getAbsoluteFile().equals(f.getAbsoluteFile()))
                System.err.println("Notice: Declared network file " + filename + " is overridden by " + networkFile);
            return true;
        }
        networkFile = f;
        return true;
    }

    return false;
}
/**
 * Expands integer range items of the form "i..j" into the individual
 * integers i, i+1, ..., j (inclusive); all other items are kept as they are.
 * Whitespace around the bounds is ignored; a range with i &gt; j expands to
 * nothing.
 *
 * @param elems the domain elements as written in the declaration
 * @return the very same array if it contains no range item, otherwise a new
 *         array in which each range item is replaced by its expansion
 * @throws IllegalArgumentException if an item containing ".." does not consist
 *         of exactly two bounds (a NumberFormatException if a bound is not an integer)
 */
public static String[] makeDomainElements(String[] elems) {
    // handle "i..j" -> list of integers from i to j
    Vector<String> expanded = null; // created lazily upon the first range item
    for (int i = 0; i < elems.length; i++) {
        String item = elems[i];
        if (!item.contains("..")) {
            if (expanded != null)
                expanded.add(item);
            continue;
        }
        if (expanded == null) {
            // first range encountered: carry over all the items seen so far
            expanded = new Vector<String>();
            for (int j = 0; j < i; j++)
                expanded.add(elems[j]);
        }
        // limit -1 keeps trailing empty strings so that "1.." is detected as malformed
        String[] bounds = item.split("\\.\\.", -1);
        if (bounds.length != 2)
            throw new IllegalArgumentException("Invalid range '" + item + "': expected the form 'i..j' with integer bounds");
        int from = Integer.parseInt(bounds[0].trim());
        int to = Integer.parseInt(bounds[1].trim());
        // a long counter avoids overflow when the upper bound is Integer.MAX_VALUE
        for (long k = from; k <= to; k++)
            expanded.add(Long.toString(k));
    }
    if (expanded == null)
        return elems;
    return expanded.toArray(new String[expanded.size()]);
}
/**
 * Locates a file that is referenced from within a declarations file. The
 * name is first resolved as given (relative to the working directory if it
 * is not absolute) and then relative to the directory of each of the
 * declarations files.
 *
 * @param filename the file name as written in the declaration
 * @return the first existing candidate, or null if the file could not be found
 */
protected File findReferencedFile(String filename) {
    File f = new File(filename).getAbsoluteFile();
    if (f.exists())
        return f;
    // without declarations files there are no further directories to search
    if (this.declsFiles == null)
        return null;
    for (File declsFile : this.declsFiles) {
        File dir = declsFile.getAbsoluteFile().getParentFile();
        if (dir == null) // no parent directory to resolve against
            continue;
        f = new File(dir, filename);
        if (f.exists())
            return f;
    }
    return null;
}
/**
 * Adds a concept for a type to the taxonomy unless it is already present.
 * The taxonomy itself is created on first use.
 *
 * @param typeName the name of the type
 * @param explicitlyDeclared whether the type is added due to its explicit declaration
 *        (if it is not and the type is new, a warning is printed to stderr)
 * @return the taxonomy concept for the given type name, or null if the type
 *         is a built-in Boolean or real type (for which no concept is kept)
 */
protected Concept addType(String typeName, boolean explicitlyDeclared) {
    // built-in types are not part of the taxonomy
    final boolean builtIn = BooleanDomain.isBooleanType(typeName) || RealDomain.isRealType(typeName);
    if (builtIn) {
        return null;
    }

    if (taxonomy == null) {
        taxonomy = new Taxonomy();
    }

    Concept concept = taxonomy.getConcept(typeName);
    if (concept != null) {
        return concept;
    }

    // unknown so far: create the concept, warning about implicit creation
    concept = new Concept(typeName);
    taxonomy.addConcept(concept);
    if (!explicitlyDeclared) {
        System.err.println("Warning: The type '" + typeName + "' was not explicitly declared before it was first used; implicitly adding it to the taxonomy...");
    }
    return concept;
}
/**
 * Reads the contents of one or more (BLOG) files into a single string; the
 * content of each file is followed by a newline character.
 *
 * @param files the files to read, in the order in which they are to be concatenated
 * @return the concatenated file contents
 * @throws FileNotFoundException if a file does not exist
 * @throws IOException if a file cannot be read
 */
protected String readBlogContent(File[] files) throws FileNotFoundException, IOException {
    final StringBuilder content = new StringBuilder();
    for (int idx = 0; idx < files.length; idx++) {
        content.append(FileUtil.readTextFile(files[idx])).append('\n');
    }
    return content.toString();
}
/**
 * Generates the ground Bayesian network for the template network that this
 * model represents, instantiating it with the guaranteed domain elements:
 * for every template node (in topological order) and every combination of
 * guaranteed elements of its argument types, a ground node is created that
 * has the template node's domain, is connected to its ground parents and
 * receives a copy of the template node's CPT.
 *
 * @return the ground network
 * @throws Exception if no signature is known for a template node or if one
 *         of its argument types has no guaranteed domain elements
 * @deprecated no longer maintained; for BLNs superseded by the respective grounding process
 */
public BeliefNetworkEx getGroundBN() throws Exception {
// create a new Bayesian network
BeliefNetworkEx gbn = new BeliefNetworkEx();
// add nodes in topological order, so that the ground parents of a node are
// added before the node itself
int[] order = this.getTopologicalOrder();
for (int i = 0; i < order.length; i++) { // for each template node (in topological order)
RelationalNode node = getRelationalNode(order[i]);
// get all possible argument groundings
Signature sig = getSignature(node.functionName);
if (sig == null)
throw new Exception("Could not retrieve signature for node "
+ node.functionName);
Vector<String[]> argGroundings = groundParams(sig);
// create a new node for each grounding with the same domain and CPT
// as the template node
for (String[] args : argGroundings) {
String newName = Signature.formatVarName(node.functionName,
args);
BeliefNode newNode = new BeliefNode(newName, node.node
.getDomain());
gbn.addNode(newNode);
// link from all the parent nodes
// NOTE(review): the result of gbn.getNode is not checked here - presumably the
// parent is always present because of the topological order; verify
String[] parentNames = getParentVariableNames(node, args);
for (String parentName : parentNames) {
BeliefNode parent = gbn.getNode(parentName);
gbn.bn.connect(parent, newNode);
}
// transfer the CPT (the entries for the new node may not be in
// the same order so determine the appropriate mapping)
// TODO this assumes that a function name occurs only in one parent
CPF newCPF = newNode.getCPF(), oldCPF = node.node.getCPF();
BeliefNode[] oldProd = oldCPF.getDomainProduct();
BeliefNode[] newProd = newCPF.getDomainProduct();
// old2newindex[j] = position in the new domain product of the node that has
// the same function name as the j-th node of the old domain product
// NOTE(review): the inner loop does not stop at the first match, so the last
// matching position wins; if nothing matches, the index remains 0
int[] old2newindex = new int[oldProd.length];
for (int j = 0; j < oldProd.length; j++) {
for (int k = 0; k < newProd.length; k++)
if (RelationalNode.extractFunctionName(
newProd[k].getName()).equals(
RelationalNode.extractFunctionName(oldProd[j]
.getName())))
old2newindex[j] = k;
}
// copy every entry of the old CPF to the correspondingly permuted address in the new CPF
for (int j = 0; j < oldCPF.size(); j++) {
int[] oldAddr = oldCPF.realaddr2addr(j);
int[] newAddr = new int[oldAddr.length];
for (int k = 0; k < oldAddr.length; k++)
newAddr[old2newindex[k]] = oldAddr[k];
newCPF.put(newCPF.addr2realaddr(newAddr), oldCPF.get(j));
}
}
}
return gbn;
}
/**
 * Recursively enumerates all combinations of guaranteed elements of the
 * given domains (domNames) and stores a copy of each complete combination.
 *
 * @param domNames
 *            the names of the domains, one per position
 * @param setting
 *            the working array holding the current (partial) combination -
 *            same length as domNames; it is overwritten during the enumeration
 * @param idx
 *            the index of the domain from which to choose next (0 initially)
 * @param ret
 *            the vector in which all complete combinations are stored
 * @throws Exception
 *             if a domain has no guaranteed domain elements
 */
protected void groundParams(String[] domNames, String[] setting, int idx,
        Vector<String[]> ret) throws Exception {
    // all positions filled: store a snapshot of the combination
    final boolean complete = (idx == domNames.length);
    if (complete) {
        ret.add(setting.clone());
        return;
    }

    final String domain = domNames[idx];
    final Collection<String> candidates = guaranteedDomElements.get(domain);
    if (candidates == null) {
        throw new Exception("No guaranteed domain elements for " + domain);
    }

    // try every element at the current position and recurse into the next one
    final int next = idx + 1;
    for (String candidate : candidates) {
        setting[idx] = candidate;
        groundParams(domNames, setting, next, ret);
    }
}
/**
 * Enumerates all argument groundings of a function, i.e. all combinations of
 * guaranteed domain elements of its argument types.
 *
 * @param sig the signature of the function
 * @return one array of constants per possible argument combination
 * @throws Exception if one of the argument types has no guaranteed domain elements
 */
protected Vector<String[]> groundParams(Signature sig) throws Exception {
    final String[] domNames = sig.argTypes;
    final String[] scratch = new String[domNames.length];
    final Vector<String[]> groundings = new Vector<String[]>();
    groundParams(domNames, scratch, 0, groundings);
    return groundings;
}
/**
 * Writes the model to the given stream: first the declarations (see
 * {@link #writeDeclarations(PrintStream)}), then one
 * "name ~ TabularCPD[...](parents);" line for each non-auxiliary node.
 *
 * @param out the stream to write to
 * @throws Exception if writing the model fails
 */
public void write(PrintStream out) throws Exception {
    BeliefNode[] nodes = bn.getNodes();

    // write declarations for types, guaranteed domain elements and functions
    writeDeclarations(out);

    // write conditional probability distributions
    // NOTE: These declarations are only included for compatibility with BLOG, they are not necessary otherwise, as distributions are read from fragment networks
    // TODO handle decision parents properly by using if-then-else?
    for (RelationalNode relNode : getRelationalNodes()) {
        if (relNode.isAuxiliary)
            continue;
        CPF cpf = nodes[relNode.index].getCPF();
        BeliefNode[] deps = cpf.getDomainProduct();
        Discrete[] domains = new Discrete[deps.length];
        StringBuffer args = new StringBuffer();
        int[] addr = new int[deps.length];
        for (int j = 0; j < deps.length; j++) {
            // ignore decision nodes (they are not dependencies because
            // they are assumed to be true)
            if (deps[j].getType() == BeliefNode.NODE_DECISION)
                continue;
            if (j > 0) { // index 0 is the node itself, the parents follow
                // separate by what has actually been written, so that a skipped
                // decision parent in front does not produce a leading ", "
                if (args.length() > 0)
                    args.append(", ");
                args.append(getRelationalNode(deps[j]).getCleanName());
            }
            domains[j] = (Discrete) deps[j].getDomain();
        }
        Vector<String> lists = new Vector<String>();
        getCPD(lists, cpf, domains, addr, 1);
        out.printf("%s ~ TabularCPD[%s](%s);\n", relNode.getCleanName(),
                StringTool.join(",", lists.toArray(new String[0])), args.toString());
    }
}
/**
 * Writes the declarations of this model: the fragments file reference (if
 * known), types, guaranteed domain elements, function signatures, relation
 * keys and combining rules.
 *
 * @param out the stream to write to
 */
protected void writeDeclarations(PrintStream out) {
    if (this.networkFile != null) {
        out.printf("fragments %s;\n\n", this.networkFile.toString());
    }

    // write type decls (the taxonomy is only created once a type is added, so it may not exist)
    if (this.taxonomy != null) {
        for (Concept c : this.taxonomy.getConcepts()) {
            if (c.parent == null)
                out.printf("type %s;\n", c.name);
            else
                out.printf("type %s isa %s;\n", c.name, c.parent.name);
        }
    }
    out.println();

    // write domains
    for (Entry<String, ? extends Collection<String>> e : guaranteedDomElements.entrySet()) {
        out.println("guaranteed " + e.getKey() + " " + StringTool.join(", ", e.getValue()) + ";");
    }
    out.println();

    // signatures
    // NOTE(review): only isLogical is consulted here, so a function declared as "utility" is written as "random" - confirm whether this is intended
    for (Signature sig : getSignatures()) {
        out.printf("%s %s %s(%s);\n", sig.isLogical ? "logical" : "random", sig.returnType, sig.functionName, StringTool.join(", ", sig.argTypes));
    }
    out.println();

    // relation keys
    for (Collection<RelationKey> c : this.relationKeys.values())
        for (RelationKey relKey : c)
            out.println(relKey.toString());
    out.println();

    // combining rules
    for (Entry<String, CombiningRule> e : this.combiningRules.entrySet()) {
        out.printf("combining-rule %s %s;\n", e.getKey(), e.getValue().stringRepresention);
    }
}
/**
 * Recursively collects the conditional distributions of a CPF as strings of
 * the form "[p0,p1,...]", one for each configuration of the parents.
 * Decision parents are not enumerated but fixed to their first value.
 *
 * @param lists the vector to which the distribution strings are added
 * @param cpf the conditional probability function to read from
 * @param domains the domains of the nodes in the CPF's domain product
 *        (entries of decision nodes are not accessed)
 * @param addr the working address into the CPF; it is overwritten during the enumeration
 * @param i the index of the domain product position to set next (1 for the first parent)
 */
protected void getCPD(Vector<String> lists, CPF cpf, Discrete[] domains, int[] addr, int i) {
    if (i == addr.length) {
        // all parents are set: output the distribution over the node's own values
        StringBuffer sb = new StringBuffer();
        sb.append('[');
        for (int j = 0; j < domains[0].getOrder(); j++) {
            addr[0] = j;
            int realAddr = cpf.addr2realaddr(addr);
            double value = ((ValueDouble) cpf.get(realAddr)).getValue();
            if (j > 0)
                sb.append(',');
            sb.append(value);
        }
        sb.append(']');
        lists.add(sb.toString());
        return;
    }

    // go through all possible parent-child configurations
    BeliefNode[] domProd = cpf.getDomainProduct();
    if (domProd[i].getType() == BeliefNode.NODE_DECISION) {
        // for decision nodes, always assume true - and carry on with the
        // remaining parents (without the recursive call, no distribution at
        // all would be produced for a node that has a decision parent)
        addr[i] = 0;
        getCPD(lists, cpf, domains, addr, i + 1);
    }
    else {
        for (int j = 0; j < domains[i].getOrder(); j++) {
            addr[i] = j;
            getCPD(lists, cpf, domains, addr, i + 1);
        }
    }
}
/**
 * Sets the fragment network file of this model.
 *
 * @param networkFilename the path of the fragment network file
 */
public void setNetworkFilename(String networkFilename) {
    final File file = new File(networkFilename);
    this.networkFile = file;
}
/**
 * Example program: loads the kitchen-places model, learns domains and
 * parameters from a training database, saves the network and shows it.
 * Any exception is printed to stderr.
 *
 * @param args ignored
 */
public static void main(String[] args) {
    final String bifFile = "abl/kitchen-places/actseq.xml";
    final String declsFile = "abl/kitchen-places/actseq.abl";
    final String dbFile = "abl/kitchen-places/train.blogdb";
    try {
        final ABLModel bn = new ABLModel(new String[] { declsFile }, bifFile);

        // read the training database
        System.out.println("Reading data...");
        final Database db = new Database(bn);
        db.readBLOGDB(dbFile);
        final int numEntries = db.getEntries().size();
        System.out.println(" " + numEntries + " variables read.");

        // learn domains
        System.out.println("Learning domains...");
        final DomainLearner domLearner = new DomainLearner(bn);
        domLearner.learn(db);
        domLearner.finish();

        // learn parameters
        System.out.println("Learning parameters...");
        final CPTLearner cptLearner = new CPTLearner(bn);
        cptLearner.learnTyped(db, true, true);
        cptLearner.finish();

        // save the learnt network to the file it was loaded from
        System.out.println("Writing XML-BIF output...");
        bn.saveXMLBIF(bifFile);

        System.out.println("Showing Bayesian network...");
        bn.show();
    }
    catch (Exception e) {
        e.printStackTrace();
    }
}
}