package beast.evolution.branchratemodel;
import java.util.Arrays;
import org.apache.commons.math.MathException;
import beast.core.Citation;
import beast.core.Description;
import beast.core.Input;
import beast.core.parameter.IntegerParameter;
import beast.core.parameter.RealParameter;
import beast.core.util.Log;
import beast.evolution.tree.Node;
import beast.evolution.tree.Tree;
import beast.math.distributions.ParametricDistribution;
import beast.util.Randomizer;
/**
* @author Alexei Drummond
*/
@Description("Defines an uncorrelated relaxed molecular clock.")
@Citation(value =
"Drummond AJ, Ho SYW, Phillips MJ, Rambaut A (2006) Relaxed Phylogenetics and\n" +
" Dating with Confidence. PLoS Biol 4(5): e88", DOI = "10.1371/journal.pbio.0040088",
year = 2006, firstAuthorSurname = "drummond")
public class UCRelaxedClockModel extends BranchRateModel.Base {
final public Input<ParametricDistribution> rateDistInput = new Input<>("distr", "the distribution governing the rates among branches. Must have mean of 1. The clock.rate parameter can be used to change the mean rate.", Input.Validate.REQUIRED);
final public Input<IntegerParameter> categoryInput = new Input<>("rateCategories", "the rate categories associated with nodes in the tree for sampling of individual rates among branches.", Input.Validate.REQUIRED);
final public Input<Integer> numberOfDiscreteRates = new Input<>("numberOfDiscreteRates", "the number of discrete rate categories to approximate the rate distribution by. A value <= 0 will cause the number of categories to be set equal to the number of branches in the tree. (default = -1)", -1);
final public Input<RealParameter> quantileInput = new Input<>("rateQuantiles", "the rate quantiles associated with nodes in the tree for sampling of individual rates among branches.", Input.Validate.XOR, categoryInput);
final public Input<Tree> treeInput = new Input<>("tree", "the tree this relaxed clock is associated with.", Input.Validate.REQUIRED);
final public Input<Boolean> normalizeInput = new Input<>("normalize", "Whether to normalize the average rate (default false).", false);
// public Input<Boolean> initialiseInput = new Input<>("initialise", "Whether to initialise rates by a heuristic instead of random (default false).", false);
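    // Illustrative usage sketch (not part of this class): the model can be wired up
    // programmatically via initByName. The lognormal settings below are assumptions,
    // chosen so the distribution has mean 1 as required by the "distr" input;
    // "tree" is assumed to be an existing beast.evolution.tree.Tree.
    //
    //   LogNormalDistributionModel logNormal = new LogNormalDistributionModel();
    //   logNormal.initByName("M", "1.0", "S", "0.5", "meanInRealSpace", true);
    //   UCRelaxedClockModel clock = new UCRelaxedClockModel();
    //   clock.initByName("distr", logNormal,
    //           "rateCategories", new IntegerParameter("0"),
    //           "tree", tree);
    //   double rate = clock.getRateForBranch(tree.getNode(0));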
RealParameter meanRate;
// boolean initialise;
int LATTICE_SIZE_FOR_DISCRETIZED_RATES = 100;
// true if quantiles are used, false if discrete rate categories are used.
boolean usingQuantiles;
private int branchCount;
@Override
public void initAndValidate() {
tree = treeInput.get();
branchCount = tree.getNodeCount() - 1;
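// every node except the root has a parent branch, and each branch gets its own rate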
categories = categoryInput.get();
usingQuantiles = (categories == null);
if (!usingQuantiles) {
LATTICE_SIZE_FOR_DISCRETIZED_RATES = numberOfDiscreteRates.get();
if (LATTICE_SIZE_FOR_DISCRETIZED_RATES <= 0) LATTICE_SIZE_FOR_DISCRETIZED_RATES = branchCount;
Log.info.println(" UCRelaxedClockModel: using " + LATTICE_SIZE_FOR_DISCRETIZED_RATES + " rate " +
"categories to approximate rate distribution across branches.");
} else {
if (numberOfDiscreteRates.get() != -1) {
throw new RuntimeException("Can't specify both numberOfDiscreteRates and rateQuantiles inputs.");
}
Log.info.println(" UCRelaxedClockModel: using quantiles for rate distribution across branches.");
}
if (usingQuantiles) {
quantiles = quantileInput.get();
quantiles.setDimension(branchCount);
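// initialise each branch quantile with an independent uniform draw from [0,1)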
Double[] initialQuantiles = new Double[branchCount];
for (int i = 0; i < branchCount; i++) {
initialQuantiles[i] = Randomizer.nextDouble();
}
RealParameter other = new RealParameter(initialQuantiles);
quantiles.assignFromWithoutID(other);
quantiles.setLower(0.0);
quantiles.setUpper(1.0);
} else {
categories.setDimension(branchCount);
Integer[] initialCategories = new Integer[branchCount];
for (int i = 0; i < branchCount; i++) {
initialCategories[i] = Randomizer.nextInt(LATTICE_SIZE_FOR_DISCRETIZED_RATES);
}
// set initial values of rate categories
IntegerParameter other = new IntegerParameter(initialCategories);
categories.assignFromWithoutID(other);
categories.setLower(0);
categories.setUpper(LATTICE_SIZE_FOR_DISCRETIZED_RATES - 1);
}
distribution = rateDistInput.get();
if (!usingQuantiles) {
// rates are initially zero and are computed lazily by getRawRateForCategory as needed
rates = new double[LATTICE_SIZE_FOR_DISCRETIZED_RATES];
storedRates = new double[LATTICE_SIZE_FOR_DISCRETIZED_RATES];
//System.arraycopy(rates, 0, storedRates, 0, rates.length);
}
normalize = normalizeInput.get();
meanRate = meanRateInput.get();
if (meanRate == null) {
meanRate = new RealParameter("1.0");
}
try {
double mean = rateDistInput.get().getMean();
if (Math.abs(mean - 1.0) > 1e-6) {
Log.warning.println("WARNING: mean of distribution for relaxed clock model is not 1.0.");
}
} catch (RuntimeException e) {
// the mean may be undefined for this distribution; skip the sanity check
}
}
@Override
public double getRateForBranch(Node node) {
if (node.isRoot()) {
// the root has no parent branch, hence no rate; return the unit rate
return 1;
}
if (recompute) {
// this must be synchronized to avoid being called simultaneously by
// two different likelihood threads
synchronized (this) {
prepare();
recompute = false;
}
}
if (renormalize) {
if (normalize) {
synchronized (this) {
computeFactor();
}
}
renormalize = false;
}
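// effective branch rate = raw rate from the distribution * normalization factor * mean clock rate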
return getRawRate(node) * scaleFactor * meanRate.getValue();
}
/**
* Computes a scale factor for normalization. Only called if normalize=true.
*/
private void computeFactor() {
// scale the time-weighted mean rate over the tree to 1.0 (the mean clock rate is applied separately)
double treeRate = 0.0;
double treeTime = 0.0;
if (!usingQuantiles) {
for (int i = 0; i < tree.getNodeCount(); i++) {
Node node = tree.getNode(i);
if (!node.isRoot()) {
treeRate += getRawRateForCategory(node) * node.getLength();
treeTime += node.getLength();
}
}
} else {
for (int i = 0; i < tree.getNodeCount(); i++) {
Node node = tree.getNode(i);
if (!node.isRoot()) {
treeRate += getRawRateForQuantile(node) * node.getLength();
treeTime += node.getLength();
}
}
}
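// treeRate / treeTime is the time-weighted mean rate over the tree;
// taking its reciprocal rescales branch rates so that this mean becomes 1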
scaleFactor = 1.0 / (treeRate / treeTime);
}
private double getRawRate(Node node) {
if (usingQuantiles) {
return getRawRateForQuantile(node);
}
return getRawRateForCategory(node);
}
/**
* @param node the node to get the rate of
* @return the rate of the branch
*/
private double getRawRateForCategory(Node node) {
int nodeNumber = node.getNr();
if (nodeNumber == branchCount) {
// this node has swapped numbers with the root, whose nr is now below branchCount;
// use the root's nr so it maps to a valid category index
nodeNumber = node.getTree().getRoot().getNr();
}
int category = categories.getValue(nodeNumber);
if (rates[category] == 0.0) {
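// lazily evaluate the rate as the inverse CDF at the midpoint of the category's
// quantile interval: category k of n maps to F^{-1}((k + 0.5)/n); for example,
// with n = 4 the rates are F^{-1}(0.125), F^{-1}(0.375), F^{-1}(0.625) and F^{-1}(0.875)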
try {
rates[category] = distribution.inverseCumulativeProbability((category + 0.5) / rates.length);
} catch (MathException e) {
throw new RuntimeException("Failed to compute inverse cumulative probability!");
}
}
return rates[category];
}
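/**
 * @param node the node to get the rate of
 * @return the rate of the branch, the inverse CDF of the rate distribution
 *         evaluated at the branch's quantile
 */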
private double getRawRateForQuantile(Node node) {
int nodeNumber = node.getNr();
if (nodeNumber == branchCount) {
// this node has swapped numbers with the root, whose nr is now below branchCount;
// use the root's nr so it maps to a valid quantile index
nodeNumber = node.getTree().getRoot().getNr();
}
try {
return distribution.inverseCumulativeProbability(quantiles.getValue(nodeNumber));
} catch (MathException e) {
throw new RuntimeException("Failed to compute inverse cumulative probability!");
}
}
private void prepare() {
categories = categoryInput.get();
usingQuantiles = (categories == null);
distribution = rateDistInput.get();
tree = treeInput.get();
if (!usingQuantiles) {
// rates array initialized to correct length in initAndValidate
// here we just reset rates to zero; they are recomputed lazily by getRawRateForCategory as needed
Arrays.fill(rates, 0.0);
}
}
/**
* initialise rate categories by matching rates to tree using JC69 *
*/
// private void initialise() {
// try {
// for (BEASTObject output : outputs) {
// if (output.getInput("data") != null && output.getInput("tree") != null) {
//
// // set up treelikelihood with Jukes Cantor no gamma, no inv, strict clock
// Alignment data = (Alignment) output.getInput("data").get();
// Tree tree = (Tree) output.getInput("tree").get();
// TreeLikelihoodD likelihood = new TreeLikelihoodD();
// SiteModel siteModel = new SiteModel();
// JukesCantor substitutionModel = new JukesCantor();
// substitutionModel.initAndValidate();
// siteModel.initByName("substModel", substitutionModel);
// likelihood.initByName("data", data, "tree", tree, "siteModel", siteModel);
// likelihood.calculateLogP();
//
// // calculate distances
// Node [] nodes = tree.getNodesAsArray();
// double [] distance = new double[nodes.length];
// for (int i = 0; i < distance.length - 1; i++) {
// double len = nodes[i].getLength();
// double dist = likelihood.calcDistance(nodes[i]);
// distance[i] = len / dist;
// }
//
// // match category to distance
// double min = distance[0], max = min;
// for (int i = 1; i < distance.length - 1; i++) {
// if (!Double.isNaN(distance[i]) && !Double.isInfinite(distance[i])) {
// min = Math.min(min, distance[i]);
// max = Math.max(max, distance[i]);
// }
// }
// IntegerParameter categoriesParameter = categoryInput.get();
// Integer[] categories = new Integer[categoriesParameter.getDimension()];
// for (int i = 0; i < distance.length - 1; i++) {
// if (!Double.isNaN(distance[i]) && !Double.isInfinite(distance[i])) {
// categories[i] = (int)((distance.length - 2) * (distance[i]-min)/(max-min));
// } else {
// categories[i] = distance.length / 2;
// }
// }
// IntegerParameter other = new IntegerParameter(categories);
// other.setBounds(categoriesParameter.getLower(), categoriesParameter.getUpper());
// categoriesParameter.assignFromWithoutID(other);
// }
// }
// } catch (Exception e) {
// // ignore
// System.err.println("WARNING: UCRelaxedClock heuristic initialisation failed");
// }
// }
//
// @Description("Treelikelihood used to guesstimate rates on branches by using the JC69 model on the data")
// class TreeLikelihoodD extends TreeLikelihood {
//
// double calcDistance(Node node) {
// int nodeIndex = node.getNr();
// int patterncount = dataInput.get().getPatternCount();
// int statecount = dataInput.get().getDataType().getStateCount();
// double [] parentPartials = new double[patterncount * statecount];
// likelihoodCore.getNodePartials(node.getParent().getNr(), parentPartials);
// if (node.isLeaf()) {
// // distance of leaf to its parent, ignores ambiguities
// int [] states = new int[patterncount ];
// likelihoodCore.getNodeStates(nodeIndex, states);
// double distance = 0;
// for (int i = 0; i < patterncount; i++) {
// int k = states[i];
// double d = 0;
// for (int j = 0; j < statecount; j++) {
// if (j == k) {
// d += 1.0 - parentPartials[i * statecount + j];
// } else {
// d += parentPartials[i * statecount + j];
// }
// }
// distance += d * dataInput.get().getPatternWeight(i);
// }
// return distance;
// } else {
// // L1 distance of internal node partials to its parent partials
// double [] partials = new double[parentPartials.length];
// likelihoodCore.getNodePartials(nodeIndex, partials);
// double distance = 0;
// for (int i = 0; i < patterncount; i++) {
// double d = 0;
// for (int j = 0; j < statecount; j++) {
// d += Math.abs(partials[i * statecount + j] - parentPartials[i * statecount + j]);
// }
// distance += d * dataInput.get().getPatternWeight(i);
// }
// return distance;
// }
// }
//
// }
@Override
protected boolean requiresRecalculation() {
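// called by the MCMC machinery after a proposal, to flag which cached
// quantities need recomputing before the next call to getRateForBranch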
recompute = false;
renormalize = true;
// if (treeInput.get().somethingIsDirty()) {
// recompute = true;
// return true;
// }
// the rate distribution, or one of its parameters, may have changed
if (rateDistInput.get().isDirtyCalculation()) {
recompute = true;
return true;
}
// NOT processed as trait on the tree, so DO mark as dirty
if (categoryInput.get() != null && categoryInput.get().somethingIsDirty()) {
//recompute = true;
return true;
}
if (quantileInput.get() != null && quantileInput.get().somethingIsDirty()) {
return true;
}
if (meanRate.somethingIsDirty()) {
return true;
}
return recompute;
}
@Override
public void store() {
if (!usingQuantiles) System.arraycopy(rates, 0, storedRates, 0, rates.length);
storedScaleFactor = scaleFactor;
super.store();
}
@Override
public void restore() {
if (!usingQuantiles) {
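// swap array references rather than copying: storedRates holds the values
// saved in store(), so swapping restores the pre-proposal rates cheaply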
double[] tmp = rates;
rates = storedRates;
storedRates = tmp;
}
scaleFactor = storedScaleFactor;
super.restore();
}
ParametricDistribution distribution;
IntegerParameter categories;
RealParameter quantiles;
Tree tree;
private boolean normalize = false;
private boolean recompute = true;
private boolean renormalize = true;
private double[] rates;
private double[] storedRates;
private double scaleFactor = 1.0;
private double storedScaleFactor = 1.0;
}