/******************************************************************************* * Copyright (C) 2007-2012 Dominik Jain. * * This file is part of ProbCog. * * ProbCog is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * ProbCog is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with ProbCog. If not, see <http://www.gnu.org/licenses/>. ******************************************************************************/ package probcog.srl.directed; import java.io.File; import java.io.PrintStream; import java.util.Arrays; import java.util.Collection; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import java.util.Vector; import probcog.bayesnets.core.BeliefNetworkEx; import probcog.logic.Formula; import probcog.logic.GroundAtom; import probcog.srl.RelationKey; import probcog.srl.RelationalModel; import probcog.srl.Signature; import probcog.srl.directed.RelationalNode.Aggregator; import probcog.srl.mln.MLNWriter; import probcog.srl.taxonomy.Taxonomy; import probcog.srldb.Database; import edu.ksu.cis.bnj.ver3.core.BeliefNode; import edu.ksu.cis.bnj.ver3.core.CPF; import edu.ksu.cis.bnj.ver3.core.Discrete; import edu.ksu.cis.bnj.ver3.core.Domain; import edu.ksu.cis.bnj.ver3.core.values.ValueDouble; import edu.tum.cs.util.StringTool; import edu.tum.cs.util.datastruct.CollectionFilter; import edu.tum.cs.util.datastruct.MultiIterator; /** * Represents a relational belief network. * @author Dominik Jain (core) * @author Philipp Meyer (added utility nodes) */ public class RelationalBeliefNetwork extends BeliefNetworkEx implements RelationalModel { /** * maps a node index to the corresponding extended node */ protected HashMap<Integer,ExtendedNode> extNodesByIdx; /** * maps a function/predicate name to the signature of the corresponding function */ protected Map<String, Signature> signatures; /** * maps the name of a fixed domain to the vector of elements it contains */ protected HashMap<String, List<String>> guaranteedDomElements; /** * a mapping of function/relation names to RelationKey objects which signify argument groups that are keys of the relation (which may be used for a functional lookup) */ protected Map<String, Collection<RelationKey>> relationKeys; /** * maps function/relation names to combining rules */ protected Map<String, CombiningRule> combiningRules = new HashMap<String, CombiningRule>(); /** * list of of function names for which a uniform distribution is assumed by default if no * fragment is found. */ protected Vector<String> uniformDefaultFunctions = new Vector<String>(); protected Taxonomy taxonomy = null; protected Vector<String> prologRules = new Vector<String>(); public Collection<RelationKey> getRelationKeys(String relation) { return relationKeys.get(relation.toLowerCase()); } /** * constructs an empty relational belief network * @throws Exception */ public RelationalBeliefNetwork() throws Exception { super(); extNodesByIdx = new HashMap<Integer, ExtendedNode>(); signatures = new HashMap<String, Signature>(); relationKeys = new HashMap<String, Collection<RelationKey>>(); guaranteedDomElements = new HashMap<String, List<String>>(); } /** * instantiates a relational belief network from a fragment network * @param networkFile * @throws Exception */ public RelationalBeliefNetwork(File networkFile) throws Exception { this(); initNetwork(networkFile); } protected void initNetwork(File networkFile) throws Exception { super.initNetwork(networkFile.toString()); BeliefNode[] nodes = bn.getNodes(); for(int i = 0; i < nodes.length; i++) { ExtendedNode n = createNode(nodes[i]); addExtendedNode(n); } } /** * creates a relational node from the given belief node * @param node * @return * @throws Exception */ protected ExtendedNode createNode(BeliefNode node) throws Exception { switch(node.getType()) { case BeliefNode.NODE_CHANCE: return new RelationalNode(this, node); case BeliefNode.NODE_DECISION: return new DecisionNode(this, node); case BeliefNode.NODE_UTILITY: return new UtilityNode(this, node); default: throw new Exception("Don't know how to treat node " + node.getName() + " of type " + node.getType()); } } public void addExtendedNode(ExtendedNode node) { extNodesByIdx.put(node.index, node); } /** * gets the first relational node where the entire node label matches the given name * @param name * @return */ public RelationalNode getRelationalNode(String name) { BeliefNode node = this.getNode(name); if(node == null) return null; return getRelationalNode(this.getNodeIndex(node)); } public RelationalNode getRelationalNode(int idx) { return (RelationalNode)getExtendedNode(idx); } public RelationalNode getRelationalNode(BeliefNode node) { return (RelationalNode)getExtendedNode(node); } public ExtendedNode getExtendedNode(int idx) { return extNodesByIdx.get(new Integer(idx)); } public ExtendedNode getExtendedNode(BeliefNode node) { return getExtendedNode(this.getNodeIndex(node)); } public Collection<ExtendedNode> getExtendedNodes() { return extNodesByIdx.values(); } public Iterable<RelationalNode> getRelationalNodes() { return new CollectionFilter<RelationalNode, ExtendedNode>(getExtendedNodes(), RelationalNode.class); } public static boolean isBooleanDomain(Domain domain) { if(!(domain instanceof Discrete)) return false; int order = domain.getOrder(); if(order > 2 || order <= 0) return false; if(domain.getOrder() == 1) { if(domain.getName(0).equalsIgnoreCase("true") || domain.getName(0).equalsIgnoreCase("false")) return true; return false; } if(domain.getName(0).equalsIgnoreCase("true") || domain.getName(1).equalsIgnoreCase("true")) return true; return false; } public static boolean isRealDomain(Domain domain) { if(!(domain instanceof Discrete)) return false; int order = domain.getOrder(); return order == 1; } /** * obtains the names of parents of the variable that is given by a node name and its actual arguments * @param nodeName * @param actualArgs * @return an array of variable names * @throws Exception * TODO this should be rewritten with ParentGrounder */ public String[] getParentVariableNames(RelationalNode node, String[] actualArgs) throws Exception { RelationalNode child = node; BeliefNode[] parents = bn.getParents(child.node); String[] ret = new String[parents.length]; for(int i = 0; i < parents.length; i++) { RelationalNode parent = getRelationalNode(getNodeIndex(parents[i].getName())); StringBuffer varName = new StringBuffer(parent.getFunctionName() + "("); String param = null; for(int iCur = 0; iCur < parent.params.length; iCur++) { for(int iMain = 0; iMain < child.params.length; iMain++) { if(child.params[iMain].equals(parent.params[iCur])) { param = actualArgs[iMain]; break; } } if(param == null) throw new Exception(String.format("Could not determine parameters of parent '%s' for node '%s'", parent.getFunctionName(), node.getFunctionName() + actualArgs.toString())); varName.append(param); if(iCur < parent.params.length-1) varName.append(","); } varName.append(")"); ret[i] = varName.toString(); } return ret; } public void addSignature(Signature sig) throws Exception { String key = sig.functionName; //.toLowerCase() Signature old = signatures.get(key); if(old != null) throw new Exception("Duplicate signature definition for '" + sig.functionName + "'; previously defined as " + old + ", now defined as " + sig); signatures.put(key, sig); } /** * retrieves the signature of a function/predicate * @param functionName the name of the function/predicate * @return a Signature object */ public Signature getSignature(String functionName) { return signatures.get(functionName/*.toLowerCase()*/); } public Set<String> getFunctionNames() { return signatures.keySet(); } public Signature getSignature(RelationalNode node) { return signatures.get(node.getFunctionName()); } public Collection<Signature> getSignatures() { return signatures.values(); } /** * replace a type by a new type in all function signatures * @param oldType * @param newType */ public void replaceType(String oldType, String newType) { for(Signature sig : getSignatures()) sig.replaceType(oldType, newType); } /** * guesses the model's function signatures by assuming the same type whenever the same variable name is used (ignoring any numeric suffixes), setting the domain name to ObjType_x when x is the variable, * and assuming a different domain of return values for each node, using Dom{NodeName} as the domain name. * @throws Exception * */ public void guessSignatures() throws Exception { for(RelationalNode node : getRelationalNodes()) { if(node.isConstant) // signatures for constants are determined in checkSignatures (called below) continue; String[] argTypes = new String[node.params.length]; for(int i = 0; i < node.params.length; i++) { String param = node.params[i].replaceAll("\\d+", ""); argTypes[i] = "objType_" + param; } String retType = isBooleanDomain(((Discrete)node.node.getDomain())) ? "Boolean" : "dom" + node.getFunctionName(); Signature sig = new Signature(node.getFunctionName(), retType, argTypes); addSignature(sig); } checkSignatures(); // to fill constants } /** * check fragments for type inconsistencies and write return types for constant nodes * @throws Exception */ protected void checkSignatures() throws Exception { for(RelationalNode node : getRelationalNodes()) { if(node.isFragment()) checkFragment(node); } } /** * checks the given fragment for type inconsistencies * @param fragment * @param relevantNodes * @throws Exception */ protected void checkFragment(RelationalNode fragment) throws Exception { MultiIterator<RelationalNode> relevantNodes = new MultiIterator<RelationalNode>(); relevantNodes.add(fragment); relevantNodes.add(getRelationalParents(fragment)); HashMap<String,String> types = new HashMap<String,String>(); // parameter/argument -> type name mapping for(RelationalNode node : relevantNodes) { if(node.isBuiltInPred()) continue; if(node.isConstant) { // update constant return types using the mapping String type = types.get(node.getFunctionName()); if(type == null) // constants that were referenced by any of their parents must now have a type assigned throw new Exception("Constant " + node + " not referenced and therefore not typed."); node.constantType = type; } else { Signature sig = getSignature(node); if(sig == null) { throw new Exception("Node " + node + " has no signature!"); } // check for the right number of arguments and their types if(sig.argTypes.length != node.params.length) throw new Exception(String.format("Signature of '%s' is in conflict with node '%s': Signature requires %d arguments, node has %d.", sig.functionName, node.toString(), sig.argTypes.length, node.params.length)); for(int i = 0; i < node.params.length; i++) { String key = node.params[i]; String prevType = types.get(key); if(prevType == null) types.put(key, sig.argTypes[i]); else { if(!prevType.equals(sig.argTypes[i])) { boolean error = true; if (taxonomy != null && (taxonomy.query_isa(prevType, sig.argTypes[i]) || taxonomy.query_isa(sig.argTypes[i], prevType))) error = false; if(error) throw new Exception(String.format("Type mismatch while processing fragment '%s': '%s' has incompatible types '%s' and '%s'", fragment.getName(), key, prevType, sig.argTypes[i])); } } } } } } /** * gets all the parents of the given node that are instances of RelationalNode * @param node * @return */ public Vector<RelationalNode> getRelationalParents(RelationalNode node) { BeliefNode[] p = this.bn.getParents(node.node); Vector<RelationalNode> ret = new Vector<RelationalNode>(); for(int i = 0; i < p.length; i++) { ExtendedNode n = getExtendedNode(p[i]); if(n instanceof RelationalNode) ret.add((RelationalNode)n); } return ret; } /** * @deprecated * converts the network to a Markov logic network * @param out the stream to write to * @param compactFormulas whether to write CPTs more compactly by first learning a classification tree * @param numericWeights whether to print weighs as numbers (if false, print as log(x)) * @throws Exception */ public void toMLN(PrintStream out, boolean declarationsOnly, boolean compactFormulas, boolean numericWeights) throws Exception { MLNWriter writer = new MLNWriter(out); // write domain declarations out.println("// domain declarations"); HashSet<String> handled = new HashSet<String>(); HashMap<String, Vector<String>> domains = new HashMap<String,Vector<String>>(); for(RelationalNode node : getRelationalNodes()) { Signature sig = getSignature(node.getFunctionName()); if(sig == null) continue; if(sig.returnType.equals("Boolean")) continue; if(handled.contains(sig.returnType)) continue; handled.add(sig.returnType); Vector<String> d = new Vector<String>(); out.printf("%s = {", Database.lowerCaseString(sig.returnType)); String[] dom = getDiscreteDomainAsArray(node.node); for(int i = 0; i < dom.length; i++) { if(i > 0) out.print(", "); String elem = Database.upperCaseString(dom[i]); out.print(elem); d.add(elem); } out.println("}"); domains.put(sig.returnType, d); } out.println(); // write predicate declarations out.println("// predicate declarations"); Set<Signature> handledSigs = new HashSet<Signature>(); for(RelationalNode node : getRelationalNodes()) { if(node.isConstant) continue; Signature sig = node.getSignature(); if(sig == null) continue; if(handledSigs.contains(sig)) continue; handledSigs.add(sig); String[] argTypes; if(sig.returnType.equals("Boolean")) argTypes = new String[sig.argTypes.length]; else { argTypes = new String[sig.argTypes.length + 1]; argTypes[argTypes.length-1] = Database.lowerCaseString(sig.returnType); } for(int i = 0; i < sig.argTypes.length; i++) { if(sig.argTypes[i].length() == 0) throw new Exception("Parameter " + i + " of " + sig.functionName + " has empty type: " + sig); argTypes[i] = Database.lowerCaseString(sig.argTypes[i]); } out.printf("%s(%s)\n", Database.lowerCaseString(sig.functionName), StringTool.join(", ", argTypes)); } out.println(); // mutual exclusiveness and exhaustiveness out.println("// mutual exclusiveness and exhaustiveness"); // - non-boolean nodes for(RelationalNode node : getRelationalNodes()) { if(node.isConstant || node.isAuxiliary) continue; if(!node.isBoolean()) { out.print(Database.lowerCaseString(node.getFunctionName())); out.print('('); for(int i = 0; i <= node.params.length; i++) { if(i > 0) out.print(", "); out.printf("a%d", i); if(i == node.params.length) out.print('!'); } out.println(")"); } } // TODO - add constraints for functional dependencies in relations out.println(); if(declarationsOnly) return; // write formulas (and auxiliary predicate definitions for special nodes) int[] order = getTopologicalOrder(); for(int i = 0; i < order.length; i++) { RelationalNode node = getRelationalNode(order[i]); if(node.isConstant || node.isAuxiliary) continue; CPT2MLNFormulas converter = new CPT2MLNFormulas(node); // write auxiliary definitions and formulas required by certain node types if(node.aggregator != null && node.parentMode != null) { if(node.aggregator == Aggregator.Average && node.parentMode.equals("CP")) { // average of conditional probabilities // get the relation that is responsible for grounding the free parameters RelationalNode rel = node.getFreeParamGroundingParent(); if(rel == null) throw new Exception("Could not determine relevant relational parent"); // add predicate declaration for influence factor String influenceRelation = String.format("inflfac_%s_%s", node.getFunctionName(), rel.getFunctionName()); Signature sig = rel.getSignature(); writer.writePredicateDecl(influenceRelation, sig.argTypes, null); // write mutual exclusiveness and exhaustiveness definition writer.writeMutexDecl(influenceRelation, rel.params, node.addParams); // add a precondition that must be added to each CPT formula converter.addPrecondition(Signature.formatVarName(influenceRelation, rel.params)); // write the formula connecting the influence predicate to the regular relation: if the relation does not hold, there is no influence out.println("!" + rel.getCleanName() + " => !" + Signature.formatVarName(influenceRelation, rel.params) + "."); } else if(node.aggregator == Aggregator.NoisyOr) { // noisy or } } // write conditional probability distribution if(node.hasCPT()) { out.println("// CPT for " + node.getName()); out.println("// <group>"); if(compactFormulas) { // convert using decision trees for compactness converter.convert(out); } else { // old method: direct conversion CPF cpf = node.node.getCPF(); int[] addr = new int[cpf.getDomainProduct().length]; walkCPD_MLNformulas(out, cpf, addr, 0, converter.getPrecondition(), numericWeights); } out.println("// </group>\n"); } else { if(node.aggregator == Aggregator.NoisyOr) { out.print(MLNWriter.formatAsAtom(node.getCleanName()) + " <=> "); int k = 0; for(RelationalNode parent : node.getParents()) { // get the parameters that are free in this parent Vector<String> freeparams = new Vector<String>(); for(String p : node.addParams) if(parent.hasParam(p)) freeparams.add(p); if(freeparams.isEmpty()) continue; // print the condition if(k++ > 0) out.print(" v "); out.print("EXIST " + StringTool.join(",", freeparams.toArray(new String[0])) + " " + MLNWriter.formatAsAtom(parent.toAtom())); } if(k == 0) throw new Exception("None of the parents of OR-node " + node + " handle any of the free parameters."); out.println(); } } } } /** * converts the network to a Markov logic network * @param out the stream to write to * @param compactFormulas whether to write CPTs more compactly by first learning a classification tree * @param numericWeights whether to print weighs as numbers (if false, print as log(x)) * @throws Exception */ public void toMLN(MLNConverter converter, boolean declarationsOnly, boolean compactFormulas) throws Exception { // domain declarations System.out.printf("Converting %d domains...\n", this.getGuaranteedDomainElements().size()); for(java.util.Map.Entry<String, List<String>> e : this.getGuaranteedDomainElements().entrySet()) { converter.addGuaranteedDomainElements(e.getKey(), e.getValue()); } // predicate declarations System.out.printf("Converting %d function declarations...\n", this.getSignatures().size()); for(Signature sig : this.getSignatures()) { Signature mlnSig = sig; if(!mlnSig.isBoolean()) { String[] argTypes = new String[sig.argTypes.length+1]; for(int i = 0; i < sig.argTypes.length; i++) argTypes[i] = sig.argTypes[i]; argTypes[argTypes.length-1] = sig.returnType; mlnSig = new Signature(sig.functionName, "Boolean", argTypes); converter.addFunctionalDependency(sig.functionName, argTypes.length-1); } converter.addSignature(mlnSig); } if(declarationsOnly) return; // write formulas (and auxiliary predicate definitions for special nodes) int[] order = getTopologicalOrder(); System.out.printf("Converting %d nodes...\n", order.length); for(int i = 0; i < order.length; i++) { ExtendedNode extNode = getExtendedNode(order[i]); if(!(extNode instanceof RelationalNode)) continue; RelationalNode node = (RelationalNode) extNode; if(!node.isFragment()) continue; // TODO auxiliary definitions and formulas required by certain node types?? // write conditional probability distribution if(!node.hasAggregator()) { converter.beginCPT(node); CPT2MLNFormulas cptconverter = new CPT2MLNFormulas(node); if(compactFormulas) { // convert using decision trees for compactness throw new RuntimeException("Compact formulas not yet supported"); } else { // old method: direct conversion CPF cpf = node.node.getCPF(); int[] addr = new int[cpf.getDomainProduct().length]; walkCPD_MLNformulas(converter, cpf, addr, 0, cptconverter.getPrecondition(), true); } converter.endCPT(); } // for aggregators, consider specific conversion else { Formula f = node.toFormula(null); if(f == null) throw new Exception("Don't know how to generate a formula for " + node); converter.addHardFormula(f); } } } protected void walkCPD_MLNformulas(MLNConverter converter, CPF cpf, int[] addr, int i, String precondition, boolean numericWeights) throws Exception { BeliefNode[] nodes = cpf.getDomainProduct(); if(i == addr.length) { // we have a complete address // collect values of constants in order to replace references to them in the individual predicates HashMap<String,String> constantValues = new HashMap<String,String>(); for(int j = 0; j < addr.length; j++) { ExtendedNode extNode = getExtendedNode(nodes[j]); if(extNode instanceof RelationalNode) { RelationalNode rn = getRelationalNode(nodes[j]); if(rn.isConstant) { String value = ((Discrete)rn.node.getDomain()).getName(addr[j]); constantValues.put(rn.functionName, value); } } } // for each element of the address obtain the corresponding literal/predicate StringBuffer sb = new StringBuffer(); for(int j = 0; j < addr.length; j++) { String conjunct = null; ExtendedNode extNode = getExtendedNode(nodes[j]); if(extNode instanceof DecisionNode) { conjunct = extNode.toString(); } else { RelationalNode rn = (RelationalNode)extNode; if(!rn.isConstant) conjunct = rn.toLiteralString(addr[j], constantValues); } if(conjunct != null) { if(sb.length() > 0) sb.append(" ^ "); sb.append(conjunct); } } if(precondition != null) { sb.append(" ^ " + precondition); } // get the weight double weight = Math.log(cpf.getDouble(addr)); if(Double.isInfinite(weight)) weight = -100.0; // print weight and formula Formula f; try { f = Formula.fromString(sb.toString()); } catch(Error e) { System.err.println("Error parsing formula: " + sb.toString()); throw e; } catch(Exception e) { System.err.println("Error parsing formula: " + sb.toString()); throw e; } converter.addFormula(f, weight); } else { // the address is yet incomplete -> consider all ways of setting the next e // if the node is a necessary precondition for the child node, there is only one possible setting (True) boolean isPrecondition = false; ExtendedNode extNode = getExtendedNode(nodes[i]); RelationalNode node; if(extNode instanceof DecisionNode) isPrecondition = true; else { node = (RelationalNode) extNode; isPrecondition = node.isPrecondition; } Discrete dom = (Discrete)extNode.node.getDomain(); if(isPrecondition) { addr[i] = dom.findName("True"); if(addr[i] == -1) addr[i] = dom.findName("true"); if(addr[i] == -1) throw new Exception("Domain of necessary precondition " + extNode + " must contain either 'True' or 'true'!"); walkCPD_MLNformulas(converter, cpf, addr, i+1, precondition, numericWeights); } // otherwise consider all domain elements else { for(int j = 0; j < dom.getOrder(); j++) { addr[i] = j; walkCPD_MLNformulas(converter, cpf, addr, i+1, precondition, numericWeights); } } } } /** * @deprecated * @param out * @param cpf * @param addr * @param i * @param precondition * @param numericWeights * @throws Exception */ protected void walkCPD_MLNformulas(PrintStream out, CPF cpf, int[] addr, int i, String precondition, boolean numericWeights) throws Exception { BeliefNode[] nodes = cpf.getDomainProduct(); if(i == addr.length) { // we have a complete address // collect values of constants in order to replace references to them in the individual predicates HashMap<String,String> constantValues = new HashMap<String,String>(); for(int j = 0; j < addr.length; j++) { RelationalNode rn = getRelationalNode(nodes[j]); if(rn.isConstant) { String value = ((Discrete)rn.node.getDomain()).getName(addr[j]); constantValues.put(rn.functionName, value); } } // for each element of the address obtain the corresponding literal/predicate StringBuffer sb = new StringBuffer(); for(int j = 0; j < addr.length; j++) { RelationalNode rn = getRelationalNode(nodes[j]); if(!rn.isConstant) { if(j > 0) sb.append(" ^ "); sb.append(rn.toLiteralString(addr[j], constantValues)); } } if(precondition != null) { sb.append(" ^ " + precondition); } // get the weight int realAddr = cpf.addr2realaddr(addr); double value = ((ValueDouble)cpf.get(realAddr)).getValue(); double weight = Math.log(value); if(Double.isInfinite(weight)) weight = -100.0; // print weight and formula if(numericWeights) out.printf("%f %s\n", weight, sb.toString()); else out.printf("logx(%f) %s\n", value, sb.toString()); } else { // the address is yet incomplete -> consider all ways of setting the next e // if the node is a necessary precondition for the child node, there is only one possible setting (True) RelationalNode node = getRelationalNode(nodes[i]); Discrete dom = (Discrete)node.node.getDomain(); if(node.isPrecondition) { addr[i] = dom.findName("True"); if(addr[i] == -1) addr[i] = dom.findName("true"); if(addr[i] == -1) throw new Exception("Domain of necessary precondition " + node + " must contain either 'True' or 'true'!"); walkCPD_MLNformulas(out, cpf, addr, i+1, precondition, numericWeights); } // otherwise consider all domain elements else { for(int j = 0; j < dom.getOrder(); j++) { addr[i] = j; walkCPD_MLNformulas(out, cpf, addr, i+1, precondition, numericWeights); } } } } public ParentGrounder getParentGrounder(RelationalNode node) throws Exception { return node.getParentGrounder(); } public void addRelationKey(RelationKey k) { Collection<RelationKey> list = relationKeys.get(k.relation.toLowerCase()); if(list == null) { list = new Vector<RelationKey>(); relationKeys.put(k.relation.toLowerCase(), list); } list.add(k); //System.out.println("Key: " + k); //System.out.println(" now: " + this.getRelationKeys(k.relation.toLowerCase())); } /** * prepares this network for learning by materializing additional nodes (e.g. for noisy-or) * @throws Exception */ public void prepareForLearning() throws Exception { for(RelationalNode node : getRelationalNodes()) { if(node.parentMode != null && node.parentMode.equals("AUX")) { // create an auxiliary node that contains the ungrounded parameters // create fully grounded variant String[] params = new String[node.params.length + node.addParams.length]; int i = 0; for(int j = 0; j < node.params.length; j++) params[i++] = node.params[j]; for(int j = 0; j < node.addParams.length; j++) params[i++] = node.addParams[j]; String fullName = Signature.formatVarName(node.getFunctionName() + "_" + StringTool.join("", node.addParams), params); BeliefNode fullyGroundedNode = this.addNode(fullName); fullyGroundedNode.setDomain(node.node.getDomain()); // create the corresponding relational node and define a signature for it RelationalNode fullyGroundedRelNode = new RelationalNode(this, fullyGroundedNode); addExtendedNode(fullyGroundedRelNode); // - determine argument types for signature String[] argTypes = new String[params.length]; Signature origSig = node.getSignature(); Vector<RelationalNode> relParents = getRelationalParents(node); for(int j = 0; j < params.length; j++) { if(j < node.params.length) argTypes[j] = origSig.argTypes[j]; else { // check relational parents for parameter match boolean haveType = false; for(int k = 0; k < relParents.size() && !haveType; k++) { RelationalNode parent = relParents.get(k); Signature sig = parent.getSignature(); for(int l = 0; l < parent.params.length; l++) { if(parent.params[l].equals(params[j])) { argTypes[j] = sig.argTypes[l]; haveType = true; break; } } } if(!haveType) throw new Exception("Could not determine type of free parameter " + params[j]); } } // - add signature addSignature(new Signature(fullyGroundedRelNode.getFunctionName(), origSig.returnType, argTypes)); // connect all parents of this node to the fully grounded version BeliefNode[] parents = this.bn.getParents(node.node); for(BeliefNode parent : parents) { this.bn.disconnect(parent, node.node); this.bn.connect(parent, fullyGroundedNode); } // connect the fully grounded version to the original child this.bn.connect(fullyGroundedNode, node.node); // modify the original child node.parentMode = ""; node.setLabel(); //node.node.setName(RelationalNode.formatName(node.getFunctionName(), node.params)); } } //show(); } @Override public HashMap<String, List<String>> getGuaranteedDomainElements() { return guaranteedDomElements; } public void setGuaranteedDomainElements(String domName, String[] elements) { this.guaranteedDomElements.put(domName, Arrays.asList(elements)); } /** * retrieves the name of the random variable that corresponds to a logical ground atom * @return */ public String gndAtom2VarName(GroundAtom ga) { if(getSignature(ga.predicate).isBoolean()) return ga.toString(); else { StringBuffer s = new StringBuffer(ga.predicate + "("); for(int i = 0; i < ga.args.length-1; i++) { if(i > 0) s.append(','); s.append(ga.args[i]); } s.append(')'); return s.toString(); } } @Override public BeliefNode getNode(String name) { throw new RuntimeException("This method should never be called in relational networks because they may contain several nodes with the same name, so the mapping may not be well-defined."); } @Override public int getNodeIndex(String nodeName) { throw new RuntimeException("This method should never be called in relational networks because they may contain several nodes with the same name."); } public boolean isBoolean(String functionName) { Signature sig = getSignature(functionName); if(sig == null) throw new RuntimeException("No signature was defined for '" + functionName + "'"); return sig.isBoolean(); } public boolean isEvidenceFunction(String functionName) { Signature sig = getSignature(functionName); if(sig == null) throw new RuntimeException("No signature was defined for '" + functionName + "'"); return sig.isLogical; } public Taxonomy getTaxonomy() { return taxonomy; } @Override public Collection<String> getPrologRules() { return prologRules; } public CombiningRule getCombiningRule(String function) { return this.combiningRules.get(function); } public boolean usesUniformDefault(String functionName) { return uniformDefaultFunctions.contains(functionName); } @Override public void addGuaranteedDomainElement(String domain, String element) { throw new UnsupportedOperationException(); } }