/*
* MarkovChain.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.inference.markovchain;
import dr.evomodel.continuous.GibbsIndependentCoalescentOperator;
import dr.inference.model.CompoundLikelihood;
import dr.inference.model.Likelihood;
import dr.inference.model.Model;
import dr.inference.model.PathLikelihood;
import dr.inference.operators.*;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.logging.Logger;
/**
* A concrete Markov chain. This class is final because the only things that should
* need overriding are in the delegates (likelihood, operator schedule and acceptor).
* The design of this class is to be fairly immutable as far as settings go.
*
* @author Alexei Drummond
* @author Andrew Rambaut
* @version $Id: MarkovChain.java,v 1.10 2006/06/21 13:34:42 rambaut Exp $
*/
public final class MarkovChain implements Serializable {
    private static final long serialVersionUID = 181L;

    private static final boolean DEBUG = false;
    private static final boolean PROFILE = true;

    /**
     * Default maximum tolerated absolute difference between an incremental
     * evaluation and a full re-evaluation during the test phase.
     */
    public static final double EVALUATION_TEST_THRESHOLD = 1e-1;

    private final OperatorSchedule schedule;
    private final Acceptor acceptor;
    private final Likelihood likelihood;

    private boolean pleaseStop = false;
    private boolean isStopped = false;
    private double bestScore, currentScore, initialScore;
    private long currentLength;

    private final boolean useCoercion;

    private final long fullEvaluationCount;
    private final int minOperatorCountForFullEvaluation;
    private final double evaluationTestThreshold;

    private final ArrayList<MarkovChainListener> listeners = new ArrayList<MarkovChainListener>();

    /**
     * Creates a chain over the given likelihood.
     *
     * <p>The likelihood and all members of its likelihood set are added to
     * {@code Likelihood.CONNECTED_LIKELIHOOD_SET}. A warning is printed to stderr for
     * every likelihood in {@code Likelihood.FULL_LIKELIHOOD_SET} that is not connected.
     *
     * @param likelihood                        the density the chain samples from
     * @param schedule                          picks the operator used for each state
     * @param acceptor                          decides whether a proposed state is accepted
     * @param fullEvaluationCount               number of states during which each move is checked
     *                                          against a full re-evaluation; 0 disables the check
     * @param minOperatorCountForFullEvaluation value that
     *                                          {@code schedule.getMinimumAcceptAndRejectCount()}
     *                                          must reach before the check is switched off
     * @param evaluationTestThreshold           maximum tolerated difference between incremental
     *                                          and full evaluations
     * @param useCoercion                       whether operators whose coercion mode is neither
     *                                          explicitly on nor off are auto-tuned
     */
    public MarkovChain(Likelihood likelihood,
                       OperatorSchedule schedule, Acceptor acceptor,
                       long fullEvaluationCount, int minOperatorCountForFullEvaluation, double evaluationTestThreshold,
                       boolean useCoercion) {

        currentLength = 0;
        this.likelihood = likelihood;
        this.schedule = schedule;
        this.acceptor = acceptor;
        this.useCoercion = useCoercion;

        this.fullEvaluationCount = fullEvaluationCount;
        this.minOperatorCountForFullEvaluation = minOperatorCountForFullEvaluation;
        this.evaluationTestThreshold = evaluationTestThreshold;

        Likelihood.CONNECTED_LIKELIHOOD_SET.add(likelihood);
        Likelihood.CONNECTED_LIKELIHOOD_SET.addAll(likelihood.getLikelihoodSet());

        for (Likelihood l : Likelihood.FULL_LIKELIHOOD_SET) {
            if (!Likelihood.CONNECTED_LIKELIHOOD_SET.contains(l)) {
                System.err.println("WARNING: Likelihood component, " + l.getId() + ", created but not used in the MCMC");
            }
        }

        currentScore = evaluate(likelihood);
    }

    /**
     * Resets the chain length to zero and resets every operator in the schedule.
     */
    public void reset() {
        currentLength = 0;

        // reset operator acceptance levels
        for (int i = 0; i < schedule.getOperatorCount(); i++) {
            schedule.getOperator(i).reset();
        }
    }

    /**
     * Runs the chain for a given number of states, continuing from the current length.
     *
     * @param length        number of states to run the chain
     * @param disableCoerce if true, no operator is auto-tuned during this run
     * @return the chain length after the run; this is less than requested if a stop
     *         was requested via {@link #pleaseStop()}
     * @throws IllegalArgumentException if the starting state has zero likelihood or
     *                                  a numerical error
     * @throws RuntimeException         if an evaluation error was detected during the
     *                                  full-evaluation test phase
     */
    public long runChain(long length, boolean disableCoerce) {

        likelihood.makeDirty();
        currentScore = evaluate(likelihood);

        long currentState = currentLength;

        final Model currentModel = likelihood.getModel();

        if (currentState == 0) {
            initialScore = currentScore;
            bestScore = currentScore;
            fireBestModel(currentState, currentModel);
        }

        checkInitialScore(currentScore);

        pleaseStop = false;
        isStopped = false;

        // Single-element array so the acceptor can report the log acceptance ratio.
        double[] logr = {0.0};

        // A full evaluation count of zero disables the test phase.
        // Temporary solution until full code review.
        boolean usingFullEvaluation = fullEvaluationCount != 0;
        boolean fullEvaluationError = false;

        while (!pleaseStop && (currentState < (currentLength + length))) {

            String diagnosticStart = "";

            // periodically log states
            fireCurrentModel(currentState, currentModel);

            if (pleaseStop) {
                isStopped = true;
                break;
            }

            // Get the operator
            final int op = schedule.getNextOperatorIndex();
            final MCMCOperator mcmcOperator = schedule.getOperator(op);

            final double oldScore = currentScore;
            if (usingFullEvaluation) {
                diagnosticStart = getDiagnosis(likelihood);
            }

            // The current model is stored here in case the proposal fails
            if (currentModel != null) {
                currentModel.storeModelState();
            }

            boolean accept = false;

            logr[0] = -Double.MAX_VALUE;

            long operateStartTime = 0;
            if (PROFILE) {
                operateStartTime = System.currentTimeMillis();
            }

            if (DEBUG) {
                System.out.println("\n>> Iteration: " + currentState);
                System.out.println("\n&& Operator: " + mcmcOperator.getOperatorName());
            }

            // The new model is proposed
            final double hastingsRatio;
            if (mcmcOperator instanceof GeneralOperator) {
                hastingsRatio = ((GeneralOperator) mcmcOperator).operate(likelihood);
            } else {
                hastingsRatio = mcmcOperator.operate();
            }

            // Operators signal a failed proposal by returning a -Inf Hastings ratio
            // (this replaced throwing OperatorFailedException); evaluation is then skipped.
            final boolean operatorSucceeded = hastingsRatio != Double.NEGATIVE_INFINITY;

            if (PROFILE) {
                long duration = System.currentTimeMillis() - operateStartTime;
                if (DEBUG) {
                    System.out.println("Time: " + duration);
                }
                mcmcOperator.addEvaluationTime(duration);
            }

            double score = Double.NaN;
            double deviation = Double.NaN;

            if (operatorSucceeded) {

                if (DEBUG) {
                    System.out.println("** Evaluate");
                }

                long evaluateStartTime = 0;
                if (PROFILE) {
                    evaluateStartTime = System.currentTimeMillis();
                }

                // The new model is evaluated
                score = evaluate(likelihood);

                if (PROFILE) {
                    long duration = System.currentTimeMillis() - evaluateStartTime;
                    if (DEBUG) {
                        System.out.println("Time: " + duration);
                    }
                    mcmcOperator.addEvaluationTime(duration);
                }

                String diagnosticOperator = "";
                if (usingFullEvaluation) {
                    diagnosticOperator = getDiagnosis(likelihood);
                }

                if (score == Double.NEGATIVE_INFINITY
                        && mcmcOperator instanceof GibbsOperator
                        && !isExemptFromZeroLikelihoodWarning(mcmcOperator)) {
                    Logger.getLogger("error").severe("State " + currentState + ": A Gibbs operator, " + mcmcOperator.getOperatorName() + ", returned a state with zero likelihood.");
                }

                if (score == Double.POSITIVE_INFINITY || Double.isNaN(score)) {
                    if (likelihood instanceof CompoundLikelihood) {
                        Logger.getLogger("error").severe("State "+currentState+": A likelihood returned with a numerical error:\n" +
                                ((CompoundLikelihood)likelihood).getDiagnosis());
                    } else {
                        Logger.getLogger("error").severe("State "+currentState+": A likelihood returned with a numerical error.");
                    }

                    // Transform the error into negative infinity so that the state is rejected.
                    score = Double.NEGATIVE_INFINITY;
                }

                if (usingFullEvaluation) {
                    // Test that the state was correctly evaluated: every component is
                    // flagged as needing recalculation, the full likelihood is recomputed
                    // and compared with the incremental result. This checks that BEAST is
                    // aware of all the changes the operator induced.
                    likelihood.makeDirty();
                    final double testScore = evaluate(likelihood);
                    final String d2 = getDiagnosis(likelihood);

                    if (Math.abs(testScore - score) > evaluationTestThreshold) {
                        Logger.getLogger("error").severe(
                                "State "+currentState+": State was not correctly calculated after an operator move.\n"
                                        + "Likelihood evaluation: " + score
                                        + "\nFull Likelihood evaluation: " + testScore
                                        + "\n" + "Operator: " + mcmcOperator
                                        + " " + mcmcOperator.getOperatorName()
                                        + (diagnosticOperator.length() > 0 ? "\n\nDetails\nBefore: " + diagnosticOperator + "\nAfter: " + d2 : "")
                                        + "\n\n");
                        fullEvaluationError = true;
                    }
                }

                // The best state seen is tracked whether or not the move is accepted.
                if (score > bestScore) {
                    bestScore = score;
                    fireBestModel(currentState, currentModel);
                }

                // Gibbs moves are always accepted; everything else defers to the acceptor.
                accept = mcmcOperator instanceof GibbsOperator || acceptor.accept(oldScore, score, hastingsRatio, logr);

                deviation = score - oldScore;
            }

            // The new model is accepted or rejected
            if (accept) {
                if (DEBUG) {
                    System.out.println("** Move accepted: new score = " + score
                            + ", old score = " + oldScore);
                }

                mcmcOperator.accept(deviation);
                if (currentModel != null) {
                    currentModel.acceptModelState();
                }
                currentScore = score;

            } else {
                if (DEBUG) {
                    System.out.println("** Move rejected: new score = " + score
                            + ", old score = " + oldScore);
                }

                mcmcOperator.reject();
                if (currentModel != null) {
                    currentModel.restoreModelState();
                }

                if (usingFullEvaluation) {
                    // Test that the state was correctly restored: the restored state is
                    // fully evaluated and compared with the score before the operation.
                    likelihood.makeDirty();
                    final double testScore = evaluate(likelihood);
                    final String d2 = getDiagnosis(likelihood);

                    if (Math.abs(testScore - oldScore) > evaluationTestThreshold) {
                        final Logger logger = Logger.getLogger("error");
                        logger.severe("State "+currentState+": State was not correctly restored after reject step.\n"
                                + "Likelihood before: " + oldScore
                                + " Likelihood after: " + testScore
                                + "\n" + "Operator: " + mcmcOperator
                                + " " + mcmcOperator.getOperatorName()
                                + (diagnosticStart.length() > 0 ? "\n\nDetails\nBefore: " + diagnosticStart + "\nAfter: " + d2 : "")
                                + "\n\n");
                        fullEvaluationError = true;
                    }
                }
            }

            if (!disableCoerce && mcmcOperator instanceof CoercableMCMCOperator) {
                coerceAcceptanceProbability((CoercableMCMCOperator) mcmcOperator, logr[0]);
            }

            // Full evaluation is switched off once every operator has reached the minimum
            // accept+reject count and at least fullEvaluationCount states have been run.
            if (usingFullEvaluation
                    && schedule.getMinimumAcceptAndRejectCount() >= minOperatorCountForFullEvaluation
                    && currentState >= fullEvaluationCount) {
                usingFullEvaluation = false;
                if (fullEvaluationError) {
                    // If there has been an error then stop with an error
                    throw new RuntimeException(
                            "One or more evaluation errors occurred during the test phase of this\n" +
                                    "run. These errors imply critical errors which may produce incorrect\n" +
                                    "results.");
                }
            }

            fireEndCurrentIteration(currentState);

            currentState += 1;
        }

        currentLength = currentState;

        return currentLength;
    }

    /**
     * Throws if the score of the starting state is zero-probability or numerically invalid.
     */
    private void checkInitialScore(double score) {
        if (score == Double.NEGATIVE_INFINITY) {
            // identify which component of the score is zero...
            String message = "The initial likelihood is zero";
            if (likelihood instanceof CompoundLikelihood) {
                message += ": " + ((CompoundLikelihood) likelihood).getDiagnosis();
            } else if (likelihood instanceof PathLikelihood) {
                final PathLikelihood pathLikelihood = (PathLikelihood) likelihood;
                // The end-points are not necessarily compound likelihoods, so they are
                // not cast unconditionally (that would mask this diagnostic).
                message += ": " + getDiagnosis(pathLikelihood.getSourceLikelihood());
                message += ": " + getDiagnosis(pathLikelihood.getDestinationLikelihood());
            } else {
                message += ".";
            }
            throw new IllegalArgumentException(message);
        } else if (score == Double.POSITIVE_INFINITY || Double.isNaN(score)) {
            String message = "A likelihood returned with a numerical error";
            if (likelihood instanceof CompoundLikelihood) {
                message += ": " + ((CompoundLikelihood) likelihood).getDiagnosis();
            } else {
                message += ".";
            }
            throw new IllegalArgumentException(message);
        }
    }

    /**
     * Returns the component diagnosis of a compound likelihood, or "" for any other likelihood.
     */
    private static String getDiagnosis(Likelihood l) {
        return l instanceof CompoundLikelihood ? ((CompoundLikelihood) l).getDiagnosis() : "";
    }

    /**
     * Returns true for the independent Gibbs operators exempted from the zero-likelihood warning.
     */
    private static boolean isExemptFromZeroLikelihoodWarning(MCMCOperator op) {
        return op instanceof GibbsIndependentNormalDistributionOperator
                || op instanceof GibbsIndependentGammaOperator
                || op instanceof GibbsIndependentCoalescentOperator
                || op instanceof GibbsIndependentJointNormalGammaOperator;
    }

    /**
     * Notifies listeners that the chain has finished at the current length.
     */
    public void terminateChain() {
        fireFinished(currentLength);
    }

    public Likelihood getLikelihood() {
        return likelihood;
    }

    public Model getModel() {
        return likelihood.getModel();
    }

    public OperatorSchedule getSchedule() {
        return schedule;
    }

    public Acceptor getAcceptor() {
        return acceptor;
    }

    /** Score of the starting state, recorded when a run begins at state 0. */
    public double getInitialScore() {
        return initialScore;
    }

    /** Highest score of any evaluated proposal (accepted or not) or starting state. */
    public double getBestScore() {
        return bestScore;
    }

    public long getCurrentLength() {
        return currentLength;
    }

    public void setCurrentLength(long currentLength) {
        this.currentLength = currentLength;
    }

    public double getCurrentScore() {
        return currentScore;
    }

    /**
     * Requests that a running chain stop; the request is checked at iteration boundaries.
     */
    public void pleaseStop() {
        pleaseStop = true;
    }

    /**
     * Returns true if the last run observed a stop request inside its loop
     * (after notifying listeners of the current state).
     */
    public boolean isStopped() {
        return isStopped;
    }

    /** Evaluates the chain's likelihood; see {@link #evaluate(Likelihood)}. */
    public double evaluate() {
        return evaluate(likelihood);
    }

    /**
     * Returns the log likelihood of the given likelihood, mapping NaN to negative
     * infinity so that such states are treated as having zero probability.
     */
    protected double evaluate(Likelihood likelihood) {
        final double logLikelihood = likelihood.getLogLikelihood();
        if (Double.isNaN(logLikelihood)) {
            return Double.NEGATIVE_INFINITY;
        }
        return 0.0 + logLikelihood;
    }

    /**
     * Updates the operator's coercable (tuning) parameter towards its target acceptance
     * probability.
     *
     * <p>Assumes that acceptance decreases as the parameter increases: the parameter is
     * raised when {@code exp(logr)} exceeds the target and lowered otherwise. The step
     * size is {@code 1 / (i + 1)}, where {@code i} is the schedule's optimization
     * transform of the operator's operation count. Non-finite results are discarded.
     *
     * @param op   the operator
     * @param logr log acceptance ratio of the last proposal, presumably as reported by the
     *             acceptor; {@code -Double.MAX_VALUE} if no ratio was reported
     */
    private void coerceAcceptanceProbability(CoercableMCMCOperator op, double logr) {

        if (DEBUG) {
            System.out.println("coerceAcceptanceProbability " + isCoercable(op));
        }

        if (isCoercable(op)) {
            final double p = op.getCoercableParameter();

            final double i = schedule.getOptimizationTransform(MCMCOperator.Utils.getOperationCount(op));

            final double target = op.getTargetAcceptanceProbability();

            final double newp = p + ((1.0 / (i + 1.0)) * (Math.exp(logr) - target));

            if (newp > -Double.MAX_VALUE && newp < Double.MAX_VALUE) {
                op.setCoercableParameter(newp);
                if (DEBUG) {
                    System.out.println("Setting coercable parameter: " + newp + " target: " + target + " logr: " + logr);
                }
            }
        }
    }

    /**
     * An operator is tuned if its mode forces coercion on, or if its mode does not
     * force coercion off and the chain-wide {@code useCoercion} flag is set.
     */
    private boolean isCoercable(CoercableMCMCOperator op) {

        return op.getMode() == CoercionMode.COERCION_ON
                || (op.getMode() != CoercionMode.COERCION_OFF && useCoercion);
    }

    /** Registers a listener; {@code null} is ignored. */
    public void addMarkovChainListener(MarkovChainListener listener) {
        if (listener != null) {
            listeners.add(listener);
        }
    }

    public void removeMarkovChainListener(MarkovChainListener listener) {
        listeners.remove(listener);
    }

    private void fireBestModel(long state, Model bestModel) {

        for (MarkovChainListener listener : listeners) {
            listener.bestState(state, this, bestModel);
        }
    }

    private void fireCurrentModel(long state, Model currentModel) {
        for (MarkovChainListener listener : listeners) {
            listener.currentState(state, this, currentModel);
        }
    }

    private void fireFinished(long chainLength) {

        for (MarkovChainListener listener : listeners) {
            listener.finished(chainLength, this);
        }
    }

    /** End-of-iteration hook; currently does nothing. */
    private void fireEndCurrentIteration(long state) {
    }
}