/******************************************************************************* * Copyright (C) 2008-2012 Dominik Jain. * * This file is part of ProbCog. * * ProbCog is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * ProbCog is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with ProbCog. If not, see <http://www.gnu.org/licenses/>. ******************************************************************************/ package probcog.bayesnets.inference; import java.util.HashMap; import probcog.bayesnets.core.BeliefNetworkEx; import edu.ksu.cis.bnj.ver3.core.BeliefNode; import edu.ksu.cis.bnj.ver3.core.Discrete; import edu.tum.cs.util.Stopwatch; /** * Gibbs Sampling MCMC inference. * @author Dominik Jain */ public class GibbsSampling extends Sampler { int[] nodeOrder; HashMap<BeliefNode, BeliefNode[]> children; public GibbsSampling(BeliefNetworkEx bn) throws Exception { super(bn); children = new HashMap<BeliefNode, BeliefNode[]>(); for(int i = 0; i < nodes.length; i++) { children.put(nodes[i], bn.bn.getChildren(nodes[i])); } nodeOrder = bn.getTopologicalOrder(); } public void _infer() throws Exception { Stopwatch sw = new Stopwatch(); // get initial setting with non-zero evidence probability out.println("initial setting..."); WeightedSample s = bn.getWeightedSample(nodeOrder, evidenceDomainIndices, generator); if(s == null) throw new Exception("Could not find an initial state with non-zero probability in given number of trials."); // do Gibbs sampling out.println("Gibbs sampling..."); sw.start(); // - get a bunch of samples for(int i = 1; i <= numSamples; i++) { if(i % infoInterval == 0) out.println(" step " + i); gibbsStep(evidenceDomainIndices, s); s.trials = 1; s.weight = 1; addSample(s); } sw.stop(); report(String.format("time taken: %.2fs (%.4fs per sample)\n", sw.getElapsedTimeSecs(), sw.getElapsedTimeSecs()/numSamples)); } public double gibbsStep(int[] evidenceDomainIndices, WeightedSample s) { double p = 1.0; // resample all of the (non-evidence) nodes for(int j = 0; j < nodes.length; j++) { // skip evidence nodes if(evidenceDomainIndices[j] != -1) continue; // initialize BeliefNode n = nodes[j]; Discrete dom = (Discrete)n.getDomain(); int domSize = dom.getOrder(); double[] distribution = new double[domSize]; // for the current node, calculate a value for each setting for(int d = 0; d < domSize; d++) { s.nodeDomainIndices[j] = d; // consider the probability of the setting given the node's parents double value = getCPTProbability(n, s.nodeDomainIndices); // consider the probability of the children's settings given the respective parents for(BeliefNode child : children.get(n)) { value *= getCPTProbability(child, s.nodeDomainIndices); } distribution[d] = value; } double sum = 0; for(int i = 0; i < distribution.length; i++) sum += distribution[i]; s.nodeDomainIndices[j] = sample(distribution, sum, generator); p = distribution[s.nodeDomainIndices[j]] / sum; } return p; } }