/*******************************************************************************
* Copyright (C) 2011-2012 Dominik Jain.
*
* This file is part of ProbCog.
*
* ProbCog is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ProbCog is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ProbCog. If not, see <http://www.gnu.org/licenses/>.
******************************************************************************/
package probcog.bayesnets.inference;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.TreeMap;
import java.util.TreeSet;
import probcog.bayesnets.core.BeliefNetworkEx;
import edu.ksu.cis.bnj.ver3.core.BeliefNode;
import edu.ksu.cis.bnj.ver3.core.Domain;
import edu.tum.cs.util.datastruct.Pair;
import edu.tum.cs.util.datastruct.PrioritySet;
/**
* SampleSearch with conflict-directed backjumping and propositional constraint learning (using earliest minimal conflict sets)
* @author Dominik Jain
*/
public class SampleSearchBJLearning extends SampleSearchBJ {
/** comparator handed to the tree maps and priority queues created in this class (presumably orders higher order indices first - confirm against HighestFirst) */
protected HighestFirst highestFirst = new HighestFirst();
/** the nogoods recorded so far, grouped by the index of the node they were recorded for */
protected NoGoods noGoods = new NoGoods();
/** number of nogoods recorded so far; incremented each time a nogood is stored after a backjump */
protected int numNoGoods = 0;
/** whether recorded nogoods are consulted to exclude domain values before sampling (nogoods are recorded either way) */
protected boolean useNoGoods = true;
/** nogood size statistics that are printed in the final report of _infer() */
protected int maxNoGoodSize = 0, totalNoGoodSize = 0;
/** number of individual node settings compared while testing nogoods for applicability */
protected int numNoGoodNodeChecks = 0;
/** nogoods that were already passed through verifyNoGood(); created lazily on first use */
private HashSet<NoGood> verifiedNoGoods;
/**
* whether to verify recorded nogoods with regular SampleSearch; applies only when debug is on
*/
private boolean verifyNoGoods = false;
/**
 * Creates the sampler for the given network and registers the two parameters
 * introduced by this class ("useNoGoods" and "verifyNoGoods") with the parameter handler.
 * @param bn the belief network to run inference on
 * @throws Exception propagated from the superclass constructor or the parameter registration
 */
public SampleSearchBJLearning(BeliefNetworkEx bn) throws Exception {
    super(bn);
    paramHandler.add("useNoGoods", "setUseNoGoods");
    paramHandler.add("verifyNoGoods", "setVerifyNoGoods");
}
/**
 * Switches the use of recorded nogoods for excluding domain values on or off.
 * @param enabled true if recorded nogoods are to be consulted during sampling
 */
public void setUseNoGoods(boolean enabled) {
    this.useNoGoods = enabled;
}
/**
 * Switches the verification of nogoods (by running regular SampleSearch) on or off.
 * The verification methods are only invoked while debug is on.
 * @param enabled true if nogoods are to be verified
 */
public void setVerifyNoGoods(boolean enabled) {
    this.verifyNoGoods = enabled;
}
/**
 * A constraint consisting of a set of node settings (node index/domain index pairs),
 * together with the domain index it refers to.
 */
protected class NoGood {
    /**
     * the domain index that is excluded when this nogood applies;
     * nogoods built from a node's parent configuration are created with -1
     */
    public int domIdx;
    /** maps a variable's order index to the pair (node index, domain index) */
    protected Map<Integer,Pair<Integer,Integer>> nodeSettings;
    /** enables trace output while matching (only has an effect while debug is on) */
    protected final boolean debugNoGoodMatching = true;

    /**
     * @param domIdx the domain index this nogood refers to
     */
    public NoGood(int domIdx) {
        // a tree map keeps the scope (the keys) sorted according to highestFirst
        this.nodeSettings = new TreeMap<Integer, Pair<Integer,Integer>>(highestFirst);
        this.domIdx = domIdx;
    }

    /**
     * Adds a setting for a node that is identified by its node index.
     * @param nodeIdx index of the node
     * @param domIdx index of the value the node must have for this nogood to apply
     */
    public void addSetting(int nodeIdx, int domIdx) {
        Integer orderIdx = node2orderIndex.get(nodes[nodeIdx]);
        nodeSettings.put(orderIdx, new Pair<Integer,Integer>(nodeIdx, domIdx));
    }

    /**
     * Adds a setting for a node that is identified by its position in the node order.
     * @param orderIdx order index of the node
     * @param domIdx index of the value the node must have for this nogood to apply
     */
    public void addSettingOrderIdx(int orderIdx, int domIdx) {
        Pair<Integer,Integer> setting = new Pair<Integer, Integer>(nodeOrder[orderIdx], domIdx);
        nodeSettings.put(orderIdx, setting);
    }

    /**
     * Checks whether every setting of this nogood matches the given assignment.
     * Each compared setting is counted in numNoGoodNodeChecks.
     * @param nodeDomainIndices the current assignment (domain index per node index)
     * @return true iff all settings match
     */
    public boolean isApplicable(int[] nodeDomainIndices) {
        final boolean trace = debugNoGoodMatching && debug;
        if(trace)
            System.out.println(" checking " + this);
        for(Pair<Integer,Integer> setting : nodeSettings.values()) {
            numNoGoodNodeChecks++;
            if(nodeDomainIndices[setting.first] == setting.second)
                continue;
            // first mismatch decides
            if(trace)
                System.out.printf(" not applicable because %s=%d (should be %d)\n", nodes[setting.first], nodeDomainIndices[setting.first], setting.second);
            return false;
        }
        return true;
    }

    /**
     * @return a string of the form NoGood#ID(domIdx; node=value, ...)
     */
    @Override
    public String toString() {
        StringBuilder buf = new StringBuilder(String.format("NoGood#%X(", System.identityHashCode(this)));
        buf.append(domIdx).append("; ");
        boolean first = true;
        for(Pair<Integer,Integer> setting : nodeSettings.values()) {
            if(!first)
                buf.append(", ");
            first = false;
            buf.append(nodes[setting.first].toString())
               .append("=")
               .append(nodes[setting.first].getDomain().getName(setting.second));
        }
        return buf.append(")").toString();
    }

    /**
     * @return the collection of variable order indices referenced by this constraint
     */
    public Collection<Integer> getDomain() {
        return nodeSettings.keySet();
    }

    /**
     * @return the nodes in this constraint's scope as a comma-separated string
     * (the empty string if the scope is empty)
     */
    public String getDomainString() {
        StringBuilder buf = new StringBuilder();
        boolean first = true;
        for(Integer orderIdx : getDomain()) {
            if(!first)
                buf.append(", ");
            first = false;
            buf.append(nodes[nodeOrder[orderIdx]].toString());
        }
        return buf.toString();
    }
}
/**
 * Container for the recorded nogoods, grouped by the index of the node they were recorded for.
 * Within each group the nogoods are kept sorted according to the {@link Earlier} relation.
 */
public static class NoGoods {
    /** maps a node index to the ordered collection of nogoods recorded for that node */
    protected HashMap<Integer, Collection<NoGood>> node2nogoods = new HashMap<Integer, Collection<NoGood>>();
    /** the ordering used for every per-node collection */
    protected Earlier earlierRelation = new Earlier();

    /**
     * Stores a nogood for the given node.
     * @param nodeIdx index of the node the nogood was recorded for
     * @param nogood the nogood to store
     */
    public void add(int nodeIdx, NoGood nogood) {
        Collection<NoGood> v = node2nogoods.get(nodeIdx);
        if(v == null) {
            // use a tree set with the earlier relation to make sure the
            // constraints are ordered correctly (earliest first).
            // We don't actually require a set, because equal elements are never
            // added, but there seems to be no standard data structure for this.
            v = new TreeSet<NoGood>(earlierRelation);
            node2nogoods.put(nodeIdx, v);
        }
        v.add(nogood);
    }

    /**
     * returns a collection of constraints for the given node index. The collection is
     * ordered according to the earlier relation.
     * @param nodeIdx index of the node
     * @return the nogoods recorded for the node, or null if none were recorded
     */
    public Collection<NoGood> get(int nodeIdx) {
        return node2nogoods.get(nodeIdx);
    }

    /**
     * defines the earlier relation for constraints/nogoods:
     * The original definition of this relation is as follows:
     * A constraint C1 with scope S1 is said to be earlier than a constraint C2 with scope S2
     * if the largest-order variable in S1-S2 has a lower order than the largest-order variable
     * in S2-S1.
     * Where one of the two differences is empty, this implementation decides as follows:
     * if S1-S2 is empty (S1=S2 or S2 is a proper superset of S1), C2 is reported as earlier;
     * if only S2-S1 is empty (S1 is a proper superset of S2), C1 is reported as earlier.
     * The comparator never returns 0; a TreeSet using it therefore never treats two
     * nogoods as duplicates, even if their scopes coincide.
     * The implementation relies on the scopes being iterated from the highest to the
     * lowest order index.
     */
    public static final class Earlier implements Comparator<NoGood> {
        /**
         * @param c1 the first constraint
         * @param c2 the second constraint
         * @return -1 if c1 is earlier than c2, 1 otherwise
         */
        @Override
        public int compare(NoGood c1, NoGood c2) {
            final int c1_earlier = -1;
            final int c2_earlier = 1;
            // look for the largest number that appears in one domain but not the other
            // based on the reverse-order iterators
            Iterator<Integer> i1 = c1.getDomain().iterator();
            Iterator<Integer> i2 = c2.getDomain().iterator();
            Integer relevant1 = null, relevant2 = null;
            while(true) {
                // if either list has no further elements, we are done
                if(!i1.hasNext() || !i2.hasNext())
                    break;
                // get the next two elements
                int n1 = i1.next();
                int n2 = i2.next();
                // if they are the same, proceed to the next elements
                if(n1 == n2)
                    continue;
                // otherwise, we now have one of the relevant numbers and only need to determine the second
                if(n1 < n2) {
                    // since n2 was not found in c1, n2 is c2's relevant number
                    if(relevant2 == null) {
                        relevant2 = n2;
                        if(relevant1 != null)
                            break;
                    }
                    // at this point, relevant1 must be null
                    assert relevant1 == null;
                    // look for n1 in c2
                    while(n1 < n2) {
                        if(!i2.hasNext())
                            // cannot reach a number as low as n1, so n1 is a relevant number
                            // for c1 that is not found in c2, so c1 is definitely earlier
                            return c1_earlier;
                        n2 = i2.next();
                    }
                    if(n2 < n1) {
                        // n1 was not found in c2, so n1 is c1's relevant number, which is definitely
                        // smaller than c2's relevant number, so c1 is earlier
                        return c1_earlier;
                    }
                    // otherwise (n1 == n2), continue
                }
                else { // n2 < n1 (analogous to the above)
                    if(relevant1 == null) {
                        relevant1 = n1;
                        if(relevant2 != null)
                            break;
                    }
                    assert relevant2 == null;
                    while(n2 < n1) {
                        if(!i1.hasNext())
                            return c2_earlier;
                        n1 = i1.next();
                    }
                    if(n1 < n2) {
                        return c2_earlier;
                    }
                }
            }
            // Whenever the loop is left because one iterator ran dry, the elements consumed
            // last were equal (or none were consumed), so every element still left in the
            // other iterator is lower than all elements of the exhausted scope and thus
            // belongs to the set difference. If that side has no relevant number yet, the
            // first leftover element is it. Without this step, a scope with unmatched
            // lower-order elements would be mistaken for a subset of the other scope.
            if(relevant1 == null && i1.hasNext())
                relevant1 = i1.next();
            if(relevant2 == null && i2.hasNext())
                relevant2 = i2.next();
            // we have all the information...
            if(relevant1 == null) // S1-S2 is empty: the scopes are equal or c2's scope is a superset of c1's
                return c2_earlier;
            if(relevant2 == null) // S2-S1 is empty: c1's scope is a proper superset of c2's
                return c1_earlier;
            return relevant1 < relevant2 ? c1_earlier : c2_earlier;
        }
    }

    /**
     * @param n1 the first nogood
     * @param n2 the second nogood
     * @return true iff n1 is earlier than n2 according to the earlier relation
     */
    public boolean earlier(NoGood n1, NoGood n2) {
        return earlierRelation.compare(n1, n2) <= 0;
    }
}
/**
 * Runs the inference procedure of the superclass and then reports the nogood statistics.
 * @throws Exception propagated from the superclass implementation
 */
@Override
public void _infer() throws Exception {
    super._infer();
    // without any recorded nogood the plain quotient would be 0/0, i.e. NaN
    float avgNoGoodSize = numNoGoods > 0 ? (float)totalNoGoodSize / numNoGoods : 0.0f;
    report(String.format("#no-goods: %s; max. size: %d; avg. size: %f; total node checks: %d", numNoGoods, maxNoGoodSize, avgNoGoodSize, numNoGoodNodeChecks));
}
/**
 * Prints a progress line stating how many nogoods have been recorded so far.
 * @param step the current step number
 */
@Override
protected void info(int step) {
    final int recorded = numNoGoods;
    System.out.printf(" step %d: %d no-goods recorded\n", step, recorded);
}
/**
 * Generates one sample by assigning values to the nodes in the given order.
 * Whenever no value can be assigned to a node, the search jumps back to the
 * highest-order node in that node's backjump set, records a nogood for the node
 * jumped back to and continues from there.
 * @param s the sample to fill; its trials, operations, weight and nodeDomainIndices are overwritten
 * @param nodeOrder the node index for each position in the assignment order
 * @param evidenceDomainIndices for each node index, the domain index given as evidence
 * (a non-negative value), or -1 if there is no evidence for the node
 * @return the sample s
 * @throws Exception if there is nowhere left to jump back to, or if (in debug mode)
 * the verification of a nogood fails
 */
@Override
public WeightedSample getWeightedSample(WeightedSample s, int[] nodeOrder, int[] evidenceDomainIndices) throws Exception {
    s.trials = 1;
    s.operations = 0;
    s.weight = 1.0;
    // backjump sets of the nodes that were jumped back to, keyed by order index
    HashMap<Integer,PrioritySet<Integer>> backjumpSets = new HashMap<Integer,PrioritySet<Integer>>();
    boolean backtracking = false;
    DomainExclusions domExclusions = new DomainExclusions();
    for(int i = 0; i < evidenceDomainIndices.length; i++)
        s.nodeDomainIndices[i] = evidenceDomainIndices[i];
    // assign values to the nodes in order
    for(int orderIdx = 0; orderIdx < nodeOrder.length;) {
        s.operations++;
        int nodeIdx = nodeOrder[orderIdx];
        boolean valueSuccessfullyAssigned = false;
        if(!backtracking) {
            // if we get to a node going forward, any previous exclusions are obsolete
            // and the backjump set can be reset
            domExclusions.remove(nodeIdx);
            backjumpSets.remove(orderIdx);
            if(debug) System.out.println(" Op" + s.operations + ": #" + node2orderIndex.get(nodes[nodeIdx]) + " " + nodes[nodeIdx].getName() + ", current setting: " + s.nodeDomainIndices[nodeIdx]);
        }
        else {
            if(debug) System.out.println(" Op" + s.operations + ": backtracking to #" + node2orderIndex.get(nodes[nodeIdx]) + " " + nodes[nodeIdx].getName() + ", current setting: " + s.nodeDomainIndices[nodeIdx]);
            // the value that led to the dead end must not be chosen again
            domExclusions.add(nodeIdx, s.nodeDomainIndices[nodeIdx]);
        }
        if(!debug && infoInterval == 1)
            System.out.printf("#%d, %d nogoods\r", nodeIdx, numNoGoods);
        int domainIdx = evidenceDomainIndices[nodeIdx];
        // continue with the set stored for this node by an earlier backjump, if any
        PrioritySet<Integer> backjumpSet = backjumpSets.get(orderIdx);
        if(backjumpSet == null)
            backjumpSet = new PrioritySet<Integer>(new PriorityQueue<Integer>(2, highestFirst));
        // for evidence nodes, we can continue if the evidence
        // probability was non-zero
        if(domainIdx >= 0) {
            s.nodeDomainIndices[nodeIdx] = domainIdx;
            samplingProb[nodeIdx] = 1.0;
            double prob = getCPTProbability(nodes[nodeIdx], s.nodeDomainIndices);
            if(prob != 0.0)
                valueSuccessfullyAssigned = true;
            else {
                // the minimal conflict set is given by the non-evidence parents,
                // so add them to the backjump set
                BeliefNode[] domProd = nodes[nodeIdx].getCPF().getDomainProduct();
                // index 0 is skipped; the remaining entries are treated as the parents
                for(int k = 1; k < domProd.length; k++) {
                    int parentNodeIdx = getNodeIndex(domProd[k]);
                    if(evidenceDomainIndices[parentNodeIdx] == -1) {
                        Integer parentOrderIdx = node2orderIndex.get(domProd[k]);
                        backjumpSet.add(parentOrderIdx);
                    }
                }
            }
        }
        // for non-evidence nodes, do forward sampling
        else {
            // get domain exclusions
            boolean[] excluded = domExclusions.get(nodeIdx);
            // get conditional distribution that applies to the current parent configuration and
            // determine the domain indices for which the parents constitute a nogood
            double[] dist = getConditionalDistribution(nodes[nodeIdx], s.nodeDomainIndices);
            NoGood parentNoGood = null; // built lazily and shared by all zero-probability values
            NoGood[] parentNoGoods = new NoGood[dist.length];
            for(int i = 0; i < dist.length; i++) {
                if(dist[i] == 0.0) {
                    if(parentNoGood == null) {
                        parentNoGood = new NoGood(-1);
                        BeliefNode[] domProd = nodes[nodeIdx].getCPF().getDomainProduct();
                        for(int k = 1; k < domProd.length; k++) {
                            int parentNodeIdx = getNodeIndex(domProd[k]);
                            if(evidenceDomainIndices[parentNodeIdx] == -1)
                                parentNoGood.addSetting(parentNodeIdx, s.nodeDomainIndices[parentNodeIdx]);
                        }
                    }
                    parentNoGoods[i] = parentNoGood;
                }
                else {
                    parentNoGoods[i] = null;
                    if(excluded[i])
                        dist[i] = 0;
                }
            }
            // add additional exclusions based on no-goods
            // (earliest[i] holds the first applicable recorded nogood for value i)
            NoGood[] earliest = new NoGood[dist.length];
            if(useNoGoods) {
                Collection<NoGood> v = noGoods.get(nodeIdx);
                if(v != null) {
                    for(NoGood ng : v) {
                        int domIdx = ng.domIdx;
                        if(!excluded[domIdx] && earliest[domIdx] == null) {
                            // if there is a parent nogood for this domain index, check if the current
                            // nogood is earlier
                            boolean checkNoGood = true;
                            if(parentNoGoods[domIdx] != null)
                                checkNoGood = noGoods.earlier(ng, parentNoGoods[domIdx]);
                            if(checkNoGood && ng.isApplicable(s.nodeDomainIndices)) {
                                if(debug) {
                                    // temporarily assign the excluded value so the nogood can be verified in context
                                    s.nodeDomainIndices[nodeIdx] = ng.domIdx;
                                    boolean OK1 = verifyNoGoodInContext(ng, s.nodeDomainIndices);
                                    boolean OK2 = verifyNoGood(this.nodes[nodeIdx], ng);
                                    if(!OK1 || !OK2)
                                        throw new Exception("nogood is bad");
                                }
                                earliest[domIdx] = ng;
                                dist[domIdx] = 0;
                                if(debug) System.out.printf(" nogood excluded %d (%s): %s\n", ng.domIdx, nodes[nodeIdx].getDomain().getName(ng.domIdx), ng.toString());
                            }
                        }
                    }
                }
            }
            // sample
            SampledAssignment sa;
            sa = sample(dist);
            if(sa != null) {
                domainIdx = sa.domIdx;
                samplingProb[nodeIdx] = sa.probability;
                s.nodeDomainIndices[nodeIdx] = domainIdx;
                valueSuccessfullyAssigned = true;
            }
            else {
                // extend the backjump set for all the constraints that applied
                for(int i = 0; i < dist.length; i++) {
                    NoGood ng = earliest[i];
                    if(ng == null)
                        ng = parentNoGoods[i];
                    if(ng != null) {
                        if(debug) System.out.println(" mce constraint for value " + i + " on: " + ng.getDomainString() + (ng == parentNoGoods[i] ? " (parents)" : ""));
                        for(Integer o : ng.getDomain()) {
                            backjumpSet.add(o);
                        }
                    }
                }
            }
        }
        // debug info
        if(debug) {
            if(evidenceDomainIndices[nodeIdx] == -1) {
                if(valueSuccessfullyAssigned) {
                    Domain dom = nodes[nodeIdx].getDomain();
                    System.out.printf(" assigned %d (%s), (%d/%d) exclusions\n", domainIdx, dom.getName(domainIdx), domExclusions.getNumExclusions(nodeIdx), dom.getOrder());
                }
                else
                    System.out.println(" out of choices; backtracking...");
            }
            else {
                if(valueSuccessfullyAssigned)
                    System.out.printf(" evidence %d (%s) OK\n", domainIdx, nodes[nodeIdx].getDomain().getName(domainIdx));
                else
                    System.out.printf(" evidence %d (%s) with probability 0.0; backtracking... cond: %s\n", domainIdx, nodes[nodeIdx].getDomain().getName(domainIdx), s.getCPDLookupString(nodes[nodeIdx]));
            }
        }
        // if a value was successfully assigned, continue to the next node in the order
        if(valueSuccessfullyAssigned) {
            backtracking = false;
            ++orderIdx;
        }
        // otherwise, jump back and record a constraint/nogood
        else {
            s.trials++;
            backtracking = true;
            // back jump
            if(backjumpSet.isEmpty())
                throw new Exception("Nowhere left to backjump to from node #" + orderIdx + ". Most likely, the evidence has 0 probability.");
            orderIdx = backjumpSet.remove();
            // record nogood: the remaining members of the backjump set, with their current
            // values, rule out the current value of the node we jump back to
            NoGood ng = new NoGood(s.nodeDomainIndices[nodeOrder[orderIdx]]);
            for(Integer oIdx : backjumpSet)
                ng.addSettingOrderIdx(oIdx, s.nodeDomainIndices[nodeOrder[oIdx]]);
            noGoods.add(nodeOrder[orderIdx], ng);
            if(debug) System.out.println(" recorded nogood for " + nodes[nodeOrder[orderIdx]] + ": " + ng);
            ++numNoGoods;
            // keep the size statistics printed by _infer() up to date
            // (the size of a nogood is taken to be the number of settings in its scope)
            int noGoodSize = ng.getDomain().size();
            totalNoGoodSize += noGoodSize;
            if(noGoodSize > maxNoGoodSize)
                maxNoGoodSize = noGoodSize;
            // merge to update the new node's backjump set
            PrioritySet<Integer> oldQueue = backjumpSets.get(orderIdx);
            if(oldQueue == null) {
                oldQueue = new PrioritySet<Integer>(new PriorityQueue<Integer>(1, highestFirst));
                backjumpSets.put(orderIdx, oldQueue);
            }
            for(Integer j : backjumpSet)
                oldQueue.add(j);
            // reset assignment (of the node that could not be assigned, not the one jumped back to)
            s.nodeDomainIndices[nodeIdx] = evidenceDomainIndices[nodeIdx];
        }
    }
    return s;
}
/**
 * Samples a domain index from an unnormalized distribution.
 * @param dist the (unnormalized) weights of the domain values
 * @return the sampled index together with its normalized probability,
 * or null if all weights sum to zero (nothing can be sampled)
 */
protected SampledAssignment sample(double[] dist) {
    double total = 0;
    for(double weight : dist)
        total += weight;
    // an all-zero distribution is an impossible case
    if(total == 0)
        return null;
    final int chosen = sample(dist, total, generator);
    return new SampledAssignment(chosen, dist[chosen] / total);
}
/**
 * Checks a nogood by running regular SampleSearch with the current evidence extended
 * by the nogood's settings and the excluded value of the given node; the nogood is
 * considered correct if the attempt to obtain a sample ends with an exception.
 * Returns true right away if verification is disabled or the nogood was checked before.
 * @param n the node the nogood was recorded for
 * @param ng the nogood to check
 * @return whether the nogood is OK
 * @throws Exception propagated from setting up the SampleSearch instance
 */
protected boolean verifyNoGood(BeliefNode n, NoGood ng) throws Exception {
    if(!verifyNoGoods)
        return true;
    if(verifiedNoGoods == null)
        verifiedNoGoods = new HashSet<NoGood>();
    else if(verifiedNoGoods.contains(ng))
        return true;
    SampleSearch checker = new SampleSearch(this.bn);
    // evidence = original evidence + excluded value + all settings of the nogood
    int[] assignment = evidenceDomainIndices.clone();
    assignment[getNodeIndex(n)] = ng.domIdx;
    for(Pair<Integer,Integer> setting : ng.nodeSettings.values())
        assignment[setting.first] = setting.second;
    checker.setEvidence(assignment);
    checker.setNumSamples(1);
    checker.setVerbose(false);
    System.out.print(" verifying " + ng + "... ");
    boolean noSampleFound = false;
    try {
        checker.infer();
    }
    catch(Exception e) {
        // failing to find a sample is what a correct nogood should cause
        noSampleFound = true;
    }
    System.out.println(noSampleFound ? "OK" : "BAD!");
    verifiedNoGoods.add(ng);
    return noSampleFound;
}
/**
 * Checks a nogood in the context of a concrete assignment by running regular
 * SampleSearch with that assignment as evidence; the nogood is considered correct
 * if the attempt to obtain a sample ends with an exception.
 * Returns true right away if verification is disabled.
 * @param ng the nogood being checked (used for the log output only)
 * @param nodeDomainIndices the assignment to use as evidence
 * @return whether the nogood is OK
 * @throws Exception propagated from setting up the SampleSearch instance
 */
protected boolean verifyNoGoodInContext(NoGood ng, int[] nodeDomainIndices) throws Exception {
    if(!verifyNoGoods)
        return true;
    SampleSearch checker = new SampleSearch(this.bn);
    checker.setEvidence(nodeDomainIndices);
    checker.setNumSamples(1);
    checker.setVerbose(false);
    System.out.print(" verifying " + ng + " in current context... ");
    boolean noSampleFound = false;
    try {
        checker.infer();
    }
    catch(Exception e) {
        // failing to find a sample is what a correct nogood should cause
        noSampleFound = true;
    }
    System.out.println(noSampleFound ? "OK" : "BAD!");
    return noSampleFound;
}
}