/*******************************************************************************
* Copyright (C) 2008-2012 Dominik Jain.
*
* This file is part of ProbCog.
*
* ProbCog is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ProbCog is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ProbCog. If not, see <http://www.gnu.org/licenses/>.
******************************************************************************/
package probcog.srl.directed.inference;
import java.io.File;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Vector;
import java.util.regex.Pattern;
import probcog.bayesnets.core.BNDatabase;
import probcog.bayesnets.inference.ITimeLimitedInference;
import probcog.bayesnets.inference.SampledDistribution;
import probcog.inference.BasicSampledDistribution;
import probcog.inference.GeneralSampledDistribution;
import probcog.inference.IParameterHandler;
import probcog.inference.ParameterHandler;
import probcog.srl.Database;
import probcog.srl.directed.RelationalBeliefNetwork;
import probcog.srl.directed.bln.AbstractBayesianLogicNetwork;
import probcog.srl.directed.bln.AbstractGroundBLN;
import probcog.srl.directed.bln.BayesianLogicNetwork;
import probcog.srl.directed.bln.py.BayesianLogicNetworkPy;
import edu.ksu.cis.bnj.ver3.core.BeliefNode;
import edu.ksu.cis.bnj.ver3.core.CPT;
import edu.ksu.cis.bnj.ver3.core.values.ValueDouble;
import edu.tum.cs.util.Stopwatch;
/**
* BLN inference tool
*
* @author Dominik Jain
*/
public class BLNinfer implements IParameterHandler {
// --- input files (filled in by readArgs from -b, -x, -l and -e) ---
String declsFile;
String networkFile;
String logicFile;
String dbFile;
// --- configuration ---
// becomes true as soon as setMaxSteps has been called
boolean useMaxSteps;
// inference algorithm used to create the sampler in run()
Algorithm algo = Algorithm.LikelihoodWeighting;
// names of predicates that are declared closed-world on the database (-cw)
String[] cwPreds;
// whether run() displays the ground network (-s)
boolean showBN;
// whether the Python-based BLN implementation is instantiated (-py)
boolean usePython;
// whether run() prints the (sorted) results
boolean verbose = true;
// whether run() writes the ground network and the evidence to files (-si)
boolean saveInstance;
// whether run() returns right after instantiating the ground network (-ni)
boolean noInference;
// set by -skipFailedSteps; not read anywhere in this class
boolean skipFailedSteps;
// whether zero entries in the CPTs are replaced by a small probability (-nodetcpt)
boolean removeDeterministicCPTEntries;
// whether results for variables contained in the database are left out when printing (-rfe)
boolean resultsFilterEvidence;
// time limit and interval between intermediate results, both in seconds
double timeLimit = 10.0;
double infoIntervalTime = 1.0;
// becomes true as soon as setTimeLimit has been called
boolean timeLimitedInference;
// if true, the sampler is initialized before the time-limited run is started
boolean samplerInitializationBeforeTimingStarts = true;
// passed on to the BLN before grounding (-allowPartialInst)
boolean allowPartialInst;
// file to write the resulting distribution to (-od)
String outputDistFile;
// file containing a distribution to compare the result against (-cd)
String referenceDistFile;
// raw parameters that are dispatched via the parameter handler in run()
Map<String, Object> params;
// --- model, evidence and inference objects ---
AbstractBayesianLogicNetwork bln;
AbstractGroundBLN gbln;
Database db;
Iterable<String> queries;
ParameterHandler paramHandler;
Sampler sampler;
TimeLimitedInference tli;
/**
 * Orderings that can be applied to the inference results before they are
 * printed. Every constant is itself a {@link Comparator} over
 * {@link InferenceResult} objects.
 */
enum SortOrder implements Comparator<InferenceResult> {
    /** ascending by variable name */
    Atom {
        public int compare(InferenceResult o1, InferenceResult o2) {
            return o1.varName.compareTo(o2.varName);
        }
    },
    /** descending by the first entry of the probability array */
    Probability {
        public int compare(InferenceResult o1, InferenceResult o2) {
            return byDescendingProbability(o1, o2);
        }
    },
    /** ascending by predicate name, ties broken by descending probability */
    PredicateProbability {
        public int compare(InferenceResult o1, InferenceResult o2) {
            int res = predicateName(o1.varName).compareTo(predicateName(o2.varName));
            if(res != 0)
                return res;
            return byDescendingProbability(o1, o2);
        }
    },
    /** ascending by query number, ties broken by descending probability */
    QueryNumberProbability {
        public int compare(InferenceResult o1, InferenceResult o2) {
            // explicit comparison instead of a subtraction, which could overflow
            if(o1.queryNo < o2.queryNo)
                return -1;
            if(o1.queryNo > o2.queryNo)
                return 1;
            return byDescendingProbability(o1, o2);
        }
    };

    /** Compares two results such that the one with the larger first probability entry comes first. */
    private static int byDescendingProbability(InferenceResult o1, InferenceResult o2) {
        return -Double.compare(o1.probabilities[0], o2.probabilities[0]);
    }

    /**
     * Extracts the part of a variable name that precedes the first opening
     * parenthesis. A name without any parenthesis is returned unchanged
     * (instead of causing a StringIndexOutOfBoundsException).
     */
    private static String predicateName(String varName) {
        int paren = varName.indexOf('(');
        return paren < 0 ? varName : varName.substring(0, paren);
    }
}
SortOrder resultsSortOrder = SortOrder.Atom;
// computed stuff
Collection<InferenceResult> results;
double groundingTime, inferenceInitTime, inferenceTime;
int stepsTaken;
public BLNinfer() throws Exception {
this(new HashMap<String, Object>());
}
public BLNinfer(Map<String, Object> params) throws Exception {
paramHandler = new ParameterHandler(this);
paramHandler.add("verbose", "setVerbose");
paramHandler.add("maxSteps", "setMaxSteps");
paramHandler.add("numSamples", "setMaxSteps");
paramHandler.add("inferenceMethod", "setInferenceMethod");
paramHandler.add("timeLimit", "setTimeLimit");
this.params = params;
}
/**
 * Controls whether run() prints the inference results.
 * @param verbose true to print the results
 */
public void setVerbose(Boolean verbose) {
    this.verbose = verbose.booleanValue();
}
/**
 * Records that a step limit has been requested. The number itself is not
 * stored by this method; only the flag that run() consults for time-limited
 * inference is set.
 * @param steps the requested number of steps (not stored here)
 */
public void setMaxSteps(Integer steps) {
    this.useMaxSteps = true;
}
/**
 * Selects the inference algorithm by the name of its {@link Algorithm}
 * constant. If the name is unknown, an error message and the list of
 * algorithms are printed and the JVM is terminated with exit code 1.
 * @param methodName the name of the algorithm
 */
public void setInferenceMethod(String methodName) {
    final Algorithm selected;
    try {
        selected = Algorithm.valueOf(methodName);
    }
    catch(IllegalArgumentException unknownName) {
        System.err.println("Error: Unknown inference algorithm '" + methodName + "'");
        Algorithm.printList("");
        System.exit(1);
        return;
    }
    this.algo = selected;
}
/**
 * Selects the inference algorithm that run() uses to create the sampler.
 * @param algo the algorithm to use
 */
public void setInferenceAlgorithm(Algorithm algo) {
this.algo = algo;
}
/**
 * Switches to time-limited inference and sets the time limit.
 * @param seconds the time limit in seconds
 */
public void setTimeLimit(double seconds) {
    this.timeLimit = seconds;
    this.timeLimitedInference = true;
}
/**
 * Stores a parameter in the parameter map, from where it is dispatched to
 * the registered handlers when run() is called.
 * @param param the name of the parameter
 * @param value the value of the parameter
 */
public void setParameter(String param, String value) {
    params.put(param, value);
}
/**
 * Configures this object from command-line style arguments (see the usage
 * text printed by main for the list of options).
 * @param args the arguments to process
 * @throws IllegalArgumentException if an option is unknown, lacks its value or is malformed
 * @throws Exception declared for backward compatibility with existing callers
 */
public void readArgs(String[] args) throws Exception {
    for(int i = 0; i < args.length; i++) {
        if(args[i].equals("-b"))
            declsFile = optionValue(args, ++i);
        else if(args[i].equals("-x"))
            networkFile = optionValue(args, ++i);
        else if(args[i].equals("-l"))
            logicFile = optionValue(args, ++i);
        else if(args[i].equals("-q"))
            this.queries = parseQueries(optionValue(args, ++i));
        else if(args[i].equals("-e"))
            dbFile = optionValue(args, ++i);
        else if(args[i].equals("-s"))
            showBN = true;
        else if(args[i].equals("-rfe"))
            resultsFilterEvidence = true;
        else if(args[i].equals("-nodetcpt"))
            removeDeterministicCPTEntries = true;
        else if(args[i].equals("-si"))
            saveInstance = true;
        else if(args[i].equals("-ni"))
            noInference = true;
        else if(args[i].equals("-skipFailedSteps"))
            skipFailedSteps = true;
        else if(args[i].equals("-py"))
            usePython = true;
        else if(args[i].equals("-cw"))
            cwPreds = optionValue(args, ++i).split(",");
        else if(args[i].equals("-maxSteps")) {
            int steps = Integer.parseInt(optionValue(args, ++i));
            params.put("numSamples", steps);
            setMaxSteps(steps);
        }
        else if(args[i].equals("-allowPartialInst"))
            allowPartialInst = true;
        else if(args[i].equals("-maxTrials"))
            params.put("maxTrials", optionValue(args, ++i));
        else if(args[i].equals("-ia"))
            setInferenceMethod(optionValue(args, ++i));
        else if(args[i].equals("-infoInterval"))
            params.put("infoInterval", optionValue(args, ++i));
        else if(args[i].equals("-debug"))
            params.put("debug", Boolean.TRUE);
        else if(args[i].equals("-t")) {
            // the number of seconds is optional; without it the current limit is kept
            if(i + 1 < args.length && !args[i + 1].startsWith("-"))
                setTimeLimit(Double.parseDouble(args[++i]));
            else
                setTimeLimit(timeLimit);
        }
        else if(args[i].equals("-infoTime"))
            infoIntervalTime = Double.parseDouble(optionValue(args, ++i));
        else if(args[i].equals("-od"))
            outputDistFile = optionValue(args, ++i);
        else if(args[i].equals("-cd"))
            referenceDistFile = optionValue(args, ++i);
        else if(args[i].startsWith("-O")) {
            String order = args[i].substring(2);
            if(order.equals("a"))
                resultsSortOrder = SortOrder.Atom;
            else if(order.equals("p"))
                resultsSortOrder = SortOrder.Probability;
            else if(order.equals("pp"))
                resultsSortOrder = SortOrder.PredicateProbability;
            else if(order.equals("qp"))
                resultsSortOrder = SortOrder.QueryNumberProbability;
            else
                throw new IllegalArgumentException("Unknown sort order '" + order + "'");
        }
        else if(args[i].startsWith("-p") || args[i].startsWith("--")) {
            // algorithm-specific parameter of the form <key>=<value>; only the
            // first '=' separates key and value, so the value may contain '='
            String[] pair = args[i].substring(2).split("=", 2);
            if(pair.length != 2 || pair[1].length() == 0)
                throw new IllegalArgumentException("Argument '" + args[i] + "' for algorithm-specific parameterization is incorrectly formatted.");
            params.put(pair[0], pair[1]);
        }
        else
            throw new IllegalArgumentException("Unknown option " + args[i]);
    }
}

/** Returns args[valueIndex], or fails with a descriptive error if the preceding option is the last argument. */
private static String optionValue(String[] args, int valueIndex) {
    if(valueIndex >= args.length)
        throw new IllegalArgumentException("Option '" + args[valueIndex - 1] + "' requires a value");
    return args[valueIndex];
}

/**
 * Splits a comma-separated list of queries. Commas inside parentheses do not
 * separate queries: pieces are joined again until the parentheses balance.
 */
private static Vector<String> parseQueries(String queryList) {
    Vector<String> parsed = new Vector<String>();
    StringBuilder pending = new StringBuilder();
    for(String piece : Pattern.compile("\\s*,\\s*").split(queryList)) {
        if(pending.length() > 0)
            pending.append(',');
        pending.append(piece);
        if(balancedParentheses(pending.toString())) {
            parsed.add(pending.toString());
            pending.setLength(0);
        }
    }
    if(pending.length() > 0)
        throw new IllegalArgumentException("Unbalanced parentheses in queries");
    return parsed;
}
/**
 * Sets the relational model to use. If a model is set, run() does not load
 * one from the declarations/network/logic files.
 * @param bln the Bayesian logic network
 */
public void setBLN(AbstractBayesianLogicNetwork bln) {
this.bln = bln;
}
/**
 * Sets the evidence database to use. If a database is set, run() does not
 * create a new one.
 * @param db the evidence database
 */
public void setDatabase(Database db) {
this.db = db;
}
/**
 * Sets the queries that are passed on to the sampler in run().
 * @param queries the queries
 */
public void setQueries(Iterable<String> queries) {
this.queries = queries;
}
/**
 * Sets an already existing ground model. The model and the database it
 * refers to are taken over as well.
 * @param gbln the ground model
 */
public void setGroundBLN(AbstractGroundBLN gbln) {
    this.gbln = gbln;
    AbstractBayesianLogicNetwork model = gbln.getBLN();
    setBLN(model);
    Database evidence = gbln.getDatabase();
    setDatabase(evidence);
}
/**
 * Loads the model and the evidence (unless they were set explicitly),
 * instantiates the ground network, runs the selected inference algorithm and
 * (if verbose) prints the results.
 * @return the inference results, or null if inference was disabled (-ni)
 * @throws IllegalArgumentException if the model, the evidence or the queries are missing
 * @throws IllegalStateException if the instance is to be saved but no network file name is known
 * @throws Exception if any of the underlying steps fails
 */
public Collection<InferenceResult> run() throws Exception {
    if(bln == null) {
        if(networkFile == null)
            throw new IllegalArgumentException("No fragment network given");
        if(declsFile == null)
            throw new IllegalArgumentException("No model declarations given");
        // a logic file is optional, so its absence is not treated as an error
    }
    if(dbFile == null && db == null)
        throw new IllegalArgumentException("No evidence given");
    if(queries == null)
        throw new IllegalArgumentException("No queries given");
    // handle parameters
    paramHandler.handle(params, false);
    // load relational model
    if(bln == null) {
        if(!usePython)
            bln = new BayesianLogicNetwork(declsFile, networkFile, logicFile);
        else
            bln = new BayesianLogicNetworkPy(declsFile, networkFile, logicFile);
    }
    RelationalBeliefNetwork blog = bln;
    // (on request) remove deterministic dependencies in CPTs
    if(removeDeterministicCPTEntries) {
        final double lowProb = 0.001;
        for(BeliefNode node : blog.bn.getNodes()) {
            CPT cpf = (CPT) node.getCPF();
            for(int i = 0; i < cpf.size(); i++) {
                if(cpf.getDouble(i) == 0.0)
                    cpf.put(i, new ValueDouble(lowProb));
            }
            cpf.normalizeByDomain();
        }
    }
    // read evidence database
    if(db == null)
        db = new Database(blog);
    paramHandler.addSubhandler(db.getParameterHandler());
    if(dbFile != null)
        db.readBLOGDB(dbFile);
    if(cwPreds != null) {
        for(String predName : cwPreds)
            db.setClosedWorldPred(predName);
    }
    // instantiate ground model
    if(gbln == null) {
        Stopwatch sw = new Stopwatch();
        sw.start();
        bln.setAllowPartialInstantiation(allowPartialInst);
        gbln = bln.ground(db);
        paramHandler.addSubhandler(gbln);
        gbln.instantiateGroundNetwork();
        this.groundingTime = sw.getElapsedTimeSecs();
    }
    if(showBN) {
        gbln.getGroundNetwork().show();
    }
    if(saveInstance)
        saveGroundInstance();
    if(noInference)
        return null;
    // read reference distribution if any
    GeneralSampledDistribution referenceDist = null;
    int[] evidenceDomainIndices = null; // to filter out evidence in distribution comparisons
    if(referenceDistFile != null) {
        referenceDist = GeneralSampledDistribution.fromFile(new File(referenceDistFile));
        evidenceDomainIndices = gbln.getFullEvidence(gbln.getDatabase().getEntriesAsArray());
    }
    // run inference
    // - create sampler and pass on parameters
    sampler = algo.createSampler(gbln);
    sampler.setQueries(queries);
    // - set options
    paramHandler.addSubhandler(sampler);
    // - run inference
    SampledDistribution dist;
    if(timeLimitedInference) {
        if(!(sampler instanceof ITimeLimitedInference))
            throw new Exception(sampler.getAlgorithmName() + " does not support time-limited inference");
        ITimeLimitedInference tliSampler = (ITimeLimitedInference) sampler;
        if(!useMaxSteps)
            sampler.setNumSamples(Integer.MAX_VALUE);
        // provide intermediate results only triggered by time-limited inference
        sampler.setInfoInterval(Integer.MAX_VALUE);
        tli = new TimeLimitedInference(tliSampler, timeLimit, infoIntervalTime);
        paramHandler.addSubhandler(tli);
        tli.setReferenceDistribution(referenceDist);
        tli.setEvidenceDomainIndices(evidenceDomainIndices);
        if(samplerInitializationBeforeTimingStarts)
            tliSampler.initialize(); // otherwise initialization is called by infer()
        dist = tli.run();
        if(referenceDist != null)
            System.out.println("MSEs: " + tli.getMSEs());
        results = tli.getResults(dist);
    }
    else {
        dist = sampler.infer();
        results = sampler.getResults(dist);
    }
    this.inferenceTime = sampler.getInferenceTime();
    this.inferenceInitTime = sampler.getInitTime();
    if(dist != null)
        this.stepsTaken = dist.steps;
    // print results
    if(verbose) {
        ArrayList<InferenceResult> sortedResults = new ArrayList<InferenceResult>(results);
        Collections.sort(sortedResults, this.resultsSortOrder);
        for(InferenceResult res : sortedResults) {
            boolean isFilteredEvidence = resultsFilterEvidence && db.contains(res.varName);
            if(!isFilteredEvidence)
                res.print();
        }
    }
    // save output distribution
    if(outputDistFile != null) {
        if(dist != null) {
            GeneralSampledDistribution gdist = dist.toGeneralDistribution();
            File f = new File(outputDistFile);
            gdist.write(f);
            GeneralSampledDistribution gdist2 = GeneralSampledDistribution.fromFile(f);
            gdist2.print(System.out);
        }
        else
            System.err.println("Warning: No distribution was obtained; nothing written to " + outputDistFile);
    }
    // compare distributions
    if(referenceDist != null && dist != null) {
        System.out.println("comparing to reference distribution...");
        BasicSampledDistribution.compareDistributions(dist, referenceDist, evidenceDomainIndices);
    }
    return results;
}

/**
 * Writes the ground network (.instance.xml) and the evidence (.instance.bndb)
 * to files whose names are derived from the name of the network file.
 */
private void saveGroundInstance() throws Exception {
    if(networkFile == null)
        throw new IllegalStateException("Cannot save the instance: no network file name to derive the output file names from");
    // strip the extension, if there is one in the last path component
    int dot = networkFile.lastIndexOf('.');
    int separator = Math.max(networkFile.lastIndexOf('/'), networkFile.lastIndexOf('\\'));
    String baseName = dot > separator ? networkFile.substring(0, dot) : networkFile;
    // save Bayesian network
    gbln.getGroundNetwork().saveXMLBIF(baseName + ".instance.xml");
    // save evidence data
    BNDatabase bndb = new BNDatabase();
    for(probcog.srl.Variable var : db.getEntries())
        bndb.add(var.getName(), var.value);
    PrintStream out = new PrintStream(new File(baseName + ".instance.bndb"));
    try {
        bndb.write(out);
    }
    finally {
        out.close(); // make sure the file handle is released even if writing fails
    }
}
/**
 * @return the results stored by the most recent call to run() (null if
 *         no results have been stored yet)
 */
public Collection<InferenceResult> getResults() {
    return results;
}
/**
 * @return the sum of the inference time and the inference initialization
 *         time, in seconds
 */
public double getTotalInferenceTime() {
    final double computation = getInferenceTime();
    final double initialization = getInferenceInitTime();
    return computation + initialization;
}
/**
 * @return the inference time (excluding initialization) in seconds, as
 *         reported by the sampler in run()
 */
public double getInferenceTime() {
    return this.inferenceTime;
}
/**
 * @return the number of seconds it took run() to instantiate the ground
 *         model
 */
public double getGroundingTime() {
    return this.groundingTime;
}
/**
 * @return the initialization time reported by the sampler in run()
 */
public double getInferenceInitTime() {
    return this.inferenceInitTime;
}
/**
 * @return the number of steps recorded in the distribution that the
 *         inference algorithm returned
 */
public int getNumSteps() {
    return this.stepsTaken;
}
/**
 * @return the sampler created by run() (null if none has been created yet)
 */
public Sampler getInferenceObject() {
    return this.sampler;
}
/**
 * Command-line entry point: reads the arguments, runs inference and reports
 * parameters that no handler consumed. If an IllegalArgumentException occurs,
 * the usage text is printed. Any exception terminates the JVM with exit
 * code 1.
 * @param args the command-line arguments (see the usage text below)
 */
public static void main(String[] args) {
    Locale.setDefault(new Locale("en"));
    try {
        BLNinfer infer = new BLNinfer();
        infer.readArgs(args);
        infer.run();
        // report any unhandled parameters
        ParameterHandler handler = infer.getParameterHandler();
        Collection<String> unhandledParams = handler.getUnhandledParams();
        if(!unhandledParams.isEmpty()) {
            System.err.println("Warning: Some parameters could not be handled: " + unhandledParams.toString() + "; supported parameters: ");
            handler.printHelp(System.err);
        }
    } catch(IllegalArgumentException e) {
        e.printStackTrace();
        System.out.println(
            "\n usage: BLNinfer <arguments>\n\n" +
            " required arguments:\n\n" +
            " -b <declarations file> declarations file (types, domains, signatures, etc.)\n" +
            " -x <network file> fragment network (XML-BIF or PMML)\n" +
            " -l <logic file> logical constraints file\n" +
            " -e <evidence db pattern> an evidence database file or file mask\n" +
            " -q <comma-sep. queries> queries (predicate names or partially grounded terms with lower-case vars)\n\n" +
            " options:\n\n" + " -allowPartialInst allow partial ground network instantiations (skip nodes with no applicable fragment)\n" +
            " -maxSteps # the maximum number of steps to take (default: 1000 for non-time-limited inf.)\n" +
            " -maxTrials # the maximum number of trials per step for BN sampling algorithms (default: 5000)\n" +
            " -infoInterval # the number of steps after which to output a status message\n" +
            " -skipFailedSteps failed steps (> max trials) should just be skipped\n\n" +
            " -t [secs] use time-limited inference (default: 10 seconds)\n" +
            " -infoTime # interval in secs after which to display intermediate results (time-limited inference, default: 1.0)\n" +
            " -ia <name> inference algorithm selection; valid names:");
        Algorithm.printList(" ");
        System.out.println(
            " --<key>=<value> set algorithm-specific parameter\n" +
            " -debug debug mode with additional outputs\n" +
            " -s show ground network in editor\n" +
            " -si save ground network instance in BIF format (.instance.xml) and evidence (.instance.bndb)\n" +
            // the line break at the end of the following entry was missing,
            // which glued the -ni and -rfe entries together
            " -ni do not actually run the inference method (only instantiate ground network)\n" +
            " -rfe filter evidence in results\n" +
            " -nodetcpt remove deterministic CPT columns by replacing 0s with low prob. values\n" +
            " -cw <predNames> set predicates as closed-world (comma-separated list of names)\n" +
            " -O<a|p|pp|qp> order printed results by atom name (a), probability (p), predicate then probability (pp), query then probability (qp)\n" +
            " -od <file> save output distribution to file\n" +
            " -cd <file> compare results of inference to reference distribution in file\n" +
            " -py use Python-based logic engine [deprecated]\n");
        System.exit(1);
    } catch(Exception e) {
        e.printStackTrace();
        System.exit(1);
    }
}
/**
 * Checks whether the round parentheses in a string are balanced, i.e. every
 * closing parenthesis matches a preceding opening one and no opening
 * parenthesis is left unclosed. All other characters are ignored.
 * @param s the string to check
 * @return true if the parentheses in s are balanced
 */
public static boolean balancedParentheses(String s) {
    int open = 0;
    for(int i = 0; i < s.length(); i++) {
        char c = s.charAt(i);
        if(c == '(')
            open++;
        else if(c == ')') {
            // a closing parenthesis without a preceding opening one can
            // never be balanced by what follows (e.g. ")(")
            if(--open < 0)
                return false;
        }
    }
    return open == 0;
}
/**
 * @return the parameter handler that dispatches the parameters of this object
 */
@Override
public ParameterHandler getParameterHandler() {
    return this.paramHandler;
}
/**
 * Provides access to the object that controlled time-limited inference.
 *
 * @return the {@link TimeLimitedInference} instance created by run(), or
 *         null if time-limited inference was not carried out
 */
public TimeLimitedInference getTimeLimitedInference() {
    return tli;
}
}