/*******************************************************************************
* Copyright (C) 2009-2012 Dominik Jain.
*
* This file is part of ProbCog.
*
* ProbCog is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ProbCog is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ProbCog. If not, see <http://www.gnu.org/licenses/>.
******************************************************************************/
package probcog.bayesnets.inference;
import probcog.bayesnets.core.BeliefNetworkEx;
import edu.ksu.cis.bnj.ver3.core.Domain;
import edu.tum.cs.util.Stopwatch;
/**
* Inference via enumeration of possible worlds (exact).
* @author Dominik Jain
*/
public class EnumerationAsk extends Sampler {
    /**
     * order in which the nodes are instantiated, as returned by
     * {@link BeliefNetworkEx#getTopologicalOrder()}
     */
    int[] nodeOrder;
    /**
     * number of partial assignments that were abandoned because a CPT entry of zero was reached
     */
    int numPathsPruned;
    /**
     * progress counters, both measured in worlds of the full joint state space:
     * numWorldsPruned is the number of worlds that were disposed of without being
     * enumerated (zero-probability subtrees and values contradicting the evidence),
     * numWorldsCounted is the number of complete worlds that were added as samples
     */
    double numWorldsPruned, numWorldsCounted;
    /**
     * measures the time that has passed since the last status message
     */
    Stopwatch timer;
    /**
     * total number of possible worlds
     */
    double numTotalWorlds;

    /**
     * Creates the enumerator for the given network, caching the network's
     * topological node order and its total number of worlds.
     * @param bn the belief network to run inference on
     * @throws Exception if the superclass constructor or the network queries fail
     */
    public EnumerationAsk(BeliefNetworkEx bn) throws Exception {
        super(bn);
        nodeOrder = bn.getTopologicalOrder();
        numTotalWorlds = bn.getNumWorlds();
    }

    /**
     * Resets all counters, enumerates the worlds starting from an empty sample
     * and reports the time taken along with the enumeration statistics.
     * @throws Exception if the enumeration fails
     */
    public void _infer() throws Exception {
        Stopwatch sw = new Stopwatch();
        numPathsPruned = 0;
        numWorldsPruned = numWorldsCounted = 0;
        if(verbose) out.printf("enumerating %s worlds...\n", numTotalWorlds);
        sw.start();
        WeightedSample s = new WeightedSample(bn);
        timer = new Stopwatch();
        timer.start();
        enumerateWorlds(s, nodeOrder, evidenceDomainIndices, 0, 1);
        sw.stop();
        report(String.format("\ntime taken: %.2fs (%f worlds enumerated, %d paths pruned)\n", sw.getElapsedTimeSecs(), numWorldsCounted, numPathsPruned));
    }

    /**
     * Recursively extends the partial world in s by the node at position i of the
     * node order. Evidence nodes are fixed to their evidence value; for all other
     * nodes, every domain value is tried in turn. The sample weight is multiplied
     * by the CPT entry of each value that is set, and a branch is abandoned as soon
     * as such an entry is zero. Once all nodes are set, the sample is added.
     *
     * The sample is modified in place and is not restored before returning, so
     * callers must not rely on its weight or domain indices afterwards.
     *
     * @param s the sample holding the partial world and its weight
     * @param nodeOrder the order in which nodes are instantiated; nodeOrder[i] is the index of the node handled at depth i
     * @param evidenceDomainIndices for each node index, the domain index of the evidence value, or a negative value if the node has no evidence
     * @param i the current depth, i.e. the position in nodeOrder (0 for the initial call)
     * @param combinationsHandled the product of the domain sizes of all nodes at positions before i (1 for the initial call)
     * @throws Exception if the CPT lookup or adding the sample fails
     */
    public void enumerateWorlds(WeightedSample s, int[] nodeOrder, int[] evidenceDomainIndices, int i, double combinationsHandled) throws Exception {
        printEnumerationStatus();

        // if we have completed the world, we are done and can add the world as a sample
        if(i == nodes.length) {
            addSample(s);
            numWorldsCounted++;
            return;
        }

        // otherwise continue with the next node in the order
        int nodeIdx = nodeOrder[i];
        Domain domain = nodes[nodeIdx].getDomain();
        int order = domain.getOrder();
        combinationsHandled *= order;
        // number of worlds (in the full joint state space) below each single value of this node
        double worldsPerValue = numTotalWorlds / combinationsHandled;
        int domainIdx = evidenceDomainIndices[nodeIdx];

        // for evidence nodes, fix the value and adjust the weight
        if(domainIdx >= 0) {
            // The remaining order-1 values contradict the evidence and are never visited.
            // They are part of numTotalWorlds, so they must be booked as handled here;
            // otherwise the progress percentage could never reach 100%.
            numWorldsPruned += (order - 1) * worldsPerValue;
            s.nodeDomainIndices[nodeIdx] = domainIdx;
            double prob = getCPTProbability(nodes[nodeIdx], s.nodeDomainIndices);
            s.weight *= prob;
            if(prob == 0.0) { // every world below has weight zero, so there is no point in descending
                numPathsPruned++;
                numWorldsPruned += worldsPerValue;
                return;
            }
            enumerateWorlds(s, nodeOrder, evidenceDomainIndices, i+1, combinationsHandled);
        }
        // for non-evidence nodes, consider all settings
        else {
            // weight of the partial world before this node is set; deeper levels overwrite
            // s.weight, so it has to be rebuilt from this value for every setting
            double weight = s.weight;
            for(int j = 0; j < order; j++) {
                s.nodeDomainIndices[nodeIdx] = j;
                double prob = getCPTProbability(nodes[nodeIdx], s.nodeDomainIndices);
                if(prob == 0.0) {
                    numPathsPruned++;
                    numWorldsPruned += worldsPerValue;
                    continue;
                }
                s.weight = weight * prob;
                enumerateWorlds(s, nodeOrder, evidenceDomainIndices, i+1, combinationsHandled);
            }
        }
    }

    /**
     * Prints a progress line (in verbose mode) if more than one second has passed
     * since the last one, and restarts the status timer in that case. If no timer
     * has been started yet (enumerateWorlds was called without going through
     * _infer), it is started here instead of failing on a null reference.
     */
    private void printEnumerationStatus() {
        if(timer == null) {
            timer = new Stopwatch();
            timer.start();
            return;
        }
        if(timer.getElapsedTimeSecs() > 1) {
            double numDone = numWorldsCounted+numWorldsPruned;
            if(verbose) out.printf(" ~ %.4f%% done (%s worlds handled, %d paths pruned)\r", 100.0*numDone/numTotalWorlds, numDone, numPathsPruned);
            timer = new Stopwatch();
            timer.start();
        }
    }
}