package aima.core.probability.bayes.impl;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import aima.core.probability.RandomVariable;
import aima.core.probability.bayes.BayesianNetwork;
import aima.core.probability.bayes.Node;
/**
 * Default implementation of the BayesianNetwork interface.
 * <p>
 * The network is defined by its root nodes; every other node is discovered by
 * following child links from them. On construction the network is checked to
 * be a directed acyclic graph (DAG) and its random variables are collected in
 * topological order, i.e. every variable appears after all of its parents.
 *
 * @author Ciaran O'Reilly
 * @author Ravi Mohan
 */
public class BayesNet implements BayesianNetwork {
	/** The root nodes this network was constructed from, in the order given. */
	protected Set<Node> rootNodes = new LinkedHashSet<Node>();
	/**
	 * The network's variables in topological order; wrapped as an
	 * unmodifiable list once construction completes.
	 */
	protected List<RandomVariable> variables = new ArrayList<RandomVariable>();
	/** Lookup from each random variable to the node representing it. */
	protected Map<RandomVariable, Node> varToNodeMap = new HashMap<RandomVariable, Node>();

	/**
	 * Constructs a Bayesian network from its root nodes.
	 *
	 * @param rootNodes
	 *            the network's root nodes; all other nodes are reached by
	 *            following child links from these.
	 * @throws IllegalArgumentException
	 *             if {@code rootNodes} is null, contains a null or duplicate
	 *             entry, if some node has a parent that cannot be reached from
	 *             the given root nodes, or if the network contains a cycle.
	 */
	public BayesNet(Node... rootNodes) {
		if (null == rootNodes) {
			throw new IllegalArgumentException(
					"Root Nodes need to be specified.");
		}
		for (Node n : rootNodes) {
			// Fail fast here instead of with a NullPointerException during
			// the graph walk.
			if (null == n) {
				throw new IllegalArgumentException(
						"Root Nodes must not contain null.");
			}
			this.rootNodes.add(n);
		}
		// The set collapses duplicates, so a size mismatch means one was passed.
		if (this.rootNodes.size() != rootNodes.length) {
			throw new IllegalArgumentException(
					"Duplicate Root Nodes Passed in.");
		}
		// Ensure is a DAG
		checkIsDAGAndCollectVariablesInTopologicalOrder();
		variables = Collections.unmodifiableList(variables);
	}

	//
	// START-BayesianNetwork

	/**
	 * @return the network's random variables in topological order (each
	 *         variable after all of its parents), as an unmodifiable list.
	 */
	@Override
	public List<RandomVariable> getVariablesInTopologicalOrder() {
		return variables;
	}

	/**
	 * @param rv
	 *            a random variable.
	 * @return the node representing {@code rv}, or null if {@code rv} is not
	 *         part of this network.
	 */
	@Override
	public Node getNode(RandomVariable rv) {
		return varToNodeMap.get(rv);
	}

	// END-BayesianNetwork
	//

	//
	// PRIVATE METHODS
	//

	/**
	 * Walks the network from the root nodes, verifies that it is a complete
	 * DAG and fills {@link #variables} and {@link #varToNodeMap} in
	 * topological order.
	 *
	 * @throws IllegalArgumentException
	 *             if a parent is unreachable from the root nodes or the
	 *             network contains a cycle.
	 */
	private void checkIsDAGAndCollectVariablesInTopologicalOrder() {
		// Topological sort (Kahn's algorithm) based on logic described at:
		// http://en.wikipedia.org/wiki/Topological_sorting

		// For every reachable node, the parents whose edges are not yet
		// consumed. Key membership also serves as the "visited" marker.
		Map<Node, Set<Node>> incomingEdges = new HashMap<Node, Set<Node>>();
		// Nodes with no remaining incoming edges, processed first-in-first-out.
		Deque<Node> ready = new ArrayDeque<Node>();
		for (Node n : this.rootNodes) {
			walkNode(n, incomingEdges, ready);
		}
		checkAllParentsReachable(incomingEdges);

		while (!ready.isEmpty()) {
			Node n = ready.poll();
			variables.add(n.getRandomVariable());
			varToNodeMap.put(n.getRandomVariable(), n);
			for (Node m : n.getChildren()) {
				Set<Node> edges = incomingEdges.get(m);
				// Enqueue only when this call removes the last incoming edge,
				// so no node can ever be queued (and emitted) twice.
				if (edges.remove(n) && edges.isEmpty()) {
					ready.add(m);
				}
			}
		}
		// Every parent is reachable (checked above), so any edge left
		// unconsumed must lie on a cycle.
		for (Set<Node> edges : incomingEdges.values()) {
			if (!edges.isEmpty()) {
				throw new IllegalArgumentException(
						"Network contains at least one cycle in it, must be a DAG.");
			}
		}
	}

	/**
	 * Ensures every parent of every reachable node was itself reached from
	 * the root nodes. Otherwise the sort could never consume that edge, which
	 * would wrongly be reported as a cycle.
	 */
	private void checkAllParentsReachable(Map<Node, Set<Node>> incomingEdges) {
		for (Map.Entry<Node, Set<Node>> entry : incomingEdges.entrySet()) {
			for (Node parent : entry.getValue()) {
				if (!incomingEdges.containsKey(parent)) {
					throw new IllegalArgumentException("Node for "
							+ entry.getKey().getRandomVariable()
							+ " has a parent (" + parent.getRandomVariable()
							+ ") that is not reachable from the specified Root Nodes.");
				}
			}
		}
	}

	/**
	 * Depth-first walk along child links. It records each newly reached
	 * node's parents as its pending incoming edges and queues nodes that
	 * report themselves as roots.
	 */
	private void walkNode(Node n, Map<Node, Set<Node>> incomingEdges,
			Deque<Node> ready) {
		if (incomingEdges.containsKey(n)) {
			return; // already visited
		}
		incomingEdges.put(n, new HashSet<Node>(n.getParents()));
		// A root has no incoming edges, so it can be emitted immediately.
		if (n.isRoot()) {
			ready.add(n);
		}
		for (Node c : n.getChildren()) {
			walkNode(c, incomingEdges, ready);
		}
	}
}