/******************************************************************************* * Copyright (C) 2009-2012 Dominik Jain. * * This file is part of ProbCog. * * ProbCog is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * ProbCog is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with ProbCog. If not, see <http://www.gnu.org/licenses/>. ******************************************************************************/ package probcog; import java.io.File; import java.io.PrintStream; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Vector; import java.util.regex.Pattern; import probcog.logic.sat.weighted.WeightedFormula; import probcog.srl.Database; import probcog.srl.mln.MarkovLogicNetwork; import probcog.srl.mln.MarkovRandomField; import probcog.srl.mln.inference.InferenceAlgorithm; import probcog.srl.mln.inference.InferenceResult; import probcog.srl.mln.inference.MAPInferenceAlgorithm; import probcog.srl.mln.inference.MCSAT; import probcog.srl.mln.inference.MaxWalkSAT; import probcog.srl.mln.inference.Toulbar2MAPInference; import edu.tum.cs.util.Stopwatch; import edu.tum.cs.util.StringTool; /** * MLN command-line inference tool * @author Dominik Jain */ public class MLNinfer { enum Algorithm {MaxWalkSAT, MCSAT, Toulbar2, MaxWalkSATRooms}; /** * @param args */ public static void main(String[] args) { try { String[] mlnFiles = null; String dbFile = null; String query = null; Integer maxSteps = null; String resultsFile = null; Algorithm algo = Algorithm.MCSAT; String[] cwPreds = null; boolean debug = false; HashMap<String,Object> params = new HashMap<String,Object>(); // read arguments for(int i = 0; i < args.length; i++) { if(args[i].equals("-i")) mlnFiles = args[++i].split(","); else if(args[i].equals("-q")) query = args[++i]; else if(args[i].equals("-e")) dbFile = args[++i]; else if(args[i].equals("-r")) resultsFile = args[++i]; else if(args[i].equals("-cw")) cwPreds = args[++i].split(","); else if(args[i].equals("-maxSteps")) maxSteps = Integer.parseInt(args[++i]); else if(args[i].equals("-mws")) algo = Algorithm.MaxWalkSAT; else if(args[i].equals("-mwsr")) algo = Algorithm.MaxWalkSATRooms; else if(args[i].equals("-mcsat")) algo = Algorithm.MCSAT; else if(args[i].equals("-t2")) algo = Algorithm.Toulbar2; else if(args[i].equals("-debug")) debug = true; else if(args[i].startsWith("-p") || args[i].startsWith("--")) { // algorithm-specific parameter String[] pair = args[i].substring(2).split("="); if(pair.length != 2) throw new Exception("Argument '" + args[i] + "' for algorithm-specific parameterization is incorrectly formatted."); params.put(pair[0], pair[1]); } else System.err.println("Warning: unknown option " + args[i] + " ignored!"); } if(mlnFiles == null || dbFile == null || query == null) { System.out.println("\n usage: MLNinfer <-i <(comma-sep.) MLN file(s)>> <-e <evidence db file>> <-q <comma-sep. queries>> [options]\n\n"+ " -maxSteps # the maximum number of steps to take (default determined by algorithm, if any)\n" + " -r <filename> save results to file\n" + " -mws algorithm: MaxWalkSAT (MAP inference)\n" + " -mcsat algorithm: MC-SAT (default)\n" + " -t2 algorithm: Toulbar2 branch & bound\n" + " -debug debug mode with additional outputs\n" + " -cw <predNames> set predicates as closed-world (comma-separated list of names)\n" + " --<key>=<value> set algorithm-specific parameter\n" ); return; } // determine queries Pattern comma = Pattern.compile("\\s*,\\s*"); String[] candQueries = comma.split(query); Vector<String> queries = new Vector<String>(); String q = ""; for(int i = 0; i < candQueries.length; i++) { if(!q.equals("")) q += ","; q += candQueries[i]; if(balancedParentheses(q)) { queries.add(q); q = ""; } } if(!q.equals("")) throw new IllegalArgumentException("Unbalanced parentheses in queries"); // load relational model Stopwatch constructSW = new Stopwatch(); constructSW.start(); System.out.printf("reading model %s...\n", StringTool.join(", ", mlnFiles)); MarkovLogicNetwork mln = new MarkovLogicNetwork(mlnFiles); // instantiate ground model System.out.printf("reading database %s...\n", dbFile); Database db = new Database(mln); db.readMLNDB(dbFile); if(cwPreds != null) { for(String predName : cwPreds) db.setClosedWorldPred(predName); } System.out.printf("creating ground MRF...\n"); MarkovRandomField mrf = mln.ground(db); if(debug) { System.out.println("MRF:"); for(WeightedFormula wf : mrf) System.out.println(" " + wf.toString()); } constructSW.stop(); // run inference System.out.println("starting inference process..."); Stopwatch sw = new Stopwatch(); sw.start(); InferenceAlgorithm infer = null; switch(algo) { case MCSAT: infer = new MCSAT(mrf); break; case MaxWalkSAT: infer = new MaxWalkSAT(mrf); break; case Toulbar2: infer = new Toulbar2MAPInference(mrf); break; } infer.setDebugMode(debug); if(maxSteps != null) infer.setMaxSteps(maxSteps); infer.getParameterHandler().handle(params, true); System.out.printf("algorithm: %s\n", infer.getAlgorithmName()); List<InferenceResult> results = infer.infer(queries); sw.stop(); // show results System.out.printf("\nconstruction time: %.4fs, inference time: %.4fs\n", constructSW.getElapsedTimeSecs(), sw.getElapsedTimeSecs()); System.out.println("results:"); Collections.sort(results); PrintStream out = null; if(resultsFile != null) out = new PrintStream(new File(resultsFile)); for(InferenceResult r : results) { r.print(); if(out != null) out.printf("%s %f\n", r.ga.toString().replace(" ", ""), r.value); } if(out != null) out.close(); if(infer instanceof MAPInferenceAlgorithm) { MAPInferenceAlgorithm mapi = (MAPInferenceAlgorithm)infer; double value = mrf.getWorldValue(mapi.getSolution()); System.out.printf("\nsolution value: %f\n", value); System.out.printf("\nsum of unsatisfied formula weights: %f\n", mrf.getSumOfUnsatClauseWeights(mapi.getSolution())); } } catch(Exception e) { e.printStackTrace(); } } public static boolean balancedParentheses(String s) { int n = 0; for(int i = 0; i < s.length(); i++) { if(s.charAt(i) == '(') n++; else if(s.charAt(i) == ')') n--; } return n == 0; } }