package beast.evolution.branchratemodel; import beast.core.Description; import beast.core.Input; import beast.core.parameter.BooleanParameter; import beast.core.parameter.RealParameter; import beast.core.util.Log; import beast.evolution.tree.Node; import beast.evolution.tree.Tree; /** * @author Alexei Drummond */ @Description("Random Local Clock Model, whatever that is....") public class RandomLocalClockModel extends BranchRateModel.Base { final public Input<BooleanParameter> indicatorParamInput = new Input<>("indicators", "the indicators associated with nodes in the tree for sampling of individual rate changes among branches.", Input.Validate.REQUIRED); final public Input<RealParameter> rateParamInput = new Input<>("rates", "the rate parameters associated with nodes in the tree for sampling of individual rates among branches.", Input.Validate.REQUIRED); // public Input<RealParameter> meanRateInput = // new Input<>("meanRate", // "an optional parameter to set the mean rate across the whole tree"); final public Input<Tree> treeInput = new Input<>("tree", "the tree this relaxed clock is associated with.", Input.Validate.REQUIRED); final public Input<Boolean> ratesAreMultipliersInput = new Input<>("ratesAreMultipliers", "true if the rates should be treated as multipliers (default false).", false); Tree m_tree; RealParameter meanRate; @Override public void initAndValidate() { m_tree = treeInput.get(); BooleanParameter indicators = indicatorParamInput.get(); if (indicators.getDimension() != m_tree.getNodeCount() - 1) { Log.warning.println("RandomLocalClockModel::Setting dimension of indicators to " + (m_tree.getNodeCount() - 1)); indicators.setDimension(m_tree.getNodeCount() - 1); } unscaledBranchRates = new double[m_tree.getNodeCount()]; RealParameter rates = rateParamInput.get(); if (rates.lowerValueInput.get() == null || rates.lowerValueInput.get() < 0.0) { rates.setLower(0.0); } if (rates.upperValueInput.get() == null || rates.upperValueInput.get() < 0.0) { rates.setUpper(Double.MAX_VALUE); } if (rates.getDimension() != m_tree.getNodeCount() - 1) { Log.warning.println("RandomLocalClockModel::Setting dimension of rates to " + (m_tree.getNodeCount() - 1)); rates.setDimension(m_tree.getNodeCount() - 1); } ratesAreMultipliers = ratesAreMultipliersInput.get(); meanRate = meanRateInput.get(); if (meanRate == null) { meanRate = new RealParameter("1.0"); } } /** * This is a recursive function that does the work of * calculating the unscaled branch rates across the tree * taking into account the indicator variables. * * @param node the node * @param rate the rate of the parent node */ private void calculateUnscaledBranchRates(Node node, double rate, BooleanParameter indicators, RealParameter rates) { int nodeNumber = getNr(node); if (!node.isRoot()) { if (indicators.getValue(nodeNumber)) { if (ratesAreMultipliers) { rate *= rates.getValue(nodeNumber); } else { rate = rates.getValue(nodeNumber); } } } unscaledBranchRates[nodeNumber] = rate; if (!node.isLeaf()) { calculateUnscaledBranchRates(node.getLeft(), rate, indicators, rates); calculateUnscaledBranchRates(node.getRight(), rate, indicators, rates); } } private void recalculateScaleFactor() { BooleanParameter indicators = indicatorParamInput.get(); RealParameter rates = rateParamInput.get(); calculateUnscaledBranchRates(m_tree.getRoot(), 1.0, indicators, rates); double timeTotal = 0.0; double branchTotal = 0.0; for (int i = 0; i < m_tree.getNodeCount(); i++) { Node node = m_tree.getNode(i); if (!node.isRoot()) { double branchInTime = node.getParent().getHeight() - node.getHeight(); double branchLength = branchInTime * unscaledBranchRates[node.getNr()]; timeTotal += branchInTime; branchTotal += branchLength; } } scaleFactor = timeTotal / branchTotal; scaleFactor *= meanRate.getValue(); } @Override public double getRateForBranch(Node node) { // this must be synchronized to avoid being called simultaneously by // two different likelihood threads synchronized (this) { if (recompute) { recalculateScaleFactor(); recompute = false; } } return unscaledBranchRates[getNr(node)] * scaleFactor; } private int getNr(Node node) { int nodeNr = node.getNr(); if (nodeNr > m_tree.getRoot().getNr()) { nodeNr--; } return nodeNr; } @Override protected boolean requiresRecalculation() { // this is only called if any of its inputs is dirty, hence we need to recompute recompute = true; return true; } @Override protected void store() { recompute = true; super.store(); } @Override protected void restore() { recompute = true; super.restore(); } private boolean recompute = true; double[] unscaledBranchRates; double scaleFactor; boolean ratesAreMultipliers = false; }