/*******************************************************************************
 * Copyright (C) 2010-2012 Dominik Jain.
 * 
 * This file is part of ProbCog.
 * 
 * ProbCog is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 * 
 * ProbCog is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with ProbCog. If not, see <http://www.gnu.org/licenses/>.
 ******************************************************************************/
package probcog.bayesnets.inference;

import java.util.Comparator;
import java.util.HashMap;
import java.util.PriorityQueue;

import probcog.bayesnets.core.BeliefNetworkEx;
import edu.ksu.cis.bnj.ver3.core.BeliefNode;
import edu.tum.cs.util.datastruct.PrioritySet;

/**
 * SampleSearch with backjumping.
 * 
 * When no value can be assigned to a node, this variant does not simply return to
 * the preceding node in the sampling order. Instead, it jumps back to the highest
 * level (position in the sampling order) among the failed node's non-evidence
 * parents and any candidate levels inherited from earlier jumps.
 * @author Dominik Jain
 */
public class SampleSearchBJ extends SampleSearch {

	/** Maps each node to its position (level) in the sampling order; filled by {@link #computeNodeOrdering()}. */
	protected HashMap<BeliefNode, Integer> node2orderIndex;

	public SampleSearchBJ(BeliefNetworkEx bn) throws Exception {
		super(bn);
	}

	/**
	 * Computes the order in which nodes are sampled and records each node's level in
	 * that order in {@link #node2orderIndex} (needed to determine backjump targets).
	 * @return the sampling order as an array of node indices
	 */
	@Override
	protected int[] computeNodeOrdering() throws Exception {
		if(verbose) System.out.println("computing node ordering...");
		// An alternative would be the order produced by new TopologicalSort(this.bn.bn).run();
		// bn.getTopologicalOrder() seems to work slightly better in practice.
		int[] samplingOrder = bn.getTopologicalOrder();
		// maintain mapping of node to index in ordering
		node2orderIndex = new HashMap<BeliefNode, Integer>();
		for(int i = 0; i < samplingOrder.length; i++)
			node2orderIndex.put(nodes[samplingOrder[i]], i);
		return samplingOrder;
	}

	/**
	 * Orders integers in descending order, so a priority queue using this comparator
	 * yields the highest value first.
	 */
	protected static final class HighestFirst implements Comparator<Integer> {
		public HighestFirst() {}

		@Override
		public int compare(Integer o1, Integer o2) {
			// swap the operands instead of negating the result
			return o2.compareTo(o1);
		}
	}

	/**
	 * Tracks, per node index, which domain values are currently excluded from sampling.
	 * Exclusion arrays are created lazily on first access.
	 */
	public class DomainExclusions {
		HashMap<Integer, boolean[]> domExclusions = new HashMap<Integer, boolean[]>();

		/**
		 * Returns the exclusion flags of the given node. If none exist yet, it creates
		 * and stores an all-false array with one entry per domain value.
		 */
		public boolean[] get(Integer nodeIdx) {
			boolean[] excluded = domExclusions.get(nodeIdx);
			if(excluded == null) {
				excluded = new boolean[nodes[nodeIdx].getDomain().getOrder()];
				domExclusions.put(nodeIdx, excluded);
			}
			return excluded;
		}

		/** Marks domain value {@code domIdx} of the given node as excluded. */
		public void add(Integer nodeIdx, int domIdx) {
			boolean[] excl = get(nodeIdx);
			excl[domIdx] = true;
		}

		/** Discards all exclusions of the given node. */
		public void remove(Integer nodeIdx) {
			domExclusions.remove(nodeIdx);
		}

		/** Returns how many domain values of the given node are currently excluded (0 if none were recorded). */
		public int getNumExclusions(Integer nodeIdx) {
			boolean[] excl = domExclusions.get(nodeIdx);
			if(excl == null)
				return 0;
			int n = 0;
			for(boolean b : excl)
				if(b) ++n;
			return n;
		}
	}

	/**
	 * Generates a sample by assigning values to the nodes in the given order. Whenever
	 * a node cannot be assigned a value, it backjumps to an earlier level.
	 * 
	 * @param s the sample to fill. Its trials, operations and weight fields are reset
	 *        here; its nodeDomainIndices receive the assignment.
	 * @param nodeOrder the node indices in the order in which they are to be assigned
	 * @param evidenceDomainIndices per node index, the domain index of the evidence
	 *        value; values &lt; 0 mark non-evidence nodes
	 * @return {@code s}
	 * @throws Exception if backtracking would go past the first level, or if there is
	 *         no level left to backjump to
	 */
	@Override
	public WeightedSample getWeightedSample(WeightedSample s, int[] nodeOrder, int[] evidenceDomainIndices) throws Exception {
		s.trials = 1;
		s.operations = 0;
		s.weight = 1.0;

		boolean backtracking = false;
		HighestFirst highestFirst = new HighestFirst();
		DomainExclusions domExclusions = new DomainExclusions();
		// per level: the candidate levels to jump back to should this level fail
		HashMap<Integer, PrioritySet<Integer>> backtrackQueues = new HashMap<Integer, PrioritySet<Integer>>();

		// assign values to the nodes in order
		for(int orderIdx = 0; orderIdx < nodeOrder.length;) {
			s.operations++;
			boolean valueSuccessfullyAssigned = false;
			int nodeIdx = nodeOrder[orderIdx];
			if(!debug && infoInterval == 1)
				System.out.printf(" #%d \r", orderIdx);

			if(!backtracking) {
				// reaching this level in forward direction: forget earlier exclusions/conflicts
				domExclusions.remove(nodeIdx);
				backtrackQueues.remove(orderIdx);
			}
			else {
				// since we are backtracking, the previous setting of this node is
				// inapplicable, so we add it to the exclusions
				domExclusions.add(nodeIdx, s.nodeDomainIndices[nodeIdx]);
			}

			int domainIdx = evidenceDomainIndices[nodeIdx];
			if(domainIdx >= 0) {
				// for evidence nodes, we can continue if the evidence probability is non-zero
				s.nodeDomainIndices[nodeIdx] = domainIdx;
				samplingProb[nodeIdx] = 1.0;
				double prob = getCPTProbability(nodes[nodeIdx], s.nodeDomainIndices);
				if(prob != 0.0)
					valueSuccessfullyAssigned = true;
			}
			else {
				// for non-evidence nodes, do forward sampling among the non-excluded values
				boolean[] excluded = domExclusions.get(nodeIdx);
				SampledAssignment sa = sampleForward(nodes[nodeIdx], s.nodeDomainIndices, excluded);
				if(sa != null) {
					domainIdx = sa.domIdx;
					samplingProb[nodeIdx] = sa.probability;
					s.nodeDomainIndices[nodeIdx] = domainIdx;
					valueSuccessfullyAssigned = true;
				}
			}

			if(valueSuccessfullyAssigned) {
				// if we were backtracking and could assign a value, we are done
				// backtracking and continue processing the actual order
				backtracking = false;
				++orderIdx;
				continue;
			}

			if(orderIdx == 0) // can't backtrack further
				throw new Exception("Backtracking past first level. Most likely, the evidence that was specified is contradictory");
			backtracking = true;
			// NOTE: assignments at the levels skipped by the jump are not undone
			// (undoAssignment(j, s) for j = orderIdx-1 down to the target). Undoing them
			// was noted as necessary for backward sampling distributions; it is disabled here.
			orderIdx = backjump(orderIdx, nodeIdx, evidenceDomainIndices, backtrackQueues, highestFirst);
		}
		return s;
	}

	/**
	 * Determines where to jump back to after the node at level {@code orderIdx} could
	 * not be assigned a value. The failed level's candidate queue is extended with the
	 * levels of the node's non-evidence parents. The highest candidate becomes the
	 * target, and the remaining candidates are merged into the target level's queue.
	 * 
	 * @return the level to jump back to
	 * @throws Exception if there is no candidate level to jump back to
	 */
	private int backjump(int orderIdx, int nodeIdx, int[] evidenceDomainIndices,
			HashMap<Integer, PrioritySet<Integer>> backtrackQueues, HighestFirst highestFirst) throws Exception {
		PrioritySet<Integer> backtrackQueue = backtrackQueues.get(orderIdx);
		if(backtrackQueue == null)
			backtrackQueue = newBacktrackQueue(highestFirst);
		if(debug)
			System.out.printf(" initial backtrack queue: %s\n", backtrackQueue);

		// extend the queue depending on the current constraint: add the non-evidence parents
		// (entries from index 1 on are treated as parents; index 0 is presumably the node itself)
		BeliefNode[] domprod = nodes[nodeIdx].getCPF().getDomainProduct();
		for(int j = 1; j < domprod.length; j++) {
			int parentNodeIdx = getNodeIndex(domprod[j]);
			if(evidenceDomainIndices[parentNodeIdx] < 0) {
				Integer level = node2orderIndex.get(domprod[j]);
				if(debug)
					System.out.printf(" adding %d (%s)\n", level, domprod[j].getName());
				backtrackQueue.add(level);
			}
		}

		// back jump to the highest candidate level
		if(backtrackQueue.isEmpty())
			throw new Exception("Nowhere left to backjump to from node #" + orderIdx + ". Most likely, the evidence has 0 probability.");
		int target = backtrackQueue.remove();
		assert target < orderIdx : "Invalid backjump from " + orderIdx + " to " + target;

		// merge the remaining candidates into the target level's queue; a separate queue is
		// used so that the failed level and the target never share the same queue object
		PrioritySet<Integer> targetQueue = backtrackQueues.get(target);
		if(targetQueue == null) {
			targetQueue = newBacktrackQueue(highestFirst);
			backtrackQueues.put(target, targetQueue);
		}
		for(Integer level : backtrackQueue)
			targetQueue.add(level);
		return target;
	}

	/** Creates an empty backtrack queue that yields the highest level first. */
	private static PrioritySet<Integer> newBacktrackQueue(HighestFirst highestFirst) {
		return new PrioritySet<Integer>(new PriorityQueue<Integer>(1, highestFirst));
	}
}