/******************************************************************************* * Copyright (C) 2008-2012 Dominik Jain. * * This file is part of ProbCog. * * ProbCog is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * ProbCog is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with ProbCog. If not, see <http://www.gnu.org/licenses/>. ******************************************************************************/ package probcog.service; import java.io.IOException; import java.util.ArrayList; import java.util.Collection; import java.util.Map; import java.util.Vector; import java.util.Map.Entry; import probcog.logic.parser.ParseException; import probcog.srl.Database; import probcog.srl.Signature; import probcog.srl.Variable; import probcog.srl.directed.RelationalNode; import probcog.srl.directed.bln.BayesianLogicNetwork; import probcog.srl.directed.bln.GroundBLN; import probcog.srl.directed.inference.BLNinfer; import edu.tum.cs.util.StringTool; import edu.tum.cs.util.datastruct.Pair; /** * Represents a Bayesian logic network model for use in the ProbCog service. * @author Dominik Jain */ public class BLNModel extends Model { protected BayesianLogicNetwork bln; protected GroundBLN gbln; protected Database db; protected String filenames; public BLNModel(String modelName, String blogFile, String networkFile, String logicFile) throws IOException, ParseException, Exception { super(modelName); this.filenames = String.format("%s;%s;%s", blogFile, networkFile, logicFile); this.bln = new BayesianLogicNetwork(blogFile, networkFile, logicFile); } @Override public void instantiate() throws Exception { gbln = bln.ground(db); paramHandler.addSubhandler(gbln); gbln.instantiateGroundNetwork(); } @Override public void beginSession(Map<String, Object> params) throws Exception { super.beginSession(params); db = new Database(bln.rbn); paramHandler.addSubhandler(db); } @Override protected Vector<InferenceResult> _infer(Iterable<String> queries) throws Exception { BLNinfer inference = new BLNinfer(actualParams); paramHandler.addSubhandler(inference); inference.setGroundBLN(gbln); inference.setQueries(queries); Collection<probcog.srl.directed.inference.InferenceResult> results = inference.run(); // store results in common InferenceResult format Vector<InferenceResult> ret = new Vector<InferenceResult>(); for(probcog.srl.directed.inference.InferenceResult res : results) { Pair<String, String[]> var = RelationalNode.parse(res.varName); Signature sig = bln.rbn.getSignature(var.first); String[] params = var.second; boolean isBool = sig.isBoolean(); if(!isBool) { String[] fullParams = new String[params.length+1]; for(int i = 0; i < params.length; i++) fullParams[i] = params[i]; params = fullParams; } for(int i = 0; i < res.domainElements.length; i++) { if(!isBool) params[params.length-1] = res.domainElements[i]; else if(!res.domainElements[i].equalsIgnoreCase("True")) continue; ret.add(new InferenceResult(var.first, params.clone(), res.probabilities[i])); } } return ret; } @Override protected void _setEvidence(Iterable<String[]> evidence) throws Exception { for(String[] tuple : evidence) { String functionName = tuple[0]; Signature sig = bln.rbn.getSignature(functionName); if(sig == null) throw new Exception("Function '" + functionName + "' appearing in evidence not found in model " + name); String value; String[] params; if(sig.argTypes.length == tuple.length-1) { params = new String[tuple.length-1]; for(int i = 0; i < params.length; i++) params[i] = tuple[i+1]; value = "True"; } else { if(tuple.length < sig.argTypes.length+2) throw new Exception("Evidence entry has too few parameters: " + StringTool.join(", ", tuple)); params = new String[sig.argTypes.length]; for(int i = 0; i < params.length; i++) params[i] = tuple[i+1]; value = tuple[params.length+1]; } db.addVariable(new Variable(functionName, params, value, this.bln.rbn)); } } @Override public Vector<String[]> getPredicates() { return getPredicatesFromSignatures(this.bln.rbn.getSignatures()); } public Vector<String[]> getDomains() { Vector<String[]> ret = new Vector<String[]>(); for(Entry<String,? extends Collection<String>> e : this.bln.rbn.getGuaranteedDomainElements().entrySet()) { Collection<String> elems = e.getValue(); ArrayList<String> tuple = new ArrayList<String>(elems.size()+1); tuple.add(e.getKey()); for(String elem : elems) { String c = mapConstantFromProbCog(elem); if(c == null) continue; tuple.add(c); } ret.add(tuple.toArray(new String[tuple.size()])); } return ret; } @Override protected String _getConstantType(String constant) { return db.getConstantType(constant); } @Override public String toString() { return String.format("%s=BLN[%s]", this.name, this.filenames); } }