/*******************************************************************************
 * Copyright (C) 2009-2012 Dominik Jain.
 *
 * This file is part of ProbCog.
 *
 * ProbCog is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * ProbCog is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with ProbCog. If not, see <http://www.gnu.org/licenses/>.
 ******************************************************************************/
package probcog.bayesnets.inference;

import java.math.BigInteger;
import java.util.HashMap;
import java.util.Vector;

import probcog.bayesnets.core.BeliefNetworkEx;
import edu.ksu.cis.bnj.ver3.core.BeliefNode;
import edu.ksu.cis.bnj.ver3.core.CPF;
import edu.tum.cs.util.Stopwatch;
import edu.tum.cs.util.datastruct.Map2D;

/**
 * Simple implementation of the SampleSearch algorithm by Gogate & Dechter.
 * <p>The implementation will apply an unbiased estimator (which uses more memory)
 * only if enabled via {@link #setUseProperWeighting}.</p>
 * <p>Nodes are assigned one after the other in the order given by
 * {@link #computeNodeOrdering()}; whenever an assignment turns out to be
 * impossible, the sampler backtracks to the most recent non-evidence node,
 * excludes the value it had chosen there and samples again.</p>
 *
 * @author Dominik Jain
 */
public class SampleSearch extends Sampler {
    /** order in which the nodes are assigned (set in {@link #_initialize()}) */
    protected int[] nodeOrder;
    /** number of the sample that is currently being drawn (1-based) */
    protected int currentStep;
    /** for each node (by node index), the probability with which its value in the current sample was drawn */
    protected double[] samplingProb;
    /** whether to use the unbiased estimator instead of the biased one */
    protected boolean useProperWeighting = false;
    /** if true, values with a CPT entry of zero are filtered when sampling from the importance distribution */
    protected boolean usingTopologicalOrdering = true;
    protected ImportanceFunction importanceFunction = ImportanceFunction.Prior;
    /** importance distribution computed in {@link #_initialize()}; null if sampling from the prior */
    protected SampledDistribution importanceDist = null;
    /** number of steps given to the algorithm that computes the importance distribution */
    protected int importanceFunctionSteps = 2;

    protected enum ImportanceFunction {
        Prior, BP, IJGP;
    }

    public SampleSearch(BeliefNetworkEx bn) throws Exception {
        super(bn);
        this.paramHandler.add("importanceFunction", "setImportanceFunction");
        this.paramHandler.add("ifSteps", "setImportanceFunctionSteps");
        this.paramHandler.add("bpSteps", "setImportanceFunctionSteps");
        this.paramHandler.add("ijgpSteps", "setImportanceFunctionSteps");
        this.paramHandler.add("unbiased", "setUseProperWeighting");
    }

    /**
     * Computes the node ordering, allocates the sampling probability array and,
     * unless the prior is used, runs BP or IJGP to obtain the importance distribution.
     */
    @Override
    protected void _initialize() throws Exception {
        // TODO could help to guarantee for BLNs that formula nodes appear as early as possible
        nodeOrder = computeNodeOrdering();
        samplingProb = new double[nodes.length];
        if(importanceFunction != ImportanceFunction.Prior) {
            if(verbose)
                System.out.println("computing importance function with " + importanceFunction + "...");
            Sampler s = importanceFunction == ImportanceFunction.BP ? new BeliefPropagation(this.bn) : new IJGP(bn);
            s.setNumSamples(importanceFunctionSteps);
            s.setEvidence(this.evidenceDomainIndices);
            importanceDist = s.infer();
            if(debug) {
                System.out.println("importance distribution:");
                importanceDist.print(System.out);
            }
        }
    }

    /**
     * @param name the name of one of the {@link ImportanceFunction} constants (Prior, BP, IJGP)
     * @throws IllegalArgumentException if there is no constant with the given name
     */
    public void setImportanceFunction(String name) {
        importanceFunction = ImportanceFunction.valueOf(name);
    }

    public void setImportanceFunctionSteps(int steps) {
        this.importanceFunctionSteps = steps;
    }

    /**
     * @return the order in which nodes are to be assigned (here: the network's topological order)
     */
    protected int[] computeNodeOrdering() throws Exception {
        return bn.getTopologicalOrder();
    }

    public void setUseProperWeighting(boolean enabled) {
        useProperWeighting = enabled;
    }

    protected void info(int step) {
        out.println(" step " + step);
    }

    /**
     * Draws up to numSamples samples (stopping early on convergence), adds them
     * to the distribution builder and reports timing statistics.
     */
    @Override
    public void _infer() throws Exception {
        // sample
        Stopwatch sw = new Stopwatch();
        out.println("sampling...");
        sw.start();
        WeightedSample s = new WeightedSample(bn);
        for(int i = 1; i <= numSamples; i++) {
            currentStep = i;
            if(i % infoInterval == 0)
                info(i);
            WeightedSample ret = getWeightedSample(s, nodeOrder, evidenceDomainIndices);
            if(ret != null) {
                addSample(ret);
                /*
                // debugging of weighting
                out.print("w=" + ret.weight);
                double prod = 1.0;
                for(int j = 0; j < evidenceDomainIndices.length; j++)
                    if(true || evidenceDomainIndices[j] == -1) {
                        BeliefNode node = nodes[j];
                        out.print(" " + node.getName() + "=" + node.getDomain().getName(s.nodeDomainIndices[j]));
                        double p = bn.getCPTProbability(node, s.nodeDomainIndices);
                        out.printf(" %f", p);
                        if(p == 0.0) throw new Exception("Sample has 0 probability.");
                        prod *= p;
                        if(prod == 0.0) throw new Exception("Precision loss - product became 0");
                    }
                out.println();
                */
            }
            if(converged())
                break;
        }

        SampledDistribution dist = distributionBuilder.getDistribution();
        report(String.format("time taken: %.2fs (%.4fs per sample, %.1f trials/sample, %.4f*N assignments/sample, %d samples)\n", sw.getElapsedTimeSecs(), sw.getElapsedTimeSecs()/numSamples, dist.getTrialsPerStep(), (float)dist.operations/nodes.length/numSamples, dist.steps));
    }

    /**
     * Draws a single sample by assigning the nodes in the given order, backtracking
     * whenever a node cannot be assigned a value with non-zero probability.
     * The sampling probabilities of the chosen values are stored in {@link #samplingProb}.
     *
     * @param s the sample object to fill (its trials, operations, weight and domain indices are overwritten)
     * @param nodeOrder the order (array of node indices) in which to assign the nodes
     * @param evidenceDomainIndices for each node, the domain index fixed by the evidence, or a negative value if the node is not an evidence node
     * @return the sample that was passed in, now holding a complete assignment
     * @throws Exception if the search backtracks beyond the first node, i.e. no assignment with non-zero probability exists
     */
    public WeightedSample getWeightedSample(WeightedSample s, int[] nodeOrder, int[] evidenceDomainIndices) throws Exception {
        s.trials = 1;
        s.operations = 0;
        s.weight = 1.0;

        // assign values to the nodes in order
        HashMap<Integer, boolean[]> domExclusions = new HashMap<Integer, boolean[]>();
        for(int i = 0; i < nodeOrder.length;) {
            s.operations++;
            int nodeIdx = nodeOrder[i];
            int domainIdx = evidenceDomainIndices[nodeIdx];

            // get domain exclusions
            boolean[] excluded = domExclusions.get(nodeIdx);
            if(excluded == null) {
                excluded = new boolean[nodes[nodeIdx].getDomain().getOrder()];
                domExclusions.put(nodeIdx, excluded);
            }

            // debug info
            if(debug) {
                int numex = 0;
                for(int j = 0; j < excluded.length; j++)
                    if(excluded[j])
                        numex++;
                out.printf(" step %d, node %d '%s' (%d/%d exclusions)\n", currentStep, i, nodes[nodeIdx].getName(), numex, excluded.length);
            }

            // for evidence nodes, we can continue if the evidence probability was non-zero
            if(domainIdx >= 0) {
                s.nodeDomainIndices[nodeIdx] = domainIdx;
                samplingProb[nodeIdx] = 1.0;
                double prob = getCPTProbability(nodes[nodeIdx], s.nodeDomainIndices);
                if(prob != 0.0) {
                    ++i;
                    continue;
                }
                else {
                    if(debug)
                        out.println(" evidence with probability 0.0; backtracking...");
                }
            }
            // for non-evidence nodes, do forward sampling
            else {
                SampledAssignment sa = sampleForward(nodes[nodeIdx], s.nodeDomainIndices, excluded);
                if(sa != null) {
                    domainIdx = sa.domIdx;
                    samplingProb[nodeIdx] = sa.probability;
                    s.nodeDomainIndices[nodeIdx] = domainIdx;
                    ++i;
                    continue;
                }
                else if(debug)
                    out.println(" impossible case; backtracking...");
            }

            // if we get here, we need to backtrack to the last non-evidence node
            s.trials++;
            do {
                // kill the current node's exclusions
                domExclusions.remove(nodeIdx);
                // add the previous node's setting as an exclusion
                --i;
                if(i < 0)
                    throw new Exception("Could not find a sample with non-zero probability. Most likely, the evidence specified has 0 probability.");
                nodeIdx = nodeOrder[i];
                boolean[] prevExcl = domExclusions.get(nodeIdx);
                prevExcl[s.nodeDomainIndices[nodeIdx]] = true;
                // proceed with previous node...
                // (evidence test uses the same ">= 0" convention as the forward step above)
            } while(evidenceDomainIndices[nodeIdx] >= 0);
        }

        return s;
    }

    /**
     * The result of sampling a value for a node: the chosen domain index and the
     * probability with which it was chosen.
     */
    public class SampledAssignment {
        public int domIdx;
        public double probability;

        public SampledAssignment(int domainIdx, double p) {
            domIdx = domainIdx;
            probability = p;
        }
    }

    /**
     * samples forward, i.e. samples a value for 'node' given its parents, using the
     * node's CPT column restricted to the values that are not excluded
     * @param node the node for which to sample a value
     * @param nodeDomainIndices array of domain indices for all nodes in the network; the values for the parents of 'node' must be set already
     * @param excluded for each domain element of 'node', whether it must not be sampled
     * @return the sampled domain index together with its (renormalized) sampling probability, or null if sampling is impossible because all non-excluded entries in the relevant column are 0
     */
    protected SampledAssignment sampleForwardPrior(BeliefNode node, int[] nodeDomainIndices, boolean[] excluded) {
        CPF cpf = node.getCPF();
        BeliefNode[] domProd = cpf.getDomainProduct();
        int[] addr = new int[domProd.length];
        // get the addresses of the first two relevant fields and the difference between them
        for(int i = 1; i < addr.length; i++)
            addr[i] = nodeDomainIndices[this.nodeIndices.get(domProd[i])];
        addr[0] = 0; // (the first element in the index into the domain of the node we are sampling)
        int realAddr = cpf.addr2realaddr(addr);
        addr[0] = 1;
        int diff = cpf.addr2realaddr(addr) - realAddr; // diff is the address difference between two consecutive entries in the relevant column
        // get probabilities for outcomes
        double[] cpt_entries = new double[domProd[0].getDomain().getOrder()];
        double sum = 0;
        for(int i = 0; i < cpt_entries.length; i++) {
            double value;
            if(excluded[i])
                value = 0.0;
            else
                value = cpf.getDouble(realAddr);
            cpt_entries[i] = value;
            sum += value;
            realAddr += diff;
        }
        // if the column contains only zeros, it is an impossible case -> cannot sample
        if(sum == 0)
            return null;
        int domIdx = sample(cpt_entries, sum, generator);
        return new SampledAssignment(domIdx, cpt_entries[domIdx]/sum);
    }

    /**
     * samples a value for 'node' given its parents; uses the importance distribution
     * if one was computed and falls back to {@link #sampleForwardPrior} otherwise
     * @param node the node for which to sample a value
     * @param nodeDomainIndices array of domain indices for all nodes in the network; the values for the parents of 'node' must be set already
     * @param excluded for each domain element of 'node', whether it must not be sampled
     * @return the sampled domain index together with its (renormalized) sampling probability, or null if no value can be sampled
     */
    protected SampledAssignment sampleForward(BeliefNode node, int[] nodeDomainIndices, boolean[] excluded) {
        if(this.importanceDist == null)
            return sampleForwardPrior(node, nodeDomainIndices, excluded);

        CPF cpf = node.getCPF();
        BeliefNode[] domProd = cpf.getDomainProduct();
        int[] addr = new int[domProd.length];
        // get the addresses of the first two relevant fields and the difference between them
        for(int i = 1; i < addr.length; i++)
            addr[i] = nodeDomainIndices[this.nodeIndices.get(domProd[i])];
        addr[0] = 0; // (the first element in the index into the domain of the node we are sampling)
        int realAddr = cpf.addr2realaddr(addr);
        addr[0] = 1;
        int diff = cpf.addr2realaddr(addr) - realAddr; // diff is the address difference between two consecutive entries in the relevant column
        // get probabilities for outcomes
        // If we are sampling in top. order, we always additionally filter
        // values that are zero given the parents
        // NOTE: entries are zeroed below, so we work on a private copy; this guarantees that
        // exclusions made for one draw can never leak into the shared importance distribution,
        // regardless of whether getDistribution hands out its internal array or a fresh one
        double[] samplingDist = importanceDist.getDistribution(getNodeIndex(node)).clone();
        double sum = 0;
        for(int i = 0; i < samplingDist.length; i++) {
            Double cptValue = null;
            if(usingTopologicalOrdering)
                cptValue = cpf.getDouble(realAddr);
            if(excluded[i] || (cptValue != null && cptValue.equals(0.0)))
                samplingDist[i] = 0.0;
            sum += samplingDist[i];
            realAddr += diff;
        }
        // if the column contains only zeros, it is an impossible case -> cannot sample
        if(sum == 0)
            return null;
        int domIdx = sample(samplingDist, sum, generator);
        return new SampledAssignment(domIdx, samplingDist[domIdx]/sum);
    }

    @Override
    public String getAlgorithmName() {
        return super.getAlgorithmName() + "[" + importanceFunction + "]";
    }

    /**
     * @return an {@link UnbiasedEstimator} if proper weighting is enabled, a {@link BiasedEstimator} otherwise
     */
    @Override
    protected IDistributionBuilder createDistributionBuilder() throws Exception {
        if(useProperWeighting)
            return new UnbiasedEstimator();
        else
            return new BiasedEstimator();
    }

    /**
     * simple but biased estimator
     */
    protected class BiasedEstimator extends DirectDistributionBuilder {
        public BiasedEstimator() throws Exception {
            super(createDistribution());
        }

        /**
         * Weights the sample with the product, over all nodes, of the CPT probability
         * divided by the probability with which the node's value was sampled.
         */
        @Override
        public void addSample(WeightedSample s) {
            // do weighting
            s.weight = 1.0;
            for(int i = 0; i < nodes.length; i++) {
                s.weight *= getCPTProbability(nodes[i], s.nodeDomainIndices) / samplingProb[i];
            }
            // directly add to distribution
            super.addSample(s);
        }
    }

    /**
     * unbiased "max"-estimator (additional storage space and computation time required)
     */
    protected class UnbiasedEstimator implements IDistributionBuilder {
        /** for each position in the node order and each partial assignment up to that position, the largest sampling probability seen */
        protected Map2D<Integer,BigInteger,Double> maxQ;
        /** copies of all samples added so far; they are weighted only when the distribution is requested */
        protected Vector<WeightedSample> samples;
        protected SampledDistribution dist;
        /** true if samples were added since the distribution was last built */
        protected boolean dirty = false;

        public UnbiasedEstimator() throws Exception {
            maxQ = new Map2D<Integer,BigInteger,Double>();
            samples = new Vector<WeightedSample>();
        }

        /**
         * Stores a copy of the sample and updates the maximum sampling probabilities
         * of all partial assignments (over non-evidence nodes) that it contains.
         */
        @Override
        public synchronized void addSample(WeightedSample s) throws Exception {
            // the partial assignment is encoded as a mixed-radix number over the domain sizes
            BigInteger partAssign = BigInteger.valueOf(0);
            for(int i = 0; i < nodeOrder.length; i++) {
                int nodeIdx = nodeOrder[i];
                if(evidenceDomainIndices[nodeIdx] < 0) {
                    partAssign = partAssign.multiply(BigInteger.valueOf(nodes[nodeIdx].getDomain().getOrder()));
                    partAssign = partAssign.add(BigInteger.valueOf(s.nodeDomainIndices[nodeIdx]));

                    Double p = maxQ.get(i, partAssign);
                    if(p == null || samplingProb[nodeIdx] > p) {
                        this.maxQ.put(i, partAssign, samplingProb[nodeIdx]);
                    }
                }
            }

            samples.add(s.clone());
            dirty = true;
        }

        /**
         * Builds the distribution from all stored samples, weighting each one using the
         * maximum sampling probabilities collected so far. The result is cached until
         * another sample is added.
         *
         * @return the distribution; never null (an empty distribution if no sample was added)
         */
        @Override
        public synchronized SampledDistribution getDistribution() throws Exception {
            // dist is still null if no sample was ever added; build an (empty) distribution
            // in that case instead of handing null to the caller
            if(!dirty && dist != null)
                return dist;
            System.out.println("unbiased sample weighting...");
            dist = createDistribution();
            for(WeightedSample s : samples) {
                s.weight = 1.0;
                BigInteger partAssign = BigInteger.valueOf(0);
                for(int i = 0; i < nodeOrder.length; i++) {
                    int nodeIdx = nodeOrder[i];
                    if(evidenceDomainIndices[nodeIdx] < 0) {
                        partAssign = partAssign.multiply(BigInteger.valueOf(nodes[nodeIdx].getDomain().getOrder()));
                        partAssign = partAssign.add(BigInteger.valueOf(s.nodeDomainIndices[nodeIdx]));
                        s.weight *= getCPTProbability(nodes[nodeIdx], s.nodeDomainIndices) / maxQ.get(i, partAssign);
                    }
                    else
                        s.weight *= getCPTProbability(nodes[nodeIdx], s.nodeDomainIndices);
                }
                dist.addSample(s);
            }
            dirty = false;
            return dist;
        }
    }
}