package beast.math.distributions; import java.io.PrintStream; import java.util.ArrayList; import java.util.LinkedHashSet; import java.util.List; import java.util.Random; import java.util.Set; import beast.core.Description; import beast.core.Distribution; import beast.core.Input; import beast.core.Input.Validate; import beast.core.State; import beast.evolution.alignment.TaxonSet; import beast.evolution.tree.Node; import beast.evolution.tree.Tree; @Description("Prior over set of taxa, useful for defining monophyletic constraints and " + "distributions over MRCA times or (sets of) tips of trees") public class MRCAPrior extends Distribution { public final Input<Tree> treeInput = new Input<>("tree", "the tree containing the taxon set", Validate.REQUIRED); public final Input<TaxonSet> taxonsetInput = new Input<>("taxonset", "set of taxa for which prior information is available"); public final Input<Boolean> isMonophyleticInput = new Input<>("monophyletic", "whether the taxon set is monophyletic (forms a clade without other taxa) or nor. Default is false.", false); public final Input<ParametricDistribution> distInput = new Input<>("distr", "distribution used to calculate prior over MRCA time, " + "e.g. normal, beta, gamma. If not specified, monophyletic must be true"); public final Input<Boolean> onlyUseTipsInput = new Input<>("tipsonly", "flag to indicate tip dates are to be used instead of the MRCA node. " + "If set to true, the prior is applied to the height of all tips in the taxonset " + "and the monophyletic flag is ignored. Default is false.", false); public final Input<Boolean> useOriginateInput = new Input<>("useOriginate", "Use parent of clade instead of clade. Cannot be used with tipsonly, or on the root.", false); /** * shadow members * */ ParametricDistribution dist; Tree tree; // number of taxa in taxon set int nrOfTaxa = -1; // array of flags to indicate which taxa are in the set Set<String> isInTaxaSet = new LinkedHashSet<>(); // array of indices of taxa int[] taxonIndex; // stores time to be calculated double MRCATime = -1; double storedMRCATime = -1; // flag indicating taxon set is monophyletic boolean isMonophyletic = false; boolean onlyUseTips = false; boolean useRoot = false; boolean useOriginate = false; boolean initialised = false; @Override public void initAndValidate() { dist = distInput.get(); tree = treeInput.get(); final List<String> taxaNames = new ArrayList<>(); for (final String taxon : tree.getTaxaNames()) { taxaNames.add(taxon); } // determine nr of taxa in taxon set List<String> set = null; if (taxonsetInput.get() != null) { set = taxonsetInput.get().asStringList(); nrOfTaxa = set.size(); } else { // assume all taxa nrOfTaxa = taxaNames.size(); } onlyUseTips = onlyUseTipsInput.get(); useOriginate = useOriginateInput.get(); if (nrOfTaxa == 1) { // ignore test for Monophyletic when it only involves a tree tip if (!useOriginate && !onlyUseTips) { onlyUseTips = true; } } if (!onlyUseTips && !useOriginate && nrOfTaxa < 2) { throw new IllegalArgumentException("At least two taxa are required in a taxon set"); } if (!onlyUseTips && taxonsetInput.get() == null) { throw new IllegalArgumentException("Taxonset must be specified OR tipsonly be set to true"); } if (useOriginate && onlyUseTips) { throw new IllegalArgumentException("'useOriginate' and 'tipsOnly' cannot be both true"); } useRoot = nrOfTaxa == tree.getLeafNodeCount(); if (useOriginate && useRoot) { throw new IllegalArgumentException("Cannot use originate of root. You can set useOriginate to false to fix this"); } initialised = false; } boolean [] nodesTraversed; int nseen; protected Node getCommonAncestor(Node n1, Node n2) { // assert n1.getTree() == n2.getTree(); if( ! nodesTraversed[n1.getNr()] ) { nodesTraversed[n1.getNr()] = true; nseen += 1; } if( ! nodesTraversed[n2.getNr()] ) { nodesTraversed[n2.getNr()] = true; nseen += 1; } while (n1 != n2) { double h1 = n1.getHeight(); double h2 = n2.getHeight(); if ( h1 < h2 ) { n1 = n1.getParent(); if( ! nodesTraversed[n1.getNr()] ) { nodesTraversed[n1.getNr()] = true; nseen += 1; } } else if( h2 < h1 ) { n2 = n2.getParent(); if( ! nodesTraversed[n2.getNr()] ) { nodesTraversed[n2.getNr()] = true; nseen += 1; } } else { //zero length branches hell Node n; double b1 = n1.getLength(); double b2 = n2.getLength(); if( b1 > 0 ) { n = n2; } else { // b1 == 0 if( b2 > 0 ) { n = n1; } else { // both 0 n = n1; while( n != null && n != n2 ) { n = n.getParent(); } if( n == n2 ) { // n2 is an ancestor of n1 n = n1; } else { // always safe to advance n2 n = n2; } } } if( n == n1 ) { n = n1 = n.getParent(); } else { n = n2 = n.getParent(); } if( ! nodesTraversed[n.getNr()] ) { nodesTraversed[n.getNr()] = true; nseen += 1; } } } return n1; } // A lightweight version for finding the most recent common ancestor of a group of taxa. // return the node-ref of the MRCA. // would be nice to use nodeRef's, but they are not preserved :( public Node getCommonAncestor() { if (!initialised) { initialise(); } nodesTraversed = new boolean[tree.getNodeCount()]; Node n = getCommonAncestorInternal(); assert ! (useRoot && !n.isRoot() ) ; return n; } private Node getCommonAncestorInternal() { Node cur = tree.getNode(taxonIndex[0]); for (int k = 1; k < taxonIndex.length; ++k) { cur = getCommonAncestor(cur, tree.getNode(taxonIndex[k])); } return cur; } @Override public double calculateLogP() { if (!initialised) { initialise(); } logP = 0; if (onlyUseTips) { // tip date if (dist == null) { return logP; } for (final int i : taxonIndex) { MRCATime = tree.getNode(i).getDate(); logP += dist.logDensity(MRCATime); } return logP; } else if (useRoot) { if (dist != null) { MRCATime = tree.getRoot().getDate(); logP += dist.logDensity(MRCATime); } return logP; } else { // internal node if( false) { calcMRCAtime(tree.getRoot(), new int[1]); } else { Node m; if (taxonIndex.length == 1) { isMonophyletic = true; m = tree.getNode(taxonIndex[0]); } else { nseen = 0; m = getCommonAncestor(); isMonophyletic = (nseen == 2 * taxonIndex.length - 1); } if (useOriginate) { if (!m.isRoot()) { MRCATime = m.getParent().getDate(); } else { MRCATime = m.getDate(); } } else { MRCATime = m.getDate(); } } } if (isMonophyleticInput.get() && !isMonophyletic) { logP = Double.NEGATIVE_INFINITY; return Double.NEGATIVE_INFINITY; } if (dist != null) { logP = dist.logDensity(MRCATime); // - dist.offsetInput.get()); } return logP; } protected void initialise() { // determine which taxa are in the set List<String> set = null; if (taxonsetInput.get() != null) { set = taxonsetInput.get().asStringList(); } final List<String> taxaNames = new ArrayList<>(); for (final String taxon : tree.getTaxaNames()) { taxaNames.add(taxon); } taxonIndex = new int[nrOfTaxa]; if ( set != null ) { // m_taxonset.get() != null) { isInTaxaSet.clear(); int k = 0; for (final String taxon : set) { final int taxonIndex_ = taxaNames.indexOf(taxon); if (taxonIndex_ < 0) { throw new RuntimeException("Cannot find taxon " + taxon + " in data"); } if (isInTaxaSet.contains(taxon)) { throw new RuntimeException("Taxon " + taxon + " is defined multiple times, while they should be unique"); } isInTaxaSet.add(taxon); taxonIndex[k++] = taxonIndex_; } } else { for (int i = 0; i < nrOfTaxa; i++) { taxonIndex[i] = i; } } initialised = true; } /** * Recursively visit all leaf nodes, and collect number of taxa in the taxon * set. When all taxa in the set are visited, record the time. * * * @param node * @param taxonCount2 */ int calcMRCAtime(final Node node, final int[] taxonCount2) { if (node.isLeaf()) { taxonCount2[0]++; if (isInTaxaSet.contains(node.getID())) { return 1; } else { return 0; } } else { int taxonCount = calcMRCAtime(node.getLeft(), taxonCount2); final int leftTaxa = taxonCount2[0]; taxonCount2[0] = 0; if (node.getRight() != null) { taxonCount += calcMRCAtime(node.getRight(), taxonCount2); final int rightTaxa = taxonCount2[0]; taxonCount2[0] = leftTaxa + rightTaxa; if (taxonCount == nrOfTaxa) { if (nrOfTaxa == 1 && useOriginate) { MRCATime = node.getDate(); isMonophyletic = true; return taxonCount + 1; } // we are at the MRCA, so record the height if (useOriginate) { Node parent = node.getParent(); if (parent != null) { MRCATime = parent.getDate(); } else { MRCATime = node.getDate(); } } else { MRCATime = node.getDate(); } isMonophyletic = (taxonCount2[0] == nrOfTaxa); return taxonCount + 1; } } return taxonCount; } } @Override public void store() { storedMRCATime = MRCATime; // don't need to store m_bIsMonophyletic since it is never reported // explicitly, only logP and MRCA time are (re)stored super.store(); } @Override public void restore() { MRCATime = storedMRCATime; super.restore(); } @Override protected boolean requiresRecalculation() { return super.requiresRecalculation(); } /** * Loggable interface implementation follows * */ @Override public void init(final PrintStream out) { if (!initialised) { initialise(); } if (onlyUseTips) { if (dist != null) { out.print("logP(mrca(" + getID() + "))\t"); } for (final int i : taxonIndex) { out.print("height(" + tree.getTaxaNames()[i] + ")\t"); } } else { if (!isMonophyleticInput.get()) { out.print("monophyletic(" + taxonsetInput.get().getID() + ")\t"); } if (dist != null) { out.print("logP(mrca(" + taxonsetInput.get().getID() + "))\t"); } out.print("mrcatime(" + taxonsetInput.get().getID() + (useOriginate ? ".originate" : "") +")\t"); } } @Override public void log(final int sample, final PrintStream out) { if (onlyUseTips) { if (dist != null) { out.print(getCurrentLogP() + "\t"); } for (final int i : taxonIndex) { out.print(tree.getNode(i).getDate() + "\t"); } } else { if (!isMonophyleticInput.get()) { out.print((isMonophyletic ? 1 : 0) + "\t"); } if (dist != null) { out.print(getCurrentLogP() + "\t"); } else { calcMRCAtime(tree.getRoot(), new int[1]); } out.print(MRCATime + "\t"); } } @Override public void close(final PrintStream out) { // nothing to do } /** * Valuable interface implementation follows, first dimension is log likelihood, second the time * */ @Override public int getDimension() { return 2; } @Override public double getArrayValue() { if (Double.isNaN(logP)) { try { calculateLogP(); }catch (Exception e) { logP = Double.NaN; } } return logP; } @Override public double getArrayValue(final int dim) { if (Double.isNaN(logP)) { try { calculateLogP(); }catch (Exception e) { logP = Double.NaN; } } switch (dim) { case 0: return logP; case 1: return MRCATime; default: return 0; } } @Override public void sample(final State state, final Random random) { } @Override public List<String> getArguments() { return null; } @Override public List<String> getConditions() { return null; } }