/******************************************************************************* * Copyright (C) 2006-2012 Dominik Jain. * * This file is part of ProbCog. * * ProbCog is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * ProbCog is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with ProbCog. If not, see <http://www.gnu.org/licenses/>. ******************************************************************************/ package probcog.bayesnets.core; import java.io.BufferedReader; import java.io.FileInputStream; import java.io.FileNotFoundException; import java.io.FileOutputStream; import java.io.IOException; import java.io.InputStreamReader; import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.Random; import java.util.Set; import java.util.Stack; import java.util.regex.Matcher; import java.util.regex.Pattern; import probcog.bayesnets.core.io.Converter_ergo; import probcog.bayesnets.core.io.Converter_hugin; import probcog.bayesnets.core.io.Converter_pmml; import probcog.bayesnets.core.io.Converter_uai; import probcog.bayesnets.core.io.Converter_xmlbif; import probcog.bayesnets.inference.WeightedSample; import edu.ksu.cis.bnj.ver3.core.BeliefNetwork; import edu.ksu.cis.bnj.ver3.core.BeliefNode; import edu.ksu.cis.bnj.ver3.core.CPF; import edu.ksu.cis.bnj.ver3.core.CPT; import edu.ksu.cis.bnj.ver3.core.Discrete; import edu.ksu.cis.bnj.ver3.core.DiscreteEvidence; import edu.ksu.cis.bnj.ver3.core.Domain; import 
edu.ksu.cis.bnj.ver3.core.values.ValueDouble; import edu.ksu.cis.bnj.ver3.inference.approximate.sampling.ForwardSampling; import edu.ksu.cis.bnj.ver3.inference.exact.Pearl; import edu.ksu.cis.bnj.ver3.plugin.IOPlugInLoader; import edu.ksu.cis.bnj.ver3.streams.Exporter; import edu.ksu.cis.bnj.ver3.streams.Importer; import edu.ksu.cis.bnj.ver3.streams.OmniFormatV1_Reader; import edu.ksu.cis.util.graph.algorithms.TopologicalSort; import edu.ksu.cis.util.graph.core.Graph; import edu.ksu.cis.util.graph.core.Vertex; /** * An instance of class BeliefNetworkEx represents a full Bayesian Network. * It is a wrapper for BNJ's BeliefNetwork class with extended functionality. * BeliefNetwork could not simply be extended by inheritance because virtually all members are * declared private. Therefore, BeliefNetworkEx has a public member bn, which is an instance of * BeliefNetwork. * * @author Dominik Jain * */ public class BeliefNetworkEx { /*static final Logger logger = Logger.getLogger(BeliefNetworkEx.class); static { logger.setLevel(Level.WARN); }*/ static boolean defaultPluginsRegistered = false; /** * The maximum number of unsuccessful trials for sampling. * TODO: This should perhaps depend on the number of samples to be gathered? */ public static final int MAX_TRIALS = 5000; /** * the BNJ BeliefNetwork object that is wrapped by the instance of class BeliefNetworkEx. * When using BNJ directly, you may need this; or you may want to use the methods of BeliefNetwork to perform * an operation on the network that BeliefNetworkEx does not wrap. */ public BeliefNetwork bn; /** * the name of the currently loaded belief network file */ protected String filename; /** * The mapping from attribute names to the node names of nodes that should get data from the attribute. */ protected Map<String, String> nodeNameToAttributeMapping = new HashMap<String, String>(); /** * The inverse mapping of {@link #nodeNameToAttributeMapping}. 
*/ protected Map<String, Set<String>> attributeToNodeNameMapping = new HashMap<String, Set<String>>(); /** * constructs a BeliefNetworkEx object from a BNJ BeliefNetwork object * @param bn the BNJ BeliefNetwork object */ public BeliefNetworkEx(BeliefNetwork bn) { this.bn = bn; initAttributeMapping(); } /** * constructs a BeliefNetworkEx object from a saved network file * @param filename the name of the file to load the network from * @throws Exception */ public BeliefNetworkEx(String filename) throws Exception { initNetwork(filename); } /** * constructs an empty network. Use methods addNode and connect to define the network structure. */ public BeliefNetworkEx() { this.bn = new BeliefNetwork(); } protected void initNetwork(String filename) throws Exception { this.filename = filename; this.bn = load(filename); initAttributeMapping(); } /** * Initialize the attribute mapping with the basenodes' names to itself respectively. */ protected void initAttributeMapping() { for (BeliefNode node: bn.getNodes()) { addAttributeMapping(node.getName(), node.getName()); } } /** * Add a link from the node name to the attribute name. * Insert an entry into {@link #nodeNameToAttributeMapping} and into {@link #attributeToNodeNameMapping}. * @param nodeName the name of the node to link. * @param attributeName the name of the attribute to be linked with the node. */ protected void addAttributeMapping(String nodeName, String attributeName) { nodeNameToAttributeMapping.put(nodeName, attributeName); Set<String> nodeNames = attributeToNodeNameMapping.get(attributeName); if (nodeNames == null) { nodeNames = new HashSet<String>(); attributeToNodeNameMapping.put(attributeName, nodeNames); } nodeNames.add(nodeName); } /** * Get the attribute name that is linked to the given node. * @param nodeName the name of the node. * @return the attribute's name. 
*/
    public String getAttributeNameForNode(String nodeName) {
        return nodeNameToAttributeMapping.get(nodeName);
    }

    /**
     * Get the node names that are linked to the given attribute name.
     * @param attributeName the attribute name the nodes are linked to.
     * @return the node names that are linked to the attribute
     *         (null if no node has been linked to the attribute).
     */
    public Set<String> getNodeNamesForAttribute(String attributeName) {
        return attributeToNodeNameMapping.get(attributeName);
    }

    /**
     * Adds a node to the network and links it to the attribute that has the same name as the node.
     * @param node the node that is to be added
     */
    public void addNode(BeliefNode node) {
        bn.addBeliefNode(node);
        addAttributeMapping(node.getName(), node.getName());
    }

    /**
     * Adds a decision node (boolean, i.e. with the domain {True, False}) to the network.
     * Like all other methods that add nodes, this links the node to the attribute of the same name.
     * @param name label of the node
     * @return a reference to the BeliefNode object that was constructed
     */
    public BeliefNode addDecisionNode(String name) {
        BeliefNode node = new BeliefNode(name, new Discrete(new String[]{"True", "False"}));
        node.setType(BeliefNode.NODE_DECISION);
        bn.addBeliefNode(node);
        // keep the attribute mapping complete: every other way of adding a node registers it as well
        addAttributeMapping(name, name);
        return node;
    }

    /**
     * Adds a node with the given name and the standard discrete domain {True, False} to the network.
     * @param name the name of the node
     * @return a reference to the BeliefNode object that was constructed
     */
    public BeliefNode addNode(String name) {
        return addNode(name, new Discrete(new String[]{"True", "False"}));
    }

    /**
     * Adds a node with the given name and domain to the network.
     * The node is linked to the attribute that has the same name as the node.
     * @param name the name of the node
     * @param domain the node's domain (usually an instance of BNJ's class Discrete)
     * @return a reference to the BeliefNode object that was constructed
     */
    public BeliefNode addNode(String name, Domain domain) {
        return addNode(name, domain, name);
    }

    /**
     * adds a node with the given name and domain and attribute name to the network.
* @param name the name of the node * @param domain the node's domain (usually an instance of BNJ's class Discrete) * @param attributeName the name of the attribute that is assigned to the node * @return a reference to the BeliefNode object that was constructed */ public BeliefNode addNode(String name, Domain domain, String attributeName) { return addNode(name, domain, attributeName, BeliefNode.NODE_CHANCE); } /** * adds a node with the given name and domain and attribute name to the network. * @param name the name of the node * @param domain the node's domain (usually an instance of BNJ's class Discrete) * @param attributeName the type of the node (BeliefNode.NODE_CHANCE, BeliefNode.NODE_UTILITY or BeliefNode.NODE_DECISION) * @return a reference to the BeliefNode object that was constructed */ public BeliefNode addNode(String name, Domain domain, int type) { return addNode(name, domain, name, type); } /** * adds a node with the given name and domain and attribute name to the network. * @param name the name of the node * @param domain the node's domain (usually an instance of BNJ's class Discrete) * @param attributeName the name of the attribute that is assigned to the node * @param attributeName the type of the node (BeliefNode.NODE_CHANCE, BeliefNode.NODE_UTILITY or BeliefNode.NODE_DECISION) * @return a reference to the BeliefNode object that was constructed */ public BeliefNode addNode(String name, Domain domain, String attributeName, int type) { BeliefNode node = new BeliefNode(name, domain); node.setType(type); bn.addBeliefNode(node); addAttributeMapping(name, attributeName); //logger.debug("Added node "+name+" with attributeName "+attributeName); return node; } /** * adds an edge to the network, i.e. 
a dependency * @param node1 the name of the node that influences another * @param node2 the name of node that is influenced * @throws Exception if either of the node names are invalid */ public void connect(String node1, String node2) throws Exception { try { //logger.debug("connecting "+node1+" and "+node2); //logger.debug("Memory free: "+Runtime.getRuntime().freeMemory()+"/"+Runtime.getRuntime().totalMemory()); BeliefNode n1 = getNode(node1); BeliefNode n2 = getNode(node2); if(n1 == null || n2 == null) throw new Exception("One of the node names "+node1+" or "+node2+" is invalid!"); //logger.debug("Domainsize: "+n1.getDomain().getOrder()+"x"+n2.getDomain().getOrder()); //logger.debug("Doing the connect..."); bn.connect(n1, n2); //logger.debug("Memory free: "+Runtime.getRuntime().freeMemory()+"/"+Runtime.getRuntime().totalMemory()); //logger.debug("Connection done."); } catch(Exception e) { System.out.println("Exception occurred in connect!"); e.printStackTrace(System.out); throw e; } catch(Error e2) { System.out.println("Error occurred"); e2.printStackTrace(System.out); throw e2; } } /** connect two nodes * @param parent parent which the bnode will be a child of * @param child node which will be made a child of parent * @param adjustCPF whether to adjust the CPF as well (otherwise only the graph is altered); should be set to false only if the CPF is manually initialized later on */ public void connect(BeliefNode parent, BeliefNode child, boolean adjustCPF) { Graph graph = bn.getGraph(); graph.addDirectedEdge(parent.getOwner(), child.getOwner()); if(adjustCPF) { Vertex[] parents = graph.getParents(child.getOwner()); BeliefNode[] after = new BeliefNode[parents.length + 1]; for (int i = 0; i < parents.length; i++) { after[i + 1] = ((BeliefNode) parents[i].getObject()); } after[0] = child; CPT beforeCPF = (CPT)child.getCPF(); child.setCPF(beforeCPF.expand(after)); } } public void connect(BeliefNode parent, BeliefNode child) { connect(parent, child, true); } /** * 
retrieves the node with the given name * @param name the name of the node * @return a reference to the node (or null if there is no node with the given name) */ public BeliefNode getNode(String name) { int idx = getNodeIndex(name); if(idx == -1) return null; return bn.getNodes()[idx]; } public BeliefNode getNode(int idx) { return bn.getNodes()[idx]; } /** * get the index (into the BeliefNetwork's array of nodes) of the node with the given name * @param name the name of the node * @return the index of the node (or -1 if there is no node with the given name) */ public int getNodeIndex(String name) { BeliefNode[] nodes = bn.getNodes(); for(int i = 0; i < nodes.length; i++) if(nodes[i].getName().equals(name)) return i; return -1; } /** * Get the indices of the nodes that the CPT of the given node depends on. * @param node the node to take the CPT from. * @return the indices of the nodes that the CPT of the given node depends on. */ public int[] getDomainProductNodeIndices(BeliefNode node) { BeliefNode[] nodes = node.getCPF().getDomainProduct(); int[] nodeIndices = new int[nodes.length]; for(int i = 0; i < nodes.length; i++) nodeIndices[i] = this.getNodeIndex(nodes[i].getName()); return nodeIndices; } /** * Get the indices into the domains of the nodes for the given node value assignments. * @param nodeAndDomains the assignments to be converted. * @return the assignment converted to doamin indices. 
*/ public int[] getNodeDomainIndicesFromStrings(String[][] nodeAndDomains) { BeliefNode[] nodes = bn.getNodes(); int[] nodeDomainIndices = new int[nodes.length]; Arrays.fill(nodeDomainIndices, -1); for (String[] nodeAndDomain: nodeAndDomains) { if (nodeAndDomain == null || nodeAndDomain.length != 2) throw new IllegalArgumentException("Evidences not in the correct format: "+Arrays.toString(nodeAndDomain)+"!"); int nodeIdx = getNodeIndex(nodeAndDomain[0]); if (nodeIdx < 0) throw new IllegalArgumentException("Variable with the name "+nodeAndDomain[0]+" not found!"); /*if (nodeDomainIndices[nodeIdx] > 0) logger.warn("Evidence "+nodeAndDomain[0]+" set twice!");*/ Discrete domain = (Discrete)nodes[nodeIdx].getDomain(); int domainIdx = domain.findName(nodeAndDomain[1]); if (domainIdx < 0) { if (domain instanceof Discretized) { try { double value = Double.parseDouble(nodeAndDomain[1]); String domainStr = ((Discretized)domain).getNameFromContinuous(value); domainIdx = domain.findName(domainStr); } catch (Exception e) { throw new IllegalArgumentException("Cannot find evidence value "+nodeAndDomain[1]+" in domain "+domain+"!"); } } else { throw new IllegalArgumentException("Cannot find evidence value "+nodeAndDomain[1]+" in domain "+domain+"!"); } } nodeDomainIndices[nodeIdx]=domainIdx; } return nodeDomainIndices; } public int getNodeIndex(BeliefNode node) { BeliefNode[] nodes = bn.getNodes(); for(int i = 0; i < nodes.length; i++) if(nodes[i] == node) return i; return -1; } /** * sets evidence for one of the network's node * @param nodeName the name of the node for which evidence is to be set * @param outcome the outcome, which must be in compliance with the node's domain * @throws Exception if the node name does not exist in the network or the outcome is not valid for the node's domain */ public void setEvidence(String nodeName, String outcome) throws Exception { BeliefNode node = getNode(nodeName); if(node == null) throw new Exception("Invalid node reference: " + nodeName); 
Discrete domain = (Discrete) node.getDomain(); int idx = domain.findName(outcome); if(idx == -1) throw new Exception("Outcome " + outcome + " not in domain of " + nodeName); node.setEvidence(new DiscreteEvidence(idx)); } /** * calculates a probability Pr[X=x, Y=y, ... | E=e, F=f, ...] * @param queries an array of 2-element string arrays (variable, outcome) * that represents the conjunction "X=x AND Y=y AND ..."; * @param evidences the conjunction of evidences, specified in the same way * @return the calculated probability * @throws Exception */ public double getProbability(String[][] queries, String[][] evidences) throws Exception { // queries with only one query variable (i.e. Pr[X | A,B,...]) can be solved directly // ... for others, recursion is necessary if(queries.length == 1) { // remove any previous evidence BeliefNode[] nodes = bn.getNodes(); for(int i = 0; i < nodes.length; i++) nodes[i].setEvidence(null); // set new evidence if(evidences != null) for(int i = 0; i < evidences.length; i++) { setEvidence(evidences[i][0], evidences[i][1]); } // run inference Pearl inf = new Pearl(); inf.run(this.bn); // return result BeliefNode node = getNode(queries[0][0]); CPF cpf = inf.queryMarginal(node); BeliefNode[] dp = cpf.getDomainProduct(); boolean done = false; int[] addr = cpf.realaddr2addr(0); while(!done) { for (int i = 0; i < addr.length; i++) if(dp[i].getDomain().getName(addr[i]).equals(queries[0][1])) { ValueDouble v = (ValueDouble) cpf.get(addr); return v.getValue(); } done = cpf.addOne(addr); } throw new Exception("Outcome not in domain!"); //inf.printResults(); } else { // Pr[A,B,C,D | E] = Pr[A | B,C,D,E] * Pr[B,C,D | E] String[][] _queries = new String[1][2]; String[][] _queries2 = new String[queries.length-1][2]; _queries[0] = queries[0]; int numEvidences = evidences == null ? 
0 : evidences.length; String[][] _evidences = new String[numEvidences+queries.length-1][2]; int idx = 0; for(int i = 1; i < queries.length; i++, idx++) { _evidences[idx] = queries[i]; _queries2[idx] = queries[i]; } for(int i = 0; i < numEvidences; i++, idx++) _evidences[idx] = evidences[i]; return getProbability(_queries, _evidences) * getProbability(_queries2, evidences); } } protected void printProbabilities(int node, Stack<String[]> evidence) throws Exception { BeliefNode[] nodes = bn.getNodes(); if(node == nodes.length) { String[][] e = new String[evidence.size()][]; evidence.toArray(e); double prob = getProbability(e, null); StringBuffer s = new StringBuffer(); s.append(String.format("%6.2f%% ", 100*prob)); int i = 0; for(String[] pair : evidence) { if(i > 0) s.append(", "); s.append(String.format("%s=%s", pair[0], pair[1])); i++; } System.out.println(s); return; } Domain dom = nodes[node].getDomain(); for(int i = 0; i < dom.getOrder(); i++) { evidence.push(new String[]{nodes[node].getName(), dom.getName(i)}); printProbabilities(node+1, evidence); evidence.pop(); } } public void printFullJoint() throws Exception { printProbabilities(0, new Stack<String[]>()); } /** * prints domain information for all nodes of the network to System.out */ public void printDomain() { BeliefNode[] nodes = bn.getNodes(); for(int i = 0; i < nodes.length; i++) { System.out.print(nodes[i].getName()); Discrete domain = (Discrete)nodes[i].getDomain(); System.out.print(" {"); int c = domain.getOrder(); for(int j = 0; j < c; j++) { if(j > 0) System.out.print(", "); System.out.print(domain.getName(j)); } System.out.println("}"); } } /** * static function for loading a Bayesian network into an instance of class BeliefNetwork * @param filename the file containing the network data * @param importer an importer that is capable of understanding the file format * @return the loaded network in a new instance of class BeliefNetwork * @throws FileNotFoundException */ public static BeliefNetwork 
load(String filename, Importer importer) throws FileNotFoundException { FileInputStream fis = new FileInputStream(filename); OmniFormatV1_Reader ofv1w = new OmniFormatV1_Reader(); importer.load(fis, ofv1w); return ofv1w.GetBeliefNetwork(0); } /** * loads a Bayesian network from the given file (determining a suitable importer from the extension) * @param filename * @return * @throws Exception */ public static BeliefNetwork load(String filename) throws Exception { registerDefaultPlugins(); IOPlugInLoader iopl = IOPlugInLoader.getInstance(); String ext = iopl.GetExt(filename); Importer imp = iopl.GetImporterByExt(ext); if(imp == null) throw new Exception("Unable to find an importer that can handle " + ext + " files."); return load(filename, imp); } /** * saves a Bayesian network to the given filename (determining a suitable exporter from the extension) * @param filename * @return * @throws Exception */ public static void save(BeliefNetwork net, String filename) throws Exception { registerDefaultPlugins(); IOPlugInLoader iopl = IOPlugInLoader.getInstance(); String ext = iopl.GetExt(filename); Exporter exporter = iopl.GetExportersByExt(ext); if(exporter == null) throw new Exception("Unable to find an exporter that can handle " + ext + " files."); save(net, filename, exporter); } /** * static function for writing a Bayesian network to a file using a given exporter * @param net the network to be written * @param filename the file to write to * @param exporter an exporter for the desired file format * @throws FileNotFoundException */ public static void save(BeliefNetwork net, String filename, Exporter exporter) throws FileNotFoundException { exporter.save(net, new FileOutputStream(filename)); //OmniFormatV1_Writer.Write(net, (OmniFormatV1)exporter); } /** * saves a Bayesian network to the given filename (determining a suitable exporter from the extension) * @param filename * @return * @throws Exception */ public void save(String filename) throws Exception { save(this.bn, 
filename); } /** * writes the Bayesian network to a file with the given name using an exporter * @param filename the file to write to * @param exporter an exporter for the desired file format * @throws FileNotFoundException */ public void save(String filename, Exporter exporter) throws FileNotFoundException { save(this.bn, filename, exporter); } /** * writes the Bayesian network to a file with the given name in XML-BIF format * @param filename the file to write to * @throws FileNotFoundException */ public void saveXMLBIF(String filename) throws FileNotFoundException { save(filename, new Converter_xmlbif()); } /** * writes the Bayesian network to a file with the given name in a PMML-based format * @param filename the file to write to * @throws FileNotFoundException */ public void savePMML(String filename) throws FileNotFoundException { save(filename, new Converter_pmml()); } /** * writes the Bayesian network to the same file it was loaded from * @throws Exception * */ public void save() throws Exception { IOPlugInLoader pil = IOPlugInLoader.getInstance(); if(filename == null) throw new Exception("Cannot save - filename not given!"); Exporter exporter = pil.GetExportersByExt(pil.GetExt(filename)); save(filename, exporter); } /** * sorts the domain of the node with the given name alphabetically (if numeric is false) or * numerically (if numeric is true) - in ascending order * @param nodeName the name of the node whose domain is to be sorted * @param numeric whether to sort numerically or not. If numeric is true, * all domain values are converted to double for sorting. * If numeric is false, the values are simply sorted alphabetically. 
* @throws Exception if the node name is invalid */ public void sortNodeDomain(String nodeName, boolean numeric) throws Exception { BeliefNode node = getNode(nodeName); if(node == null) throw new Exception("Node not found"); Discrete domain = (Discrete)node.getDomain(); int ord = domain.getOrder(); String[] strings = new String[ord]; if(!numeric) { for(int i = 0; i < ord; i++) strings[i] = domain.getName(i); Arrays.sort(strings); } else { double[] values = new double[ord]; for(int i = 0; i < ord; i++) values[i] = Double.parseDouble(domain.getName(i)); double[] sorted_values = values.clone(); Arrays.sort(sorted_values); for(int i = 0; i < ord; i++) for(int j = 0; j < ord; j++) if(sorted_values[i] == values[j]) strings[i] = domain.getName(j); } bn.changeBeliefNodeDomain(node, new Discrete(strings)); } /** * returns the domain of the node with the given name * @param nodeName the name of the node for which the domain is to be returned * @return the domain of the node (usually instance of class Discrete) * or null if the node name is invalid */ public Domain getDomain(String nodeName) { BeliefNode node = getNode(nodeName); if(node == null) return null; return node.getDomain(); } /** * shows the Bayesian Network in an editor window (with support for standard IO plugins) */ public void show() { registerDefaultPlugins(); edu.ksu.cis.bnj.gui.GUIWindow window = new edu.ksu.cis.bnj.gui.GUIWindow(); window.register(); window.open(bn, filename); } public static void registerDefaultPlugins() { if(defaultPluginsRegistered) return; IOPlugInLoader iopl = IOPlugInLoader.getInstance(); // XML-BIF Converter_xmlbif xmlbif = new Converter_xmlbif(); iopl.addPlugin(xmlbif, xmlbif); // PMML Converter_pmml pmml = new Converter_pmml(); iopl.addPlugin(pmml, pmml); // Hugin Converter_hugin hugin = new Converter_hugin(); iopl.addPlugin(null, hugin); // Ergo Converter_ergo ergo = new Converter_ergo(); iopl.addPlugin(ergo, ergo); // UAI Converter_uai uai = new Converter_uai(); 
iopl.addPlugin(null, uai); defaultPluginsRegistered = true; } /** * shows the Bayesian Network in a BNJ editor window, * loading the BNJ plugins in the given directory * @param pluginDir a directory containing BNJ plugins (jar files) */ public void show(String pluginDir) { IOPlugInLoader iopl = IOPlugInLoader.getInstance(); iopl.loadPlugins(pluginDir); show(); } /** * helper function for queryShell that reads a list of comma-separated assignments "A=a,B=b,..." * into an array [["A","a"],["B","b"],...] * @param list * @return * @throws java.lang.Exception */ protected static String[][] readList(String list) throws java.lang.Exception { if(list == null) return null; String[] items = list.split(","); String[][] res = new String[items.length][2]; for(int i = 0; i < items.length; i++) { res[i] = items[i].split("="); if(res[i].length != 2) throw new java.lang.Exception("syntax error!"); } return res; } /** * starts a shell that allows the user to query the network */ public void queryShell() { // output some usage information System.out.println("Domain:"); printDomain(); System.out.println("\nUsage: Pr[X=x, Y=y, ... | E=e, F=f, ...] (X,Y: query vars;\n" + " E,F: evidence vars;\n" + " x,y,e,f: outcomes\n" + " exit (close shell)"); BufferedReader br = new BufferedReader(new InputStreamReader(System.in)); for(;;) { try { // get input query from stdin System.out.print("\n> "); String input = br.readLine(); if(input.equalsIgnoreCase("exit")) break; // parse expression... input = input.replaceAll("\\s+", ""); Pattern p = Pattern.compile("Pr\\[([^\\]\\|]*)(?:\\|([^\\]]*))?\\]"); Matcher m = p.matcher(input); if(!m.matches()) { System.out.println("syntax error!"); } else { String[][] queries = readList(m.group(1)); String[][] evidences = readList(m.group(2)); try { // evaluate and output result... 
double result = getProbability(queries, evidences); System.out.println(result); } catch(Exception e) { System.out.println(e.getMessage()); } } } catch(IOException e) { System.err.println(e.getMessage()); } catch(java.lang.Exception e) { System.out.println(e.getMessage()); } } } /** * Get the sample assignment and its sampled probability as the weight sorted by probability. * @param evidences the evidences for the distribution. * @param queryNodes the nodes that should be sampled. * @param numSamples the number of samples to draw from. * @return the accumulated samples and their sampled conditional probabilities given the evidences * or null if we run out of trials for the first sample. * @throws Exception */ public WeightedSample[] getAssignmentDistribution(String[][] evidences, String[] queryNodeNames, int numSamples) throws Exception { HashMap<WeightedSample, Double> sampleSums = new HashMap<WeightedSample, Double>(); int[] queryNodes = new int[queryNodeNames.length]; for (int i=0; i<queryNodeNames.length; i++) { queryNodes[i]=getNodeIndex(queryNodeNames[i]); if (queryNodes[i] < 0) throw new IllegalArgumentException("Cannot find node with name "+queryNodeNames[i]); } Random generator = new Random(); for (int i=0; i<numSamples; i++) { WeightedSample sample = getWeightedSample(evidences, generator); if (sample == null && i == 0) // If we need too many trials and we have no samples return null; // it will be very probable that we have an endless loop because of a bad evidence! 
WeightedSample subSample = sample.subSample(queryNodes); if (sampleSums.containsKey(subSample)) { sampleSums.put(subSample, sampleSums.get(subSample)+subSample.weight); } else { sampleSums.put(subSample, subSample.weight); } } double sum = 0; for (WeightedSample sample: sampleSums.keySet()) { //logger.debug(sample); double value = sampleSums.get(sample); sum += value; } WeightedSample[] samples = sampleSums.keySet().toArray(new WeightedSample[0]); for (WeightedSample sample: samples) { sample.weight = sampleSums.get(sample)/sum; } Arrays.sort(samples, new Comparator<WeightedSample>() { public int compare(WeightedSample o1, WeightedSample o2) { return Double.compare(o2.weight, o1.weight); } }); return samples; } /** * gets a topological ordering of the network's nodes * @return an array of integers containing node indices */ public int[] getTopologicalOrder() { TopologicalSort topsort = new TopologicalSort(); topsort.execute(bn.getGraph()); return topsort.alpha; } /** * Get a specific entry in the cpt of the given node. * The nodeDomainIndices should contain a value for each node in the BeliefNet but only values * in the domain product of the node are queried for. * WARNING: This is very slow (mainly because getDomainProductNodeIndices performs a linear search for each node) * @param node the node the CPT should come from. * @param nodeDomainIndices the values the nodes should have (domain indices for all the nodes in the network) * @return the probability entry in the CPT. */ public double getCPTProbability(BeliefNode node, int[] nodeDomainIndices ) { CPF cpf = node.getCPF(); int[] domainProduct = getDomainProductNodeIndices(node); int[] address = new int[domainProduct.length]; for (int i=0; i<address.length; i++) { address[i]=nodeDomainIndices[domainProduct[i]]; } int realAddress = cpf.addr2realaddr(address); return cpf.getDouble(realAddress); } /** * Remove all evidences. 
*/ public void removeAllEvidences() { // remove evidences (restoring original state) for(BeliefNode node : bn.getNodes()) { node.setEvidence(null); } } /** * Calculates a probability Pr[X=x, Y=y, ... | E=e, F=f, ...] by sampling a number of samples. * @param queries an array of 2-element string arrays (variable, outcome) * that represents the conjunction "X=x AND Y=y AND ...". * @param evidences the conjunction of evidences, specified in the same way. * @param numSamples the number of samples to draw. * @return the calculated probability. * @throws Exception */ public double getSampledProbability(String[][] queries, String[][] evidences, int numSamples) throws Exception { String[] queryNodes = new String[queries.length]; for (int i=0; i<queryNodes.length; i++) { queryNodes[i]=queries[i][0]; } WeightedSample[] samples = getAssignmentDistribution(evidences, queryNodes, numSamples); double goodSum = 0; double allSum = 0; for (int i=0; i<samples.length; i++) { allSum += samples[i].weight; if (samples[i].checkAssignment(queries)) goodSum += samples[i].weight; } return goodSum/allSum; } /** * Sample from the BeliefNet via likelihood weighted sampling. * @param evidences the evidences for the sample. * @param sampleDomainIndexes the resulting domain indexes for each node. * The length must be initialized to the number of nodes in the net. 
* @return * @throws Exception */ public WeightedSample getWeightedSample(String[][] evidences, Random generator) throws Exception { if (generator == null) { generator = new Random(); } return getWeightedSample(getTopologicalOrder(), evidence2DomainIndices(evidences), generator); } public WeightedSample getWeightedSample(int[] nodeOrder, int[] evidenceDomainIndices, Random generator) throws Exception { BeliefNode[] nodes = bn.getNodes(); int[] sampleDomainIndices = new int[nodes.length]; boolean successful = false; double weight = 1.0; int trials=0; success:while (!successful) { //System.out.println(trials); weight = 1.0; if (trials > MAX_TRIALS) return null; for (int i=0; i< nodeOrder.length; i++) { int nodeIdx = nodeOrder[i]; int domainIdx = evidenceDomainIndices[nodeIdx]; if (domainIdx >= 0) { // This is an evidence node? sampleDomainIndices[nodeIdx] = domainIdx; nodes[nodeIdx].setEvidence(new DiscreteEvidence(domainIdx)); // TODO this call is inefficient double prob = getCPTProbability(nodes[nodeIdx], sampleDomainIndices); if (prob == 0.0) { //System.out.println("sampling failed at evidence node " + nodes[nodeIdx].getName()); removeAllEvidences(); trials++; continue success; } weight *= prob; } else { domainIdx = ForwardSampling.sampleForward(nodes[nodeIdx], bn, generator); if (domainIdx < 0) { System.out.println("could not sample forward because of column with 0s in CPT of " + nodes[nodeIdx].getName()); removeAllEvidences(); trials++; continue success; } sampleDomainIndices[nodeIdx] = domainIdx; nodes[nodeIdx].setEvidence(new DiscreteEvidence(domainIdx)); } } trials++; removeAllEvidences(); successful = true; } return new WeightedSample(this, sampleDomainIndices, weight, null, trials); } public int[] evidence2DomainIndices(String[][] evidences) { BeliefNode[] nodes = bn.getNodes(); int[] evidenceDomainIndices = new int[nodes.length]; Arrays.fill(evidenceDomainIndices, -1); for (String[] evidence: evidences) { if (evidence == null || evidence.length != 2) throw 
new IllegalArgumentException("Evidences not in the correct format: "+Arrays.toString(evidence)+"!"); int nodeIdx = getNodeIndex(evidence[0]); if (nodeIdx < 0) { String error = "Variable with the name "+evidence[0]+" not found in model but mentioned in evidence!"; System.err.println("Warning: " + error); continue; //throw new IllegalArgumentException(error); } /*if (evidenceDomainIndices[nodeIdx] > 0) logger.warn("Evidence "+evidence[0]+" set twice!");*/ Discrete domain = (Discrete)nodes[nodeIdx].getDomain(); int domainIdx = domain.findName(evidence[1]); if (domainIdx < 0) { if (domain instanceof Discretized) { try { double value = Double.parseDouble(evidence[1]); String domainStr = ((Discretized)domain).getNameFromContinuous(value); domainIdx = domain.findName(domainStr); } catch (Exception e) { throw new IllegalArgumentException("Cannot find evidence value "+evidence[1]+" in domain "+domain+"!"); } } else { throw new IllegalArgumentException("Cannot find evidence value "+evidence[1]+" in domain "+domain+" of node " + nodes[nodeIdx].getName()); } } evidenceDomainIndices[nodeIdx]=domainIdx; } return evidenceDomainIndices; } /** * performs sampling on the network and returns a sample of the marginal distribution represented by this Bayesian network; evidences that are set during sampling are removed * afterwards in order to retain the original state of the network. 
* @return a hashmap of (node name, string value) pairs representing the sample * @param generator random number generator to use to generate sample (null to create one) * @throws Exception */ public HashMap<String,String> getSample(Random generator) throws Exception { if(generator == null) generator = new Random(); HashMap<String,String> ret = new HashMap<String,String>(); // perform topological sort to determine sampling order TopologicalSort topsort = new TopologicalSort(); topsort.execute(bn.getGraph()); int[] order = topsort.alpha; // sample BeliefNode[] nodes = bn.getNodes(); boolean succeeded = false; while(!succeeded) { ArrayList<BeliefNode> setEvidences = new ArrayList<BeliefNode>(); // remember nodes for which we set evidences while sampling for(int i = 0; i < order.length; i++) { BeliefNode node = nodes[order[i]]; if(node.hasEvidence()) { throw new Exception("At least one node has evidence. You can only sample from the marginal distribution!"); } int idxValue = ForwardSampling.sampleForward(node, bn, generator); if(idxValue == -1) { // sampling node failed - most probably because the distribution was all 0 values -> retry from start succeeded = false; break; } succeeded = true; Domain dom = node.getDomain(); //System.out.println("set node " + node.getName() + " to " + dom.getName(idxValue)); ret.put(node.getName(), dom.getName(idxValue)); node.setEvidence(new DiscreteEvidence(idxValue)); setEvidences.add(node); } // remove evidences (restoring original state) for(BeliefNode node : setEvidences) { node.setEvidence(null); } } return ret; } public static String[] getDiscreteDomainAsArray(BeliefNode node) { Discrete domain = (Discrete)node.getDomain(); String[] ret = new String[domain.getOrder()]; for(int i = 0; i < ret.length; i++) ret[i] = domain.getName(i); return ret; } public String[] getDiscreteDomainAsArray(String nodeName) { return getDiscreteDomainAsArray(getNode(nodeName)); } /* public void dump() { BeliefNode[] nodes = bn.getNodes(); for (int i=0; 
i<nodes.length; i++) { logger.debug("Node "+i+": "+nodes[i].getName()); logger.debug("\tAttribute: "+getAttributeNameForNode(nodes[i].getName())); } for (String attributeName: attributeToNodeNameMapping.keySet()) { logger.debug("Attribute "+attributeName+": "+attributeToNodeNameMapping.get(attributeName)); } } */

    /**
     * Visitor that is told about the entries of a node's CPT by
     * {@link BeliefNetworkEx#walkCPT(BeliefNode, CPTWalker, boolean)}.
     * The methods are called in the order tellSize, tellNode, tellValue (once per CPT entry).
     */
    public interface CPTWalker {
        /**
         * @param childConfigs the size of the domain of the node whose CPT is walked
         * @param parentConfigs the product of the domain sizes of all other nodes in the CPT's domain product
         */
        public abstract void tellSize(int childConfigs, int parentConfigs);

        /**
         * @param n the node whose CPT is walked
         */
        public abstract void tellNode(BeliefNode n);

        /**
         * @param addr the address of the entry (one domain index per node in the CPT's domain product);
         *        the same array instance is modified and passed again for subsequent entries, so
         *        copy it if it needs to be kept
         * @param v the value of the CPT entry at that address
         */
        public abstract void tellValue(int[] addr, double v);
    }

    /**
     * walks over all entries of the given node's CPT, reporting them to the walker
     * @param node the node whose CPT to walk
     * @param walker the visitor
     * @param byColumn whether to walk the CPT by column rather than by row, i.e. whether the index of
     *        the node itself (position 0 of the domain product) is varied in the innermost loop
     *        rather than in the outermost one
     */
    public void walkCPT(BeliefNode node, CPTWalker walker, boolean byColumn) {
        CPF cpf = node.getCPF();
        BeliefNode[] nodes = cpf.getDomainProduct();
        int parentConfigs = 1;
        for(int i = 1; i < nodes.length; i++)
            parentConfigs *= nodes[i].getDomain().getOrder();
        walker.tellSize(nodes[0].getDomain().getOrder(), parentConfigs);
        int[] addr = new int[nodes.length];
        walker.tellNode(node);
        // when walking by column, start with position 1 and handle position 0 last
        walkCPT(walker, cpf, addr, byColumn ? 1 : 0, byColumn);
    }

    /**
     * recursive helper that fills the address position by position and reports each complete address
     * @param walker the visitor
     * @param cpf the CPF being walked
     * @param addr the (partially filled) address into the CPF
     * @param i the running position counter; the position that is set is i modulo the address length,
     *        such that, when walking by column, position 0 is set after all other positions
     * @param byColumn whether the walk started at position 1 rather than 0
     */
    protected void walkCPT(CPTWalker walker, CPF cpf, int[] addr, int i, boolean byColumn) {
        // by column, the counter runs from 1 to addr.length (wrapping around to position 0)
        int end = byColumn ? addr.length + 1 : addr.length;
        if(i == end) {
            // the address is complete -> get the probability value and report it
            int realAddr = cpf.addr2realaddr(addr);
            double value = ((ValueDouble)cpf.get(realAddr)).getValue();
            walker.tellValue(addr, value);
            return;
        }
        // the address is yet incomplete -> consider all ways of setting the next element
        BeliefNode[] nodes = cpf.getDomainProduct();
        int idx = i % addr.length;
        Discrete dom = (Discrete)nodes[idx].getDomain();
        for(int j = 0; j < dom.getOrder(); j++) {
            addr[idx] = j;
            walkCPT(walker, cpf, addr, i+1, byColumn);
        }
    }

    /**
     * gets the index of the given value inside the given node's domain
     * @param node a node with a discrete domain
     * @param value the value whose index to search for
     * @return the index of the value in the node's domain, as returned by the domain's findName
     *         (presumably negative if the value is not in the domain - to be confirmed)
     */
    public int getDomainIndex(BeliefNode node, String value) {
        Discrete domain = (Discrete)node.getDomain();
        return domain.findName(value);
    }

    /**
     * computes the prior distribution of each node
     * @param evidenceDomainIndices may be null, otherwise evidence to be "faked in" (domain index for each of the nodes, -1 for no evidence). The prior of an evidence node is then calculated as 1 and nodes lower in the topology will consider the evidence in their priors.
     * @return a mapping from each node to its distribution (one entry per domain index)
     */
    public HashMap<BeliefNode, double[]> computePriors(int[] evidenceDomainIndices) {
        HashMap<BeliefNode, double[]> priors = new HashMap<BeliefNode, double[]>();
        BeliefNode[] nodes = bn.getNodes();
        // process parents before children, because a node's prior is computed from its parents' priors
        int[] topOrder = getTopologicalOrder();
        for(int i : topOrder) {
            BeliefNode node = nodes[i];
            double[] dist = new double[node.getDomain().getOrder()];
            int evidence = evidenceDomainIndices != null ? evidenceDomainIndices[i] : -1;
            if(evidence >= 0) {
                // evidence node: all the probability mass is on the evidence value
                for(int j = 0; j < dist.length; j++)
                    dist[j] = evidence == j ? 1.0 : 0.0;
            }
            else {
                CPF cpf = node.getCPF();
                computePrior(priors, evidenceDomainIndices, cpf, 0, new int[cpf.getDomainProduct().length], dist);
            }
            priors.put(node, dist);
        }
        return priors;
    }

    /**
     * recursive helper that sums, over all addresses of the given CPF, the CPT entry multiplied by the
     * (already computed) priors of the individual parent values, adding each product to the
     * distribution entry of the node's value
     * @param priors the priors computed so far; must contain an entry for every parent in the CPF's domain product
     * @param evidenceDomainIndices may be null, otherwise the evidence (domain index for each of the nodes, -1 for no evidence); positions of nodes with evidence are fixed to the evidence value
     * @param cpf the CPF of the node whose prior is being computed
     * @param i the position of the address that is to be set next
     * @param addr the (partially filled) address into the CPF
     * @param dist the distribution to which the results are added
     */
    protected void computePrior(HashMap<BeliefNode, double[]> priors, int[] evidenceDomainIndices, CPF cpf, int i, int[] addr, double[] dist) {
        BeliefNode[] domProd = cpf.getDomainProduct();
        if(i == addr.length) {
            double p = cpf.getDouble(addr); // p = P(node setting | parent configuration)
            for(int j = 1; j < addr.length; j++) {
                double[] parentPrior = priors.get(domProd[j]);
                p *= parentPrior[addr[j]];
            } // p = P(node setting, parent configuration)
            dist[addr[0]] += p;
            return;
        }
        BeliefNode node = domProd[i];
        // computePriors documents the evidence as optional, so null must be treated as "no evidence"
        int evidence = -1;
        if(evidenceDomainIndices != null)
            evidence = evidenceDomainIndices[getNodeIndex(node)];
        if(evidence >= 0) {
            // the node has evidence -> only its evidence value needs to be considered
            addr[i] = evidence;
            computePrior(priors, evidenceDomainIndices, cpf, i+1, addr, dist);
        }
        else {
            // no evidence -> consider every value of the node
            Domain dom = node.getDomain();
            for(int j = 0; j < dom.getOrder(); j++) {
                addr[i] = j;
                computePrior(priors, evidenceDomainIndices, cpf, i+1, addr, dist);
            }
        }
    }

    /**
     * gets the probability of the possible world given by the vector of domain indices
     * @param nodeDomainIndices domain indices for each of the node's random variables
     * @return the product of the CPT probabilities of all nodes for the given values
     */
    public double getWorldProbability(int[] nodeDomainIndices) {
        double ret = 1.0;
        for(BeliefNode node : bn.getNodes())
            ret *= getCPTProbability(node, nodeDomainIndices);
        return ret;
    }

    /**
     * @return the nodes of the wrapped belief network
     */
    public BeliefNode[] getNodes() {
        return bn.getNodes();
    }

    /**
     * gets the total number of possible worlds
     * @return the product of the domain sizes of all nodes
     */
    public double getNumWorlds() {
        double num = 1;
        for(BeliefNode n : getNodes())
            num *= n.getDomain().getOrder();
        return num;
    }
}