/******************************************************************************* * Copyright (C) 2009-2012 Ralf Wernicke, Dominik Jain. * * This file is part of ProbCog. * * ProbCog is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * ProbCog is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with ProbCog. If not, see <http://www.gnu.org/licenses/>. ******************************************************************************/ package probcog.srl.mln; import java.io.BufferedReader; import java.io.File; import java.io.FileNotFoundException; import java.io.PrintStream; import java.io.StringReader; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; import java.util.HashSet; import java.util.Set; import java.util.Vector; import java.util.regex.Matcher; import java.util.regex.Pattern; import probcog.logic.Formula; import probcog.logic.parser.ParseException; import probcog.logic.sat.weighted.WeightedFormula; import probcog.srl.Database; import probcog.srl.RelationKey; import probcog.srl.RelationalModel; import probcog.srl.Signature; import probcog.srl.taxonomy.Taxonomy; import probcog.tools.JythonInterpreter; import edu.tum.cs.util.FileUtil; /** * represents a Markov logic network * @author Ralf Wernicke * @author Dominik Jain */ public class MarkovLogicNetwork implements RelationalModel { protected File mlnFile; //protected HashMap<Formula, Double> formula2weight; protected Vector<WeightedFormula> formulas; /** * maps a predicate name to its signature */ protected HashMap<String, Signature> signatures; /** * maps domain/type names to a list of guaranteed domain elements */ protected HashMap<String, HashSet<String>> guaranteedDomainElements; /** * mapping from predicate name to index of argument that is functionally determined */ protected HashMap<String, Integer> functionalPreds; double sumAbsWeights = 0; /** * constructs a Markov logic network from an MLN file * @param mlnFileLoc location of the MLN-file * @throws Exception */ public MarkovLogicNetwork(String mlnFileLoc) throws Exception { this(); // read the complete MLN-File and save it in a String mlnFile = new File(mlnFileLoc); String content = FileUtil.readTextFile(mlnFile); read(content); } public MarkovLogicNetwork(String[] mlnFiles) throws Exception { this(); mlnFile = new File(mlnFiles[0]); StringBuffer content = new StringBuffer(); for(String filename : mlnFiles) { content.append(FileUtil.readTextFile(filename)); content.append("\n"); } read(content.toString()); } /** * constructs an empty MLN */ public MarkovLogicNetwork() { mlnFile = null; signatures = new HashMap<String, Signature>(); functionalPreds = new HashMap<String, Integer>(); guaranteedDomainElements = new HashMap<String, HashSet<String>>(); formulas = new Vector<WeightedFormula>(); } /** * adds a predicate signature to this model * @param sig signature of this predicate */ public void addSignature(Signature sig) { signatures.put(sig.functionName, sig); } public void addFormula(Formula f, double weight) throws Exception { f.addConstantsToModel(this); this.formulas.add(new WeightedFormula(f, weight, false)); } public void addHardFormula(Formula f) throws Exception { f.addConstantsToModel(this); this.formulas.add(new WeightedFormula(f, getHardWeight(), true)); } public void addFunctionalDependency(String predicateName, Integer index) { this.functionalPreds.put(predicateName, index); } public void addGuaranteedDomainElement(String domain, String element) { HashSet<String> s = guaranteedDomainElements.get(domain); if(s == null) guaranteedDomainElements.put(domain, s=new HashSet<String>()); s.add(element); } public void addGuaranteedDomainElements(String domain, String[] elements) { for(String e : elements) addGuaranteedDomainElement(domain, e); } public Vector<WeightedFormula> getFormulas() { return formulas; } /** * returns the signature for the given predicate * @param predName name of predicate (signature for this predicate will be returned) * @return */ public Signature getSignature(String predName) { return signatures.get(predName); } /** * gets the functionally determined argument of a functional predicate * @return the index of the argument that is functionally determined or null if there is no such argument */ public Integer getFunctionallyDeterminedArgument(String predicateName) { return this.functionalPreds.get(predicateName); } /** * Method that grounds MLN to a MarkovRandomField * @param dbFileLoc file location of evidence for this scenario * @return returns a grounded MLN as a MarkovRandomField MRF * @throws Exception */ public MarkovRandomField ground(Database db) throws Exception { return ground(db, true, null); } public MarkovRandomField ground(Database db, boolean storeFormulasInMRF, GroundingCallback gc) throws Exception { return new MarkovRandomField(this, db, storeFormulasInMRF, gc); } /** * reads the contents of an MLN file * @throws Exception */ public void read(String content) throws Exception { String actLine; ArrayList<Formula> hardFormulas = new ArrayList<Formula>(); // remove all comments Pattern comments = Pattern.compile("//.*?$|/\\*.*?\\*/", Pattern.MULTILINE | Pattern.DOTALL); Matcher matcher = comments.matcher(content); content = matcher.replaceAll(""); BufferedReader breader = new BufferedReader(new StringReader(content)); String identifier = "\\w+"; String constant = "(?:[A-Z]\\w*|[0-9]+)"; // predicate declaration Pattern predDecl = Pattern.compile(String.format("(%s)\\(\\s*(%s!?(?:\\s*,\\s*%s!?)*)\\s*\\)", identifier, identifier, identifier)); // domain declaration Pattern domDecl = Pattern.compile(String.format("(%s)\\s*=\\s*\\{\\s*(%s(?:\\s*,\\s*%s)*)\\s*\\}", identifier, constant, constant)); JythonInterpreter jython = null; // parse line by line for(actLine = breader.readLine(); breader != null && actLine != null; actLine = breader.readLine()) { String line = actLine.trim(); if(line.length() == 0) continue; // hard constraint if(line.endsWith(".")) { Formula f; String strF = line.substring(0, line.length() - 1); try { f = Formula.fromString(strF); } catch(ParseException e) { throw new Exception("The hard formula '" + strF + "' could not be parsed: " + e.toString()); } hardFormulas.add(f); continue; } // predicate declaration Matcher m = predDecl.matcher(line); if(m.matches()) { String predName = m.group(1); Signature sig = getSignature(predName); if(sig != null) { throw new Exception(String.format("Signature declared in line '%s' was previously declared as '%s'", line, sig.toString())); } String[] argTypes = m.group(2).trim().split("\\s*,\\s*"); for (int c = 0; c < argTypes.length; c++) { // check whether it's a blockvariable if(argTypes[c].endsWith("!")) { argTypes[c] = argTypes[c].replace("!", ""); Integer oldValue = functionalPreds.put(predName, c); if(oldValue != null) throw new Exception(String.format("Predicate '%s' was declared to have more than one functionally determined parameter", predName)); } } sig = new Signature(predName, "boolean", argTypes); addSignature(sig); continue; } // domain declaration m = domDecl.matcher(line); if(m.matches()) { addGuaranteedDomainElements(m.group(1), m.group(2).trim().split("\\s*,\\s*")); continue; } // must be a weighted formula int iSpace = line.indexOf(' '); if(iSpace == -1) throw new Exception("This line is not a correct declaration of a weighted formula: " + line); String strWeight = line.substring(0, iSpace); Double weight = null; try { weight = Double.parseDouble(strWeight); } catch(NumberFormatException e) { if(jython == null) { jython = new JythonInterpreter(); jython.exec("from math import *"); jython.exec("def logx(x):\n if x == 0: return -100\n return log(x)"); } try { weight = jython.evalDouble(strWeight); } catch(Exception e2) { throw new Exception("Could not interpret weight '" + strWeight + "': " + e2.toString()); } } String strF = line.substring(iSpace+1).trim(); Formula f; try { f = Formula.fromString(strF); } catch(ParseException e) { throw new Exception("The formula '" + strF + "' could not be parsed: " + e.toString()); } addFormula(f, weight); sumAbsWeights += Math.abs(weight); } for (Formula f : hardFormulas) addHardFormula(f); } /** * @return the weight used for hard constraints */ public double getHardWeight() { return sumAbsWeights + 100000; // TODO this number should be selected with extreme care (especially for MPE inference it is very relevant); we should set it to the sum of abs. weights of soft formulas in the *ground* model + X } /** * replace a type by a new type in all function signatures * @param oldType * @param newType */ public void replaceType(String oldType, String newType) { for(Signature sig : signatures.values()) sig.replaceType(oldType, newType); } /** * @return a mapping from domain names to arrays of elements */ @Override public HashMap<String, HashSet<String>> getGuaranteedDomainElements() { return guaranteedDomainElements; } /** * * @param relation * @return */ public Collection<RelationKey> getRelationKeys(String relation) { // TODO return null; } /** * @return the set of functional predicates (i.e. their names) */ public Set<String> getFunctionalPreds() { return functionalPreds.keySet(); } public Collection<Signature> getSignatures() { return this.signatures.values(); } /** * @return null because MLNs do not use a taxonomy */ public Taxonomy getTaxonomy() { return null; } public void write(PrintStream out) { MLNWriter writer = new MLNWriter(out); // domain declarations if(this.guaranteedDomainElements.size() > 0) { out.println("// domain declarations"); for(java.util.Map.Entry<String,? extends Iterable<String>> e : this.getGuaranteedDomainElements().entrySet()) { writer.writeDomainDecl(e.getKey(), e.getValue()); } out.println(); } // predicate declarations out.println("// predicate declarations"); for(Signature sig : this.getSignatures()) { writer.writePredicateDecl(sig, this.getFunctionallyDeterminedArgument(sig.functionName)); } out.println(); out.println("// formulas"); for(WeightedFormula wf : getFormulas()) { if(wf.isHard) out.printf("%s.\n", wf.formula.toString()); else out.printf("%f %s\n", wf.weight, wf.formula.toString()); } } public void write(File f) throws FileNotFoundException { write(new PrintStream(f)); } @Override public Collection<String> getPrologRules() { return null; } }