/*
* File MCMC.java
*
* Copyright (C) 2010 Remco Bouckaert remco@cs.auckland.ac.nz
*
* This file is part of BEAST2.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package beast.core;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import javax.xml.parsers.ParserConfigurationException;
import org.xml.sax.SAXException;
import beast.core.util.CompoundDistribution;
import beast.core.util.Evaluator;
import beast.core.util.Log;
import beast.util.Randomizer;
@Description("MCMC chain. This is the main element that controls which posterior " +
"to calculate, how long to run the chain and all other properties, " +
"which operators to apply on the state space and where to log results.")
@Citation(value=
"Bouckaert RR, Heled J, Kuehnert D, Vaughan TG, Wu C-H, Xie D, Suchard MA,\n" +
" Rambaut A, Drummond AJ (2014) BEAST 2: A software platform for Bayesian\n" +
" evolutionary analysis. PLoS Computational Biology 10(4): e1003537"
, year = 2014, firstAuthorSurname = "bouckaert",
DOI="10.1371/journal.pcbi.1003537")
public class MCMC extends Runnable {
// ---- configurable inputs of the chain ----
/** Required. Read into {@code chainLength} by {@link #run()}; the main loop runs up to and including this sample number. */
final public Input<Integer> chainLengthInput =
new Input<>("chainLength", "Length of the MCMC chain i.e. number of samples taken in main loop",
Input.Validate.REQUIRED);
/** Optional start state. When absent, {@link #initAndValidate()} builds a State from the state nodes listed by the operators. */
final public Input<State> startStateInput =
new Input<>("state", "elements of the state space");
/** State node initialisers (default: empty list); {@link #run()} calls them when the chain is not restored from file. */
final public Input<List<StateNodeInitialiser>> initialisersInput =
new Input<>("init", "one or more state node initilisers used for determining " +
"the start state of the chain",
new ArrayList<>());
/** Interval for writing the state to disk (default -1). Only a positive value overrides the state's own m_storeEvery input. */
final public Input<Integer> storeEveryInput =
new Input<>("storeEvery", "store the state to disk every X number of samples so that we can " +
"resume computation later on if the process failed half-way.", -1);
/** Number of pre-burnin samples (default 0): the main loop starts at sample number {@code -burnIn}. Ignored when restoring from file. */
final public Input<Integer> burnInInput =
new Input<>("preBurnin", "Number of burn in samples taken before entering the main loop", 0);
/** Upper bound (default 10) on the number of times {@link #run()} initialises the state nodes while looking for a usable start state. */
final public Input<Integer> numInitializationAttempts =
new Input<>("numInitializationAttempts", "Number of initialization attempts before failing (default=10)", 10);
/** Required. The distribution whose log probability drives the accept/reject decision. */
final public Input<Distribution> posteriorInput =
new Input<>("distribution", "probability distribution to sample over (e.g. a posterior)",
Input.Validate.REQUIRED);
/** Operators that generate proposals (default: empty list). Not validated as required; an empty list only triggers a warning in {@link #initAndValidate()}. */
final public Input<List<Operator>> operatorsInput =
new Input<>("operator", "operator for generating proposals in MCMC state space",
new ArrayList<>());//, Input.Validate.REQUIRED);
/** Required. Loggers that are initialised, called after every step and closed by this chain. */
final public Input<List<Logger>> loggersInput =
new Input<>("logger", "loggers for reporting progress of MCMC chain",
new ArrayList<>(), Input.Validate.REQUIRED);
/** Default false. When true, {@link #initAndValidate()} removes the distribution with id 'likelihood' from a compound posterior, or throws if it cannot. */
final public Input<Boolean> sampleFromPriorInput = new Input<>("sampleFromPrior", "whether to ignore the likelihood when sampling (default false). " +
"The distribution with id 'likelihood' in the posterior input will be ignored when this flag is set.", false);
/** Schedule used to select operators (default: a new OperatorSchedule); the operators of {@code operatorsInput} are added to it in {@link #initAndValidate()}. */
final public Input<OperatorSchedule> operatorScheduleInput = new Input<>("operatorschedule", "specify operator selection and optimisation schedule", new OperatorSchedule());
/**
* Alternative representation of operatorsInput that allows random selection
* of operators and calculation of statistics.
* Taken from operatorScheduleInput in initAndValidate(), where all operators
* of operatorsInput are added to it.
*/
protected OperatorSchedule operatorSchedule;
/**
* The state that takes care of managing StateNodes,
* operations on StateNodes and propagates store/restore/requireRecalculation
* calls to the appropriate BEASTObjects.
* Either the value of the 'state' input or, when that is absent, a State
* assembled in initAndValidate() from the operators' state nodes.
*/
protected State state;
/**
* Number of samples taken where calculation is checked against full
* recalculation of the posterior. Note that after every proposal that
* is checked, there are 2 that are not checked. This allows errors
* in store/restore to be detected that cannot be found when every single
* consecutive sample is checked.
* So, only after 3*NR_OF_DEBUG_SAMPLES samples checking is stopped.
*/
final protected int NR_OF_DEBUG_SAMPLES = 2000;
/**
* Interval for storing state to disk; the state is only stored periodically
* when this value is positive.
* Mirrors the storeEvery input when that is positive, otherwise the
* m_storeEvery input of the State.
*/
protected int storeEvery;
/**
* Set this to true to enable detailed MCMC debugging information
* to be displayed (written to System.err by propagateState()).
*/
private static final boolean printDebugInfo = false;
/**
 * Creates a chain without configuring anything: the body is empty, so all
 * set-up is left to {@link #initAndValidate()} and {@link #run()}.
 */
public MCMC() {
}
/**
 * Validates the inputs and wires up the operator schedule and the state.
 * <p>
 * In order, this method: prints the citations; adds every operator of
 * {@code operatorsInput} to the operator schedule; removes the distribution
 * with id 'likelihood' from the posterior when {@code sampleFromPrior} is set;
 * adopts the start state or assembles one from the operators' state nodes;
 * and finally runs sanity checks on how operators and state nodes relate.
 *
 * @throws RuntimeException if {@code sampleFromPrior} is set but the posterior is not a
 *                          CompoundDistribution or has no input with id 'likelihood';
 *                          if two initialisers claim the same state node (only checked
 *                          when {@code restoreFromFile} is set); if an operator lists no
 *                          state nodes; or if an operator lists a state node that is
 *                          missing from the state
 */
@Override
public void initAndValidate() {
    Log.info.println("===============================================================================");
    Log.info.println("Citations for this model:");
    Log.info.println(getCitations());
    Log.info.println("===============================================================================");

    operatorSchedule = operatorScheduleInput.get();
    for (final Operator op : operatorsInput.get()) {
        operatorSchedule.addOperator(op);
    }

    if (sampleFromPriorInput.get()) {
        removeLikelihoodFromPosterior();
    }

    // StateNode initialisation, only required when the state is not read from file
    // NOTE(review): the condition below is the opposite of what the comment above describes.
    // It is left untouched because inverting it would start rejecting configurations that
    // currently load -- confirm the intended behaviour before changing it.
    if (restoreFromFile) {
        checkInitialisersDoNotOverlap();
    }

    // State initialisation
    final HashSet<StateNode> operatorStateNodes = new HashSet<>();
    for (final Operator op : operatorsInput.get()) {
        operatorStateNodes.addAll(op.listStateNodes());
    }
    setUpState(operatorStateNodes);

    checkOperatorsAgainstState(operatorStateNodes);
} // init

/**
 * Removes the distribution with id 'likelihood' from the posterior so that only the
 * remaining distributions are sampled.
 */
private void removeLikelihoodFromPosterior() {
    if (!(posteriorInput.get() instanceof CompoundDistribution)) {
        throw new RuntimeException("Don't know how to sample from prior since posterior is not a compound distribution. " +
                "Suggestion: set sampleFromPrior flag to false.");
    }
    // remove beastObject with id likelihood from posterior
    final CompoundDistribution posterior = (CompoundDistribution) posteriorInput.get();
    final List<Distribution> distrs = posterior.pDistributions.get();
    final int distrCount = distrs.size();
    for (int i = 0; i < distrCount; i++) {
        final Distribution distr = distrs.get(i);
        if ("likelihood".equals(distr.getID())) {
            distrs.remove(distr);
            break;
        }
    }
    // an unchanged size means nothing was removed
    if (distrs.size() == distrCount) {
        throw new RuntimeException("Sample from prior flag is set, but distribution with id 'likelihood' is " +
                "not an input to posterior.");
    }
}

/** Fails when two initialisers report that they initialise the same state node. */
private void checkInitialisersDoNotOverlap() {
    final HashSet<StateNode> initialisedStateNodes = new HashSet<>();
    for (final StateNodeInitialiser initialiser : initialisersInput.get()) {
        // make sure that the initialiser does not re-initialise a StateNode
        final List<StateNode> list = new ArrayList<>(1);
        initialiser.getInitialisedStateNodes(list);
        for (final StateNode stateNode : list) {
            if (initialisedStateNodes.contains(stateNode)) {
                throw new RuntimeException("Trying to initialise stateNode (id=" + stateNode.getID() + ") more than once. " +
                        "Remove an initialiser from MCMC to fix this.");
            }
        }
        initialisedStateNodes.addAll(list);
        // the initialisation itself is done in run(), not here
    }
}

/** Adopts the start state or builds one from the operators' state nodes, then settles the store interval. */
private void setUpState(final HashSet<StateNode> operatorStateNodes) {
    final State startState = startStateInput.get();
    if (startState != null) {
        this.state = startState;
        // only a positive storeEvery on the chain overrides the value configured on the state
        if (storeEveryInput.get() > 0) {
            this.state.m_storeEvery.setValue(storeEveryInput.get(), this.state);
        }
    } else {
        // create state from scratch by collecting StateNode inputs from Operators
        this.state = new State();
        for (final StateNode stateNode : operatorStateNodes) {
            this.state.stateNodeInput.setValue(stateNode, this.state);
        }
        this.state.m_storeEvery.setValue(storeEveryInput.get(), this.state);
    }

    // grab the interval for storing the state to file
    if (storeEveryInput.get() > 0) {
        storeEvery = storeEveryInput.get();
    } else {
        storeEvery = state.m_storeEvery.get();
    }

    this.state.initialise();
    this.state.setPosterior(posteriorInput.get());
}

/** Sanity checks: operators must act on state nodes of the state; missing operators only produce warnings. */
private void checkOperatorsAgainstState(final HashSet<StateNode> operatorStateNodes) {
    // sanity check: all operator state nodes should be in the state
    final List<StateNode> stateNodes = this.state.stateNodeInput.get();
    for (final Operator op : operatorsInput.get()) {
        // queried once per operator and reused for both checks
        final List<StateNode> nodes = op.listStateNodes();
        if (nodes.isEmpty()) {
            // otherwise the chain may hang without obvious reason
            throw new RuntimeException("Operator " + op.getID() + " has no state nodes in the state. "
                    + "Each operator should operate on at least one estimated state node in the state. "
                    + "Remove the operator or add its statenode(s) to the state and/or set estimate='true'.");
        }
        for (final StateNode stateNode : nodes) {
            if (!stateNodes.contains(stateNode)) {
                throw new RuntimeException("Operator " + op.getID() + " has a statenode " + stateNode.getID() + " in its inputs that is missing from the state.");
            }
        }
    }

    // sanity check: at least one operator required to run MCMC
    if (operatorsInput.get().isEmpty()) {
        Log.warning.println("Warning: at least one operator required to run the MCMC properly, but none found.");
    }

    // sanity check: all state nodes should be operated on
    for (final StateNode stateNode : stateNodes) {
        if (!operatorStateNodes.contains(stateNode)) {
            Log.warning.println("Warning: state contains a node " + stateNode.getID() + " for which there is no operator.");
        }
    }
}
/**
 * Passes the given sample number on to each logger of this chain, in list order.
 *
 * @param sampleNr the sample number handed to every logger
 */
public void log(final int sampleNr) {
    loggers.forEach(logger -> logger.log(sampleNr));
} // log
/**
 * Closes each logger of this chain, in list order.
 */
public void close() {
    loggers.forEach(logger -> logger.close());
} // close
// ---- working variables of a run; all of them are (re)assigned in run() ----
/**
 * Log acceptance ratio of the most recent proposal that was evaluated:
 * newLogLikelihood - oldLogLikelihood + log Hastings ratio. Reset to 0 at the
 * start of run() and handed to Operator.optimize() in doLoop().
 */
protected double logAlpha;
/**
 * Whether the posterior is re-checked every third sample. Initialised from the
 * "beast.debug" system property and switched off by doLoop() once a check happens
 * after sample 3 * NR_OF_DEBUG_SAMPLES.
 */
protected boolean debugFlag;
/** Log probability of the posterior for the current (last accepted) state. */
protected double oldLogLikelihood;
/** Log probability of the posterior for the most recently proposed state. */
protected double newLogLikelihood;
/** Number of pre-burnin samples; taken from the preBurnin input, forced to 0 when restoring from file. */
protected int burnIn;
/** Last sample number of the main loop; taken from the chainLength input. */
protected int chainLength;
/** The distribution being sampled; taken from the distribution input. */
protected Distribution posterior;
/** The loggers of the logger input; run() sorts this list so that loggers writing to stdout come last. */
protected List<Logger> loggers;
/**
 * Runs the chain from start to finish.
 * <p>
 * Re-validates the state, either restores it from file or initialises it (retrying up to
 * {@code numInitializationAttempts} times while the start posterior is infinite or NaN),
 * initialises the loggers, executes {@link #doLoop()}, reports operator rates and timing,
 * closes the loggers and finally stores the state and operator schedule to file.
 *
 * @throws RuntimeException if no start state with a finite posterior could be found
 * @throws IOException declared by this method; presumably raised by state/logger file access
 * @throws SAXException declared by this method; presumably raised while reading a stored state
 * @throws ParserConfigurationException declared by this method; presumably raised while reading a stored state
 */
@Override
public void run() throws IOException, SAXException, ParserConfigurationException {
    // set up state (again). Other beastObjects may have manipulated the
    // StateNodes, e.g. set up bounds or dimensions
    state.initAndValidate();
    // also, initialise state with the file name to store and set-up whether to resume from file
    state.setStateFileName(stateFileName);
    operatorSchedule.setStateFileName(stateFileName);

    burnIn = burnInInput.get();
    chainLength = chainLengthInput.get();
    int initialisationAttempts = 0;
    state.setEverythingDirty(true);
    posterior = posteriorInput.get();

    if (restoreFromFile) {
        state.restoreFromFile();
        operatorSchedule.restoreFromFile();
        // a resumed chain has already left burn-in behind
        burnIn = 0;
        oldLogLikelihood = state.robustlyCalcPosterior(posterior);
    } else {
        final int maxAttempts = numInitializationAttempts.get();
        do {
            for (final StateNodeInitialiser initialiser : initialisersInput.get()) {
                initialiser.initStateNodes();
            }
            oldLogLikelihood = state.robustlyCalcPosterior(posterior);
            initialisationAttempts += 1;
            // a NaN posterior is rejected below just like an infinite one, so retry on both
        } while ((Double.isInfinite(oldLogLikelihood) || Double.isNaN(oldLogLikelihood))
                && initialisationAttempts < maxAttempts);
    }
    final long startTime = System.currentTimeMillis();

    state.storeCalculationNodes();

    // do the sampling
    logAlpha = 0;
    debugFlag = Boolean.valueOf(System.getProperty("beast.debug"));

    Log.info.println("Start likelihood: " + oldLogLikelihood + " " + (initialisationAttempts > 1 ? "after " + initialisationAttempts + " initialisation attempts" : ""));
    if (Double.isInfinite(oldLogLikelihood) || Double.isNaN(oldLogLikelihood)) {
        reportLogLikelihoods(posterior, "");
        throw new RuntimeException("Could not find a proper state to initialise. Perhaps try another seed.");
    }

    loggers = loggersInput.get();

    // put the loggers logging to stdout at the bottom of the logger list so that screen output is tidier.
    // false sorts before true, and the sort is stable, so the relative order is otherwise kept
    Collections.sort(loggers, (o1, o2) -> Boolean.compare(o1.isLoggingToStdout(), o2.isLoggingToStdout()));

    warnIfNoScreenFeedback();

    // initialises log so that log file headers are written, etc.
    for (final Logger log : loggers) {
        log.init();
    }

    doLoop();

    Log.info.println();
    operatorSchedule.showOperatorRates(System.out);

    Log.info.println();
    final long endTime = System.currentTimeMillis();
    Log.info.println("Total calculation time: " + (endTime - startTime) / 1000.0 + " seconds");
    close();

    Log.warning.println("End likelihood: " + oldLogLikelihood);
    state.storeToFile(chainLength);
    operatorSchedule.storeToFile();
} // run;

/** Warns when none of the loggers writes to stdout, so the user knows why the screen stays silent. */
private void warnIfNoScreenFeedback() {
    boolean hasStdOutLogger = false;
    boolean hasScreenLog = false;
    for (final Logger l : loggers) {
        if (l.isLoggingToStdout()) {
            hasStdOutLogger = true;
        }
        if ("screenlog".equals(l.getID())) {
            hasScreenLog = true;
        }
    }
    if (hasStdOutLogger) {
        return;
    }
    Log.warning.println("WARNING: If nothing seems to be happening on screen this is because none of the loggers give feedback to screen.");
    if (hasScreenLog) {
        Log.warning.println("WARNING: This happens when a filename is specified for the 'screenlog' logger.");
        Log.warning.println("WARNING: To get feedback to screen, leave the filename for screenlog blank.");
        Log.warning.println("WARNING: Otherwise, the screenlog is saved into the specified file.");
    }
}
/**
 * Main MCMC loop: takes samples {@code -burnIn} up to and including {@code chainLength}.
 * <p>
 * Every iteration performs one propose/accept/reject step. On every 10000th sample, and on
 * every third sample while the debug flag is set, the posterior is recalculated robustly
 * and compared with the tracked value; on all other non-negative samples the operator is
 * asked to optimise itself instead. A mismatch found up to sample
 * {@code 3 * NR_OF_DEBUG_SAMPLES} stores the state and exits the JVM; a later mismatch is
 * repaired, and more than 100 repairs also store the state and exit the JVM.
 *
 * @throws IOException declared by this method; presumably raised while storing to file
 * @throws RuntimeException if the posterior becomes positive infinite
 */
protected void doLoop() throws IOException {
    final boolean stochasticPosterior = posterior.isStochastic();
    int correctionCount = 0;

    if (burnIn > 0) {
        Log.warning.println("Please wait while BEAST takes " + burnIn + " pre-burnin samples");
    }

    int sampleNr = -burnIn;
    while (sampleNr <= chainLength) {
        final Operator operator = propagateState(sampleNr);

        final boolean checkPosterior = (debugFlag && sampleNr % 3 == 0) || sampleNr % 10000 == 0;
        if (!checkPosterior) {
            if (sampleNr >= 0) {
                operator.optimize(logAlpha);
            }
        } else {
            // the tracked value has to be read before the robust recalculation
            final double originalLogP;
            final double logLikelihood;
            if (stochasticPosterior) {
                originalLogP = posterior.getNonStochasticLogP();
                logLikelihood = state.robustlyCalcNonStochasticPosterior(posterior);
            } else {
                originalLogP = oldLogLikelihood;
                logLikelihood = state.robustlyCalcPosterior(posterior);
            }

            final boolean mismatch = isTooDifferent(logLikelihood, originalLogP);
            if (mismatch) {
                reportLogLikelihoods(posterior, "");
                Log.err.println("At sample " + sampleNr + "\nLikelihood incorrectly calculated: " + originalLogP + " != " + logLikelihood
                        + "(" + (originalLogP - logLikelihood) + ")"
                        + " Operator: " + operator.getClass().getName());
            }

            final boolean debugPeriodOver = sampleNr > NR_OF_DEBUG_SAMPLES * 3;
            if (debugPeriodOver) {
                // a sufficiently large sample has been checked: leave debug mode
                debugFlag = false;
            }

            if (mismatch) {
                if (debugPeriodOver) {
                    // rare enough after the debug period that a robust recalculation should repair it
                    correctionCount++;
                    if (correctionCount > 100) {
                        // that many repairs point at a broken implementation
                        Log.err.println("Too many corrections. There is something seriously wrong that cannot be corrected");
                        state.storeToFile(sampleNr);
                        operatorSchedule.storeToFile();
                        System.exit(1);
                    }
                    oldLogLikelihood = state.robustlyCalcPosterior(posterior);
                } else {
                    // wrong posterior during the initial debug period: save and halt
                    state.storeToFile(sampleNr);
                    operatorSchedule.storeToFile();
                    System.exit(1);
                }
            }
        }

        callUserFunction(sampleNr);

        // periodic save, plus a guaranteed save at the very last sample
        final boolean periodicStore = storeEvery > 0 && (sampleNr + 1) % storeEvery == 0;
        if (periodicStore || sampleNr == chainLength) {
            state.robustlyCalcNonStochasticPosterior(posterior);
            state.storeToFile(sampleNr);
            operatorSchedule.storeToFile();
        }

        if (posterior.getCurrentLogP() == Double.POSITIVE_INFINITY) {
            throw new RuntimeException("Encountered a positive infinite posterior. This is a sign there may be numeric instability in the model.");
        }

        sampleNr++;
    }

    if (correctionCount > 0) {
        Log.err.println("\n\nNB: " + correctionCount + " posterior calculation corrections were required. This analysis may not be valid!\n\n");
    }
}
/**
 * Perform a single MCMC propose+accept/reject step.
 * <p>
 * Stores the state, lets the operator schedule select an operator and asks it for a
 * proposal. Unless the proposal returns negative infinity, the posterior is calculated
 * into {@code newLogLikelihood}, {@code logAlpha} is updated and the proposal is accepted
 * (updating {@code oldLogLikelihood}) or rejected (restoring the state). The loggers are
 * called at the end of every step, whatever the outcome.
 *
 * @param sampleNr the index of the current MCMC step; the operator is only told about
 *                 acceptance or rejection when this is non-negative
 * @return the selected {@link beast.core.Operator}
 */
protected Operator propagateState(final int sampleNr) {
// presumably snapshots the state so that the restore() calls below can undo the proposal -- TODO confirm in State
state.store(sampleNr);
// if (m_nStoreEvery > 0 && sample % m_nStoreEvery == 0 && sample > 0) {
// state.storeToFile(sample);
// operatorSchedule.storeToFile();
// }
final Operator operator = operatorSchedule.selectOperator();
if (printDebugInfo) System.err.print("\n" + sampleNr + " " + operator.getName()+ ":");
// an operator may name a distribution it wants evaluated; only then is an Evaluator
// built, otherwise the operator is handed null
final Distribution evaluatorDistribution = operator.getEvaluatorDistribution();
Evaluator evaluator = null;
if (evaluatorDistribution != null) {
evaluator = new Evaluator() {
@Override
public double evaluate() {
double logP = 0.0;
state.storeCalculationNodes();
state.checkCalculationNodesDirtiness();
try {
logP = evaluatorDistribution.calculateLogP();
} catch (Exception e) {
// any failure of the calculation terminates the JVM with exit status 1
e.printStackTrace();
System.exit(1);
}
// restore the state and store it again after the evaluation
state.restore();
state.store(sampleNr);
return logP;
}
};
}
final double logHastingsRatio = operator.proposal(evaluator);
// a negative infinite Hastings ratio is handled as a failed proposal (else branch below)
if (logHastingsRatio != Double.NEGATIVE_INFINITY) {
if (operator.requiresStateInitialisation()) {
state.storeCalculationNodes();
state.checkCalculationNodesDirtiness();
}
newLogLikelihood = posterior.calculateLogP();
logAlpha = newLogLikelihood - oldLogLikelihood + logHastingsRatio; //CHECK HASTINGS
if (printDebugInfo) System.err.print(logAlpha + " " + newLogLikelihood + " " + oldLogLikelihood);
// accept outright when logAlpha >= 0, otherwise when a random draw is below exp(logAlpha);
// a random number is only drawn in the second case
if (logAlpha >= 0 || Randomizer.nextDouble() < Math.exp(logAlpha)) {
// accept
oldLogLikelihood = newLogLikelihood;
state.acceptCalculationNodes();
// samples with a negative number (pre-burnin) are not reported to the operator
if (sampleNr >= 0) {
operator.accept();
}
if (printDebugInfo) System.err.print(" accept");
} else {
// reject
if (sampleNr >= 0) {
// reason code passed to the operator: -1 when the new log probability is negative infinity, 0 otherwise
operator.reject(newLogLikelihood == Double.NEGATIVE_INFINITY ? -1 : 0);
}
state.restore();
state.restoreCalculationNodes();
if (printDebugInfo) System.err.print(" reject");
}
state.setEverythingDirty(false);
} else {
// operation failed
if (sampleNr >= 0) {
// reason code -2 is used for a failed proposal
operator.reject(-2);
}
state.restore();
// the dirty flags and calculation nodes are only reset here for operators that
// do not require state initialisation
if (!operator.requiresStateInitialisation()) {
state.setEverythingDirty(false);
state.restoreCalculationNodes();
}
if (printDebugInfo) System.err.print(" direct reject");
}
log(sampleNr);
return operator;
}
/**
 * Tells whether a recalculated log probability disagrees with the tracked one.
 * <p>
 * Two values agree when they are identical (including both being the same infinity or
 * both NaN) or when they differ by at most an absolute tolerance of 1e-6. A NaN on one
 * side only counts as a disagreement: {@code Math.abs(NaN) > 1e-6} is false, so a plain
 * difference test would silently let a NaN recalculation pass.
 *
 * @param logLikelihood the freshly recalculated log probability
 * @param originalLogP  the log probability tracked so far
 * @return true if the two values differ by more than the tolerance, or exactly one is NaN
 */
private boolean isTooDifferent(double logLikelihood, double originalLogP) {
    if (Double.compare(logLikelihood, originalLogP) == 0) {
        // identical, which covers equal infinities and NaN on both sides
        return false;
    }
    final double difference = Math.abs(logLikelihood - originalLogP);
    // after the identity check, a NaN difference can only mean that exactly one value is NaN
    return Double.isNaN(difference) || difference > 1e-6;
}
/**
 * Reports a distribution and, recursively, the components of a compound distribution;
 * meant for debugging incorrectly recalculated posteriors.
 * <p>
 * One line is printed per distribution with its current and its stored log probability;
 * the line is flagged with " **" when the two values are not equal.
 *
 * @param distr     the distribution to report
 * @param tabString indentation prefix of the line; each nesting level appends one tab
 */
protected void reportLogLikelihoods(final Distribution distr, final String tabString) {
    final double full = distr.logP;
    final double last = distr.storedLogP;

    String changed = " **";
    if (full == last) {
        changed = "";
    }
    Log.info.println(tabString + "P(" + distr.getID() + ") = " + full + " (was " + last + ")" + changed);

    if (!(distr instanceof CompoundDistribution)) {
        return;
    }
    final CompoundDistribution compound = (CompoundDistribution) distr;
    final String childIndent = tabString + "\t";
    for (final Distribution component : compound.pDistributions.get()) {
        reportLogLikelihoods(component, childIndent);
    }
}
/**
 * Hook that {@link #doLoop()} calls once per sample, after the propose/accept/reject step
 * and the posterior check or operator optimisation, and before the state is stored to
 * file. The implementation here is empty.
 *
 * @param sample the number of the sample that has just been taken
 */
protected void callUserFunction(final int sample) {
}
/**
 * Robust calculation of the posterior: delegates to
 * {@code State.robustlyCalcPosterior}, which is described as setting all StateNodes
 * and CalculationNodes dirty and cleaning everything afterwards.
 *
 * @param posterior the distribution to calculate
 * @return the value returned by the state for that distribution
 */
public double robustlyCalcPosterior(final Distribution posterior) {
    final double logP = state.robustlyCalcPosterior(posterior);
    return logP;
}
/**
 * Robust calculation of the non-stochastic posterior: delegates to
 * {@code State.robustlyCalcNonStochasticPosterior}.
 *
 * @param posterior the distribution to calculate
 * @return the value returned by the state for that distribution
 */
public double robustlyCalcNonStochasticPosterior(final Distribution posterior) {
    final double logP = state.robustlyCalcNonStochasticPosterior(posterior);
    return logP;
}
} // class MCMC