/*******************************************************************************
* Copyright (C) 2008-2012 Dominik Jain.
*
* This file is part of ProbCog.
*
* ProbCog is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* ProbCog is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with ProbCog. If not, see <http://www.gnu.org/licenses/>.
******************************************************************************/
package probcog.bayesnets.inference;
import java.io.PrintStream;
import probcog.bayesnets.core.BeliefNetworkEx;
import probcog.inference.BasicSampledDistribution;
import edu.ksu.cis.bnj.ver3.core.BeliefNode;
import edu.ksu.cis.bnj.ver3.core.Discrete;
/**
* class that allows the incremental construction of a probability distribution from (weighted) samples
* (see {@link WeightedSample})
* @author Dominik Jain
*
*/
public class SampledDistribution extends BasicSampledDistribution implements Cloneable {
/**
* the belief network for which we are representing a distribution
*/
public BeliefNetworkEx bn;
/**
* values that may be used by certain algorithms to store the number of steps involved in creating the distribution
*/
public int steps, trials, operations;
protected double maxWeight = 0.0;
protected boolean debug = true;
protected BeliefNode[] nodes;
public SampledDistribution(BeliefNetworkEx bn) throws Exception {
this.bn = bn;
this.Z = 0.0;
nodes = bn.bn.getNodes();
values = new double[nodes.length][];
for(int i = 0; i < nodes.length; i++)
values[i] = new double[nodes[i].getDomain().getOrder()];
}
public synchronized void addSample(WeightedSample s) {
if(s.weight == 0.0) {
throw new RuntimeException("Zero-weight sample was added to distribution. Precision loss?");
}
// update normalization constant and maximum weight
Z += s.weight;
if(maxWeight < s.weight)
maxWeight = s.weight;
// debug info
if(debug) {
double prob = bn.getWorldProbability(s.nodeDomainIndices);
/*for(int i = 0; i < nodes.length; i++) {
System.out.printf(" %s = %s\n", nodes[i].getName(), nodes[i].getDomain().getName(s.nodeDomainIndices[i]));
}*/
System.out.printf("sample weight: %s (%.2f%%); max weight: %s (%.2f%%); prob: %s\n", s.weight, s.weight*100/Z, maxWeight, maxWeight*100/Z, prob);
}
// update distribution values
for(int i = 0; i < s.nodeIndices.length; i++) {
try {
values[s.nodeIndices[i]][s.nodeDomainIndices[i]] += s.weight;
}
catch(ArrayIndexOutOfBoundsException e) {
System.err.println("Error: Node " + nodes[s.nodeIndices[i]].getName() + " was not sampled correctly.");
throw e;
}
}
// update number of steps and trials
trials += s.trials;
operations += s.operations;
steps++;
}
@Override
public void printVariableDistribution(PrintStream out, int index) {
BeliefNode node = nodes[index];
out.println(node.getName() + ":");
Discrete domain = (Discrete)node.getDomain();
for(int j = 0; j < domain.getOrder(); j++) {
double prob = values[index][j] / Z;
out.println(String.format(" %.4f %s", prob, domain.getName(j)));
}
}
public double getTrialsPerStep() {
return (double)trials/steps;
}
@Override
public synchronized SampledDistribution clone() throws CloneNotSupportedException {
return (SampledDistribution)super.clone();
}
@Override
public String[] getDomain(int idx) {
return BeliefNetworkEx.getDiscreteDomainAsArray(bn.getNode(idx));
}
@Override
public String getVariableName(int idx) {
return bn.getNode(idx).getName();
}
@Override
public int getVariableIndex(String name) {
return bn.getNodeIndex(name);
}
public void setDebugMode(boolean active) {
debug = active;
}
@Override
public Integer getNumSamples() {
return steps;
}
}