/*
* File CompoundProbabilityDistribution.java
*
* Copyright (C) 2010 Remco Bouckaert remco@cs.auckland.ac.nz
*
* This file is part of BEAST2.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package beast.core.util;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.RejectedExecutionException;
import beast.app.BeastMCMC;
import beast.core.BEASTInterface;
import beast.core.Description;
import beast.core.Distribution;
import beast.core.Input;
import beast.core.State;
@Description("Takes a collection of distributions, typically a number of likelihoods " +
"and priors and combines them into the compound of these distributions " +
"typically interpreted as the posterior.")
public class CompoundDistribution extends Distribution {
// no need to make this input REQUIRED. If no distribution input is
// specified the class just returns probability 1.
final public Input<List<Distribution>> pDistributions =
new Input<>("distribution",
"individual probability distributions, e.g. the likelihood and prior making up a posterior",
new ArrayList<>());
final public Input<Boolean> useThreadsInput = new Input<>("useThreads", "calculated the distributions in parallel using threads (default false)", false);
final public Input<Integer> maxNrOfThreadsInput = new Input<>("threads","maximum number of threads to use, if less than 1 the number of threads in BeastMCMC is used (default -1)", -1);
final public Input<Boolean> ignoreInput = new Input<>("ignore", "ignore all distributions and return 1 as distribution (default false)", false);
/**
* flag to indicate threads should be used. Only effective if the useThreadsInput is
* true and BeasMCMC.nrOfThreads > 1
*/
boolean useThreads;
int nrOfThreads;
boolean ignore;
public static ExecutorService exec;
@Override
public void initAndValidate() {
super.initAndValidate();
useThreads = useThreadsInput.get() && (BeastMCMC.m_nThreads > 1);
nrOfThreads = useThreads ? BeastMCMC.m_nThreads : 1;
if (useThreads && maxNrOfThreadsInput.get() > 0) {
nrOfThreads = Math.min(maxNrOfThreadsInput.get(), BeastMCMC.m_nThreads);
}
if (useThreads) {
exec = Executors.newFixedThreadPool(nrOfThreads);
}
ignore = ignoreInput.get();
if (pDistributions.get().size() == 0) {
logP = 0;
}
// for(Distribution dists : pDistributions.get()) {
// logP += dists.calculateLogP();
// }
}
/**
* Distribution implementation follows *
*/
@Override
public double calculateLogP() {
logP = 0;
if (ignore) {
return logP;
}
int workAvailable = 0;
if (useThreads) {
for (Distribution dists : pDistributions.get()) {
if (dists.isDirtyCalculation()) {
workAvailable++;
}
}
}
if (useThreads && workAvailable > 1) {
logP = calculateLogPUsingThreads();
} else {
for (Distribution dists : pDistributions.get()) {
if (dists.isDirtyCalculation()) {
logP += dists.calculateLogP();
} else {
logP += dists.getCurrentLogP();
}
if (Double.isInfinite(logP) || Double.isNaN(logP)) {
return logP;
}
}
}
return logP;
}
class CoreRunnable implements Runnable {
Distribution distr;
CoreRunnable(Distribution core) {
distr = core;
}
@Override
public void run() {
try {
if (distr.isDirtyCalculation()) {
logP += distr.calculateLogP();
} else {
logP += distr.getCurrentLogP();
}
} catch (Exception e) {
Log.err.println("Something went wrong in a calculation of " + distr.getID());
e.printStackTrace();
System.exit(1);
}
countDown.countDown();
}
} // CoreRunnable
CountDownLatch countDown;
private double calculateLogPUsingThreads() {
try {
int dirtyDistrs = 0;
for (Distribution dists : pDistributions.get()) {
if (dists.isDirtyCalculation()) {
dirtyDistrs++;
}
}
countDown = new CountDownLatch(dirtyDistrs);
// kick off the threads
for (Distribution dists : pDistributions.get()) {
if (dists.isDirtyCalculation()) {
CoreRunnable coreRunnable = new CoreRunnable(dists);
exec.execute(coreRunnable);
}
}
countDown.await();
logP = 0;
for (Distribution distr : pDistributions.get()) {
logP += distr.getCurrentLogP();
}
return logP;
} catch (RejectedExecutionException | InterruptedException e) {
useThreads = false;
Log.err.println("Stop using threads: " + e.getMessage());
return calculateLogP();
}
}
@Override
public void sample(State state, Random random) {
for (Distribution distribution : pDistributions.get()) {
distribution.sample(state, random);
}
}
@Override
public List<String> getArguments() {
List<String> arguments = new ArrayList<>();
for (Distribution distribution : pDistributions.get()) {
arguments.addAll(distribution.getArguments());
}
return arguments;
}
@Override
public List<String> getConditions() {
List<String> conditions = new ArrayList<>();
for (Distribution distribution : pDistributions.get()) {
conditions.addAll(distribution.getConditions());
}
return conditions;
}
@Override
public List<BEASTInterface> listActiveBEASTObjects() {
if (ignoreInput.get()) {
return new ArrayList<>();
} else {
return super.listActiveBEASTObjects();
}
}
@Override
public boolean isStochastic() {
for (Distribution distribution : pDistributions.get()) {
if (distribution.isStochastic())
return true;
}
return false;
}
@Override
public double getNonStochasticLogP() {
double logP = 0;
if (ignore) {
return logP;
}
// The loop could gain a little bit from being multithreaded
// though getNonStochasticLogP is called for debugging purposes only
// so efficiency is not an immediate issue.
for (Distribution dists : pDistributions.get()) {
logP += dists.getNonStochasticLogP();
if (Double.isInfinite(logP) || Double.isNaN(logP)) {
return logP;
}
}
return logP;
}
} // class CompoundDistribution