/*
* LatentStateBranchRateModel.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evomodel.branchratemodel;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evomodel.tree.TreeModel;
import dr.evomodel.tree.TreeParameterModel;
import dr.inference.markovjumps.TwoStateOccupancyMarkovReward;
import dr.inference.model.AbstractModelLikelihood;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
/**
* LatentStateBranchRateModel
*
* @author Andrew Rambaut
* @author Marc Suchard
* @version $Id$
* <p/>
* $HeadURL$
* <p/>
* $LastChangedBy$
* $LastChangedDate$
* $LastChangedRevision$
*/
public class LatentStateBranchRateModel extends AbstractModelLikelihood implements BranchRateModel {
public static final String LATENT_STATE_BRANCH_RATE_MODEL = "latentStateBranchRateModel";
public static final boolean USE_CACHING = true;
// seed 666, caching off: 204.69 seconds for 20000 states
// state 20000 -5510.2520
// 85.7% 5202 + 6 dr.inference.markovjumps.SericolaSeriesMarkovReward.accumulatePdf
// seed 666, caching on: 119.43 seconds for 20000 states
// state 20000 -5510.2520
// 83.4% 3156 + 4 dr.inference.markovjumps.SericolaSeriesMarkovReward.accumulatePdf
private final TreeModel tree;
private final BranchRateModel nonLatentRateModel;
private final Parameter latentTransitionRateParameter;
private final Parameter latentTransitionFrequencyParameter;
private final TreeParameterModel latentStateProportions;
private final Parameter latentStateProportionParameter;
private final CountableBranchCategoryProvider branchCategoryProvider;
private TwoStateOccupancyMarkovReward markovReward;
private TwoStateOccupancyMarkovReward storedMarkovReward;
private boolean likelihoodKnown = false;
private boolean storedLikelihoodKnown;
private double logLikelihood;
private double storedLogLikelihood;
private double[] branchLikelihoods;
private double[] storedbranchLikelihoods;
private boolean[] updateBranch;
private boolean[] storedUpdateBranch;
private boolean[] updateCategory;
private boolean[] storedUpdateCategory;
public LatentStateBranchRateModel(String name,
TreeModel treeModel,
BranchRateModel nonLatentRateModel,
Parameter latentTransitionRateParameter,
Parameter latentTransitionFrequencyParameter,
Parameter latentStateProportionParameter,
CountableBranchCategoryProvider branchCategoryProvider) {
super(name);
this.tree = treeModel;
addModel(tree);
this.nonLatentRateModel = nonLatentRateModel;
addModel(nonLatentRateModel);
this.latentTransitionRateParameter = latentTransitionRateParameter;
addVariable(latentTransitionRateParameter);
this.latentTransitionFrequencyParameter = latentTransitionFrequencyParameter;
addVariable(latentTransitionFrequencyParameter);
if (branchCategoryProvider == null) {
this.latentStateProportions = new TreeParameterModel(tree, latentStateProportionParameter, false, Intent.BRANCH);
addModel(latentStateProportions);
this.latentStateProportionParameter = null;
this.branchCategoryProvider = null;
} else {
this.latentStateProportions = null;
this.branchCategoryProvider = branchCategoryProvider;
this.latentStateProportionParameter = latentStateProportionParameter;
this.latentStateProportionParameter.setDimension(branchCategoryProvider.getCategoryCount());
if (USE_CACHING) {
updateCategory = new boolean[branchCategoryProvider.getCategoryCount()];
storedUpdateCategory = new boolean[branchCategoryProvider.getCategoryCount()];
setUpdateAllCategories();
}
addVariable(latentStateProportionParameter);
}
branchLikelihoods = new double[tree.getNodeCount()];
if (USE_CACHING) {
updateBranch = new boolean[tree.getNodeCount()];
storedUpdateBranch = new boolean[tree.getNodeCount()];
storedbranchLikelihoods = new double[tree.getNodeCount()];
setUpdateAllBranches();
}
}
public LatentStateBranchRateModel(Parameter rate, Parameter prop) {
super(LATENT_STATE_BRANCH_RATE_MODEL);
tree = null;
nonLatentRateModel = null;
latentTransitionRateParameter = rate;
latentTransitionFrequencyParameter = prop;
latentStateProportions = null;
this.latentStateProportionParameter = null;
this.branchCategoryProvider = null;
}
private double[] createLatentInfinitesimalMatrix() {
final double rate = latentTransitionRateParameter.getParameterValue(0);
final double prop = latentTransitionFrequencyParameter.getParameterValue(0);
double[] mat = new double[]{
-rate * prop, rate * prop,
rate * (1.0 - prop), -rate * (1.0 - prop)
};
return mat;
}
private static double[] createReward() {
return new double[]{0.0, 1.0};
}
private TwoStateOccupancyMarkovReward createMarkovReward() {
TwoStateOccupancyMarkovReward markovReward = new TwoStateOccupancyMarkovReward(createLatentInfinitesimalMatrix());
return markovReward;
}
public TwoStateOccupancyMarkovReward getMarkovReward() {
if (markovReward == null) {
markovReward = createMarkovReward();
}
return markovReward;
}
@Override
public double getBranchRate(Tree tree, NodeRef node) {
double nonLatentRate = nonLatentRateModel.getBranchRate(tree, node);
double latentProportion = getLatentProportion(tree, node);
return calculateBranchRate(nonLatentRate, latentProportion);
}
public double getLatentProportion(Tree tree, NodeRef node) {
if (latentStateProportions != null) {
return latentStateProportions.getNodeValue(tree, node);
} else {
return latentStateProportionParameter.getParameterValue(branchCategoryProvider.getBranchCategory(tree, node));
}
}
private double calculateBranchRate(double nonLatentRate, double latentProportion) {
return nonLatentRate * (1.0 - latentProportion);
}
@Override
protected void handleModelChangedEvent(Model model, Object object, int index) {
if (model == tree) {
likelihoodKnown = false; // node heights change elapsed times on branches, TODO could cache
if (index == -1) {
setUpdateAllBranches();
} else {
setUpdateBranch(index);
}
} else if (model == nonLatentRateModel) {
// rates will change but the latent proportions haven't so the density is unchanged
} else if (model == latentStateProportions) {
likelihoodKnown = false; // argument of density has changed
if (index == -1) {
setUpdateAllBranches();
} else {
setUpdateBranch(index);
}
}
fireModelChanged();
}
@Override
protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {
if (variable == latentTransitionFrequencyParameter || variable == latentTransitionRateParameter) {
// markovReward computations have changed
markovReward = null;
setUpdateAllBranches();
likelihoodKnown = false;
} else if (variable == latentStateProportionParameter) {
if (index == -1) {
setUpdateAllBranches();
} else {
setUpdateBranchCategory(index);
}
likelihoodKnown = false;
fireModelChanged();
}
}
private void setUpdateBranch(int nodeNumber) {
if (USE_CACHING) {
updateBranch[nodeNumber] = true;
}
}
private void setUpdateAllBranches() {
if (USE_CACHING) {
for (int i = 0; i < updateBranch.length; i++) {
updateBranch[i] = true;
}
}
}
private void clearUpdateAllBranches() {
if (USE_CACHING) {
for (int i = 0; i < updateBranch.length; i++) {
updateBranch[i] = false;
}
}
}
private void setUpdateBranchCategory(int category) {
if (USE_CACHING) {
updateCategory[category] = true;
}
}
private void setUpdateAllCategories() {
if (USE_CACHING) {
for (int i = 0; i < updateCategory.length; i++) {
updateCategory[i] = true;
}
}
}
private void clearAllCategories() {
if (USE_CACHING && updateCategory != null) {
for (int i = 0; i < updateCategory.length; i++) {
updateCategory[i] = false;
}
}
}
@Override
protected void storeState() {
storedMarkovReward = markovReward;
storedLogLikelihood = logLikelihood;
storedLikelihoodKnown = likelihoodKnown;
if (USE_CACHING) {
System.arraycopy(branchLikelihoods, 0, storedbranchLikelihoods, 0, branchLikelihoods.length);
System.arraycopy(updateBranch, 0, storedUpdateBranch, 0, updateBranch.length);
if (updateCategory != null) {
System.arraycopy(updateCategory, 0, storedUpdateCategory, 0, updateCategory.length);
}
}
}
@Override
protected void restoreState() {
markovReward = storedMarkovReward;
logLikelihood = storedLogLikelihood;
likelihoodKnown = storedLikelihoodKnown;
if (USE_CACHING) {
double[] tmp = branchLikelihoods;
branchLikelihoods = storedbranchLikelihoods;
storedbranchLikelihoods = tmp;
boolean[] tmp2 = updateBranch;
updateBranch = storedUpdateBranch;
storedUpdateBranch = tmp2;
boolean[] tmp3 = updateCategory;
updateCategory = storedUpdateCategory;
storedUpdateCategory = tmp3;
}
}
@Override
protected void acceptState() {
}
@Override
public Model getModel() {
return this;
}
@Override
public double getLogLikelihood() {
if (!likelihoodKnown) {
logLikelihood = calculateLogLikelihood();
likelihoodKnown = true;
}
return logLikelihood;
}
private double calculateLogLikelihood() {
double logLike = 0.0;
for (int i = 0; i < tree.getInternalNodeCount(); ++i) {
NodeRef node = tree.getNode(i);
if (node != tree.getRoot()) {
if (updateNeededForNode(tree, node)) {
double branchLength = tree.getBranchLength(node);
double latentProportion = getLatentProportion(tree, node);
assert(latentProportion < 1.0);
double density = getBranchRewardDensity(latentProportion, branchLength);
branchLikelihoods[node.getNumber()] = Math.log(density);
}
logLike += branchLikelihoods[node.getNumber()];
}
}
clearUpdateAllBranches();
clearAllCategories();
return logLike;
}
private boolean updateNeededForNode(Tree tree, NodeRef node) {
if (USE_CACHING) {
return (updateCategory != null && updateCategory[branchCategoryProvider.getBranchCategory(tree, node)]) || updateBranch[node.getNumber()];
} else {
return true;
}
}
public double getBranchRewardDensity(double proportion, double branchLength) {
if (markovReward == null) {
markovReward = createMarkovReward();
}
// int state = 0 * 2 + 0; // just start = end = 0 entry
// Reward is [0,1], and we want to track time in latent state (= 1).
// Therefore all nodes are in state 0
// double joint = markovReward.computePdf(reward, branchLength)[state];
final double joint = markovReward.computePdf(proportion * branchLength, branchLength, 0, 0);
final double marg = markovReward.computeConditionalProbability(branchLength, 0, 0);
final double rate = latentTransitionRateParameter.getParameterValue(0) *
latentTransitionFrequencyParameter.getParameterValue(0) * branchLength;
final double zeroJumps = Math.exp(-rate);
// Check numerical tolerance
if (marg - zeroJumps <= 0.0) {
return 0.0;
}
// TODO Overhead in creating double[] could be saved by changing signature to computePdf
double density = joint / (marg - zeroJumps); // conditional on ending state and >= 2 jumps
density *= branchLength; // random variable is latentProportion = reward / branchLength, so include Jacobian
if (DEBUG) {
if (Double.isInfinite(Math.log(density))) {
System.err.println("Infinite density in LatentStateBranchRateModel:");
System.err.println("proportion = " + proportion);
System.err.println("branchLength = " + branchLength);
System.err.println("lTRP = " + latentTransitionRateParameter.getParameterValue(0));
System.err.println("lTFP = " + latentTransitionFrequencyParameter.getParameterValue(0));
System.err.println("rate = " + rate);
System.err.println("joint = " + joint);
System.err.println("marg = " + marg);
System.err.println("zero = " + zeroJumps);
System.err.println("Hit debugger");
final double joint2 = markovReward.computePdf(proportion * branchLength, branchLength, 0, 0);
final double marg2 = markovReward.computeConditionalProbability(branchLength, 0, 0);
}
}
return density;
}
@Override
public void makeDirty() {
likelihoodKnown = false;
markovReward = null;
setUpdateAllBranches();
}
@Override
public String getTraitName() {
return BranchRateModel.RATE;
}
@Override
public Intent getIntent() {
return Intent.BRANCH;
}
@Override
public TreeTrait getTreeTrait(final String key) {
if (key.equals(BranchRateModel.RATE)) {
return this;
} else if (latentStateProportions != null && key.equals(latentStateProportions.getTraitName())) {
return latentStateProportions;
} else if (branchCategoryProvider != null && key.equals(branchCategoryProvider.getTraitName())) {
return branchCategoryProvider;
} else {
throw new IllegalArgumentException("Unrecognised Tree Trait key, " + key);
}
}
@Override
public TreeTrait[] getTreeTraits() {
return new TreeTrait[]{this, latentStateProportions, branchCategoryProvider};
}
@Override
public Class getTraitClass() {
return Double.class;
}
@Override
public boolean getLoggable() {
return true;
}
@Override
public Double getTrait(final Tree tree, final NodeRef node) {
return getBranchRate(tree, node);
}
@Override
public String getTraitString(final Tree tree, final NodeRef node) {
return Double.toString(getBranchRate(tree, node));
}
public static void main(String[] args) {
Parameter rate = new Parameter.Default(4.4);
Parameter prop = new Parameter.Default(0.25);
LatentStateBranchRateModel model = new LatentStateBranchRateModel(rate, prop);
double branchLength = 2.0;
for (double reward = 0; reward < branchLength; reward += 0.01) {
System.out.println(reward + ",\t" + model.getBranchRewardDensity(reward, branchLength) + ",");
}
System.out.println();
System.out.println(model.getMarkovReward());
}
private static boolean DEBUG = true;
}