/*******************************************************************************
* Copyright (C) 2008-2012 Dominik Jain.
*
* This file is part of ProbCog.
*
* ProbCog is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ProbCog is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ProbCog. If not, see <http://www.gnu.org/licenses/>.
******************************************************************************/
package probcog.bayesnets.inference;
import java.io.PrintStream;
import java.util.Collection;
import java.util.HashMap;
import java.util.Random;
import probcog.bayesnets.core.BeliefNetworkEx;
import probcog.inference.IParameterHandler;
import probcog.inference.ParameterHandler;
import probcog.inference.BasicSampledDistribution.ConfidenceInterval;
import edu.ksu.cis.bnj.ver3.core.BeliefNode;
import edu.ksu.cis.bnj.ver3.core.CPF;
import edu.tum.cs.util.Stopwatch;
/**
 * Abstract base class for sampling-based inference on a belief network.
 * <p>
 * The class holds the network, the evidence, the random number generator and the
 * general sampler settings, and provides the {@link #initialize()}/{@link #infer()}
 * life cycle (including time measurement) as well as helpers for drawing values from
 * unnormalized discrete distributions and from CPT columns. Concrete algorithms
 * implement {@link #_infer()} and pass their samples to {@link #addSample(WeightedSample)}.
 * <p>
 * Only {@link #addSample(WeightedSample)} and {@link #pollResults()} are synchronized;
 * no other member is protected against concurrent access.
 */
public abstract class Sampler implements ITimeLimitedInference, IParameterHandler {
    /** the network on which inference is performed */
    public BeliefNetworkEx bn;
    /** maps each node to its position in {@link #nodes} */
    public HashMap<BeliefNode, Integer> nodeIndices;
    /** the random number generator used for sampling */
    public Random generator;
    /** the nodes of the network, as returned by {@code bn.bn.getNodes()} */
    public BeliefNode[] nodes;
    /**
     * evidence as one domain index per node (indexed like {@link #nodes});
     * negative entries are not checked in {@link #addSample(WeightedSample)}
     */
    public int[] evidenceDomainIndices;
    protected ParameterHandler paramHandler;
    /** indices of the query variables; used by the convergence check */
    protected Collection<Integer> queryVars = null;
    /** text that is printed by {@link #infer()} once inference has finished (if verbose) */
    protected StringBuffer report = new StringBuffer();
    protected boolean verbose;
    /** output stream; {@code System.out} if verbose, otherwise a stream that discards everything */
    protected PrintStream out;
    /** whether {@link #initialize()} has been run since the last change of evidence/query variables */
    protected boolean initialized = false;
    protected IDistributionBuilder distributionBuilder;
    /**
     * general sampler setting: how many samples to pull from the distribution
     */
    public int numSamples = 1000;
    /** general sampler setting (not read in this class) */
    protected int maxTrials = 5000;
    /** general sampler setting (not read in this class) */
    protected boolean skipFailedSteps = false;
    /** maximum confidence interval size at which sampling is considered converged (null = criterion disabled) */
    protected Double confidenceIntervalSizeThreshold = null;
    /** the convergence criterion is evaluated only when the number of samples is a multiple of this value */
    public double convergenceCheckInterval = 100;
    protected double totalInferenceTime, initTime, inferenceTime;
    /**
     * general sampler setting: after how many samples to display a message that reports the current status
     */
    public int infoInterval = 100;
    public boolean debug = false;

    /**
     * constructs a sampler for the given network, indexing its nodes, creating a random
     * number generator and registering the parameters that are handled by this class
     * @param bn the network on which to perform inference
     * @throws Exception
     */
    public Sampler(BeliefNetworkEx bn) throws Exception {
        this.bn = bn;
        this.nodes = bn.bn.getNodes();
        nodeIndices = new HashMap<BeliefNode, Integer>();
        for(int i = 0; i < nodes.length; i++) {
            nodeIndices.put(nodes[i], i);
        }
        generator = new Random();
        setVerbose(true);
        paramHandler = new ParameterHandler(this);
        paramHandler.add("confidenceIntervalSizeThreshold", "setConfidenceIntervalSizeThreshold");
        paramHandler.add("randomSeed", "setRandomSeed");
        paramHandler.add("verbose", "setVerbose");
    }

    /**
     * creates the distribution in which samples are collected; the distribution inherits the
     * debug mode and its parameter handler is registered as a subhandler of this sampler's handler
     * @return the new (empty) distribution
     * @throws Exception
     */
    protected SampledDistribution createDistribution() throws Exception {
        SampledDistribution dist = new SampledDistribution(bn);
        dist.setDebugMode(debug);
        paramHandler.addSubhandler(dist.getParameterHandler());
        return dist;
    }

    /**
     * adds a sample to the distribution builder
     * @param s the sample to add
     * @throws Exception in debug mode, if the sample contradicts the evidence
     */
    protected synchronized void addSample(WeightedSample s) throws Exception {
        // security check: in debug mode, check if sample respects evidence
        // (nothing to check if no evidence has been set)
        if(debug && evidenceDomainIndices != null) {
            for(int i = 0; i < evidenceDomainIndices.length; i++) {
                if(evidenceDomainIndices[i] >= 0 && s.nodeDomainIndices[i] != evidenceDomainIndices[i])
                    throw new Exception("Attempted to add sample to distribution that does not respect evidence");
            }
        }
        // add to distribution builder
        distributionBuilder.addSample(s);
    }

    /**
     * sets the query variables; the sampler is (re-)initialized upon the next call to {@link #infer()}
     * @param queryVars indices of the query variables
     */
    public void setQueryVars(Collection<Integer> queryVars) {
        this.queryVars = queryVars;
        initialized = false;
    }

    /**
     * checks whether the convergence criterion is met. The check is carried out only if
     * samples are collected by a DirectDistributionBuilder and the current number of samples
     * is a multiple of {@link #convergenceCheckInterval}.
     * @return true if a confidence interval size threshold is set and the largest confidence
     *         interval among the query variables does not exceed it, false otherwise
     * @throws Exception if the threshold is set but the distribution does not compute
     *         confidence intervals or no query variables were specified
     */
    protected boolean converged() throws Exception {
        if(!(this.distributionBuilder instanceof DirectDistributionBuilder))
            return false;
        SampledDistribution dist = distributionBuilder.getDistribution();
        if(dist.getNumSamples() % this.convergenceCheckInterval != 0)
            return false; // TODO assumes that all algorithms call this method after each step
        // determine convergence based on confidence interval sizes
        if(confidenceIntervalSizeThreshold != null) {
            if(!dist.usesConfidenceComputation())
                throw new Exception("Cannot determine convergence based on confidence interval size: No confidence level specified.");
            if(queryVars == null)
                throw new Exception("Cannot determine convergence based on confidence interval size: No query variables specified.");
            double max = 0;
            for(Integer i : queryVars) {
                ConfidenceInterval interval = dist.getConfidenceInterval(i, 0);
                max = Math.max(max, interval.getSize());
            }
            if(max <= confidenceIntervalSizeThreshold) {
                if(verbose) out.printf("Convergence criterion reached: maximum confidence interval size = %f\n", max);
                return true;
            }
        }
        return false;
    }

    public void setConfidenceIntervalSizeThreshold(double t) {
        confidenceIntervalSizeThreshold = t;
    }

    /**
     * polls the results during time-limited inference
     * @return a copy of the distribution as currently held by the distribution builder,
     *         or null if there is no distribution builder or no distribution yet
     * @throws Exception
     */
    public synchronized SampledDistribution pollResults() throws Exception {
        if(distributionBuilder == null)
            return null;
        SampledDistribution dist = distributionBuilder.getDistribution();
        if(dist == null)
            return null;
        return dist.clone();
    }

    /**
     * samples from a distribution whose normalization constant is not known
     * @param distribution
     * @param generator
     * @return the index of the value that was sampled (or -1 if the distribution is not well-defined)
     */
    public static int sample(double[] distribution, Random generator) {
        double sum = 0;
        for(int i = 0; i < distribution.length; i++)
            sum += distribution[i];
        return sample(distribution, sum, generator);
    }

    /**
     * samples from the given distribution
     * @param distribution
     * @param sum the distribution's normalization constant
     * @param generator
     * @return the index of the value that was sampled (or -1 if the distribution is not well-defined,
     *         i.e. if the sum is not positive or the entries do not add up to the drawn value)
     */
    public static int sample(double[] distribution, double sum, Random generator) {
        // the number is drawn before any check such that the generator is advanced exactly once per call
        double random = generator.nextDouble() * sum;
        if(!(sum > 0))
            return -1;
        double cumulative = 0;
        for(int i = 0; i < distribution.length; i++) {
            cumulative += distribution[i];
            // entries with zero weight must never be chosen (relevant only if random == 0)
            if(distribution[i] > 0 && cumulative >= random)
                return i;
        }
        return -1;
    }

    /**
     * samples from a distribution whose normalization constant is not known
     * @param distribution
     * @param generator
     * @return the index of the value in the collection that was sampled (or -1 if the distribution is not well-defined)
     */
    public static int sample(Collection<Double> distribution, Random generator) {
        double sum = 0;
        for(Double d : distribution)
            sum += d;
        return sample(distribution, sum, generator);
    }

    /**
     * samples from the given distribution
     * @param distribution
     * @param sum the distribution's normalization constant
     * @param generator
     * @return the index of the value in the collection that was sampled (or -1 if the distribution is not well-defined,
     *         i.e. if the sum is not positive or the entries do not add up to the drawn value)
     */
    public static int sample(Collection<Double> distribution, double sum, Random generator) {
        // the number is drawn before any check such that the generator is advanced exactly once per call
        double random = generator.nextDouble() * sum;
        if(!(sum > 0))
            return -1;
        double cumulative = 0;
        int i = 0;
        for(Double d : distribution) {
            cumulative += d;
            // entries with zero weight must never be chosen (relevant only if random == 0)
            if(d > 0 && cumulative >= random)
                return i;
            ++i;
        }
        return -1;
    }

    /**
     * gets the CPT entry of the given node for the configuration of parents that is provided in the array of domain indices
     * @param node
     * @param nodeDomainIndices domain indices for each node in the network (the entries of all nodes in the
     *        domain product of the node's CPF are read)
     * @return the probability value
     */
    protected double getCPTProbability(BeliefNode node, int[] nodeDomainIndices) {
        CPF cpf = node.getCPF();
        BeliefNode[] domProd = cpf.getDomainProduct();
        int[] addr = new int[domProd.length];
        for(int i = 0; i < addr.length; i++)
            addr[i] = nodeDomainIndices[this.nodeIndices.get(domProd[i])];
        return cpf.getDouble(addr);
    }

    public void setNumSamples(int numSamples) {
        this.numSamples = numSamples;
    }

    public void setInfoInterval(int infoInterval) {
        this.infoInterval = infoInterval;
    }

    public void setMaxTrials(int maxTrials) {
        this.maxTrials = maxTrials;
    }

    public void setSkipFailedSteps(boolean canSkip) {
        this.skipFailedSteps = canSkip;
    }

    /**
     * sets the evidence; the sampler is (re-)initialized upon the next call to {@link #infer()}
     * @param evidenceDomainIndices one domain index per node in the network
     * @throws Exception
     */
    public void setEvidence(int[] evidenceDomainIndices) throws Exception {
        this.evidenceDomainIndices = evidenceDomainIndices;
        initialized = false;
    }

    public void setRandomSeed(int seed) {
        generator.setSeed(seed);
    }

    /**
     * the actual inference procedure, which is to be implemented by the concrete algorithm
     * @throws Exception
     */
    protected abstract void _infer() throws Exception;

    /**
     * algorithm-specific initialization; does nothing by default
     * @throws Exception
     */
    protected void _initialize() throws Exception {}

    /**
     * initializes the inference method such that inference can be run
     */
    public final void initialize() throws Exception {
        Stopwatch sw = new Stopwatch();
        sw.start();
        _initialize();
        distributionBuilder = createDistributionBuilder();
        sw.stop();
        initTime = sw.getElapsedTimeSecs();
        initialized = true;
    }

    /**
     * runs the actual inference method (initializing first if necessary)
     * @return the distribution held by the distribution builder after inference
     */
    public final SampledDistribution infer() throws Exception {
        // initialize
        if(!initialized)
            initialize();
        // run inference
        Stopwatch sw = new Stopwatch();
        sw.start();
        _infer();
        sw.stop(); // stop before reading the time, as is done in initialize()
        inferenceTime = sw.getElapsedTimeSecs();
        report(String.format("total inference time: %fs (initialization: %fs; core %fs)\n", getTotalInferenceTime(), getInitTime(), getInferenceTime()));
        if(verbose) out.print(report.toString());
        return distributionBuilder.getDistribution();
    }

    /**
     * @return returns the distribution builder that creates the distribution
     * based on weighted samples
     * @throws Exception
     */
    protected IDistributionBuilder createDistributionBuilder() throws Exception {
        return new DirectDistributionBuilder(createDistribution());
    }

    /**
     * @return the time taken for the inference process in seconds (initialization plus core inference)
     */
    public double getTotalInferenceTime() {
        return getInferenceTime() + getInitTime();
    }

    /**
     * @return the time taken by the last run of the core inference procedure in seconds
     */
    public double getInferenceTime() {
        return inferenceTime;
    }

    /**
     * @return the time taken by the last initialization in seconds
     */
    public double getInitTime() {
        return initTime;
    }

    /**
     * samples forward, i.e. samples a value for 'node' given its parents
     * @param node the node for which to sample a value
     * @param nodeDomainIndices array of domain indices for all nodes in the network; the values for the parents of 'node' must be set already
     * @return the index of the domain element of 'node' that is sampled, or -1 if sampling is impossible because all entries in the relevant column are 0
     */
    protected int sampleForward(BeliefNode node, int[] nodeDomainIndices) {
        double[] cpt_entries = readCPTColumn(node, nodeDomainIndices);
        double sum = 0;
        for(int i = 0; i < cpt_entries.length; i++)
            sum += cpt_entries[i];
        // if the column contains only zeros, it is an impossible case -> cannot sample
        if(sum == 0)
            return -1;
        return sample(cpt_entries, sum, generator);
    }

    /**
     * gets the CPT column of 'node' that applies given the values of its parents
     * @param node the node whose conditional distribution is to be retrieved
     * @param nodeDomainIndices array of domain indices for all nodes in the network; the values for the parents of 'node' must be set
     * @return an array with one CPT entry per element of the domain of 'node'
     */
    public double[] getConditionalDistribution(BeliefNode node, int[] nodeDomainIndices) {
        return readCPTColumn(node, nodeDomainIndices);
    }

    /**
     * reads, from the CPF of 'node', the entries for all values of the first variable in the
     * CPF's domain product, with the remaining variables fixed as given in nodeDomainIndices
     */
    private double[] readCPTColumn(BeliefNode node, int[] nodeDomainIndices) {
        CPF cpf = node.getCPF();
        BeliefNode[] domProd = cpf.getDomainProduct();
        int[] addr = new int[domProd.length];
        // get the addresses of the first two relevant fields and the difference between them
        for(int i = 1; i < addr.length; i++)
            addr[i] = nodeDomainIndices[this.nodeIndices.get(domProd[i])];
        addr[0] = 0; // (the first element in the index into the domain of the node we are sampling)
        int realAddr = cpf.addr2realaddr(addr);
        addr[0] = 1;
        int diff = cpf.addr2realaddr(addr) - realAddr; // diff is the address difference between two consecutive entries in the relevant column
        // get probabilities for outcomes
        double[] cpt_entries = new double[domProd[0].getDomain().getOrder()];
        for(int i = 0; i < cpt_entries.length; i++) {
            cpt_entries[i] = cpf.getDouble(realAddr);
            realAddr += diff;
        }
        return cpt_entries;
    }

    /**
     * @param node a node of the network
     * @return the index of the node in {@link #nodes}
     * @throws NullPointerException if the node is not contained in {@link #nodeIndices}
     */
    public int getNodeIndex(BeliefNode node) {
        return nodeIndices.get(node);
    }

    public void setDebugMode(boolean active) {
        debug = active;
    }

    /**
     * sets the verbosity and redirects {@link #out} accordingly
     * @param verbose if true, output goes to System.out; otherwise it is discarded
     */
    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
        if(verbose)
            out = System.out;
        else
            out = new PrintStream(new java.io.OutputStream() { public void write(int b){} });
    }

    /**
     * @return the simple name of the class that implements the algorithm
     */
    public String getAlgorithmName() {
        return this.getClass().getSimpleName();
    }

    public ParameterHandler getParameterHandler() {
        return paramHandler;
    }

    /**
     * adds a string to the report that is displayed after the inference procedure has returned
     * @param s the text to add (a newline is appended)
     */
    protected void report(String s) {
        this.report.append(s);
        this.report.append('\n');
    }
}