/*
* Author: tdanford
* Date: Dec 3, 2008
*/
package org.seqcode.ml.bayesnets;
import java.util.*;
import org.seqcode.gseutils.graphs.DirectedAlgorithms;
import org.seqcode.gseutils.graphs.DirectedGraph;
import org.seqcode.gseutils.models.Model;
import org.seqcode.math.probability.FiniteDistribution;
import org.seqcode.ml.regression.DataFrame;
public class BN<X extends Model> {
public DirectedGraph graph;
private Map<String,BNVar> vars;
private Map<String,BNCpd> cpds;
private DataFrame<X> data;
public BN(BN bn) {
data = bn.data;
vars = new TreeMap<String,BNVar>(bn.vars);
cpds = new TreeMap<String,BNCpd>(bn.cpds);
graph = new DirectedGraph(bn.graph);
}
public BN(DataFrame<X> d, String... varNames) {
data = d;
graph = new DirectedGraph();
vars = new TreeMap<String,BNVar>();
cpds = new TreeMap<String,BNCpd>();
for(int i = 0; i < varNames.length; i++) {
if(vars.containsKey(varNames[i])) {
throw new IllegalArgumentException(varNames[i]);
}
BNVar var = new BNVar(varNames[i], data.fieldValues(varNames[i]));
vars.put(varNames[i], var);
graph.addVertex(varNames[i]);
}
// No call to learnCPDs(), since we presume that the user will add some edges to
// the graph first, and then call that method him/herself.
}
public BN(DataFrame<X> d, DirectedGraph g) {
data = d;
graph = g;
vars = new TreeMap<String,BNVar>();
cpds = new TreeMap<String,BNCpd>();
Set<String> verts = graph.getVertices();
for(String varName : verts) {
BNVar var = new BNVar(varName, data.fieldValues(varName));
vars.put(varName, var);
}
learnCPDs();
}
public FiniteDistribution posterior(X m, String var) {
BNVar bnvar = vars.get(var);
Object bnvalue = bnvar.findValue(m);
if(bnvalue != null) {
int code = bnvar.encode(bnvalue);
return new FiniteDistribution(bnvar.size(), code);
}
return null;
}
public void print() {
DirectedAlgorithms algos = new DirectedAlgorithms(graph);
Vector<String> verts = algos.getTopologicalOrdering();
graph.printGraph(System.out);
for(String name : verts) {
cpds.get(name).print();
System.out.println();
}
}
public DataFrame<X> getData() { return data; }
public Set<String> varNames() { return graph.getVertices(); }
public X sample() {
try {
X model = data.getModelClass().newInstance();
DirectedAlgorithms algos = new DirectedAlgorithms(graph);
Vector<String> verts = algos.getTopologicalOrdering();
for(String v : verts) {
cpds.get(v).resample(model);
}
return model;
} catch (InstantiationException e) {
e.printStackTrace();
} catch (IllegalAccessException e) {
e.printStackTrace();
}
return null;
}
public double logLikelihood(Model m) {
double sum = 0.0;
for(String varName : cpds.keySet()) {
sum += cpds.get(varName).logLikelihood(m);
}
return sum;
}
public double logLikelihood(Iterator<? extends Model> ms) {
double sum = 0.0;
while(ms.hasNext()) {
sum += logLikelihood(ms.next());
}
return sum;
}
public double logLikelihood() {
return logLikelihood(data.iterator());
}
public BNVar getVar(String v) {
return vars.get(v);
}
public BNCpd getCPD(String v) {
return cpds.get(v);
}
public int countParameters() {
int count = 0;
for(String key : cpds.keySet()) {
count += cpds.get(key).countParameters();
}
return count;
}
public int countParameters(String child, String... parents) {
BNVar cvar = vars.get(child);
int count = cvar.size();
for(int i = 0; i < parents.length; i++) {
count *= vars.get(parents[i]).size();
}
return count;
}
public void learnCPDs() {
DirectedAlgorithms algos = new DirectedAlgorithms(graph);
if(algos.hasCycle()) {
throw new IllegalStateException("Graph has a cycle.");
}
for(String name : vars.keySet()) {
cpds.put(name, learnCPD(name));
}
}
private BNCpd learnCPD(String node) {
Set<String> parentNames = graph.getParents(node);
BNVar[] parents = new BNVar[parentNames.size()];
BNVar child = vars.get(node);
int pi = 0;
for(String p : parentNames) {
parents[pi++] = vars.get(p);
}
BNCpd cpd = new BNCpd(parents, child);
cpd.learn(data.iterator());
return cpd;
}
}