/*******************************************************************************
* Copyright (C) 2009-2012 Dominik Jain.
*
* This file is part of ProbCog.
*
* ProbCog is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ProbCog is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ProbCog. If not, see <http://www.gnu.org/licenses/>.
******************************************************************************/
package probcog.bayesnets.inference;
import java.io.File;
import java.util.Arrays;
import java.util.Vector;
import java.util.Map.Entry;
import java.util.regex.Pattern;
import probcog.bayesnets.core.BNDatabase;
import probcog.bayesnets.core.BeliefNetworkEx;
import probcog.inference.BasicSampledDistribution;
import probcog.inference.GeneralSampledDistribution;
import edu.ksu.cis.bnj.ver3.core.BeliefNode;
import edu.ksu.cis.bnj.ver3.core.CPF;
import edu.ksu.cis.bnj.ver3.core.CPT;
import edu.ksu.cis.bnj.ver3.core.Discrete;
import edu.ksu.cis.bnj.ver3.core.values.ValueDouble;
import edu.tum.cs.util.Stopwatch;
/**
 * Main application of Bayesian network inference.
 *
 * Command-line front end that loads a network, reads an evidence database,
 * runs the selected inference algorithm and prints the distributions of the
 * queried variables. Optionally the resulting distribution is written to a
 * file and/or compared to a reference distribution.
 *
 * @author Dominik Jain
 */
public class BNinfer {

    /** Separator between query candidates: a comma with optional surrounding whitespace. */
    private static final Pattern QUERY_SEPARATOR = Pattern.compile("\\s*,\\s*");

    /**
     * Parses the command line, sets up the sampler and runs inference.
     *
     * If any of the required options (-n, -e, -q) is missing, the usage text
     * is printed and the JVM exits with status 1. An unknown algorithm name
     * given to -ia also exits with status 1. Any other failure is caught and
     * its stack trace is printed.
     *
     * @param args the command-line arguments (see the usage text printed by
     *             {@link #printUsage()})
     */
    public static void main(String[] args) {
        try {
            String networkFile = null;
            String dbFile = null;
            String query = null;
            int maxSteps = 1000;
            int maxTrials = 5000;
            int infoInterval = 100;
            Algorithm algo = Algorithm.LikelihoodWeighting;
            boolean debug = false;
            boolean skipFailedSteps = false;
            boolean removeDeterministicCPTEntries = false;
            double timeLimit = 10.0, infoIntervalTime = 1.0;
            boolean timeLimitedInference = false;
            boolean useMaxSteps = false;
            String outputDistFile = null, referenceDistFile = null;

            // read arguments
            for(int i = 0; i < args.length; i++) {
                String opt = args[i];
                if(opt.equals("-n"))
                    networkFile = requireValue(args, ++i);
                else if(opt.equals("-q"))
                    query = requireValue(args, ++i);
                else if(opt.equals("-e"))
                    dbFile = requireValue(args, ++i);
                else if(opt.equals("-nodetcpt"))
                    removeDeterministicCPTEntries = true;
                else if(opt.equals("-skipFailedSteps"))
                    skipFailedSteps = true;
                else if(opt.equals("-maxSteps")) {
                    maxSteps = Integer.parseInt(requireValue(args, ++i));
                    // remember that the user asked for a step limit, so that
                    // time-limited inference does not override it below
                    useMaxSteps = true;
                }
                else if(opt.equals("-maxTrials"))
                    maxTrials = Integer.parseInt(requireValue(args, ++i));
                else if(opt.equals("-ia")) {
                    // fetch the value outside the try block so that a missing value
                    // is not misreported as an unknown algorithm
                    String algoName = requireValue(args, ++i);
                    try {
                        algo = Algorithm.valueOf(algoName);
                    }
                    catch(IllegalArgumentException e) {
                        System.err.println("Error: Unknown inference algorithm '" + algoName + "'");
                        System.exit(1);
                    }
                }
                else if(opt.equals("-infoInterval"))
                    infoInterval = Integer.parseInt(requireValue(args, ++i));
                else if(opt.equals("-infoTime"))
                    // documented in the usage text; interval for intermediate results
                    // during time-limited inference
                    infoIntervalTime = Double.parseDouble(requireValue(args, ++i));
                else if(opt.equals("-debug"))
                    debug = true;
                else if(opt.equals("-t")) {
                    timeLimitedInference = true;
                    // the number of seconds is optional: only consume the next
                    // argument if it does not look like another option
                    if(i+1 < args.length && !args[i+1].startsWith("-"))
                        timeLimit = Double.parseDouble(args[++i]);
                }
                else if(opt.equals("-od"))
                    outputDistFile = requireValue(args, ++i);
                else if(opt.equals("-cd"))
                    referenceDistFile = requireValue(args, ++i);
                else
                    System.err.println("Warning: unknown option " + opt + " ignored!");
            }
            if(networkFile == null || dbFile == null || query == null) {
                printUsage();
                System.exit(1);
            }

            // determine queries
            Vector<String> queries = splitQueries(query);

            // load model
            BeliefNetworkEx bn = new BeliefNetworkEx(networkFile);
            BeliefNode[] nodes = bn.bn.getNodes();

            // (on request) remove deterministic dependencies in CPTs
            if(removeDeterministicCPTEntries) {
                final double lowProb = 0.001;
                for(BeliefNode node : nodes) {
                    CPF cpf = node.getCPF();
                    // replace every zero entry by a small value, then renormalize
                    for(int i = 0; i < cpf.size(); i++)
                        if(cpf.getDouble(i) == 0.0)
                            cpf.put(i, new ValueDouble(lowProb));
                    ((CPT)cpf).normalizeByDomain();
                }
            }

            // read evidence database
            // (one slot per node; slots stay at -1 for nodes without a database entry)
            int[] evidenceDomainIndices = new int[nodes.length];
            Arrays.fill(evidenceDomainIndices, -1);
            BNDatabase db = new BNDatabase(new File(dbFile));
            for(Entry<String,String> entry : db.getEntries()) {
                BeliefNode node = bn.getNode(entry.getKey());
                if(node == null)
                    throw new Exception("Evidence node '" + entry.getKey() + "' not found in model.");
                Discrete dom = (Discrete)node.getDomain();
                int domidx = dom.findName(entry.getValue());
                if(domidx == -1)
                    throw new Exception("Value '" + entry.getValue() + "' not found in domain of node '" + entry.getKey() + "'");
                evidenceDomainIndices[bn.getNodeIndex(node)] = domidx;
            }

            // read reference distribution if any
            GeneralSampledDistribution referenceDist = null;
            if(referenceDistFile != null) {
                referenceDist = GeneralSampledDistribution.fromFile(new File(referenceDistFile));
            }

            // map query names to node indices
            Vector<Integer> queryVars = new Vector<Integer>();
            for(String qq : queries) {
                int varIdx = bn.getNodeIndex(qq);
                if(varIdx == -1)
                    throw new Exception("Unknown variable '" + qq + "'");
                queryVars.add(varIdx);
            }

            // run inference
            // (the stopwatch is started and stopped here, but its reading is not printed by this method)
            Stopwatch sw = new Stopwatch();
            sw.start();
            // - create sampler
            Sampler sampler = algo.createSampler(bn);
            // - set evidence and options
            sampler.setEvidence(evidenceDomainIndices);
            sampler.setQueryVars(queryVars);
            sampler.setDebugMode(debug);
            sampler.setMaxTrials(maxTrials);
            sampler.setSkipFailedSteps(skipFailedSteps);
            sampler.setNumSamples(maxSteps);
            sampler.setInfoInterval(infoInterval);
            // - run inference
            SampledDistribution dist = null;
            if(timeLimitedInference) {
                if(!(sampler instanceof ITimeLimitedInference))
                    throw new Exception(sampler.getAlgorithmName() + " does not support time-limited inference");
                ITimeLimitedInference tliSampler = (ITimeLimitedInference) sampler;
                // without an explicit -maxSteps, only the time limit bounds the run
                if(!useMaxSteps)
                    sampler.setNumSamples(Integer.MAX_VALUE);
                sampler.setInfoInterval(Integer.MAX_VALUE); // provide intermediate results only triggered by time-limited inference
                TimeLimitedInference tli = new TimeLimitedInference(tliSampler, timeLimit, infoIntervalTime);
                tli.setReferenceDistribution(referenceDist);
                dist = tli.run();
                if(referenceDist != null)
                    System.out.println("MSEs: " + tli.getMSEs());
            }
            else
                dist = sampler.infer();
            sw.stop();

            // print results
            for(String qq : queries) {
                int varIdx = bn.getNodeIndex(qq);
                dist.printVariableDistribution(System.out, varIdx);
            }

            // save output distribution
            if(outputDistFile != null) {
                GeneralSampledDistribution gdist = dist.toGeneralDistribution();
                File f = new File(outputDistFile);
                gdist.write(f);
                // read the file back and print what was stored
                GeneralSampledDistribution gdist2 = GeneralSampledDistribution.fromFile(f);
                gdist2.print(System.out);
            }

            // compare distributions
            if(referenceDist != null) {
                System.out.println("comparing to reference distribution...");
                BasicSampledDistribution.compareDistributions(dist, referenceDist, evidenceDomainIndices);
            }
        }
        catch(Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Returns the value following an option on the command line.
     *
     * @param args       the full argument array
     * @param valueIndex index at which the value is expected (the option
     *                   itself is at {@code valueIndex - 1})
     * @return the argument at {@code valueIndex}
     * @throws IllegalArgumentException if the option is the last argument and
     *                                  thus has no value
     */
    private static String requireValue(String[] args, int valueIndex) {
        if(valueIndex >= args.length)
            throw new IllegalArgumentException("Option " + args[valueIndex - 1] + " requires a value");
        return args[valueIndex];
    }

    /**
     * Splits the comma-separated query string into individual queries.
     * Commas inside parentheses do not separate queries: pieces are joined
     * back together (with a plain ",") until the parentheses are balanced.
     *
     * @param query the raw value of the -q option
     * @return the individual queries in the order given
     * @throws IllegalArgumentException if the parentheses in the query string
     *                                  are not balanced
     */
    private static Vector<String> splitQueries(String query) {
        Vector<String> queries = new Vector<String>();
        String pending = "";
        for(String piece : QUERY_SEPARATOR.split(query)) {
            if(!pending.equals(""))
                pending += ",";
            pending += piece;
            if(balancedParentheses(pending)) {
                queries.add(pending);
                pending = "";
            }
        }
        if(!pending.equals(""))
            throw new IllegalArgumentException("Unbalanced parentheses in queries");
        return queries;
    }

    /**
     * Prints the usage text, including the list of available inference
     * algorithms, to standard output.
     */
    private static void printUsage() {
        System.out.println("\n usage: BNinfer <arguments>\n\n" +
            " required arguments:\n\n" +
            " -n <network file> fragment network (XML-BIF or PMML)\n" +
            " -e <evidence db pattern> an evidence database file or file mask\n" +
            " -q <comma-sep. queries> queries (predicate names or partially grounded terms with lower-case vars)\n\n" +
            " options:\n\n" +
            " -maxSteps # the maximum number of steps to take, where applicable (default: 1000)\n" +
            " -maxTrials # the maximum number of trials per step for BN sampling algorithms (default: 5000)\n" +
            " -infoInterval # the number of steps after which to output a status message\n" +
            " -skipFailedSteps failed steps (> max trials) should just be skipped\n\n" +
            " -t [secs] use time-limited inference (default: 10 seconds)\n" +
            " -infoTime # interval in secs after which to display intermediate results (time-limited inference, default: 1.0)\n" +
            " -ia <name> inference algorithm selection; valid names:");
        for(Algorithm a : Algorithm.values())
            System.out.printf(" %-28s %s\n", a.toString(), a.getDescription());
        System.out.println(
            " -od <file> save output distribution to file\n" +
            " -cd <file> compare results of inference to reference distribution in file\n" +
            " -debug debug mode with additional outputs\n" +
            " -nodetcpt remove deterministic CPT columns by replacing 0s with low prob. values\n");
    }

    /**
     * Checks whether the round parentheses in a string are balanced, i.e.
     * every ')' closes a previously opened '(' and no '(' is left open.
     * Characters other than '(' and ')' are ignored.
     *
     * @param s the string to check
     * @return true if the parentheses are balanced (including the case where
     *         there are none), false otherwise
     */
    public static boolean balancedParentheses(String s) {
        int depth = 0;
        for(int i = 0; i < s.length(); i++) {
            char c = s.charAt(i);
            if(c == '(')
                depth++;
            else if(c == ')') {
                // a closing parenthesis without a matching opener, e.g. ")(",
                // can never be balanced by what follows
                if(--depth < 0)
                    return false;
            }
        }
        return depth == 0;
    }
}