/*******************************************************************************
* Copyright (C) 2006-2012 Dominik Jain.
*
* This file is part of ProbCog.
*
* ProbCog is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ProbCog is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ProbCog. If not, see <http://www.gnu.org/licenses/>.
******************************************************************************/
package probcog.bayesnets.learning;
//import de.tum.in.fipm.base.data.QueryResult;
import edu.ksu.cis.bnj.ver3.core.*;
import edu.ksu.cis.bnj.ver3.core.values.Field;
import edu.ksu.cis.bnj.ver3.core.values.ValueDouble;
import java.sql.*;
import java.util.*;
import probcog.bayesnets.core.BeliefNetworkEx;
import probcog.bayesnets.core.Discretized;
import probcog.inference.IParameterHandler;
import probcog.inference.ParameterHandler;
import weka.clusterers.*;
import weka.core.*;
/**
* learns the conditional probability tables for all nodes in a Bayesian network
* when given a set of examples. CPTs are learnt by initializing all the table values to zero
* and incrementing individual values whenever a corresponding example is passed.
* In the end, probabilities are obtained by means of normalization.
* @author Dominik Jain
*/
public class CPTLearner extends Learner implements IParameterHandler {
    /**
     * an array of example counter objects - one for each node in the network
     * (created by {@link #init()})
     */
    protected ExampleCounter[] counters;
    /**
     * an array of clusterers - one for each node;
     * for nodes that do not use clustering to determine the index of the domain, the entry is null.
     * The array is allocated lazily (by {@link #init()} or {@link #addClusterer(String, Clusterer)},
     * whichever comes first).
     */
    protected Clusterer[] clusterers;
    /**
     * controls how to finalize a column of the CPT for which there were no examples (i.e. all of the
     * column entries are 0); If true, assume a uniform distribution, otherwise keep the zeros.
     */
    protected boolean uniformDefault = false;
    /**
     * whether {@link #init()} has been run, i.e. whether the example counters exist
     */
    protected boolean initialized = false;
    /**
     * the value every CPT entry is initialized with when the example counters are created
     */
    protected double pseudoCount = 0.0;
    /**
     * handler through which the learner's parameters (currently "pseudoCount") can be set by name
     */
    protected ParameterHandler paramHandler;

    /**
     * constructs a CPTLearner object from a BeliefNetworkEx object.
     * The example counters are not created here but on the first call to one of the
     * learn methods, so that a pseudo-count set after construction is still applied.
     * @param bn the network whose CPTs are to be learnt
     * @throws Exception
     */
    public CPTLearner(BeliefNetworkEx bn) throws Exception {
        super(bn);
        setUpParameterHandler();
    }

    /**
     * constructs a CPTLearner object from a DomainLearner. If you consecutively want to
     * learn domains and CPTs, you should make use of this constructor, because it relieves
     * you of the burden of having to pass the clusterers that categorize instances for
     * certain domains manually (duplicate domains are taken into consideration, i.e. clusterers
     * will be reused appropriately).
     * Note that this constructor creates the example counters immediately, i.e. with the
     * pseudo-count that is in effect at construction time.
     * @param dl the domain learner
     * @throws Exception
     */
    public CPTLearner(DomainLearner dl) throws Exception {
        super(dl.bn.bn);
        // previously missing here, which made getParameterHandler() return null
        setUpParameterHandler();
        init();
        // initialize clusterers from the domain learner
        if(dl.clusteredDomains != null) {
            for(int i = 0; i < dl.clusteredDomains.length; i++)
                addClusterer(dl.clusteredDomains[i].nodeName, dl.clusterers[i]);
            if(dl.duplicateDomains != null) {
                for(int i = 0; i < dl.duplicateDomains.length; i++)
                    for(int j = 0; j < dl.clusteredDomains.length; j++)
                        // the first entry of a duplicate group names the node whose clusterer is shared
                        if(dl.duplicateDomains[i][0].equals(dl.clusteredDomains[j].nodeName)) {
                            for(int k = 1; k < dl.duplicateDomains[i].length; k++)
                                addClusterer(dl.duplicateDomains[i][k], dl.clusterers[j]);
                            break;
                        }
            }
        }
    }

    /**
     * creates the parameter handler and registers the parameters this learner supports
     * @throws Exception if the parameter cannot be registered
     */
    private void setUpParameterHandler() throws Exception {
        paramHandler = new ParameterHandler(this);
        paramHandler.add("pseudoCount", "setPseudoCount");
    }

    /**
     * controls how to finalize a column of the CPT when there were no examples (i.e. all of the column's entries are zero); By default, the zeros are kept
     * @param value If true, use a uniform distribution for such columns; otherwise leave the column as it was (all zeros)
     */
    public void setUniformDefault(boolean value) {
        uniformDefault = value;
    }

    /**
     * sets the value with which all CPT entries are initialized before examples are counted.
     * The value is read when the example counters are created (see {@link #init()}); setting it
     * afterwards has no effect on counters that already exist.
     * @param pseudoCount the initial value of each CPT entry
     */
    public void setPseudoCount(double pseudoCount) {
        this.pseudoCount = pseudoCount;
    }

    /**
     * initializes the array of clusterers (an array of null references, unless clusterers
     * have already been registered via {@link #addClusterer(String, Clusterer)})
     * and the array of example counters (one for each node)
     */
    protected void init() {
        // keep clusterers that were registered before the first learn call
        if(clusterers == null)
            clusterers = new Clusterer[nodes.length];
        // create example counters for each node
        counters = new ExampleCounter[nodes.length];
        for(int i = 0; i < nodes.length; i++)
            counters[i] = new ExampleCounter(nodes[i], bn, this.pseudoCount);
        initialized = true;
    }

    /**
     * learns all the examples in the result set. Each row in the result set represents one example.
     * All the random variables (nodes) in the network
     * need to be found in each result row as columns that are named accordingly, i.e. for each
     * random variable, there must be a column with a matching name in the result set.
     * @param rs the result set
     * @throws Exception if the result set is empty, if a node without clusterer has no matching
     *         column, or if a value is not found in the domain of its node
     * @throws SQLException on any database error (propagated to the caller, so that a failure
     *         cannot go unnoticed after only part of the data has been counted)
     */
    public void learn(ResultSet rs) throws Exception {
        if(!initialized) init();
        // if it's an empty result set, throw exception
        if(!rs.next())
            throw new Exception("empty result set!");
        BeliefNode[] nodes = bn.bn.getNodes();
        ResultSetMetaData rsmd = rs.getMetaData();
        int numCols = rsmd.getColumnCount();
        // map node indices to result set column indices
        // (there may be more nodes than columns, so unmapped nodes keep the index -1)
        int[] nodeIdx2colIdx = new int[nodes.length];
        Arrays.fill(nodeIdx2colIdx, -1);
        for(int i = 1; i <= numCols; i++) {
            Set<String> nodeNames = bn.getNodeNamesForAttribute(rsmd.getColumnName(i));
            // columns that do not correspond to any node are ignored
            if(nodeNames == null)
                continue;
            for(String nodeName : nodeNames) {
                int node_idx = bn.getNodeIndex(nodeName);
                if(node_idx == -1)
                    throw new Exception("Unknown node referenced in result set: " + rsmd.getColumnName(i));
                nodeIdx2colIdx[node_idx] = i;
            }
        }
        // gather data, iterating over the result set
        int[] domainIndices = new int[nodes.length];
        do {
            // for each row...
            // - get the indices into the domains of each node
            //   that correspond to the current row of data
            //   (sorted in the same order as the nodes are ordered
            //   in the BeliefNetwork)
            for(int node_idx = 0; node_idx < nodes.length; node_idx++) {
                if(clusterers[node_idx] == null) {
                    int colIdx = nodeIdx2colIdx[node_idx];
                    if(colIdx < 0)
                        throw new Exception("No attribute specified for " + nodes[node_idx].getName());
                    Discrete domain = (Discrete) nodes[node_idx].getDomain();
                    String strValue;
                    if(domain instanceof Discretized) // if we have a discretized domain, we discretize first
                        strValue = ((Discretized) domain).getNameFromContinuous(rs.getDouble(colIdx));
                    else
                        strValue = rs.getString(colIdx);
                    domainIndices[node_idx] = domainIndexOf(nodes[node_idx], domain, strValue);
                }
                else {
                    // clustered nodes are read by column name rather than via the index map
                    double value = rs.getDouble(bn.getAttributeNameForNode(nodes[node_idx].getName()));
                    domainIndices[node_idx] = clusterIndexOf(clusterers[node_idx], value);
                }
            }
            // - update each node's CPT
            countExample(domainIndices);
        } while(rs.next());
    }

    /**
     * learns all the examples in the instances. Each instance in the instances represents one example.
     * All the random variables (nodes) in the network
     * need to be found in each instance as columns that are named accordingly, i.e. for each
     * random variable, there must be an attribute with a matching name in the instance.
     * @param instances the instances
     * @throws Exception if the set of instances is empty, if there is no attribute for one of the
     *         nodes, or if a value is not found in the domain of its node
     */
    public void learn(Instances instances) throws Exception {
        if(!initialized) init();
        // if it's an empty result set, throw exception
        if(instances.numInstances() == 0)
            throw new Exception("empty result set!");
        BeliefNode[] nodes = bn.bn.getNodes();
        int numAttributes = instances.numAttributes();
        // map node indices to attribute indices
        // (there may be more nodes than attributes, so unmapped nodes keep the index -1)
        int[] nodeIdx2colIdx = new int[nodes.length];
        Arrays.fill(nodeIdx2colIdx, -1);
        for(int i = 0; i < numAttributes; i++) {
            Set<String> nodeNames = bn.getNodeNamesForAttribute(instances.attribute(i).name());
            // attributes that do not correspond to any node are ignored
            if(nodeNames == null)
                continue;
            for(String nodeName : nodeNames) {
                int node_idx = bn.getNodeIndex(nodeName);
                if(node_idx == -1)
                    throw new Exception("Unknown node referenced in result set: " + instances.attribute(i).name());
                nodeIdx2colIdx[node_idx] = i;
            }
        }
        // gather data, iterating over the instances
        int[] domainIndices = new int[nodes.length];
        @SuppressWarnings("unchecked")
        Enumeration<Instance> instanceEnum = instances.enumerateInstances();
        while(instanceEnum.hasMoreElements()) {
            Instance instance = instanceEnum.nextElement();
            // for each instance...
            // - get the indices into the domains of each node
            //   that correspond to the current instance
            //   (sorted in the same order as the nodes are ordered
            //   in the BeliefNetwork)
            for(int node_idx = 0; node_idx < nodes.length; node_idx++) {
                // every node needs an attribute, regardless of how its value is categorized
                int colIdx = nodeIdx2colIdx[node_idx];
                if(colIdx < 0)
                    throw new Exception("No attribute specified for " + nodes[node_idx].getName());
                if(clusterers[node_idx] == null) {
                    Discrete domain = (Discrete) nodes[node_idx].getDomain();
                    String strValue;
                    if(domain instanceof Discretized) // if we have a discretized domain, we discretize first
                        strValue = ((Discretized) domain).getNameFromContinuous(instance.value(colIdx));
                    else
                        strValue = instance.stringValue(colIdx);
                    domainIndices[node_idx] = domainIndexOf(nodes[node_idx], domain, strValue);
                }
                else
                    domainIndices[node_idx] = clusterIndexOf(clusterers[node_idx], instance.value(colIdx));
            }
            // - update each node's CPT
            countExample(domainIndices);
        }
    }

    /**
     * learns an example from a Map<String,String>.
     * This is the only learning method without using {@link BeliefNetworkEx#getAttributeNameForNode(String)}.
     * @param data a Map containing the data for one example. The names of all the random
     *        variables (nodes) in the network must be found in the set of keys of the
     *        map.
     * @throws Exception if required keys are missing from the map or if a value is not found
     *         in the domain of its node
     */
    public void learn(Map<String,String> data) throws Exception {
        if(!initialized) init();
        // - get the indices into the domains of each node
        //   that correspond to the given example
        //   (sorted in the same order as the nodes are ordered
        //   in the BeliefNetwork)
        BeliefNode[] nodes = bn.bn.getNodes();
        int[] domainIndices = new int[nodes.length];
        for(int node_idx = 0; node_idx < nodes.length; node_idx++) {
            String value = data.get(nodes[node_idx].getName());
            if(value == null)
                throw new Exception("Key " + nodes[node_idx].getName() + " not found in data!");
            if(clusterers[node_idx] == null) {
                Discrete domain = (Discrete) nodes[node_idx].getDomain();
                domainIndices[node_idx] = domainIndexOf(nodes[node_idx], domain, value);
            }
            else
                domainIndices[node_idx] = clusterIndexOf(clusterers[node_idx], Double.parseDouble(value));
        }
        // - update each node's CPT
        countExample(domainIndices);
    }

    /**
     * looks up the index of a value in the domain of a node
     * @param node the node the domain belongs to (used for the error message only)
     * @param domain the node's domain
     * @param strValue the name of the value to look up
     * @return the index of the value in the domain
     * @throws Exception if the domain does not contain the value
     */
    private int domainIndexOf(BeliefNode node, Discrete domain, String strValue) throws Exception {
        int domain_idx = domain.findName(strValue);
        if(domain_idx == -1)
            throw new Exception(strValue + " not found in domain of " + node.getName());
        return domain_idx;
    }

    /**
     * categorizes a numeric value by wrapping it in a single-attribute instance and
     * passing it to the given clusterer
     * @param clusterer the clusterer to use for categorization
     * @param value the value to categorize
     * @return the result of the clusterer's clusterInstance call, which is used as the domain index
     * @throws Exception if the clusterer fails
     */
    private int clusterIndexOf(Clusterer clusterer, double value) throws Exception {
        Instance inst = new Instance(1);
        inst.setValue(0, value);
        return clusterer.clusterInstance(inst);
    }

    /**
     * passes a complete example to the example counters, one counter per entry of the example
     * @param domainIndices the example; one domain index per node, ordered like the network's nodes
     */
    private void countExample(int[] domainIndices) {
        for(int i = 0; i < domainIndices.length; i++)
            counters[i].count(domainIndices);
    }

    /**
     * tells the CPTLearner to use a clusterer to categorize instances (i.e. example outcomes)
     * for a certain node. May be called before any examples have been learnt.
     * @param nodeName the name of the node
     * @param clusterer the clusterer to use for categorization
     * @throws Exception if the name of the node is invalid
     */
    public void addClusterer(String nodeName, Clusterer clusterer) throws Exception {
        // the array does not exist yet if init() has not been run (it is deferred to the
        // first learn call by the BeliefNetworkEx constructor)
        if(clusterers == null)
            clusterers = new Clusterer[nodes.length];
        for(int i = 0; i < nodes.length; i++)
            if(nodes[i].getName().equals(nodeName)) {
                clusterers[i] = clusterer;
                return;
            }
        throw new Exception("Passed unknown node name!");
    }

    /**
     * normalizes the CPTs (is called by finish and should not be called directly)
     */
    protected void end_learning() {
        // normalize the CPTs
        for(int i = 0; i < nodes.length; i++)
            ((CPT)nodes[i].getCPF()).normalizeByDomain(uniformDefault);
    }

    /**
     * An instance of this class counts examples for a given node.
     */
    protected class ExampleCounter {
        /**
         * the table in which the counts are accumulated
         */
        CPF cpf;
        /**
         * indices of relevant nodes (parents and node itself)
         */
        public int[] nodeIndices;

        /**
         * creates an ExampleCounter object for one of the nodes in a Bayesian network;
         * overwrites every entry of the node's CPF with the given pseudo-count
         * @param n the node
         * @param bn the Bayesian Network the node is part of
         * @param pseudoCount the initial value of each entry of the node's CPF
         */
        public ExampleCounter(BeliefNode n, BeliefNetworkEx bn, double pseudoCount) {
            // reset the cpf (initialize all values to the pseudo-count)
            cpf = n.getCPF();
            for(int i = 0; i < cpf.size(); i++)
                cpf.put(i, new ValueDouble(pseudoCount));
            // get the indices of the nodes that the CPT depends on
            BeliefNode[] nodes = cpf.getDomainProduct();
            nodeIndices = new int[nodes.length];
            for(int i = 0; i < nodes.length; i++)
                nodeIndices[i] = bn.getNodeIndex(nodes[i]);
        }

        /**
         * creates an ExampleCounter object for one of the nodes in a Bayesian network,
         * initializing every entry of the node's CPF to 0
         * @param n the node
         * @param bn the Bayesian Network the node is part of
         */
        public ExampleCounter(BeliefNode n, BeliefNetworkEx bn) {
            this(n, bn, 0);
        }

        /**
         * creates an ExampleCounter object that counts into the given CPF as it is
         * (the CPF's entries are not modified by this constructor)
         * @param cpf the table in which to accumulate the counts
         * @param nodeIndices the indices of the nodes the CPF depends on
         */
        public ExampleCounter(CPF cpf, int[] nodeIndices) {
            this.cpf = cpf;
            this.nodeIndices = nodeIndices;
        }

        /**
         * increments the value in the CPT that corresponds to the example
         * @param domainIndices a complete example (i.e. an example containing
         *        values for each (relevant) node) specified as an array of integers,
         *        where each value is an index into the corresponding node's
         *        domain, the order being determined by the BeliefNetwork's
         *        array of nodes as returned by getNodes().
         */
        public void count(int[] domainIndices) {
            count(domainIndices, 1.0);
        }

        /**
         * adds the given weight to the value in the CPT that corresponds to the example
         * @param domainIndices a complete example (i.e. an example containing
         *        values for each (relevant) node) specified as an array of integers,
         *        where each value is an index into the corresponding node's
         *        domain, the order being determined by the BeliefNetwork's
         *        array of nodes as returned by getNodes().
         * @param weight the weight of the example
         */
        public void count(int[] domainIndices, double weight) {
            // get the address of the CPT field: pick the values of the relevant nodes from the example
            int[] addr = new int[nodeIndices.length];
            for(int i = 0; i < nodeIndices.length; i++)
                addr[i] = domainIndices[nodeIndices[i]];
            // get the real address of the table entry
            int realAddr = cpf.addr2realaddr(addr);
            // add the weight to the entry
            cpf.put(realAddr, Field.add(cpf.get(realAddr), new ValueDouble(weight)));
        }
    }

    /**
     * @return the handler through which this learner's parameters can be set
     */
    @Override
    public ParameterHandler getParameterHandler() {
        return this.paramHandler;
    }
}