/*******************************************************************************
* Copyright (C) 2008-2012 Dominik Jain.
*
* This file is part of ProbCog.
*
* ProbCog is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ProbCog is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ProbCog. If not, see <http://www.gnu.org/licenses/>.
******************************************************************************/
package probcog.srl.directed.bln;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.Map;
import java.util.Vector;
import java.util.Map.Entry;
import probcog.bayesnets.core.BeliefNetworkEx;
import probcog.bayesnets.core.Discretized;
import probcog.inference.IParameterHandler;
import probcog.inference.ParameterHandler;
import probcog.srl.BooleanDomain;
import probcog.srl.Database;
import probcog.srl.ParameterGrounder;
import probcog.srl.Signature;
import probcog.srl.directed.CombiningRule;
import probcog.srl.directed.ExtendedNode;
import probcog.srl.directed.RelationalBeliefNetwork;
import probcog.srl.directed.RelationalNode;
import probcog.srl.directed.ParentGrounder.ParentGrounding;
import probcog.srl.directed.RelationalNode.Aggregator;
import edu.ksu.cis.bnj.ver3.core.BeliefNode;
import edu.ksu.cis.bnj.ver3.core.CPF;
import edu.ksu.cis.bnj.ver3.core.CPT;
import edu.ksu.cis.bnj.ver3.core.Discrete;
import edu.ksu.cis.bnj.ver3.core.Domain;
import edu.ksu.cis.bnj.ver3.core.Value;
import edu.ksu.cis.bnj.ver3.core.values.ValueDouble;
import edu.tum.cs.util.Stopwatch;
import edu.tum.cs.util.StringTool;
import edu.tum.cs.util.datastruct.Pair;
/**
* Abstract base class for ground networks generated from BLNs.
* @author Dominik Jain
*/
public abstract class AbstractGroundBLN implements IParameterHandler {
/**
* the ground Bayesian network (or ground auxiliary Bayesian network)
*/
protected BeliefNetworkEx groundBN;
/**
* the underlying template model
*/
protected AbstractBayesianLogicNetwork bln;
/**
* list of auxiliary nodes contained in the ground Bayesian network (null if the network is not an auxiliary network)
*/
protected Vector<BeliefNode> hardFormulaNodes;
/**
* the file from which the evidence database was loaded (if any)
*/
protected String databaseFile;
/**
* the database for which the ground model was instantiated
*/
protected Database db;
/**
* temporary mapping of function names to relational nodes that can serve for instantiation (used only during grounding)
*/
protected HashMap<String, Vector<RelationalNode>> functionTemplates;
/**
* temporary storage of names of instantiated variables (to avoid duplicate instantiation during grounding)
*/
protected HashSet<String> instantiatedVariables;
/**
* temporary cache that maps CPF identifier strings to arrays of CPF values so that previously built value arrays can be reused
* (created when grounding starts and reset to null once the regular nodes have been instantiated)
*/
protected HashMap<String, Value[]> cpfCache;
/**
* whether progress messages are printed to standard output while the ground network is generated
*/
protected boolean verbose = true;
/**
* whether additional debug messages are printed to standard output during grounding
*/
protected boolean debug = false;
/**
* the parameter handler of this object; the parameters "verbose" and "debug" are registered with it during initialisation
*/
protected ParameterHandler paramHandler;
/**
* maps an instantiated ground node to a string identifying the CPF template that was used to create it
*/
protected HashMap<BeliefNode, String> cpfIDs;
/**
* maps a ground node (in the ground network) to the template node in the fragment network it was instantiated from
*/
protected HashMap<BeliefNode, RelationalNode> groundNode2TemplateNode;
/**
* creates a ground model object for the given template model and evidence database
* (the ground network itself is not generated here)
* @param bln the template model
* @param db the evidence database the model is to be grounded for
* @throws Exception
*/
public AbstractGroundBLN(AbstractBayesianLogicNetwork bln, Database db) throws Exception {
    this.init(bln, db);
}
/**
* creates a ground model object for the given template model, reading the evidence database from a file
* @param bln the template model
* @param databaseFile the file to read the evidence database from
* @throws Exception
*/
public AbstractGroundBLN(AbstractBayesianLogicNetwork bln, String databaseFile) throws Exception {
    this.databaseFile = databaseFile;
    // build a fresh database for the template network and fill it from the file
    Database evidenceDb = new Database(bln.rbn);
    evidenceDb.readBLOGDB(databaseFile, true);
    this.init(bln, evidenceDb);
}
/**
* common initialisation used by the constructors: registers the handled parameters,
* stores the template model and the (finalized) database and creates the node mappings
* @param bln the template model
* @param db the evidence database
* @throws Exception
*/
protected void init(AbstractBayesianLogicNetwork bln, Database db) throws Exception {
    // parameters that can be set via the parameter handler
    this.paramHandler = new ParameterHandler(this);
    this.paramHandler.add("verbose", "setVerbose");
    this.paramHandler.add("debug", "setDebugMode");

    this.bln = bln;

    // the database must be finalized before it is used for grounding
    db.finalize();
    this.db = db;

    this.cpfIDs = new HashMap<BeliefNode, String>();
    this.groundNode2TemplateNode = new HashMap<BeliefNode, RelationalNode>();
}
/**
* @return the template model this ground model is based on
*/
public AbstractBayesianLogicNetwork getBLN() {
    return bln;
}
/**
* instantiates the ground network for this model, including the auxiliary nodes
* for the formulaic constraints (equivalent to instantiateGroundNetwork(true))
* @throws Exception
*/
public void instantiateGroundNetwork() throws Exception {
    final boolean withAuxiliaryNodes = true;
    this.instantiateGroundNetwork(withAuxiliaryNodes);
}
/**
* builds the ground Bayesian network: first one node per ground atom of every function
* known to the template network, then (optionally) the nodes for the formulaic constraints
* @param addAuxiliaryVars if true, also adds auxiliary nodes to the network that correspond to the hard logical constraints
* @throws Exception
*/
public void instantiateGroundNetwork(boolean addAuxiliaryVars) throws Exception {
    Stopwatch timer = new Stopwatch();
    timer.start();

    if(verbose)
        System.out.println("generating network...");
    groundBN = new BeliefNetworkEx();

    // --- regular probabilistic nodes (ground atoms) ---
    if(verbose)
        System.out.println(" regular nodes");
    RelationalBeliefNetwork templateNet = bln.rbn;

    // index the fragment nodes by function name; these are the candidate templates
    functionTemplates = new HashMap<String, Vector<RelationalNode>>();
    BeliefNode[] templateNodes = templateNet.bn.getNodes();
    for(int idx = 0; idx < templateNodes.length; idx++) {
        ExtendedNode candidate = templateNet.getExtendedNode(idx);
        if(candidate instanceof RelationalNode) {
            RelationalNode fragment = (RelationalNode)candidate;
            // only nodes that correspond to fragments can serve as templates
            if(fragment.isFragment()) {
                String name = fragment.getFunctionName();
                Vector<RelationalNode> sameFunction = functionTemplates.get(name);
                if(sameFunction == null) {
                    sameFunction = new Vector<RelationalNode>();
                    functionTemplates.put(name, sameFunction);
                }
                sameFunction.add(fragment);
            }
        }
    }

    // instantiate every grounding of every function
    instantiatedVariables = new HashSet<String>();
    cpfCache = new HashMap<String, Value[]>();
    Iterable<String> allFunctions = this.bln.rbn.getFunctionNames();
    for(String functionName : allFunctions) {
        if(verbose)
            System.out.println(" " + functionName);
        Collection<String[]> groundings = ParameterGrounder.generateGroundings(bln.rbn, functionName, db);
        for(String[] actualParams : groundings)
            instantiateVariable(functionName, actualParams);
    }

    // release the temporary grounding data
    instantiatedVariables = null;
    functionTemplates = null;
    cpfCache = null;

    // --- auxiliary nodes for formulaic constraints ---
    if(addAuxiliaryVars) {
        if(verbose)
            System.out.println(" formulaic nodes");
        hardFormulaNodes = new Vector<BeliefNode>();
        groundFormulaicNodes();
    }

    if(!verbose)
        return;
    System.out.println("network size: " + getGroundNetwork().bn.getNodes().length + " nodes");
    System.out.println(String.format("construction time: %.4fs", timer.getElapsedTimeSecs()));
}
/**
* instantiates the variable that corresponds to the given function name and actual parameters
* by looking for a template and applying it, or simply returns the variable if it was previously instantiated
* @param functionName name of the function whose ground instance is to be created
* @param params the actual parameters of the ground instance
* @return the node of the ground network that represents the variable; null if no template is applicable and
* either the function is an evidence function or partial instantiation is allowed
* @throws Exception if no template is applicable although one is required, or if several parent groups
* have to be combined but no combining rule is defined for the function
*/
protected BeliefNode instantiateVariable(String functionName, String[] params) throws Exception {
// check if the variable was previously instantiated and return the node if so
String varName = Signature.formatVarName(functionName, params);
if(instantiatedVariables.contains(varName))
return groundBN.getNode(varName);
if(debug) System.out.println("instantiating variable " + varName);
// consider all the relational nodes that could be used to instantiate the variable
Vector<RelationalNode> templates = functionTemplates.get(functionName);
boolean combiningRuleNeeded = false;
// pairs of an applicable template and the parent groundings found for it
Vector<Pair<RelationalNode, Vector<ParentGrounding>>> suitableTemplates = new Vector<Pair<RelationalNode, Vector<ParentGrounding>>>();
// check potentially applicable templates
// (exceptions raised by the applicability checks are collected and only reported if no template is applicable at all)
LinkedList<Exception> exceptions = new LinkedList<Exception>();
if(templates != null) {
for(RelationalNode relNode : templates) {
Vector<ParentGrounding> groundings = null;
try {
groundings = relNode.checkTemplateApplicability(params, db);
}
catch(Exception e) { // if an exception occurs, the template is of course inapplicable
exceptions.add(e);
}
// no groundings (null result or exception): the template is not applicable
if(groundings == null)
continue;
// this template is applicable
// if we have more than one grounding, we need a combining rule if no aggregator is given
if(groundings.size() > 1 && !relNode.hasAggregator())
combiningRuleNeeded = true;
// we also need a combining rule if we already have a suitable template
if(!suitableTemplates.isEmpty())
combiningRuleNeeded = true;
suitableTemplates.add(new Pair<RelationalNode, Vector<ParentGrounding>>(relNode, groundings));
}
}
// if there is no suitable template, we may have an error case
if(suitableTemplates.isEmpty()) {
// if a uniform default distribution was defined, construct it
if(this.bln.rbn.usesUniformDefault(functionName)) {
// TODO reuse data structures
Signature sig = this.bln.rbn.getSignature(functionName);
String[] aOutcomes;
ValueDouble[] dist;
if(sig.isBoolean()) {
aOutcomes = new String[]{"True", "False"};
ValueDouble half = new ValueDouble(0.5);
dist = new ValueDouble[]{half, half};
}
else {
// the domain is only available as an Iterable, so its elements are counted in a first pass
Iterable<String> outcomes = db.getDomain(sig.returnType);
int c = 0;
for(@SuppressWarnings("unused") String o : outcomes)
c++;
aOutcomes = new String[c];
int i = 0;
double p = 1.0 / c;
dist = new ValueDouble[c];
// second pass: collect the outcome names and assign each the probability 1/c
for(String o : outcomes) {
dist[i] = new ValueDouble(p);
aOutcomes[i++] = o;
}
}
// add a parentless node carrying the uniform distribution
Discrete domain = new Discrete(aOutcomes);
BeliefNode mainNode = this.groundBN.addNode(varName, domain);
CPT cpf = new CPT();
cpf.build(new BeliefNode[]{mainNode}, dist);
mainNode.setCPF(cpf);
onAddGroundAtomNode(mainNode, params, sig);
instantiatedVariables.add(mainNode.getName());
return mainNode;
}
// otherwise, if it's not an evidence function, we have an error case
if(!this.bln.rbn.isEvidenceFunction(functionName)) {
if(bln.allowPartialInstantiation)
return null;
else {
// report the missing template along with the messages of all collected exceptions
StringBuffer error = new StringBuffer("No relational node was found that could serve as the template for the variable " + varName);
if(!exceptions.isEmpty())
error.append("\nThe following errors occurred while checking template applicability:");
for(Exception e : exceptions) {
error.append('\n');
error.append(e.getMessage());
}
throw new Exception(error.toString());
}
}
else { // it's an evidence variable
// add a detached dummy node that has a single 1.0 entry for its evidence value
/*
String value = this.db.getVariableValue(varName, true);
Domain dom = new Discrete(new String[]{value});
BeliefNode node = groundBN.addNode(varName, dom);
CPT cpf = new CPT(new BeliefNode[]{node});
cpf.setValues(new Value[]{new ValueDouble(1.0)});
node.setCPF(cpf);
*/
// TODO can't call this because we don't have a relNode; Actually we wouldn't want to have this node at all
//onAddGroundAtomNode(relNode, actualParams, mainNode);
// no node is created for the evidence variable; subclasses are merely notified
onAddEvidenceVariable(functionName, params);
if(debug)
System.out.println(" " + varName + " (skipped, is evidence)");
return null;
}
}
// get the first applicable template
Pair<RelationalNode, Vector<ParentGrounding>> template = suitableTemplates.iterator().next();
RelationalNode relNode = template.first;
// keep track of instantiated variables
// (the name is registered before the parents are instantiated further below)
String mainNodeName = relNode.getVariableName(params);
instantiatedVariables.add(mainNodeName);
if(debug)
System.out.println(" " + mainNodeName);
// add the node itself to the network
BeliefNode mainNode = groundBN.addNode(mainNodeName, relNode.node.getDomain(), relNode.node.getType());
onAddGroundAtomNode(mainNode, params, relNode.getSignature());
// we can now instantiate the variable based on the suitable templates
if(!combiningRuleNeeded) {
if(debug) System.out.println(" instantiating without combining rule");
instantiateVariableFromSingleTemplate(mainNode, template.first, template.second);
}
else { // need to use combining rule
// TODO ground nodes instantiated from combining rules do not have a template assigned to them via the mapping
CombiningRule r = bln.rbn.getCombiningRule(functionName);
if(r == null)
throw new Exception("More than one group of parents for variable " + varName + " but no combining rule was specified");
if(debug) System.out.println(" instantiating with combining rule " + r);
instantiateVariableWithCombiningRule(mainNode, suitableTemplates, r);
}
return mainNode;
}
/**
* instantiates a variable from the given node template for the actual parameters:
* connects the parents and sets the CPF of the given (already created) ground node
* @param mainNode the node in the ground network that is being instantiated
* @param relNode the node that is to serve as the template
* @param groundings a vector of node groundings, i.e. mappings from node indices to parameter lists
* @throws Exception if there is not exactly one grounding although the template has no aggregator,
* if the node's domain does not fit the aggregator, or if the aggregator is not handled
*/
protected void instantiateVariableFromSingleTemplate(BeliefNode mainNode, RelationalNode relNode, Vector<ParentGrounding> groundings) throws Exception {
// remember which template the ground node was created from
groundNode2TemplateNode.put(mainNode, relNode);
// add edges from the parents
// - normal case: just CPF application for one set of parents
if(!relNode.hasAggregator()) {
if(groundings.size() != 1)
throw new Exception("Cannot instantiate " + mainNode.getName() + " for " + groundings.size() + " groups of parents.");
if(debug) {
System.out.println(" relevant nodes/parents");
Map<Integer, String[]> grounding = groundings.firstElement().nodeArgs;
for(Entry<Integer, String[]> e : grounding.entrySet()) {
System.out.println(" " + bln.rbn.getRelationalNode(e.getKey()).getVariableName(e.getValue()));
}
}
instantiateCPF(groundings.firstElement().nodeArgs, relNode, mainNode);
}
// - other case: use combination function
else {
// domain product of the main node's CPF: the node itself followed by its parents in the ground network
ArrayList<BeliefNode> domprod = new ArrayList<BeliefNode>();
domprod.add(mainNode);
// determine if auxiliary nodes need to be used and connect the parents appropriately
if(!relNode.aggregator.isFunctional) {
Signature sig = relNode.getSignature();
// create auxiliary nodes, one for each set of parents
Vector<BeliefNode> auxNodes = new Vector<BeliefNode>();
int k = 0;
for(ParentGrounding grounding : groundings) {
// create auxiliary node
String auxNodeName = String.format("AUX%d_%s", k++, mainNode.getName());
BeliefNode auxNode = groundBN.addNode(auxNodeName, mainNode.getDomain(), mainNode.getType());
auxNodes.add(auxNode);
// the name of the auxiliary node is parsed to obtain the function name and parameters passed to the callback
Pair<String,String[]> p = RelationalNode.parse(auxNodeName);
this.onAddAuxiliaryNode(auxNode, sig.isBoolean(), p.first, p.second);
// create links from parents to auxiliary node and transfer CPF
instantiateCPF(grounding.nodeArgs, relNode, auxNode);
}
// connect auxiliary nodes to main node
for(BeliefNode parent : auxNodes) {
//System.out.printf("connecting %s and %s\n", parent.getName(), mainNode.getName());
groundBN.connect(parent, mainNode, false);
domprod.add(parent);
}
}
// if the node is functionally determined by the parents, aux. nodes carrying the CPD in the template node are not required
// we link the grounded parents directly
else {
// Note: we keep the vector of parents (in domprod) ordered by parent set (i.e. the parents belonging to a set are grouped)
for(ParentGrounding grounding : groundings) {
HashMap<BeliefNode,BeliefNode> src2targetParent = new HashMap<BeliefNode,BeliefNode>();
connectParents(grounding.nodeArgs, relNode, mainNode, src2targetParent, null);
domprod.addAll(src2targetParent.values());
}
}
// apply combination function
Aggregator combFunc = relNode.aggregator;
if(combFunc == Aggregator.FunctionalOr || combFunc == Aggregator.NoisyOr || combFunc == Aggregator.FunctionalAnd) {
// check if the domain is really boolean
// NOTE(review): the message below mentions the OR aggregator although this branch also handles FunctionalAnd
if(!RelationalBeliefNetwork.isBooleanDomain(mainNode.getDomain()))
throw new Exception("Cannot use OR aggregator on non-Boolean node " + relNode.toString());
// determine CPF-id
// (the id, which also serves as the cache key, encodes the aggregator and the number of parent groups;
// for FunctionalOr it additionally encodes the number of entries of the first grounding)
String cpfid = combFunc.getFunctionSyntax();
switch(combFunc) {
case FunctionalOr:
cpfid += String.format("-OR(%d-%d)", groundings.size(), groundings.firstElement().nodeArgs.size());
break;
case NoisyOr:
cpfid += String.format("-OR(%d)", groundings.size());
break;
case FunctionalAnd:
cpfid += String.format("-AND(%d)", groundings.size());
break;
}
// build the CPF
CPT cpf = (CPT)mainNode.getCPF();
BeliefNode[] domprod_arr = domprod.toArray(new BeliefNode[domprod.size()]);
// - check if we have a cached CPF that we can reuse
Value[] values = cpfCache.get(cpfid);
if(values != null)
cpf.build(domprod_arr, values);
// - otherwise set and apply the filler
else {
cpf.buildZero(domprod_arr, false);
// the filler is created only now because its constructor reads the domain product of the node's CPF
CPFFiller filler = null;
switch(combFunc) {
case FunctionalOr:
// group size: number of entries of the first grounding minus one
filler = new CPFFiller_ORGrouped(mainNode, groundings.firstElement().nodeArgs.size()-1);
break;
case NoisyOr:
filler = new CPFFiller_OR(mainNode);
break;
case FunctionalAnd:
filler = new CPFFiller_AND(mainNode);
break;
}
filler.fill();
// store the newly built CPF in the cache
cpfCache.put(cpfid, cpf.getValues());
}
// set the CPF-id
cpfIDs.put(mainNode, cpfid);
}
else if(combFunc == Aggregator.Sum) {
// check if the domain is really real
if(!RelationalBeliefNetwork.isRealDomain(mainNode.getDomain()))
throw new Exception("Cannot use SUM aggregator on non-Real node " + relNode.toString());
// build the CPF
CPT cpf = (CPT)mainNode.getCPF();
// the id (and cache key) encodes the aggregator and the number of parent groups
String cpfid = combFunc.getFunctionSyntax() + String.format("-%d", groundings.size());
BeliefNode[] domprod_arr = domprod.toArray(new BeliefNode[domprod.size()]);
// - check if we have a cached CPF that we can reuse
Value[] values = cpfCache.get(cpfid);
if(values != null)
cpf.build(domprod_arr, values);
// - otherwise set and apply the filler
else {
cpf.buildZero(domprod_arr, false);
CPFFiller filler = new CPFFiller_SUM(mainNode);
filler.fill();
// store the newly built CPF in the cache
cpfCache.put(cpfid, cpf.getValues());
}
// set the CPF-id
cpfIDs.put(mainNode, cpfid);
}
else
throw new Exception("Cannot ground structure because of multiple parent sets for node " + mainNode.getName() + " with unhandled aggregator " + relNode.aggregator);
}
}
/**
* instantiates a variable whose distribution is obtained by combining several parent groups
* (from one or more templates) using a combining rule
* @param mainNode the (already created) ground node to instantiate
* @param suitableTemplates the applicable templates along with their parent groundings
* @param r the combining rule to apply
* @return the instantiated node (mainNode)
* @throws Exception if a parent that is not a precondition cannot be instantiated
*/
protected BeliefNode instantiateVariableWithCombiningRule(BeliefNode mainNode, Vector<Pair<RelationalNode, Vector<ParentGrounding>>> suitableTemplates, CombiningRule r) throws Exception {
    // position of each ground parent in the domain product of the CPF under construction
    HashMap<BeliefNode, Integer> groundParent2Slot = new HashMap<BeliefNode, Integer>();
    // one entry per (template, grounding): where the template's parents are found in the new domain product
    Vector<Pair<RelationalNode, Map<BeliefNode,Integer>>> slotsPerGrounding = new Vector<Pair<RelationalNode, Map<BeliefNode,Integer>>>();

    int nextSlot = 1; // slot 0 is reserved for the main node
    for(Pair<RelationalNode, Vector<ParentGrounding>> candidate : suitableTemplates) {
        RelationalNode templateNode = candidate.first;
        for(ParentGrounding parentGrounding : candidate.second) {
            Map<BeliefNode,Integer> templateParent2Slot = new HashMap<BeliefNode,Integer>();
            for(Entry<Integer,String[]> binding : parentGrounding.nodeArgs.entrySet()) {
                RelationalNode relParent = bln.rbn.getRelationalNode(binding.getKey());
                // the template node itself and constants do not become parents
                if(relParent == templateNode || relParent.isConstant)
                    continue;
                String[] actualParams = binding.getValue();
                BeliefNode groundParent = instantiateVariable(relParent.getFunctionName(), actualParams);
                if(groundParent == null) {
                    // a parent that could not be instantiated is tolerated only if it is a precondition
                    if(!relParent.isPrecondition)
                        throw new Exception("Could not instantiate " + relParent + " with params [" + StringTool.join(", ", actualParams) + "] as a parent for " + mainNode);
                    continue;
                }
                // assign a new slot unless the ground parent is already part of the domain product
                Integer slot = groundParent2Slot.get(groundParent);
                if(slot == null) {
                    slot = nextSlot++;
                    groundParent2Slot.put(groundParent, slot);
                }
                templateParent2Slot.put(relParent.node, slot);
            }
            slotsPerGrounding.add(new Pair<RelationalNode, Map<BeliefNode,Integer>>(templateNode, templateParent2Slot));
        }
    }

    // set up the (zero-initialised) CPF and connect the parents
    CPT cpf = (CPT)mainNode.getCPF();
    BeliefNode[] domprod = new BeliefNode[groundParent2Slot.size() + 1];
    domprod[0] = mainNode;
    for(Entry<BeliefNode, Integer> slotEntry : groundParent2Slot.entrySet()) {
        BeliefNode groundParent = slotEntry.getKey();
        domprod[slotEntry.getValue()] = groundParent;
        this.groundBN.connect(groundParent, mainNode, false);
    }
    cpf.buildZero(domprod, false);

    // compute the entries
    if(debug)
        System.out.println(" combined domain is " + StringTool.join(", ", domprod));
    fillCPFCombiningRule(cpf, 1, new int[domprod.length], slotsPerGrounding, r);
    return mainNode;
}
/**
* recursively enumerates all parent configurations of the given CPF and writes, for each of them,
* the column of values obtained via the combining rule
* @param cpf the CPF to fill
* @param i index (in the domain product) of the next parent whose setting is to be chosen; start with 1
* @param addr the address array in which the current configuration is kept (modified in place)
* @param templateDomprodMap for each template/grounding, the mapping from template parents to indices in the domain product of cpf
* @param r the combining rule
* @throws Exception if the rule has Boolean semantics but the variable is not binary
*/
protected void fillCPFCombiningRule(CPF cpf, int i, int[] addr, Vector<Pair<RelationalNode, Map<BeliefNode,Integer>>> templateDomprodMap, CombiningRule r) throws Exception {
    BeliefNode[] domprod = cpf.getDomainProduct();

    // recursive case: choose every setting for the i-th node and descend
    if(i != domprod.length) {
        int numSettings = domprod[i].getDomain().getOrder();
        for(int setting = 0; setting < numSettings; setting++) {
            addr[i] = setting;
            fillCPFCombiningRule(cpf, i+1, addr, templateDomprodMap, r);
        }
        return;
    }

    // base case: all parents are set, so the column for this configuration can be written
    int numOutcomes = domprod[0].getDomain().getOrder();
    if(!r.booleanSemantics) {
        // normalization semantics: compute one value per outcome and divide by their sum
        double[] unnormalized = new double[numOutcomes];
        double total = 0.0;
        for(int j = 0; j < numOutcomes; j++) {
            unnormalized[j] = fillCPFCombiningRule_computeColumnEntry(j, addr, templateDomprodMap, r);
            total += unnormalized[j];
        }
        for(int j = 0; j < numOutcomes; j++) {
            addr[0] = j;
            cpf.put(addr, new ValueDouble(unnormalized[j] / total));
        }
        return;
    }

    // Boolean semantics: the rule yields the entry for outcome 0, outcome 1 receives the complement
    if(numOutcomes != 2)
        throw new Exception("Cannot apply combining-rule " + r + " with Boolean semantics to non-binary random variable " + domprod[0]);
    double pFirst = fillCPFCombiningRule_computeColumnEntry(0, addr, templateDomprodMap, r);
    cpf.put(addr, new ValueDouble(pFirst));
    addr[0] = 1;
    cpf.put(addr, new ValueDouble(1.0-pFirst));
}
/**
* computes a single entry of the combined CPF by looking up the corresponding entry in the CPF
* of every involved template and passing the collected values to the combining rule
* @param idx0 the domain index of the main node for which to compute the entry (written to addr[0])
* @param addr the current address in the combined CPF
* @param templateDomprodMap for each template/grounding, the mapping from template parents to indices in addr
* @param r the combining rule
* @return the value computed by the combining rule
*/
protected double fillCPFCombiningRule_computeColumnEntry(int idx0, int[] addr, Vector<Pair<RelationalNode, Map<BeliefNode,Integer>>> templateDomprodMap, CombiningRule r) {
    addr[0] = idx0;
    Vector<Double> collected = new Vector<Double>();
    for(Pair<RelationalNode, Map<BeliefNode, Integer>> item : templateDomprodMap) {
        CPF templateCPF = item.first.node.getCPF();
        BeliefNode[] templateDomprod = templateCPF.getDomainProduct();
        Map<BeliefNode, Integer> templateParent2Slot = item.second;

        // translate the address in the combined CPF to an address in the template's CPF
        int[] templateAddr = new int[templateDomprod.length];
        templateAddr[0] = idx0;
        for(int k = 1; k < templateDomprod.length; k++) {
            Integer slot = templateParent2Slot.get(templateDomprod[k]);
            // parents without a slot (decision parents, precondition parents that were not instantiated)
            // are fixed to index 0, which corresponds to True
            templateAddr[k] = (slot != null) ? addr[slot] : 0;
        }
        collected.add(templateCPF.getDouble(templateAddr));
    }
    return r.compute(collected);
}
/**
* parameterless initialisation method; the implementation in this class is empty
*/
protected void init() {}
/**
* adds the nodes for the formulaic constraints to the ground network;
* called by instantiateGroundNetwork(boolean) after the list hardFormulaNodes has been created, if auxiliary variables are requested
* @throws Exception
*/
protected abstract void groundFormulaicNodes() throws Exception;
/**
* callback that is invoked by instantiateVariable after a node for a ground atom has been added to the ground network
* @param instance the node that was added
* @param params the actual parameters of the ground atom
* @param sig the signature of the function the ground atom belongs to
*/
protected abstract void onAddGroundAtomNode(BeliefNode instance, String[] params, Signature sig);
/**
* callback that is invoked by instantiateVariable when no template is applicable to a variable of an evidence function
* (no node is created for the variable in that case)
* @param functionName the name of the evidence function
* @param params the actual parameters of the variable
*/
protected abstract void onAddEvidenceVariable(String functionName, String[] params);
/**
* callback that is invoked by instantiateVariableFromSingleTemplate for every auxiliary node ("AUX...") that is created
* for a template with a non-functional aggregator; the implementation in this class is empty
* @param var the auxiliary node that was added
* @param isBoolean whether the signature of the template node is Boolean
* @param functionName the function name obtained by parsing the name of the auxiliary node
* @param params the parameters obtained by parsing the name of the auxiliary node
*/
protected void onAddAuxiliaryNode(BeliefNode var, boolean isBoolean, String functionName, String[] params) {}
/**
* @return the database for which this model is instantiated
*/
public Database getDatabase() {
    return this.db;
}
/**
* instantiates the parents given by the grounding and connects them to the target node; the CPF of the target node is *not* initialised
* @param parentGrounding the grounding, i.e. a mapping of indices of relational nodes to arrays of actual parameters
* @param srcRelNode the relational node that is to serve as the template for the target node
* @param targetNode the node in the ground network to connect the parents to
* @param src2targetParent a mapping in which to store which node in the template model produced which instantiated parent in the ground network (or null)
* @param constantSettings a mapping in which to store bindings of constants and preconditions (or null)
* @return a vector that contains the target node followed by the connected parents (in the iteration order of the grounding)
* @throws Exception if a parent cannot be instantiated or if the same ground parent occurs twice
*/
protected Vector<BeliefNode> connectParents(Map<Integer, String[]> parentGrounding, RelationalNode srcRelNode, BeliefNode targetNode, HashMap<BeliefNode, BeliefNode> src2targetParent, HashMap<BeliefNode, Integer> constantSettings) throws Exception {
    Vector<BeliefNode> domainProduct = new Vector<BeliefNode>();
    domainProduct.add(targetNode);
    HashSet<BeliefNode> connected = new HashSet<BeliefNode>();

    for(Entry<Integer, String[]> binding : parentGrounding.entrySet()) {
        RelationalNode relParent = bln.rbn.getRelationalNode(binding.getKey());
        // the grounding also contains the template node itself, which is not a parent
        if(relParent == srcRelNode)
            continue;

        if(relParent.isConstant) {
            // constants do not become parents; only their value (domain index) is recorded
            if(constantSettings != null)
                constantSettings.put(relParent.node, ((Discrete)relParent.node.getDomain()).findName(binding.getValue()[0]));
        }
        else if(relParent.isPrecondition) {
            // precondition nodes are always true (index 0) and do not become parents either
            if(constantSettings != null)
                constantSettings.put(relParent.node, 0);
        }
        else {
            String[] actualParams = binding.getValue();
            BeliefNode groundParent = instantiateVariable(relParent.getFunctionName(), actualParams);
            if(groundParent == null)
                throw new Exception("Error instantiating parent '" + Signature.formatVarName(relParent.getFunctionName(), actualParams) + "' while instantiating " + targetNode);
            // add() returns false if the parent was already connected
            if(!connected.add(groundParent))
                throw new Exception("Error instantiating " + targetNode + " from " + srcRelNode + ": Duplicate parent " + groundParent);
            groundBN.connect(groundParent, targetNode, false);
            domainProduct.add(groundParent);
            if(src2targetParent != null)
                src2targetParent.put(relParent.node, groundParent);
        }
    }
    return domainProduct;
}
/**
* connects the parents given by the grounding to the target node and transfers the (relevant part of the) CPF
* of the template node to the target node
* @param parentGrounding a grounding (mapping of indices of relational nodes to an array of actual parameters)
* @param srcRelNode relational node that the CPF is to be copied from
* @param targetNode the target node to connect parents to and whose CPF is to be written
* @throws Exception if the parents cannot be connected or the domain product of the target cannot be matched to that of the template
*/
protected void instantiateCPF(Map<Integer, String[]> parentGrounding, RelationalNode srcRelNode, BeliefNode targetNode) throws Exception {
    // connect the parents; this also yields the values of constant/precondition nodes
    HashMap<BeliefNode, BeliefNode> templateParent2Ground = new HashMap<BeliefNode, BeliefNode>();
    HashMap<BeliefNode, Integer> fixedSettings = new HashMap<BeliefNode, Integer>();
    Vector<BeliefNode> connected = connectParents(parentGrounding, srcRelNode, targetNode, templateParent2Ground, fixedSettings);

    // decision nodes among the template's parents are fixed to index 0 (= True)
    BeliefNode[] srcDomainProd = srcRelNode.node.getCPF().getDomainProduct();
    for(int k = 1; k < srcDomainProd.length; k++) {
        BeliefNode templateParent = srcDomainProd[k];
        if(templateParent.getType() == BeliefNode.NODE_DECISION)
            fixedSettings.put(templateParent, 0);
    }

    // arrange the ground parents in the order in which their templates appear in the source domain product
    CPT targetCPF = (CPT)targetNode.getCPF();
    BeliefNode[] targetDomainProd = connected.toArray(new BeliefNode[connected.size()]);
    int filled = 1; // position 0 holds the target node itself
    HashSet<BeliefNode> placed = new HashSet<BeliefNode>();
    for(int k = 1; k < srcDomainProd.length; k++) {
        BeliefNode groundParent = templateParent2Ground.get(srcDomainProd[k]);
        if(groundParent == null)
            continue;
        if(!placed.add(groundParent))
            throw new Exception("Cannot instantiate " + targetNode + " using template " + srcRelNode + ": Duplicate parent " + groundParent);
        if(filled >= targetDomainProd.length)
            throw new Exception("Domain product of " + targetNode + " too small; size = " + targetDomainProd.length + "; tried to add " + groundParent + "; already added " + StringTool.join(",", targetDomainProd));
        targetDomainProd[filled++] = groundParent;
    }
    if(filled != targetDomainProd.length)
        throw new Exception("CPF domain product not fully filled: handled " + filled + ", needed " + targetDomainProd.length);

    // transfer the CPF values
    String cpfID = Integer.toString(srcRelNode.index);
    if(srcDomainProd.length != targetDomainProd.length) {
        // some template parents have fixed values: extract (or fetch from the cache) the columns for that setting
        cpfID += fixedSettings.toString();
        Value[] subCPF = cpfCache.get(cpfID);
        if(subCPF == null) {
            subCPF = getSubCPFValues(srcRelNode.node.getCPF(), fixedSettings);
            cpfCache.put(cpfID, subCPF);
        }
        targetCPF.build(targetDomainProd, subCPF);
    }
    else {
        // same number of nodes in both domain products: the values of the template can be used as they are
        targetCPF.build(targetDomainProd, ((CPT)srcRelNode.node.getCPF()).getValues());
    }
    cpfIDs.put(targetNode, cpfID);
    /*
    // print domain products (just to check)
    BeliefNode n = srcRelNode.node;
    System.out.println("\nsrc:");
    BeliefNode[] domProd = n.getCPF().getDomainProduct();
    for(int i = 0; i < domProd.length; i++) {
    System.out.println(" " + domProd[i].getName());
    }
    System.out.println("target:");
    n = targetNode;
    domProd = n.getCPF().getDomainProduct();
    for(int i = 0; i < domProd.length; i++) {
    System.out.println(" " + domProd[i].getName());
    }
    System.out.println();
    */
}
/**
* extracts the values of the sub-CPF that results when some of the nodes in the CPF's domain product have fixed values
* @param cpf the CPF to extract from
* @param constantSettings fixed values (domain indices) for some of the nodes
* @return the extracted values, in the order in which the remaining configurations are enumerated
*/
public static Value[] getSubCPFValues(CPF cpf, HashMap<BeliefNode, Integer> constantSettings) {
    Vector<Value> collected = new Vector<Value>();
    int[] address = new int[cpf.getDomainProduct().length];
    getSubCPFValues(cpf, constantSettings, 0, address, collected);
    return collected.toArray(new Value[collected.size()]);
}
/**
* recursive helper that enumerates the addresses of the CPF, using the fixed setting for
* nodes that have one and all settings for the others, and appends the values found
* @param cpf the CPF to extract from
* @param constantSettings fixed values (domain indices) for some of the nodes
* @param i index of the node in the domain product that is to be set next
* @param addr the address array that holds the current configuration (modified in place)
* @param ret the vector to which the values are appended
*/
protected static void getSubCPFValues(CPF cpf, HashMap<BeliefNode, Integer> constantSettings, int i, int[] addr, Vector<Value> ret) {
    BeliefNode[] domProd = cpf.getDomainProduct();
    // complete address: collect the value
    if(i == domProd.length) {
        ret.add(cpf.get(addr));
        return;
    }
    Integer fixed = constantSettings.get(domProd[i]);
    if(fixed == null) {
        // no fixed setting for the i-th node: consider all of its settings
        Domain domain = domProd[i].getDomain();
        for(int setting = 0; setting < domain.getOrder(); setting++) {
            addr[i] = setting;
            getSubCPFValues(cpf, constantSettings, i+1, addr, ret);
        }
        return;
    }
    // the i-th node has a fixed setting: use it
    addr[i] = fixed;
    getSubCPFValues(cpf, constantSettings, i+1, addr, ret);
}
/**
* abstract base class for filling a CPF that is determined by a combination function;
* subclasses define the value of each entry via getValue
* @author Dominik Jain
*/
public abstract class CPFFiller {
    /** the CPF that is filled (the CPF of the node passed to the constructor) */
    CPF cpf;
    /** the domain product of the CPF */
    BeliefNode[] nodes;

    /**
    * @param node the node whose CPF is to be filled; the domain product of its CPF is read here,
    * so the CPF must already have been built for the desired domain product
    */
    public CPFFiller(BeliefNode node) {
        cpf = node.getCPF();
        nodes = cpf.getDomainProduct();
    }

    /**
    * writes a value to every entry of the CPF
    * @throws Exception
    */
    public void fill() throws Exception {
        int[] addr = new int[nodes.length];
        fill(0, addr);
    }

    /**
    * recursively enumerates all settings of the nodes from index iNode onwards and writes
    * the value returned by getValue for every complete address
    * @param iNode index of the node (in the domain product) that is to be set next
    * @param addr the address array holding the current configuration (modified in place)
    * @throws Exception
    */
    protected void fill(int iNode, int[] addr) throws Exception {
        // if all nodes have been set, determine the value for this configuration and
        // fill the corresponding entry of the CPT
        if(iNode == nodes.length) {
            cpf.put(addr, new ValueDouble(getValue(addr)));
            return;
        }
        // only the number of elements of the domain is needed here, so the domain is not
        // downcast to Discrete (a cast would fail for any other kind of domain)
        Domain domain = nodes[iNode].getDomain();
        // - recursively consider all settings
        for(int i = 0; i < domain.getOrder(); i++) {
            // set address
            addr[iNode] = i;
            // recurse
            fill(iNode+1, addr);
        }
    }

    /**
    * @param addr a complete address in the CPF (one domain index per node of the domain product)
    * @return the value to store at the given address
    */
    protected abstract double getValue(int[] addr);
}
/**
* CPF filler for simple AND of boolean nodes
* @author Dominik Jain
*/
public class CPFFiller_AND extends CPFFiller {
    public CPFFiller_AND(BeliefNode node) {
        super(node);
    }

    @Override
    protected double getValue(int[] addr) {
        // AND of boolean nodes: the conjunction is true only if all parents are true (index 0)
        boolean allParentsTrue = true;
        for(int i = 1; i < addr.length; i++) {
            if(addr[i] != 0) {
                allParentsTrue = false;
                break;
            }
        }
        // the entry is 1 iff the node's own value (addr[0]) agrees with the conjunction
        int agreeingIndex = allParentsTrue ? 0 : 1;
        return addr[0] == agreeingIndex ? 1.0 : 0.0;
    }
}
/**
* CPF filler for simple OR of boolean nodes
* @author Dominik Jain
*/
public class CPFFiller_OR extends CPFFiller {
    public CPFFiller_OR(BeliefNode node) {
        super(node);
    }

    @Override
    protected double getValue(int[] addr) {
        // OR of boolean nodes: the disjunction is true as soon as one parent is true (index 0)
        boolean anyParentTrue = false;
        for(int i = 1; i < addr.length && !anyParentTrue; i++)
            anyParentTrue = addr[i] == 0;
        // the entry is 1 iff the node's own value (addr[0]) agrees with the disjunction
        int agreeingIndex = anyParentTrue ? 0 : 1;
        return addr[0] == agreeingIndex ? 1.0 : 0.0;
    }
}
/**
* CPF filler for disjunction of conjunction of boolean nodes:
* the entries addr[1..] are split into consecutive groups of groupSize entries,
* each group is a conjunction, and the groups are combined by disjunction
* @author Dominik Jain
*/
public class CPFFiller_ORGrouped extends CPFFiller {
// number of consecutive address entries (after addr[0]) that form one conjunction
int groupSize;
/**
* creates a filler for the CPF of the given node
* @param node node whose CPF to fill
* @param groupSize number of consecutive parents that make up a group representing a conjunction
*/
public CPFFiller_ORGrouped(BeliefNode node, int groupSize) {
super(node);
this.groupSize = groupSize;
}
/**
* @param addr domain indices (0=True, 1=False); addr[0] is compared against the
* truth value of the disjunction of conjunctions formed by addr[1..]
* @return 1.0 if addr[0] agrees with that truth value, 0.0 otherwise
*/
@Override
protected double getValue(int[] addr) {
// disjunction of conjunction of boolean nodes (each conjunction is of groupSize)
// order in boolean domains is 0=True, 1=False
// NOTE(review): a groupSize of 0 makes the modulo below throw an ArithmeticException
// as soon as there is at least one entry after addr[0] - confirm callers never pass 0
boolean isTrue = false;
// index of the group currently being examined; it is only advanced when a group fails,
// which suffices because the loop stops as soon as one group is completely true
int g = 0;
for(int i = 1; i < addr.length;) {
// i is at the first entry of a group: if the previous group was entirely true,
// the disjunction is satisfied and we can stop
if((i-1) % groupSize == 0) {
if(isTrue)
break;
}
isTrue = addr[i] == 0;
if(!isTrue) { // skip to next conjunction
++g;
i = 1 + g * groupSize;
continue;
}
++i;
}
// NOTE(review): if the entries after addr[0] are not a multiple of groupSize, a trailing
// incomplete group whose entries are all true leaves isTrue set - confirm this is intended
return (addr[0] == 0 && isTrue) || (addr[0] == 1 && !isTrue) ? 1.0 : 0.0;
}
}
/**
 * CPF filler for simple SUM of real nodes
 * @author meyerphi
 */
public class CPFFiller_SUM extends CPFFiller {
    public CPFFiller_SUM(BeliefNode node) {
        super(node);
    }

    /**
     * @param addr the address (domain indices) for which to compute the entry
     * @return the sum of all components of addr, including addr[0]
     */
    @Override
    protected double getValue(int[] addr) {
        // SUM of real nodes: add up all nodes
        // NOTE(review): what is added up are the domain indices stored in addr (starting at
        // position 0), not values looked up in the nodes' domains - confirm this is intended
        double total = 0.0;
        for(int component : addr) {
            total += component;
        }
        return total;
    }
}
/**
 * shows the ground network by delegating to its show method
 */
public void show() {
    this.groundBN.show();
}
/**
 * adds to the given evidence the evidence that is implied by the hard formulaic constraints (since all of them must be true)
 * and converts the result to domain indices
 * @param evidence an array of 2-element arrays containing node name and value
 * @return a list of domain indices for each node in the network (-1 for no evidence)
 */
public int[] getFullEvidence(String[][] evidence) {
    int numGiven = evidence.length;
    String[][] fullEvidence = new String[numGiven + this.hardFormulaNodes.size()][2];
    // first, take over the name/value pairs that were passed in
    int pos = 0;
    for(String[] pair : evidence) {
        fullEvidence[pos][0] = pair[0];
        fullEvidence[pos][1] = pair[1];
        pos++;
    }
    // then, append one entry per hard formula node, asserting that it is true
    for(BeliefNode formulaNode : hardFormulaNodes) {
        String[] entry = fullEvidence[pos++];
        entry[0] = formulaNode.getName();
        entry[1] = BooleanDomain.True;
    }
    return evidence2DomainIndices(groundBN, fullEvidence);
}
/**
 * converts variable-value pairs to an array of domain indices.
 * (Copy from BeliefNetworkEx that ignores superfluous evidence on evidence functions)
 *
 * Evidence on a variable that is not found in the network is skipped: silently if the
 * variable belongs to an evidence function of the BLN, with a warning on stderr otherwise.
 * If a variable appears more than once, the last entry overrides the earlier ones.
 *
 * @param bn the network whose nodes the evidence refers to
 * @param evidences an array of 2-element arrays containing variable name and value
 * @return an array with one entry per node of bn (same order as bn.getNodes()), containing
 *         the domain index of the node's evidence value or -1 if there is no evidence
 * @throws IllegalArgumentException if an entry is null or does not have exactly two elements,
 *         or if a value cannot be found in (or mapped to) the domain of its node
 */
protected int[] evidence2DomainIndices(BeliefNetworkEx bn, String[][] evidences) {
    BeliefNode[] nodes = bn.getNodes();
    int[] evidenceDomainIndices = new int[nodes.length];
    Arrays.fill(evidenceDomainIndices, -1);
    for (String[] evidence: evidences) {
        if(evidence == null || evidence.length != 2)
            throw new IllegalArgumentException("Evidences not in the correct format: "+Arrays.toString(evidence)+"!");
        int nodeIdx = bn.getNodeIndex(evidence[0]); // TODO inefficient linear search
        if (nodeIdx < 0) {
            // unknown variable: tolerated without comment for evidence functions, warned about otherwise
            Pair<String,String[]> p = Signature.parseVarName(evidence[0]);
            if(!this.bln.isEvidenceFunction(p.first)) {
                String error = "Variable with the name "+ evidence[0]+" not found in model but mentioned in evidence!";
                System.err.println("Warning: " + error);
            }
            continue;
        }
        Discrete domain = (Discrete)nodes[nodeIdx].getDomain();
        int domainIdx = domain.findName(evidence[1]);
        if (domainIdx < 0) {
            if (!(domain instanceof Discretized))
                throw new IllegalArgumentException("Cannot find evidence value "+evidence[1]+" in domain "+domain+" of node " + nodes[nodeIdx].getName());
            // for discretized domains, interpret the value as a number and map it to the name of its bin
            try {
                double value = Double.parseDouble(evidence[1]);
                String domainStr = ((Discretized)domain).getNameFromContinuous(value);
                // NOTE(review): if the resulting name is not found, domainIdx stays negative and
                // is stored as is below - confirm that this cannot happen
                domainIdx = domain.findName(domainStr);
            } catch (Exception e) {
                // keep the original exception as the cause so that the actual reason is not lost
                throw new IllegalArgumentException("Cannot find evidence value "+evidence[1]+" in domain "+domain+"!", e);
            }
        }
        evidenceDomainIndices[nodeIdx]=domainIdx;
    }
    return evidenceDomainIndices;
}
/**
 * gets the ground network held by this object
 * @return the ground network (the object itself, not a copy)
 */
public BeliefNetworkEx getGroundNetwork() {
    return groundBN;
}
/**
 * gets the unique identifier of the CPF that is associated with the given ground node of the network
 * @param node the ground node whose CPF identifier is requested
 * @return the identifier stored for the node in cpfIDs (presumably null if none is stored - confirm)
 */
public String getCPFID(BeliefNode node) {
    return cpfIDs.get(node);
}
/**
 * sets the debug flag of this object
 * @param enabled the new value of the debug flag
 */
public void setDebugMode(boolean enabled) {
    debug = enabled;
}
/**
 * gets the relational belief network of the underlying BLN
 * @return the rbn member of the BLN this object refers to
 */
public RelationalBeliefNetwork getRBN() {
    return this.bln.rbn;
}
/**
 * gets the template (fragment variable) used to instantiate the given ground node
 * @param node the ground node whose template is requested
 * @return the relational node stored for the given ground node in groundNode2TemplateNode
 */
public RelationalNode getTemplateOf(BeliefNode node) {
    return groundNode2TemplateNode.get(node);
}
/**
 * gets the collection of auxiliary nodes (nodes added for hard formula constraints) contained in this network
 * @return the internal vector of hard formula nodes (the vector itself, not a copy)
 */
public Vector<BeliefNode> getAuxiliaryVariables() {
    return hardFormulaNodes;
}
/**
* sets the verbose flag of this object
* @param verbose the new value of the verbose flag
*/
public void setVerbose(boolean verbose) {
this.verbose = verbose;
}
/**
 * gets the parameter handler of this object
 * @return the parameter handler (the object itself, not a copy)
 */
public ParameterHandler getParameterHandler() {
    return this.paramHandler;
}
}