/*******************************************************************************
* Copyright (C) 2007-2012 Dominik Jain.
*
* This file is part of ProbCog.
*
* ProbCog is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ProbCog is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ProbCog. If not, see <http://www.gnu.org/licenses/>.
******************************************************************************/
package probcog.srl.directed.learning;
import java.io.File;
import java.io.PrintStream;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.Vector;
import java.util.regex.Pattern;
import probcog.inference.IParameterHandler;
import probcog.inference.ParameterHandler;
import probcog.srl.Database;
import probcog.srl.GenericDatabase;
import probcog.srl.Signature;
import probcog.srl.directed.ABLModel;
/**
 * Bayesian logic network parameter learning tool.
 * Learns the conditional probability tables (and, optionally, the domains) of an
 * ABL model from one or more training databases and writes the learnt model to disk.
 * @author Dominik Jain
 */
public class BLNLearner implements IParameterHandler {
	/** flags set via the command line (see {@link #readArgs(String[])} and the usage text in {@link #main(String[])}) */
	protected boolean showBN = false, learnDomains = false, ignoreUndefPreds = false, toMLN = false, debug = false, uniformDefault = false;
	/** whether progress information is printed to stdout */
	protected boolean verbose = true;
	/** input files (declarations, network, training database pattern) and output files (declarations, network) */
	protected String declsFile = null, bifFile = null, dbFile = null, outFileDecls = null, outFileNetwork = null;
	/** if true, CPTs are not normalized after counting (i.e. they keep raw counts) */
	protected boolean noNormalization = false;
	/** if true, the domains of all training databases are checked before learning (GenericDatabase.checkDomains) */
	protected boolean mergeDomains = false;
	/** the model being learnt; loaded from declsFile/bifFile in learn() if not set explicitly */
	protected ABLModel bn;
	/** training databases: those added explicitly plus those matched by the dbFile pattern */
	protected Vector<GenericDatabase<?,?>> dbs = new Vector<GenericDatabase<?,?>>();
	protected ParameterHandler paramHandler;
	/** algorithm-specific parameters (from --name=value arguments or setParameter) */
	protected Map<String,Object> params = new HashMap<String,Object>();

	public BLNLearner() {
		paramHandler = new ParameterHandler(this);
	}

	/**
	 * Configures this learner from command-line arguments.
	 * Options -b, -x, -t, -ob and -ox consume the following argument as their value;
	 * arguments of the form {@code --name=value} are stored as algorithm-specific parameters.
	 * @param args the command-line arguments
	 * @throws IllegalArgumentException if an option is unknown, an option is missing its value,
	 *         an algorithm-specific parameter is malformed, or not both output files (-ob, -ox) are given
	 */
	public void readArgs(String[] args) throws IllegalArgumentException {
		for(int i = 0; i < args.length; i++) {
			if(args[i].equals("-s"))
				showBN = true;
			else if(args[i].equals("-d"))
				learnDomains = true;
			else if(args[i].equals("-md"))
				mergeDomains = true;
			else if(args[i].equals("-i"))
				ignoreUndefPreds = true;
			else if(args[i].equals("-b"))
				declsFile = optionValue(args, ++i);
			else if(args[i].equals("-x"))
				bifFile = optionValue(args, ++i);
			else if(args[i].equals("-t"))
				dbFile = optionValue(args, ++i);
			else if(args[i].equals("-ob"))
				outFileDecls = optionValue(args, ++i);
			else if(args[i].equals("-ox"))
				outFileNetwork = optionValue(args, ++i);
			else if(args[i].equals("-mln"))
				toMLN = true;
			else if(args[i].equals("-nn"))
				noNormalization = true;
			else if(args[i].equals("-ud"))
				uniformDefault = true;
			else if(args[i].equals("-debug"))
				debug = true;
			else if(args[i].startsWith("--")) { // algorithm-specific parameter
				// split at the first '=' only, so that values may themselves contain '='
				String[] pair = args[i].substring(2).split("=", 2);
				if(pair.length != 2 || pair[0].length() == 0 || pair[1].length() == 0)
					throw new IllegalArgumentException("Argument '" + args[i] + "' for algorithm-specific parameterization is incorrectly formatted.");
				params.put(pair[0], pair[1]);
			}
			else
				throw new IllegalArgumentException("Unknown parameter: " + args[i]);
		}
		if(outFileDecls == null || outFileNetwork == null)
			throw new IllegalArgumentException("Not all output files given");
	}

	/**
	 * Returns the value of an option that requires an argument.
	 * @param args the command-line arguments
	 * @param index the index at which the value is expected (the option itself is at index-1)
	 * @return args[index]
	 * @throws IllegalArgumentException if the argument list ends before the value
	 */
	private static String optionValue(String[] args, int index) {
		if(index >= args.length)
			throw new IllegalArgumentException("Missing value for parameter " + args[index - 1]);
		return args[index];
	}

	/**
	 * Sets the model to learn; if set, learn() does not load a model from the declaration/network files.
	 * @param abl the model
	 */
	public void setABLModel(ABLModel abl) {
		bn = abl;
	}

	/**
	 * Adds a training database (in addition to any databases matched by the -t pattern).
	 * @param db the database
	 */
	public void addTrainingDatabase(GenericDatabase<?,?> db) {
		dbs.add(db);
	}

	public void setLearnDomains(boolean enabled) {
		this.learnDomains = enabled;
	}

	public void setOutputFileNetwork(String filename) {
		this.outFileNetwork = filename;
	}

	public void setOutputFileDecls(String filename) {
		this.outFileDecls = filename;
	}

	/**
	 * Sets a parameter that is to be interpreted by an internal handler of the underlying methods
	 * @param param the name of the parameter
	 * @param value the value of the parameter
	 */
	public void setParameter(String param, String value) {
		this.params.put(param, value);
	}

	public void setVerbose(boolean verbose) {
		this.verbose = verbose;
	}

	/**
	 * Runs the learning process: loads the model (if not set), reads the training databases,
	 * optionally checks/learns domains, learns the CPTs and writes the configured output files.
	 * Errors other than invalid configuration are printed to stderr and the (possibly partially
	 * learnt) model is still returned.
	 * @return the learnt model (null if it could not be loaded)
	 * @throws IllegalArgumentException if the configuration is invalid, e.g. no network file
	 *         or no training databases are given, or the database directory does not exist
	 */
	public ABLModel learn() throws IllegalArgumentException {
		try {
			if(bn == null) {
				if(bifFile == null)
					throw new IllegalArgumentException("No network file given");
				// create an ABL model
				bn = new ABLModel(declsFile, bifFile);
			}
			// process parameters
			paramHandler.handle(params, false);
			// prepare it for learning
			bn.prepareForLearning();
			if(verbose) {
				System.out.println("Signatures:");
				for(Signature sig : bn.getSignatures())
					System.out.println(" " + sig);
			}
			// read the training databases
			if(dbFile != null)
				readTrainingDatabases();
			if(dbs.isEmpty())
				throw new IllegalArgumentException("No training databases given");
			// check domains for overlaps and merge if necessary
			if(mergeDomains) {
				if(verbose) System.out.println("Checking domains...");
				for(GenericDatabase<?,?> db : dbs)
					db.checkDomains(verbose);
			}
			// learn domains
			if(learnDomains) {
				if(verbose) System.out.println("Learning domains...");
				DomainLearner domLearner = new DomainLearner(bn);
				domLearner.setVerbose(verbose);
				for(GenericDatabase<?,?> db : dbs)
					domLearner.learn(db);
				domLearner.finish();
			}
			if(verbose) {
				System.out.println("Domains:");
				for(Signature sig : bn.getSignatures())
					System.out.println(" " + sig.functionName + ": " + sig.returnType + " ");
			}
			// learn parameters
			if(verbose) {
				System.out.println("Learning parameters...");
				if(uniformDefault)
					System.out.println(" option: uniform distribution is assumed as default");
			}
			CPTLearner cptLearner = new CPTLearner(bn, uniformDefault, debug);
			paramHandler.addSubhandler(cptLearner);
			int i = 1;
			for(GenericDatabase<?,?> db : dbs) {
				if(verbose) System.out.printf("database %d/%d\n", i, dbs.size());
				cptLearner.learnTyped(db, true, verbose);
				++i;
			}
			if(!noNormalization)
				cptLearner.finish();
			writeModel();
			Collection<String> unhandledParams = paramHandler.getUnhandledParams();
			if(!unhandledParams.isEmpty()) {
				System.err.println("WARNING: There were unhandled parameters: " + unhandledParams.toString());
				paramHandler.printHelp(System.out);
			}
			// write MLN
			if(toMLN)
				writeMLN();
			// show bayesian network
			if(showBN) {
				if(verbose) System.out.println("Showing Bayesian network...");
				bn.show();
			}
		}
		catch(IllegalArgumentException e) {
			// invalid configuration: propagate as declared so callers (e.g. main) can report it
			throw e;
		}
		catch(Exception e) {
			e.printStackTrace();
		}
		return bn;
	}

	/**
	 * Reads all files in the directory part of dbFile whose names match the file-name part of
	 * dbFile (interpreted as a regular expression) and adds them to the training databases.
	 * @throws IllegalArgumentException if the directory does not exist or its contents cannot be listed
	 */
	private void readTrainingDatabases() throws Exception {
		if(verbose) System.out.println("Reading data...");
		File dbPattern = new File(dbFile);
		String regex = dbPattern.getName();
		Pattern p = Pattern.compile(regex);
		File directory = dbPattern.getParentFile();
		if(directory == null)
			directory = new File(".");
		else if(!directory.exists())
			throw new IllegalArgumentException("The directory '" + directory + "', which was specified in the pattern, does not exist");
		if(verbose) System.out.printf("Searching for '%s' in '%s'...\n", regex, directory);
		// listFiles returns null if the path is not a directory or an I/O error occurs
		File[] files = directory.listFiles();
		if(files == null)
			throw new IllegalArgumentException("The contents of '" + directory + "' could not be listed");
		for(File file : files) {
			if(p.matcher(file.getName()).matches()) {
				Database db = new Database(bn);
				if(verbose) System.out.printf("reading %s...\n", file.getAbsolutePath());
				db.readBLOGDB(file.getPath(), ignoreUndefPreds);
				//db.finalize(); // TODO determine whether to do this or not
				dbs.add(db);
			}
		}
	}

	/**
	 * Writes the learnt declarations to outFileDecls and the network to outFileNetwork
	 * (each only if the respective file name is set).
	 */
	private void writeModel() throws Exception {
		// write learnt BLOG/ABL model
		if(outFileDecls != null) {
			if(verbose) System.out.println("Writing declarations to " + outFileDecls + "...");
			// presumably makes the written declarations reference the new network file
			if(outFileNetwork != null)
				bn.setNetworkFilename(outFileNetwork);
			PrintStream out = new PrintStream(new File(outFileDecls));
			try {
				bn.write(out);
			}
			finally {
				out.close();
			}
		}
		// write parameters to Bayesian network template
		if(outFileNetwork != null) {
			if(verbose) System.out.println("Writing network to " + outFileNetwork + "...");
			bn.save(outFileNetwork);
		}
	}

	/**
	 * Converts the learnt model to a Markov logic network and writes it to outFileDecls + ".mln".
	 */
	private void writeMLN() throws Exception {
		String filename = outFileDecls + ".mln";
		if(verbose) System.out.println("Writing MLN " + filename);
		PrintStream out = new PrintStream(new File(filename));
		try {
			bn.toMLN(out, false, false, false);
		}
		finally {
			out.close();
		}
	}

	/**
	 * Command-line entry point. Prints the error and usage information if the arguments
	 * or the resulting configuration are invalid.
	 * @param args command-line arguments (see usage text)
	 */
	public static void main(String[] args) {
		BLNLearner l = new BLNLearner();
		try {
			l.readArgs(args);
			l.learn();
		}
		catch(IllegalArgumentException e) {
			System.err.println("Error: " + e.getMessage());
			String acronym = "ABL";
			System.out.println("\n usage: learn" + acronym + " [-b <" + acronym + " file>] <-x <network file>> <-t <training db pattern>> <-ob <" + acronym + " output>> <-ox <network output>> [-s] [-d]\n\n"+
					" -b " + acronym + " file from which to read function signatures\n" +
					" -s show learned fragment network\n" +
					" -d learn domains\n" +
					" -md merge domains containing the same constants\n" +
					" -i ignore data on predicates not defined in the model\n" +
					" -ud apply uniform distribution by default (for CPT columns with no examples)\n" +
					" -nn no normalization (i.e. keep counts in CPTs)\n" +
					" -mln convert learnt model to a Markov logic network\n" +
					" -debug output debug information\n" +
					" --<name>=<value> algorithm-specific parameter\n");
		}
	}

	@Override
	public ParameterHandler getParameterHandler() {
		return paramHandler;
	}
}