package bayesGame.bayesbayes; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Set; import java.util.Stack; import org.apache.commons.math3.fraction.Fraction; import org.apache.commons.math3.util.Pair; import bayesGame.bayesbayes.nodeCPD.DeterministicOR; import bayesGame.bayesbayes.nodeCPD.NodeCPD; import edu.uci.ics.jung.graph.DirectedSparseGraph; public class BayesNet implements Iterable<BayesNode> { private int edgeCounter = 0; private Map<Object,BayesNode> nodes; private NetGraph graph; private HashSet<BayesNode> visitedDownstreamNodes; private Stack<Pair<BayesNode, BayesNode>> downstreamMessagePaths; public BayesNet() { graph = new NetGraph(this); nodes = new HashMap<Object,BayesNode>(); } public boolean addNode(BayesNode node){ boolean added = graph.addVertex(node); if (added){ nodes.put(node.type, node); } return added; } public boolean addNode(Object object){ BayesNode node = new BayesNode(object); return addNode(node); } public boolean addNode(Object object, boolean value){ BayesNode node = new BayesNode(object); node.setTrueValue(value); return addNode(node); } public boolean addNode(Object object, Object[] scope){ BayesNode node = getNode(object, scope); return addNode(node); } public boolean addNode(Object object, NodeCPD cpd, Object... parents){ return addNodeWithParents(object, cpd, parents); } public void addNodes(Object... nodes){ for (Object o : nodes){ addNode(o); } } private BayesNode getNode(Object o){ BayesNode newNode = getNodeIffPresent(o); if (newNode == null){ newNode = new BayesNode(o); } return newNode; } private BayesNode getNode(Object o, Object[] scope){ BayesNode newNode = getNodeIffPresent(o); if (newNode == null){ newNode = new BayesNode(o, scope); } return newNode; } public DirectedSparseGraph<BayesNode, Pair<Integer,Integer>> getGraph(){ return graph; } public boolean isPresent(Object o){ return nodes.containsKey(o); } public boolean isFullyAssumed(){ boolean fullyAssumed = true; Collection<BayesNode> nodeValues = nodes.values(); for (BayesNode existingNode : nodeValues){ if (!existingNode.isObserved() && !existingNode.isAssumed()){ fullyAssumed = false; break; } } return fullyAssumed; } private BayesNode getNodeIffPresent(Object o){ return nodes.get(o); } public String getCPTDescription(Object o){ BayesNode node = getNodeIffPresent(o); if (node != null){ return node.cptDescription; } else { return ""; } } public boolean removeBayesNode(Object object){ BayesNode node = new BayesNode(object); nodes.remove(node); return graph.removeVertex(node); } private boolean connectNodes(BayesNode node1, BayesNode node2){ boolean result = graph.addEdge(new Pair<Integer,Integer>(edgeCounter,0), node1, node2); if (result){ edgeCounter++; } return result; } public boolean connectNodes(Object rawNode1, Object rawNode2){ BayesNode node1 = getNode(rawNode1); BayesNode node2 = getNode(rawNode2); if (!scopesCompatible(node1, node2)){ return false; } return this.connectNodes(node1, node2); } /** * Connects two existing nodes, adding the parent to the child's scope if neither has the * other in its scope. Note that this will rewrite the child's conditional probability table. * To avoid this, use the regular connectNodes function which takes no action for incompatible * scopes, or use the changeNodeCPD function to make the child conform to the desired probability * distribution. * * @param rawNode1 * @param rawNode2 * @return */ public boolean forceConnectNodes(Object rawNode1, Object rawNode2){ BayesNode node1 = getNode(rawNode1); BayesNode node2 = getNode(rawNode2); if (!scopesCompatible(node1, node2)){ node2.addItemToScope(node1.type); } return this.connectNodes(node1, node2); } /** * Adds a node to the network which evaluates to true (with P = 1) iff at least one * of its parents is true (with P = 1). The parents of the node must already exist * in the network, whereas the node itself must not exist: if these criteria are not * met, or if no parents are provided, the function will return false and do nothing. * * @param orNode The identifier of the deterministic OR node to be added * @param parents The parents of the OR node * @return true if the node was added, false otherwise */ public boolean addDeterministicOr(Object orNode, Object... parents){ return addNodeWithParents(orNode, new DeterministicOR(), parents); } public boolean addNodeWithParents(Object object, NodeCPD cpd, Object... parents){ BayesNode orBayesNode = new BayesNode(object, parents); if (parents.length == 0){ return false; } // remove any duplicate parents parents = new HashSet<Object>(Arrays.asList(parents)).toArray(new Object[0]); for (Object o : parents){ if (!isPresent(o)){ return false; } } this.setNodeTo(orBayesNode, parents, cpd); boolean added = addNode(orBayesNode); if (!added){ return false; } for (Object o : parents){ boolean sanityCheck = connectNodes(o, object); if (!sanityCheck){ throw new IllegalStateException("Failed to connect nodes that should be connected fine ??? Shouldn't be possible"); } } return true; } private BayesNode setNodeTo(BayesNode node, Object[] parents, NodeCPD cpd){ return node = cpd.getNode(node, parents); } /** * Turns an existing node with at least one parent into a deterministic OR node. * * @param orNode The node to be made into a deterministic OR node * @return false if the node doesn't exist or has no parents, true otherwise. */ public boolean makeDeterministicOr(Object orObject){ return changeNodeCPD(orObject, new DeterministicOR()); } /** * Gives an existing node with at least one parent the kind of probability distribution specified in the parameter. * * @param object The node to be changed * @param cpd The desired probability distribution * @return false if the node doesn't exist or has no parents, true otherwise. */ public boolean changeNodeCPD(Object object, NodeCPD cpd){ BayesNode node = getNodeIffPresent(object); // the node has to already exist for us to do anything if (node == null){ return false; } ArrayList<BayesNode> parentNodes = new ArrayList<BayesNode>(graph.getPredecessors(node)); // and it should have parents if (parentNodes.size() == 0){ return false; } Object[] parents = new Object[parentNodes.size()]; // the graph object returns the parents as a list of BayesNodes, so let's extract their // types into a separate array for (int i = 0; i < parentNodes.size(); i++){ parents[i] = parentNodes.get(i).type; } // as the node relies on the values of its parents, it should have its parents in its scope Set<Object> scopeSet = new HashSet<Object>(Arrays.asList(node.scope)); for (Object p : parents){ if (!scopeSet.contains(p)){ node.addItemToScope(p); } } this.setNodeTo(node, parents, cpd); return true; } public List<Object> getFamily(Object object){ ArrayList<Object> family = new ArrayList<Object>(); BayesNode node = getNodeIffPresent(object); // the node has to already exist for us to do anything if (node == null){ return family; } ArrayList<BayesNode> parentNodes = new ArrayList<BayesNode>(graph.getPredecessors(node)); ArrayList<BayesNode> childNodes = new ArrayList<BayesNode>(graph.getSuccessors(node)); for (BayesNode b : parentNodes){ family.add(b.type); } for (BayesNode b : childNodes){ family.add(b.type); } return family; } private boolean scopesCompatible(BayesNode node1, BayesNode node2){ ArrayList<Object> difference = getScopeDifference(node1, node2); if (difference.isEmpty()){ return false; } return true; } private ArrayList<Object> getScopeDifference(BayesNode node1, BayesNode node2){ ArrayList<Object> list1 = new ArrayList<Object>(Arrays.asList(node1.scope)); ArrayList<Object> list2 = new ArrayList<Object>(Arrays.asList(node2.scope)); list1.retainAll(list2); return list1; } public boolean containsNode(Object rawNode){ return isPresent(rawNode); } public Fraction getProbability(Object object){ BayesNode node = getNodeIffPresent(object); if (node == null){ throw new IllegalArgumentException("Requested object not found in the graph"); } return node.getProbability(); } public ArrayList<Map<Object,Boolean>> getNonZeroProbabilities(Object object){ BayesNode node = getNodeIffPresent(object); if (node == null){ throw new IllegalArgumentException("Requested object not found in the graph"); } return node.getNonZeroProbabilities(); } public Map<Object,Boolean> getCurrentAssignments(){ Map<Object,Boolean> assignments = new HashMap<Object,Boolean>(nodes.size()); Collection<BayesNode> nodeValues = nodes.values(); for (BayesNode node : nodeValues){ if (node.isObserved() || node.isAssumed()){ boolean truthValue = node.getProbability().equals(Fraction.ONE); assignments.put(node.type, truthValue); } } return assignments; } public boolean observe(Object object){ BayesNode node = getNode(object); node.observe(); Fraction probability = node.getProbability(); return (probability.intValue() == 1); } public void observe(Object object, boolean value){ BayesNode node = getNode(object); node.observe(value); } public boolean isObserved(Object object){ BayesNode node = getNodeIffPresent(object); return node.isObserved(); } public void assume(Object object, boolean value){ BayesNode node = getNodeIffPresent(object); if (node == null){ throw new IllegalArgumentException("Requested object not found in the graph"); } node.assumeValue(value); } public void assume(Object object){ BayesNode node = getNodeIffPresent(object); if (node == null){ throw new IllegalArgumentException("Requested object not found in the graph"); } node.clearAssumedValue(); } public void addProperty(Object object, String property){ BayesNode node = getNodeIffPresent(object); if (node == null){ throw new IllegalArgumentException("Requested object not found in the graph"); } node.addProperty(property); } public void removeProperty(Object object, String property){ BayesNode node = getNodeIffPresent(object); node.removeProperty(property); } public boolean setProbabilityOfUntrue(Object object, Fraction probability, Object... variables){ BayesNode node = getNode(object); return node.setProbabilityOfUntrueVariables(probability, variables); } public boolean setTrueValue(Object object, boolean value){ BayesNode node = getNodeIffPresent(object); if (node == null){ return false; } node.setTrueValue(value); return true; } public void clearAssumptions(){ Collection<BayesNode> nodeValues = nodes.values(); for (BayesNode node : nodeValues){ node.clearAssumedValue(); } } public void resetNetworkBeliefs(){ Collection<BayesNode> nodeValues = nodes.values(); for (BayesNode node : nodeValues){ node.resetPotential(); } } public void resetNetworkBeliefsObservations(){ Collection<BayesNode> nodeValues = nodes.values(); for (BayesNode node : nodeValues){ node.resetNode(); } } // TODO: currently assumes that all the nodes are connected, this should be checked // TODO: only works on polytrees, doesn't check that the network is one public void updateBeliefs(){ if (nodes.size() > 1){ HashMap<BayesNode,Fraction> calculatedProbabilities = new HashMap<BayesNode,Fraction>(); Collection<BayesNode> nodeValues = nodes.values(); for (BayesNode root : nodeValues){ resetNetworkBeliefs(); visitedDownstreamNodes = new HashSet<BayesNode>(); downstreamMessagePaths = new Stack<Pair<BayesNode,BayesNode>>(); sendDownstreamMessages(root); visitedDownstreamNodes.clear(); sendUpstreamMessages(); root.multiplyPotentialWithMessages(); Fraction rootProbability = root.getProbability(); calculatedProbabilities.put(root, rootProbability); } for (BayesNode node : nodeValues){ Fraction calculatedProbability = calculatedProbabilities.get(node); node.setProbability(calculatedProbability); } } } private void sendDownstreamMessages(BayesNode source){ visitedDownstreamNodes.add(source); ArrayList<BayesNode> sourceNeighbors = new ArrayList<BayesNode>(graph.getNeighbors(source)); for (BayesNode neighbor : sourceNeighbors){ if (!visitedDownstreamNodes.contains(neighbor)){ Object[] sharedScope = getScopeDifference(source, neighbor).toArray(); //b Object[] sharedScope = new Object[]{source.type}; Message message = source.generateDownstreamMessage(sharedScope); neighbor.receiveDownstreamMessage(message); downstreamMessagePaths.push(new Pair<BayesNode,BayesNode>(source, neighbor)); sendDownstreamMessages(neighbor); } } } private void sendUpstreamMessages(){ while (!downstreamMessagePaths.isEmpty()){ Pair<BayesNode,BayesNode> path = downstreamMessagePaths.pop(); BayesNode source = path.getSecond(); BayesNode receiver = path.getFirst(); Object[] sharedScope = getScopeDifference(source, receiver).toArray(); // Object[] sharedScope = new Object[]{source.type}; Message message = source.generateUpstreamMessage(sharedScope); receiver.receiveUpstreamMessage(message); source.multiplyPotentialWithMessages(); } } public Iterator iterator(){ return nodes.values().iterator(); } public List<Object> getParents(Object object){ ArrayList<Object> family = new ArrayList<Object>(); BayesNode node = getNodeIffPresent(object); // the node has to already exist for us to do anything if (node == null){ return family; } ArrayList<BayesNode> parentNodes = new ArrayList<BayesNode>(graph.getPredecessors(node)); for (BayesNode b : parentNodes){ family.add(b.type); } return family; } }