/******************************************************************************* * Copyright (C) 2008-2012 Dominik Jain. * * This file is part of ProbCog. * * ProbCog is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * ProbCog is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with ProbCog. If not, see <http://www.gnu.org/licenses/>. ******************************************************************************/ package probcog.srl.directed.inference; import java.util.HashMap; import java.util.HashSet; import probcog.bayesnets.core.BeliefNetworkEx; import probcog.bayesnets.inference.BackwardSamplingWithPriors; import probcog.bayesnets.inference.SampledDistribution; import probcog.bayesnets.inference.WeightedSample; import probcog.srl.directed.bln.GroundBLN; import edu.ksu.cis.bnj.ver3.core.BeliefNode; import edu.ksu.cis.bnj.ver3.core.CPF; import edu.ksu.cis.bnj.ver3.core.Discrete; import edu.ksu.cis.bnj.ver3.core.Domain; import edu.tum.cs.util.Stopwatch; import edu.tum.cs.util.datastruct.Cache2D; import edu.tum.cs.util.datastruct.MutableDouble; /** * Semi-lifted version of backward sampling with children, where the parameter sharing property * is fully exploited by the cache mechanism. * @author Dominik Jain */ public class LiftedBackwardSampling extends Sampler { /** * a mapping from belief node objects to integers identifying equivalence classes with respect to the algorithm */ HashMap<BeliefNode,Integer> node2class = new HashMap<BeliefNode, Integer>(); public LiftedBackwardSampling(GroundBLN gbln) throws Exception { super(gbln); } @Override public SampledDistribution _infer() throws Exception { // compute node equivalence classes with respect to the "backward sampling // with children" procedure System.out.println("computing equivalence classes..."); Integer classNo = 0; Cache2D<String, String, Integer> classes = new Cache2D<String, String, Integer>(); BeliefNetworkEx groundBN = gbln.getGroundNetwork(); for(BeliefNode node : groundBN.bn.getNodes()) { // construct string key StringBuffer key = new StringBuffer(); BeliefNode[] domprod = node.getCPF().getDomainProduct(); for(int i = 1; i < domprod.length; i++) { key.append(",").append(gbln.getCPFID(domprod[i])); for(BeliefNode c : groundBN.bn.getChildren(domprod[i])) { for(BeliefNode d : c.getCPF().getDomainProduct()) { key.append(",").append(gbln.getCPFID(d)); } } } String skey = key.toString(); // check if we already have it String mainCPFID = gbln.getCPFID(node); if(mainCPFID == null) throw new Exception("Node " + node + " has no CPF-ID"); Integer value = classes.get(mainCPFID, skey); if(value == null) { value = ++classNo; classes.put(classNo); } node2class.put(node, value); if(debug) System.out.println(node + " is class " + value + "\n " + mainCPFID + skey); } System.out.println(" reduced " + groundBN.bn.getNodes().length + " nodes to " + classNo + " equivalence classes"); // inference String[][] evidence = this.gbln.getDatabase().getEntriesAsArray(); int[] evidenceDomainIndices = gbln.getFullEvidence(evidence); Sampler sampler = new Sampler(gbln.getGroundNetwork()); sampler.setDebugMode(debug); sampler.setNumSamples(numSamples); sampler.setInfoInterval(infoInterval); sampler.setEvidence(evidenceDomainIndices); //sampler.setMaxTrials(maxTrials); //sampler.setSkipFailedSteps(skipFailedSteps); SampledDistribution dist = sampler.infer(); return dist; } /** * the actual backward sampler (largely equivalent to BackwardSamplingWithChildren) * @author Dominik Jain * */ protected class Sampler extends BackwardSamplingWithPriors { // TODO I think there are problems with the probability caches here, because using the CPFID alone ignores the fact that the priors may be different since nodes with the same CPFID may have differently instantiated ancestors // Could solve this problem by using a 3D cache that includes the prior of the node protected Cache2D<String, Integer, Double> probCache; /** * cache of backward sampling distributions */ protected Cache2D<Integer, Long, BackSamplingDistribution> distCache; protected Stopwatch probSW, distSW; protected boolean useDistributionCache = true; protected boolean useProbabilityCache = false; public class BackSamplingDistribution extends probcog.bayesnets.inference.BackwardSamplingWithPriors.BackSamplingDistribution { public BackSamplingDistribution(BackwardSamplingWithPriors sampler) { super(sampler); } /** * recursively gets a distribution to backward sample from * @param i the node to instantiate next (as an index into the CPF's domain product) * @param addr the current setting of node indices of the CPF's domain product * @param cpf the conditional probability function of the node we are backward sampling */ @Override protected void construct(int i, int[] addr, CPF cpf, int[] nodeDomainIndices) { BeliefNode[] domProd = cpf.getDomainProduct(); if(i == addr.length) { double child_prob = cpf.getDouble(addr); // temporarily set evidence boolean[] tempEvidence = new boolean[addr.length]; for(int k = 1; k < addr.length; k++) { int nodeIdx = sampler.nodeIndices.get(domProd[k]); tempEvidence[k] = nodeDomainIndices[nodeIdx] == -1; if(tempEvidence[k]) nodeDomainIndices[nodeIdx] = addr[k]; } // consider parent configuration double parent_prob = 1.0; HashSet<BeliefNode> handledChildren = new HashSet<BeliefNode>(); handledChildren.add(domProd[0]); for(int j = 1; j < addr.length; j++) { double[] parentPrior = ((BackwardSamplingWithPriors)sampler).priors.get(domProd[j]); parent_prob *= parentPrior[addr[j]]; // consider children of parents with evidence // get child probability BeliefNode[] children = sampler.bn.bn.getChildren(domProd[j]); for(BeliefNode child : children) { if(nodeDomainIndices[sampler.getNodeIndex(child)] >= 0 && !handledChildren.contains(child)) { //getProb(childCPF, 0, new int[childCPF.getDomainProduct().length], nodeDomainIndices, p); double p = getProb(child, nodeDomainIndices); parent_prob *= p; handledChildren.add(child); } } } // unset temporary evidence for(int k = 1; k < addr.length; k++) { if(tempEvidence[k]) nodeDomainIndices[sampler.nodeIndices.get(domProd[k])] = -1; } // add to distribution double p = child_prob * parent_prob; if(p != 0) { addValue(p, addr.clone()); parentProbs.add(parent_prob); } return; } int nodeIdx = sampler.nodeIndices.get(domProd[i]); if(nodeDomainIndices[nodeIdx] >= 0) { addr[i] = nodeDomainIndices[nodeIdx]; construct(i+1, addr, cpf, nodeDomainIndices); } else { Discrete dom = (Discrete)domProd[i].getDomain(); for(int j = 0; j < dom.getOrder(); j++) { addr[i] = j; construct(i+1, addr, cpf, nodeDomainIndices); } } } protected double getProb(BeliefNode node, int[] nodeDomainIndices) { CPF cpf = node.getCPF(); boolean debugCache = debug; probSW.start(); // get the key in the CPF-specific cache Double cacheValue = null; BeliefNode[] domProd = cpf.getDomainProduct(); int[] addr = new int[domProd.length]; boolean allSet = true; int key = 0; for(int i = 0; i < addr.length; i++) { int idx = nodeDomainIndices[sampler.getNodeIndex(domProd[i])]; allSet = allSet && idx >= 0; addr[i] = idx; key *= cpf._SizeBuffer[i]+1; key += idx == -1 ? cpf._SizeBuffer[i] : idx; } if(allSet) { probSW.stop(); return cpf.getDouble(addr); } // check if we already have the value in the cache Double value = null; if(useProbabilityCache) value = cacheValue = probCache.get(gbln.getCPFID(node), key); if(value != null) { probSW.stop(); if(!debugCache) return value; } // (not in the cache, so) calculate the value MutableDouble p = new MutableDouble(0.0); getProb(cpf, 0, addr, nodeDomainIndices, p); // store in cache if(useProbabilityCache) { probCache.put(p.value); if(cacheValue != null && p.value != cacheValue) { throw new RuntimeException("Probability cache mismatch"); } } // return value probSW.stop(); return p.value; } /** * gets the probability indicated by the given CPF for the given domain indices, summing over all parents whose values are not set (i.e. set to -1) in nodeDomainIndices; * i.e. computes the probability of the node whose CPF is provided given the evidence set in nodeDomainIndices * @param cpf the conditional probability function * @param i index of the next node to instantiate * @param addr the address (list of node domain indices relevant to the CPF) * @param nodeDomainIndices evidences (mapping of all nodes in the network to domain indices, -1 for no evidence) * @param ret variable in which to store the result (initialize to 0.0, because we are summing probability values) */ protected void getProb(CPF cpf, int i, int[] addr, int[] nodeDomainIndices, MutableDouble ret) { BeliefNode[] domProd = cpf.getDomainProduct(); // if all nodes have been instantiated... if(i == addr.length) { double p = cpf.getDouble(addr); for(int j = 1; j < addr.length; j++) { if(nodeDomainIndices[sampler.getNodeIndex(domProd[j])] == -1); { double[] parentPrior = ((BackwardSamplingWithPriors)sampler).priors.get(domProd[j]); p *= parentPrior[addr[j]]; } } ret.value += p; return; } // otherwise instantiate the next node BeliefNode node = domProd[i]; int nodeIdx = sampler.getNodeIndex(node); // - if we have evidence, use it if(nodeDomainIndices[nodeIdx] >= 0) { addr[i] = nodeDomainIndices[nodeIdx]; getProb(cpf, i+1, addr, nodeDomainIndices, ret); } // - otherwise sum over all settings else { Domain dom = node.getDomain(); for(int j = 0; j < dom.getOrder(); j++) { addr[i] = j; getProb(cpf, i+1, addr, nodeDomainIndices, ret); } } } } @Override protected BackSamplingDistribution getBackSamplingDistribution(BeliefNode node, WeightedSample s) { BackSamplingDistribution d; long key = 0; distSW.start(); if(useDistributionCache) { // calculate key BeliefNode[] domProd = node.getCPF().getDomainProduct(); // - consider node itself and all parents for(int i = 0; i < domProd.length; i++) { BeliefNode n = domProd[i]; int idx = s.nodeDomainIndices[getNodeIndex(n)]; int order = n.getDomain().getOrder(); key *= order + 1; key += idx == -1 ? order : idx; // - children of parents if(i != 0) { BeliefNode[] children = bn.bn.getChildren(n); for(int j = 0; j < children.length; j++) { if(children[j] != node) { n = children[j]; idx = s.nodeDomainIndices[getNodeIndex(n)]; order = n.getDomain().getOrder(); key *= order + 1; key += idx == -1 ? order : idx; // - parents of children BeliefNode[] parentsofchildren = children[j].getCPF().getDomainProduct(); for(int k = 1; k < parentsofchildren.length; k++) { n = parentsofchildren[k]; idx = s.nodeDomainIndices[getNodeIndex(n)]; order = n.getDomain().getOrder(); key *= order + 1; key += idx == -1 ? order : idx; } } } } } // check if we have a cache value d = distCache.get(node2class.get(node), key); if(d != null) return d; } // obtain new distribution d = new BackSamplingDistribution(this); d.construct(node, s.nodeDomainIndices); // store in cache if(useDistributionCache) distCache.put(d); distSW.stop(); return d; } public Sampler(BeliefNetworkEx bn) throws Exception { super(bn); } @Override public void _initialize() throws Exception { probCache = new Cache2D<String, Integer, Double>(); distCache = new Cache2D<Integer, Long, BackSamplingDistribution>(); super._initialize(); } @Override public void _infer() throws Exception { probSW = new Stopwatch(); distSW = new Stopwatch(); super._infer(); System.out.println("prob time: " + probSW.getElapsedTimeSecs()); System.out.println(String.format(" cache hit ratio: %f (%d accesses)", this.probCache.getHitRatio(), this.probCache.getNumAccesses())); System.out.println("dist time: " + distSW.getElapsedTimeSecs()); System.out.println(String.format(" cache hit ratio: %f (%d accesses)", this.distCache.getHitRatio(), this.distCache.getNumAccesses())); System.out.println(); } } }