/******************************************************************************* * Copyright (C) 2010-2012 Dominik Jain. * * This file is part of ProbCog. * * ProbCog is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * ProbCog is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with ProbCog. If not, see <http://www.gnu.org/licenses/>. ******************************************************************************/ package probcog.service; import java.util.Map; import java.util.Vector; import probcog.srl.BooleanDomain; import probcog.srl.Database; import probcog.srl.Signature; import probcog.srl.Variable; import probcog.srl.mln.MarkovLogicNetwork; import probcog.srl.mln.MarkovRandomField; import probcog.srl.mln.inference.InferenceAlgorithm; import probcog.srl.mln.inference.MCSAT; /** * Represents a Markov logic network model for use in the ProbCog service. * @author Dominik Jain */ public class MLNModel extends Model { protected MarkovLogicNetwork mln; protected Database db; protected MarkovRandomField mrf; public MLNModel(String name, String mln) throws Exception { super(name); this.mln = new MarkovLogicNetwork(mln); } @Override protected String _getConstantType(String constant) { return db.getConstantType(constant); } @Override public void beginSession(Map<String, Object> params) throws Exception { super.beginSession(params); db = new Database(mln); } @Override protected Vector<InferenceResult> _infer(Iterable<String> queries) throws Exception { InferenceAlgorithm ia = new MCSAT(mrf); paramHandler.addSubhandler(ia); Vector<InferenceResult> res = new Vector<InferenceResult>(); for(probcog.srl.mln.inference.InferenceResult r : ia.infer(queries)) { InferenceResult r2 = new InferenceResult(r.ga.predicate, r.ga.args, r.value); res.add(r2); } return res; } @Override protected void _setEvidence(Iterable<String[]> evidence) throws Exception { for(String[] tuple : evidence) { String functionName = tuple[0]; Signature sig = mln.getSignature(functionName); if(sig == null) throw new Exception("Function '" + functionName + "' appearing in evidence not found in model " + name); String value; String[] params; if(sig.argTypes.length == tuple.length-1) { params = new String[tuple.length-1]; for(int i = 0; i < params.length; i++) params[i] = tuple[i+1]; value = BooleanDomain.True; } else { params = new String[tuple.length-2]; for(int i = 0; i < params.length; i++) params[i] = tuple[i+1]; value = BooleanDomain.getStandardValue(tuple[tuple.length-1]); } db.addVariable(new Variable(functionName, params, value, mln)); } } @Override public Vector<String[]> getDomains() { throw new RuntimeException("not implemented"); // TODO } @Override public Vector<String[]> getPredicates() { return getPredicatesFromSignatures(mln.getSignatures()); } @Override public void instantiate() throws Exception { mrf = mln.ground(db); } }