/*
* RandomLocalYuleModel.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evomodel.speciation;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evolution.tree.TreeTraitProvider;
import dr.evomodel.tree.TreeModel;
import dr.evomodel.tree.randomlocalmodel.RandomLocalTreeVariable;
import dr.evomodelxml.speciation.RandomLocalYuleModelParser;
import dr.inference.model.Parameter;
import java.text.NumberFormat;
import java.util.Locale;
import java.util.logging.Logger;
/**
 * A Yule speciation model whose birth rate changes at different points in the tree.
 *
 * <p>Per-node values are read from the tree model as node traits: the rate variable under the
 * name of the birth-rates parameter, and the indicator under the name of the indicators
 * parameter. The root takes the value of {@code meanRate}. Every other node inherits its
 * parent's rate unless its indicator is selected (value &gt; 0.5). In that case the node's
 * variable either multiplies the inherited rate or replaces it, depending on
 * {@code ratesAsMultipliers}.
 *
 * <p>Two per-node tree traits are exposed: {@code "I"} (1 if the node's indicator is selected,
 * otherwise 0) and {@code "b"} (the node's computed birth rate).
 *
 * @author Alexei Drummond
 */
public class RandomLocalYuleModel extends UltrametricSpeciationModel
        implements TreeTraitProvider, RandomLocalTreeVariable {

    /** Node-trait key under which per-node rate variables are read (the birth-rates parameter name). */
    private final String birthRatesName;

    /** Node-trait key under which per-node indicators are read (the indicators parameter name). */
    private final String indicatorsName;

    /** Parameter whose first value is used as the birth rate at the root. */
    private final Parameter meanRate;

    /** If true, a selected node's variable multiplies the inherited rate; otherwise it replaces it. */
    private final boolean birthRatesAreMultipliers;

    /** Per-node birth rates indexed by node number; filled by {@link #calculateBirthRates}. */
    private final double[] nodeBirthRates;

    /**
     * True when {@link #nodeBirthRates} must be recomputed before it is read. It starts as true,
     * so the first read computes the rates instead of using the zero-initialised array (which
     * would make {@code Math.log(lambda)} evaluate to -Infinity for external nodes).
     */
    private boolean calculateAllBirthRates = true;

    /** NOTE(review): configured from {@code dp} in the constructor but not read anywhere in this class. */
    private final NumberFormat format = NumberFormat.getNumberInstance(Locale.ENGLISH);

    protected TreeTraitProvider.Helper treeTraits = new Helper();

    /**
     * Creates the model and registers its parameters and tree traits.
     *
     * @param birthRates         per-node rate variables; bounded below by 0.0 and above by +Infinity
     * @param indicators         per-node indicators; every entry is reset to 0.0 here
     * @param meanRate           parameter whose first value is the root birth rate
     * @param ratesAsMultipliers whether selected variables multiply (true) or replace (false) the
     *                           inherited rate
     * @param units              units passed to the superclass
     * @param dp                 maximum fraction digits for {@code format} (currently unused elsewhere)
     */
    public RandomLocalYuleModel(Parameter birthRates, Parameter indicators, Parameter meanRate,
                                boolean ratesAsMultipliers, Type units, int dp) {
        super(RandomLocalYuleModelParser.YULE_MODEL, units);

        addVariable(birthRates);
        birthRates.addBounds(
                new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0, birthRates.getDimension()));

        // Start with every indicator switched off; values are set quietly before registration.
        for (int i = 0; i < indicators.getDimension(); i++) {
            indicators.setParameterValueQuietly(i, 0.0);
        }
        addVariable(indicators);

        this.meanRate = meanRate;
        addVariable(meanRate);

        birthRatesAreMultipliers = ratesAsMultipliers;
        format.setMaximumFractionDigits(dp);

        birthRatesName = birthRates.getParameterName();
        Logger.getLogger("dr.evomodel").info(" birth rates parameter is named '" + birthRatesName + "'");
        indicatorsName = indicators.getParameterName();
        Logger.getLogger("dr.evomodel").info(" indicator parameter is named '" + indicatorsName + "'");

        // Indexed by node number. NOTE(review): sized as the rate-parameter dimension + 1,
        // presumably one entry per non-root node plus the root -- confirm against the parser.
        nodeBirthRates = new double[birthRates.getDimension() + 1];

        treeTraits.addTrait(new TreeTrait.I() {
            public String getTraitName() {
                return "I";
            }

            public Intent getIntent() {
                return Intent.NODE;
            }

            public Integer getTrait(Tree tree, NodeRef node) {
                return (isVariableSelected(tree, node) ? 1 : 0);
            }
        });

        treeTraits.addTrait(new TreeTrait.D() {
            public String getTraitName() {
                return "b";
            }

            public Intent getIntent() {
                return Intent.NODE;
            }

            public Double getTrait(Tree tree, NodeRef node) {
                // Recompute first if the rates were marked stale, so a stale array is never reported.
                ensureBirthRates(tree);
                return nodeBirthRates[node.getNumber()];
            }
        });
    }

    /**
     * Returns the rate variable stored on {@code node}.
     *
     * @param tree the tree; must be a {@link TreeModel} (it is cast)
     * @param node the node to read
     * @return the node trait named {@code birthRatesName}
     */
    public final double getVariable(Tree tree, NodeRef node) {
        return ((TreeModel) tree).getNodeTrait(node, birthRatesName);
    }

    /**
     * Returns whether the indicator on {@code node} is switched on.
     *
     * @param tree the tree; must be a {@link TreeModel} (it is cast)
     * @param node the node to read
     * @return true if the node trait named {@code indicatorsName} exceeds 0.5
     */
    public final boolean isVariableSelected(Tree tree, NodeRef node) {
        return ((TreeModel) tree).getNodeTrait(node, indicatorsName) > 0.5;
    }

    //
    // functions that define a speciation model
    //

    /**
     * Contributes nothing to the log probability itself. It marks the per-node birth rates as
     * stale, so the next {@link #logNodeProbability} call recomputes them for the current state.
     * Presumably the likelihood calls this once per evaluation, before the per-node terms;
     * confirm against the caller.
     *
     * @param taxonCount unused
     * @return always 0.0
     */
    public final double logTreeProbability(int taxonCount) {
        calculateAllBirthRates = true;
        return 0.0;
    }

    /**
     * Returns the per-node log-probability term.
     *
     * <p>The root contributes 0. Every other node contributes {@code -lambda * branchLength},
     * where {@code lambda} is the node's birth rate and {@code branchLength} is the height
     * difference to its parent. External nodes additionally contribute {@code log(lambda)}.
     *
     * @param tree the tree; non-root nodes are read through {@link #getVariable} and
     *             {@link #isVariableSelected}, which require a {@link TreeModel}
     * @param node the node whose term is requested
     * @return the node's log-probability contribution
     */
    public final double logNodeProbability(Tree tree, NodeRef node) {
        ensureBirthRates(tree);

        if (tree.isRoot(node)) {
            return 0.0;
        }

        double lambda = nodeBirthRates[node.getNumber()];
        double branchLength = tree.getNodeHeight(tree.getParent(node)) - tree.getNodeHeight(node);

        double logP = -lambda * branchLength;
        if (tree.isExternal(node)) {
            logP += Math.log(lambda);
        }
        return logP;
    }

    /** Recomputes {@link #nodeBirthRates} from the root of {@code tree} if they are marked stale. */
    private void ensureBirthRates(Tree tree) {
        if (calculateAllBirthRates) {
            calculateBirthRates(tree, tree.getRoot(), 0.0);
            calculateAllBirthRates = false;
        }
    }

    /**
     * Walks the subtree under {@code node} depth-first and stores each node's birth rate.
     *
     * @param tree the tree
     * @param node current node
     * @param rate rate inherited from the parent; ignored at the root, which uses {@code meanRate}
     */
    private void calculateBirthRates(Tree tree, NodeRef node, double rate) {
        if (tree.isRoot(node)) {
            rate = meanRate.getParameterValue(0);
        } else if (isVariableSelected(tree, node)) {
            if (birthRatesAreMultipliers) {
                rate *= getVariable(tree, node);
            } else {
                rate = getVariable(tree, node);
            }
        }

        nodeBirthRates[node.getNumber()] = rate;

        int childCount = tree.getChildCount(node);
        for (int i = 0; i < childCount; i++) {
            calculateBirthRates(tree, tree.getChild(node, i), rate);
        }
    }

    /** Returns all tree traits registered by this model ({@code "I"} and {@code "b"}). */
    public TreeTrait[] getTreeTraits() {
        return treeTraits.getTreeTraits();
    }

    /** Returns the registered tree trait with the given name. */
    public TreeTrait getTreeTrait(String key) {
        return treeTraits.getTreeTrait(key);
    }

    /** External nodes contribute terms in {@link #logNodeProbability}, so they are included. */
    public boolean includeExternalNodesInLikelihoodCalculation() {
        return true;
    }

    // **************************************************************
    // XMLElement IMPLEMENTATION
    // **************************************************************

    /** Not supported by this model; always throws. */
    public org.w3c.dom.Element createElement(org.w3c.dom.Document d) {
        throw new RuntimeException("createElement not implemented");
    }
}