/******************************************************************************* * Copyright (C) 2011-2012 Dominik Jain, Klaus von Gleissenthall. * * This file is part of ProbCog. * * ProbCog is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * ProbCog is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with ProbCog. If not, see <http://www.gnu.org/licenses/>. ******************************************************************************/ package probcog.bayesnets.inference; import java.math.BigInteger; import java.util.Arrays; import java.util.Collection; import java.util.HashMap; import java.util.HashSet; import java.util.PriorityQueue; import java.util.Random; import java.util.Set; import java.util.Vector; import probcog.bayesnets.core.BeliefNetworkEx; import probcog.bayesnets.util.TopologicalOrdering; import probcog.bayesnets.util.TopologicalSort; import edu.ksu.cis.bnj.ver3.core.BeliefNode; import edu.ksu.cis.bnj.ver3.core.CPF; import edu.tum.cs.util.datastruct.Map2D; import edu.tum.cs.util.datastruct.Map2List; import edu.tum.cs.util.datastruct.Map2Set; import edu.tum.cs.util.datastruct.Pair; /** * Backward SampleSearch: a combination of backward simulation and sample searching * with backtracking. This simple implementation uses chronological backtracking. * @author Dominik Jain */ public class BackwardSampleSearch extends BackwardSamplingWithPriors { protected enum NodeMode {Backward, Forward, Outside}; protected TopologicalOrdering topOrder; protected boolean useProperWeighting = false; /** * the order in which nodes are sampled along with the mode in which they are to be handled */ protected Vector<Pair<BeliefNode, NodeMode>> samplingOrder; /** * the index in the sampling order of the node currently being treated */ protected int currentOrderIndex; /** * the currently sampled indices that were sampled from the applicable distributions, * i.e. sampledIndices[i] = j if for the i-th node in the sampling order, we sampled * the j-th value in the distribution (regardless whether it is a backward sampling * distribution or a condition distribution) */ protected int[] sampledIndices; /** * mapping from the sampling order index of a node N to a list of node indices * that are assigned by backward sampling N */ protected Map2List<Integer, Integer> assignedNodeIndicesByOrderIndex; /** * used only for proper weighting scheme */ protected HashMap<BeliefNode, BackSamplingDistribution> backSamplingDistributionCache; protected HashMap<BeliefNode,Double> weightingFactors; /** * whether to cache backward sampling distributions */ protected boolean useCache = true; public BackwardSampleSearch(BeliefNetworkEx bn) throws Exception { super(bn); this.paramHandler.add("unbiased","setUseProperWeighting"); } public void setUseProperWeighting(boolean enabled){ useProperWeighting = enabled; } protected boolean sampleForward(BeliefNode node, WeightedSample s, Set<Integer> excluded) throws Exception { CPF cpf = node.getCPF(); BeliefNode[] domProd = cpf.getDomainProduct(); int[] addr = new int[domProd.length]; addr[0] = 0; for(int i = 1; i < addr.length; i++) addr[i] = s.nodeDomainIndices[this.nodeIndices.get(domProd[i])]; int realAddr = cpf.addr2realaddr(addr); // address of the first element in the distribution we sample from int addrOffset = cpf.getColumnValueAddressOffset(); // get probabilities for outcomes double[] cpt_entries = new double[domProd[0].getDomain().getOrder()]; double sum = 0; double value; for(int i = 0; i < cpt_entries.length; i++) { if(excluded != null && excluded.contains(i)) { value = 0.0; //System.out.println("forward exclusion"); } else value = cpf.getDouble(realAddr); if(debug) out.printf(" %d: %f\n", i, value); cpt_entries[i] = value; sum += value; realAddr += addrOffset; } if(sum == 0) return false; // sample int domIdx = sample(cpt_entries, sum, generator); s.nodeDomainIndices[this.nodeIndices.get(node)] = domIdx; sampledIndices[currentOrderIndex] = domIdx; // remember weighting factor weightingFactors.put(node, getCPTProbability(node, s.nodeDomainIndices) / (cpt_entries[domIdx] / sum)); if(debug) out.println(" assigned " + domIdx); return true; } protected boolean sampleBackward(BeliefNode node, WeightedSample s, Set<Integer> excluded) throws Exception{ // get backward sampling distribution BackSamplingDistribution d = backSamplingDistributionCache.get(node); if(d == null) { d = getBackSamplingDistribution(node, s); if(useCache) backSamplingDistributionCache.put(node, d); } // get normalisation constant double Z = 0.0; Integer i = 0; int numValues = 0; for(Double v : d.distribution) { if(excluded == null || !excluded.contains(i)) { if(v > 0) { Z += v; numValues++; } } ++i; } if(Z == 0.0) return false; if(debug) System.out.printf(" %d choosable values in distribution\n", numValues); // sample a value int idx = sample(d.distribution, Z, excluded, generator); int[] state = d.states.get(idx); this.sampledIndices[currentOrderIndex] = idx; // apply sampled parent setting boolean buildAssignedIndices = assignedNodeIndicesByOrderIndex.get(currentOrderIndex) == null; BeliefNode[] domProd = node.getCPF().getDomainProduct(); for(i = 1; i < state.length; i++) { int nodeIdx = getNodeIndex(domProd[i]); if(buildAssignedIndices && s.nodeDomainIndices[nodeIdx] == -1) assignedNodeIndicesByOrderIndex.add(currentOrderIndex, nodeIdx); s.nodeDomainIndices[nodeIdx] = state[i]; } // save same data for weighting this.weightingFactors.put(node, Z / d.parentProbs.get(idx)); return true; } /** * sampling from a distribution with exclusions * @param distribution the distribution * @param sum the normalization constant of the distribution (which must consider the exclusions already) * @param excluded the set of excluded distribution indices * @param generator * @return a distribution index or -1 if no value can be sampled */ public static int sample(Collection<Double> distribution, double sum, Set<Integer> excluded, Random generator) { double random = generator.nextDouble() * sum; sum = 0; Integer i = 0; for(Double d : distribution) { if(excluded == null || !excluded.contains(i)) { sum += d; if(sum >= random) return i; } ++i; } return -1; } @Override protected void _initialize() throws Exception { super._initialize(); sampledIndices = new int[samplingOrder.size()]; assignedNodeIndicesByOrderIndex = new Map2List<Integer, Integer>(); weightingFactors = new HashMap<BeliefNode, Double>(); } @Override public void getSample(WeightedSample s) throws Exception { Map2Set<BeliefNode,Integer> domExclusions = new Map2Set<BeliefNode,Integer>(); initSample(s); backSamplingDistributionCache = new HashMap<BeliefNode, BackSamplingDistribution>(); boolean backtracking = false; for(int i = 0; i < samplingOrder.size();) { currentOrderIndex = i; Pair<BeliefNode,NodeMode> p = samplingOrder.get(i); // get the node BeliefNode node = p.first; NodeMode mode = p.second; // if we got to the node backtracking, we add the last value as an exclusion if(backtracking) { switch(mode) { case Backward: domExclusions.add(node, sampledIndices[i]); for(Integer idx : assignedNodeIndicesByOrderIndex.get(i)) s.nodeDomainIndices[idx] = -1; break; case Forward: domExclusions.add(node, sampledIndices[i]); s.nodeDomainIndices[getNodeIndex(node)] = -1; break; case Outside: --i; continue; } backtracking = false; } // info if(debug) out.printf(" Op%d: #%d %s\n", ++s.operations, i, node.getName()); else if(infoInterval == 1) out.printf("#%d \r", i); // get domain exclusions Set<Integer> excluded = domExclusions.get(node); boolean valueSuccessfullyAssigned = true; switch(mode) { case Backward: if(debug) out.printf(" backward sampling (%d exclusions)\n", excluded == null ? 0 : excluded.size()); //Stopwatch sw3 = new Stopwatch(); //sw3.start(); if(!sampleBackward(node, s, excluded)){ //if (debug) out.println("CPT contains only zeros for backward sampled node: "+ node); valueSuccessfullyAssigned = false; } break; case Forward: if(debug) out.printf(" forward sampling (%d exclusions)\n", excluded == null ? 0 : excluded.size()); if(!sampleForward(node, s, excluded)){ //if (debug) out.println("CPT contains only zeros for forward sampled node: "+ node); valueSuccessfullyAssigned = false; } break; case Outside: if(debug) out.printf(" outside sampling order\n", excluded == null ? 0 : excluded.size()); double prob = this.getCPTProbability(node, s.nodeDomainIndices); if(prob == 0.0) valueSuccessfullyAssigned = false; break; } //out.println("Node "+ node+ "has prob.: "+ getCPTProbability(node,s.nodeDomainIndices)); //check if the sample is consistent //if (getCPTProbability(node,s.nodeDomainIndices)==0.0){ if(!valueSuccessfullyAssigned){ // backtrack if(debug) System.out.println(" backtracking"); domExclusions.remove(node); if(mode == NodeMode.Backward) backSamplingDistributionCache.remove(node); backtracking = true; --i; if(i < 0) throw new Exception("Evidence seems to be contradictory"); s.trials++; } else { // go forward ++i; } } } /** * gets the sampling order by filling the members for backward and forward sampled nodes as well as the set of nodes not in the sampling order * @param evidenceDomainIndices * @throws Exception */ @Override protected void getOrdering(int[] evidenceDomainIndices) throws Exception { HashSet<BeliefNode> uninstantiatedNodes = new HashSet<BeliefNode>(Arrays.asList(nodes)); backwardSampledNodes = new Vector<BeliefNode>(); forwardSampledNodes = new Vector<BeliefNode>(); outsideSamplingOrder = new HashSet<BeliefNode>(); samplingOrder = new Vector<Pair<BeliefNode,NodeMode>>(); topOrder = new TopologicalSort(bn.bn).run(true); PriorityQueue<BeliefNode> backSamplingCandidates = new PriorityQueue<BeliefNode>(1, new TierComparator(topOrder)); // check which nodes have evidence; ones that are are candidates for backward sampling and are instantiated for(int i = 0; i < evidenceDomainIndices.length; i++) { if(evidenceDomainIndices[i] >= 0) { backSamplingCandidates.add(nodes[i]); uninstantiatedNodes.remove(nodes[i]); } } // check all backward sampling candidates while(!backSamplingCandidates.isEmpty()) { BeliefNode node = backSamplingCandidates.remove(); // check if there are any uninstantiated parents BeliefNode[] domProd = node.getCPF().getDomainProduct(); boolean doBackSampling = false; for(int j = 1; j < domProd.length; j++) { BeliefNode parent = domProd[j]; // if there are uninstantiated parents, we do backward sampling on the child node if(uninstantiatedNodes.remove(parent)) { doBackSampling = true; backSamplingCandidates.add(parent); } } if(doBackSampling) { backwardSampledNodes.add(node); samplingOrder.add(new Pair<BeliefNode,NodeMode>(node, NodeMode.Backward)); } // if there are no uninstantiated parents, the node is not backward sampled but is instantiated, // i.e. it is not in the sampling order else { outsideSamplingOrder.add(node); samplingOrder.add(new Pair<BeliefNode,NodeMode>(node, NodeMode.Outside)); } } // schedule all uninstantiated node for forward sampling in the topological order for(int i : topOrder) { if(uninstantiatedNodes.contains(nodes[i])) { forwardSampledNodes.add(nodes[i]); samplingOrder.add(new Pair<BeliefNode,NodeMode>(nodes[i], NodeMode.Forward)); } } } public IDistributionBuilder createDistributionBuilder() throws Exception { return new DistributionBuilder(); } protected class DistributionBuilder implements IDistributionBuilder { protected Map2D<Integer,BigInteger,Double> minFactors; protected Vector<Pair<WeightedSample,Vector<Pair<Integer,BigInteger>>>> samples; protected SampledDistribution dist; protected boolean dirty = false; public DistributionBuilder() throws Exception { if(useProperWeighting) { minFactors = new Map2D<Integer,BigInteger,Double>(); samples = new Vector<Pair<WeightedSample,Vector<Pair<Integer,BigInteger>>>>(); } else dist = createDistribution(); } @Override public synchronized void addSample(WeightedSample s) throws Exception { dirty = true; s.weight = 1.0; // for both weighting schemes: // * the nodes that are outside the // sampling order are sampled with probability 1, therefore // the conditional probability of those nodes applies as a factor for(BeliefNode node : outsideSamplingOrder) { double p = getCPTProbability(node, s.nodeDomainIndices); s.weight *= p; if(s.weight == 0.0) throw new Exception(p != 0.0 ? "Precision loss while computing sample weight" : "Sample has 0 probability"); } if(!useProperWeighting) { // simple weighting scheme (ignores the fact that we actually sample // from the backtrack-free distribution): // just use the factors that we recorded // * forward sampled nodes are sampled according to the prior, // therefore no factor would usually apply. // However, if CPTs are allowed to contain 0 columns, // then the sampling probability may be higher as a result of backtracking. for(BeliefNode node : forwardSampledNodes) { s.weight *= weightingFactors.get(node); if(s.weight == 0.0) throw new Exception("Precision loss while computing sample weight"); } for(BeliefNode node : backwardSampledNodes) { s.weight *= weightingFactors.get(node); if(s.weight == 0.0) { throw new Exception("Precision loss while computing sample weight");} } // and we just add the sample to the distribution dist.addSample(s); } else { // unbiased weighting: store all samples and keep track of minimum factors BigInteger partAssign = BigInteger.valueOf(0); Vector<Pair<Integer,BigInteger>> keys = new Vector<Pair<Integer,BigInteger>>(); for(int i = 0; i < samplingOrder.size(); i++) { Pair<BeliefNode,NodeMode> p = samplingOrder.get(i); if(p.second == NodeMode.Outside) { continue; } // extend assignment BigInteger distSize; if(p.second == NodeMode.Backward) distSize = BigInteger.valueOf(p.first.getCPF().getRowLength()); // the row length is an upper bound for the distribution size else distSize = BigInteger.valueOf(p.first.getDomain().getOrder()); partAssign = partAssign.multiply(distSize); partAssign = partAssign.add(BigInteger.valueOf(sampledIndices[i])); Double prevFactor = minFactors.get(i,partAssign); Double factor = weightingFactors.get(p.first); if(prevFactor == null || factor < prevFactor) { // TODO verify this condition minFactors.put(i, partAssign, factor); /*if(prevFactor != null) System.out.println("reducing factor " + prevFactor + " to " + factor); */ } keys.add(new Pair<Integer,BigInteger>(i,partAssign)); } Pair<WeightedSample, Vector<Pair<Integer,BigInteger>>> sample = new Pair<WeightedSample, Vector<Pair<Integer,BigInteger>>>(s.clone(), keys); samples.add(sample); } } @Override public synchronized SampledDistribution getDistribution() throws Exception { if(!useProperWeighting) return dist; else { System.out.println("unbiased sample weighting..."); if(!dirty) return dist; dist = createDistribution(); for(Pair<WeightedSample, Vector<Pair<Integer,BigInteger>>> sample : samples) { WeightedSample s = sample.first; s.weight = 1.0; for(BeliefNode node : outsideSamplingOrder) { double p = getCPTProbability(node, s.nodeDomainIndices); s.weight *= p; if(s.weight == 0.0) throw new Exception(p != 0.0 ? "Precision loss while computing sample weight" : "Sample has 0 probability"); } for(Pair<Integer,BigInteger> key : sample.second) { Double factor = minFactors.get(key.first, key.second); //System.out.println("factor " + factor); s.weight *= factor; if(s.weight == 0.0) { throw new Exception("Precision loss while computing sample weight");} } //System.out.println("added sample with weight " + s.weight); dist.addSample(s); } dirty = false; return dist; } } } }