/*******************************************************************************
* Copyright (C) 2011-2012 Dominik Jain.
*
* This file is part of ProbCog.
*
* ProbCog is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ProbCog is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ProbCog. If not, see <http://www.gnu.org/licenses/>.
******************************************************************************/
package probcog.bayesnets.inference;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.Vector;
import probcog.bayesnets.core.BeliefNetworkEx;
import probcog.bayesnets.util.TopologicalOrdering;
import probcog.bayesnets.util.TopologicalSort;
import edu.ksu.cis.bnj.ver3.core.BeliefNode;
import edu.tum.cs.util.datastruct.Map2Set;
import edu.tum.cs.util.datastruct.Pair;
import edu.tum.cs.util.datastruct.PrioritySet;
/**
* Backward SampleSearch with backjumping
*
* @author Dominik Jain
*/
public class BackwardSampleSearchBJ extends BackwardSampleSearch {
public BackwardSampleSearchBJ(BeliefNetworkEx bn) throws Exception {
super(bn);
}
public static class HighestFirst implements Comparator<Integer> {
@Override
public int compare(Integer o1, Integer o2) {
return -o1.compareTo(o2);
}
}
@Override
public void getSample(WeightedSample s) throws Exception {
Map2Set<BeliefNode,Integer> domExclusions = new Map2Set<BeliefNode,Integer>();
initSample(s);
backSamplingDistributionCache = new HashMap<BeliefNode, BackSamplingDistribution>();
boolean backtracking = false;
HighestFirst highestFirst = new HighestFirst();
HashMap<Integer,PrioritySet<Integer>> backtrackQueues = new HashMap<Integer,PrioritySet<Integer>>();
for(int i = 0; i < samplingOrder.size();) {
currentOrderIndex = i;
Pair<BeliefNode,NodeMode> p = samplingOrder.get(i);
// get the node
BeliefNode node = p.first;
NodeMode mode = p.second;
// if we got to the node backtracking, we add the last value as an exclusion
if(backtracking) {
domExclusions.add(node, sampledIndices[i]);
if(mode == NodeMode.Outside)
throw new Exception("Backtracked to node outside order");
}
else {
// if we get to a node going forward, forget all exclusions and invalidate cache
domExclusions.remove(node);
if(mode == NodeMode.Backward) backSamplingDistributionCache.remove(node);
backtrackQueues.remove(i);
}
// info
++s.operations;
//if(s.operations == 10000) debug=true;
if(debug)
out.printf(" Op%d: #%d %s\n", s.operations, i, node.getName());
else
if(infoInterval == 1) out.printf("#%d \r", i);
// get domain exclusions
Set<Integer> excluded = domExclusions.get(node);
boolean valueSuccessfullyAssigned = true;
switch(mode) {
case Backward:
if(debug) out.printf(" backward sampling (%d exclusions)\n", excluded == null ? 0 : excluded.size());
//Stopwatch sw3 = new Stopwatch();
//sw3.start();
if(!sampleBackward(node, s, excluded)){
//if (debug) out.println("CPT contains only zeros for backward sampled node: "+ node);
valueSuccessfullyAssigned = false;
}
break;
case Forward:
if(debug) out.printf(" forward sampling (%d exclusions)\n", excluded == null ? 0 : excluded.size());
if(!sampleForward(node, s, excluded)){
//if (debug) out.println("CPT contains only zeros for forward sampled node: "+ node);
valueSuccessfullyAssigned = false;
}
break;
case Outside:
if(debug) out.printf(" outside sampling order\n", excluded == null ? 0 : excluded.size());
double prob = this.getCPTProbability(node, s.nodeDomainIndices);
if(prob == 0.0)
valueSuccessfullyAssigned = false;
break;
}
if(valueSuccessfullyAssigned){ // go forward
// end backtracking
backtracking = false;
++i;
//backtrackQueues.remove(i);
}
else { // backtrack
if(i == 0) // can't backtrack further
throw new Exception("Backtracking past first level. Most likely, the evidence that was specified is contradictory");
backtracking = true;
PrioritySet<Integer> backtrackQueue = backtrackQueues.get(i);
if(backtrackQueue == null)
backtrackQueue = new PrioritySet<Integer>(new PriorityQueue<Integer>(1, highestFirst));
if(debug) System.out.printf(" initial backtrack queue: %s\n", backtrackQueue);
// extend the queue depending on the current constraint
if(mode == NodeMode.Backward) {
//if(debug) System.out.println(" parents: " + bsAlreadyGivenParents.get(node));
// nodes that instantiated parents that were given before
for(BeliefNode parent : bsAlreadyGivenParents.get(node)) {
Integer level = node2instantiatorOrderIndex.get(parent);
if(level != null) {
if(debug) System.out.printf(" adding %d (instantiated parent %s)\n", level, parent.getName());
assert level < i;
backtrackQueue.add(level);
}
}
// the node that instantiated the node itself
Integer level = node2instantiatorOrderIndex.get(node);
if(debug) System.out.println(" adding " + level + " (instantiator of main node)");
if(level != null)
backtrackQueue.add(level);
}
else { // for forward and outside nodes, add the nodes that instantiated the parents
BeliefNode[] domprod = node.getCPF().getDomainProduct();
for(int j = 1; j < domprod.length; j++) {
Integer level = node2instantiatorOrderIndex.get(domprod[j]);
if(debug) System.out.printf(" adding %d (instantiated parent %s)\n", level, domprod[j].getName());
if(level != null)
backtrackQueue.add(level);
}
// for outside nodes also the node that instantiated this node (if any)
if(mode == NodeMode.Outside) {
Integer level = node2instantiatorOrderIndex.get(node);
if(debug) System.out.println(" adding " + level + " (instantiator of main node)");
if(level != null)
backtrackQueue.add(level);
}
}
// back jump
Integer iprev = i;
if(backtrackQueue.isEmpty())
throw new Exception("Nowhere left to backjump to from node #" + i + ". Most likely, the evidence has 0 probability.");
else
i = backtrackQueue.remove();
assert i < iprev : "Invalid backjump from " + iprev + " to " + i;
// undo all assignments along the way (necessary for backward sampling distributions to be constructed correctly)
for(int j = iprev-1; j >= i; j--)
undoAssignment(j, s);
// merge to update the new node's backtracking queue
PrioritySet<Integer> oldQueue = backtrackQueues.get(i);
if(oldQueue == null) {
//backtrackQueues.put(i, backtrackQueue); // unsafe? -- would assign same queue to i and i-1
oldQueue = new PrioritySet<Integer>(new PriorityQueue<Integer>(1, highestFirst));
backtrackQueues.put(i, oldQueue);
}
for(Integer j : backtrackQueue)
oldQueue.add(j);
// this is probably unnecessary
domExclusions.remove(node);
if(mode == NodeMode.Backward) backSamplingDistributionCache.remove(node);
if(debug) System.out.printf(" backtracking to #%d, queue: %s\n", i, backtrackQueue.toString());
s.trials++;
}
}
}
protected void undoAssignment(int i, WeightedSample s) {
Pair<BeliefNode, NodeMode> p = samplingOrder.get(i);
switch(p.second) {
case Backward:
for(Integer idx : assignedNodeIndicesByOrderIndex.get(i))
s.nodeDomainIndices[idx] = -1;
break;
case Forward:
s.nodeDomainIndices[getNodeIndex(p.first)] = -1;
break;
}
}
protected HashMap<BeliefNode,Integer> node2instantiatorOrderIndex;
protected HashMap<BeliefNode,Vector<BeliefNode>> bsAlreadyGivenParents;
protected HashMap<BeliefNode,Integer> node2orderIndex;
/**
* gets the sampling order by filling the members for backward and forward sampled nodes as well as the set of nodes not in the sampling order
* @param evidenceDomainIndices
* @throws Exception
*/
@Override
protected void getOrdering(int[] evidenceDomainIndices) throws Exception {
HashSet<BeliefNode> uninstantiatedNodes = new HashSet<BeliefNode>(Arrays.asList(nodes));
backwardSampledNodes = new Vector<BeliefNode>();
forwardSampledNodes = new Vector<BeliefNode>();
outsideSamplingOrder = new HashSet<BeliefNode>();
samplingOrder = new Vector<Pair<BeliefNode,NodeMode>>();
node2instantiatorOrderIndex = new HashMap<BeliefNode,Integer>();
bsAlreadyGivenParents = new HashMap<BeliefNode,Vector<BeliefNode>>();
node2orderIndex = new HashMap<BeliefNode,Integer>();
TopologicalOrdering topOrder = new TopologicalSort(bn.bn).run(true);
PriorityQueue<BeliefNode> backSamplingCandidates = new PriorityQueue<BeliefNode>(1, new TierComparator(topOrder));
// check which nodes have evidence; ones that are are candidates for backward sampling and are instantiated
for(int i = 0; i < evidenceDomainIndices.length; i++) {
if(evidenceDomainIndices[i] >= 0) {
backSamplingCandidates.add(nodes[i]);
uninstantiatedNodes.remove(nodes[i]);
}
}
// check all backward sampling candidates
while(!backSamplingCandidates.isEmpty()) {
Integer orderIndex = samplingOrder.size();
BeliefNode node = backSamplingCandidates.remove();
// check if there are any uninstantiated parents
BeliefNode[] domProd = node.getCPF().getDomainProduct();
boolean doBackSampling = false;
Vector<BeliefNode> givenParents = new Vector<BeliefNode>();
for(int j = 1; j < domProd.length; j++) {
BeliefNode parent = domProd[j];
// if there are uninstantiated parents, we do backward sampling on the child node
if(uninstantiatedNodes.remove(parent)) {
doBackSampling = true;
backSamplingCandidates.add(parent);
node2instantiatorOrderIndex.put(parent, orderIndex);
}
else
givenParents.add(parent);
}
if(doBackSampling) {
backwardSampledNodes.add(node);
samplingOrder.add(new Pair<BeliefNode,NodeMode>(node, NodeMode.Backward));
bsAlreadyGivenParents.put(node, givenParents);
}
// if there are no uninstantiated parents, the node is not backward sampled but is instantiated,
// i.e. it is not in the sampling order
else {
outsideSamplingOrder.add(node);
samplingOrder.add(new Pair<BeliefNode,NodeMode>(node, NodeMode.Outside));
}
}
// schedule all uninstantiated node for forward sampling in the topological order
for(int i : topOrder) {
if(uninstantiatedNodes.contains(nodes[i])) {
forwardSampledNodes.add(nodes[i]);
node2instantiatorOrderIndex.put(nodes[i], samplingOrder.size());
samplingOrder.add(new Pair<BeliefNode,NodeMode>(nodes[i], NodeMode.Forward));
}
}
Integer i = 0;
for(Pair<BeliefNode,NodeMode> p : samplingOrder)
node2orderIndex.put(p.first, i++);
}
}