/******************************************************************************* * Copyright (C) 2008-2012 Dominik Jain. * * This file is part of ProbCog. * * ProbCog is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * ProbCog is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with ProbCog. If not, see <http://www.gnu.org/licenses/>. ******************************************************************************/ package probcog.srl.directed.inference; import probcog.bayesnets.core.BeliefNetworkEx; import probcog.bayesnets.inference.ITimeLimitedInference; import probcog.bayesnets.inference.SampledDistribution; import probcog.srl.directed.bln.AbstractGroundBLN; /** * Bayesian Network Sampler: reduces inference in relational models to standard Bayesian network * inference in the ground (auxiliary) network. * @author Dominik Jain */ public class BNSampler extends Sampler implements ITimeLimitedInference { protected int maxTrials; /** * whether steps that exceed the max number of trials should just be skipped rather than raising an exception */ protected boolean skipFailedSteps; protected Class<? extends probcog.bayesnets.inference.Sampler> samplerClass; protected probcog.bayesnets.inference.Sampler sampler; /** * the evidence we are working on */ protected int[] evidenceDomainIndices; public BNSampler(AbstractGroundBLN gbln, Class<? extends probcog.bayesnets.inference.Sampler> samplerClass) throws Exception { super(gbln); maxTrials = 5000; this.paramHandler.add("maxTrials", "setMaxTrials"); this.paramHandler.add("skipFailedSteps", "setSkipFailedSteps"); this.samplerClass = samplerClass; } public void setMaxTrials(int maxTrials) { this.maxTrials = maxTrials; } public void setSkipFailedSteps(boolean canSkip) { this.skipFailedSteps = canSkip; } @Override protected void _initialize() throws Exception { // create full evidence String[][] evidence = this.gbln.getDatabase().getEntriesAsArray(); evidenceDomainIndices = gbln.getFullEvidence(evidence); // initialize sampler sampler = getSampler(); paramHandler.addSubhandler(sampler.getParameterHandler()); sampler.setEvidence(evidenceDomainIndices); sampler.setQueryVars(queryVars); sampler.setDebugMode(debug); sampler.setNumSamples(numSamples); sampler.setInfoInterval(infoInterval); sampler.setMaxTrials(maxTrials); sampler.setSkipFailedSteps(skipFailedSteps); sampler.initialize(); } @Override public SampledDistribution _infer() throws Exception { // run inference if(verbose) System.out.printf("running %s...\n", sampler.getAlgorithmName()); SampledDistribution dist = sampler.infer(); return dist; } protected probcog.bayesnets.inference.Sampler getSampler() throws Exception { return samplerClass.getConstructor(BeliefNetworkEx.class).newInstance(gbln.getGroundNetwork()); } @Override public String getAlgorithmName() { return "BNInference:" + samplerClass.getSimpleName(); } public SampledDistribution pollResults() throws Exception { if(sampler == null) return null; return sampler.pollResults(); } }