package beast.evolution.speciation; import java.util.Arrays; import java.util.List; import java.util.Random; import beast.core.Description; import beast.core.Input; import beast.core.Input.Validate; import beast.core.State; import beast.core.parameter.RealParameter; import beast.evolution.alignment.TaxonSet; import beast.evolution.tree.Node; import beast.evolution.tree.TreeDistribution; import beast.math.distributions.Gamma; @Description("Species tree prior for *BEAST analysis") public class SpeciesTreePrior extends TreeDistribution { //public Input<Tree> m_speciesTree = new Input<>("speciesTree", "species tree containing the associated gene tree", Validate.REQUIRED); protected enum TreePopSizeFunction {constant, linear, linear_with_constant_root} public final Input<TreePopSizeFunction> popFunctionInput = new Input<>("popFunction", "Population function. " + "This can be " + Arrays.toString(TreePopSizeFunction.values()) + " (default 'constant')", TreePopSizeFunction.constant, TreePopSizeFunction.values()); public final Input<RealParameter> popSizesBottomInput = new Input<>("bottomPopSize", "population size parameter for populations at the bottom of a branch. " + "For linear population function, this is the same at the top of the branch.", Validate.REQUIRED); public final Input<RealParameter> popSizesTopInput = new Input<>("topPopSize", "population size parameter at the top of a branch. " + "Ignored for constant population function, but required for linear population function."); public final Input<RealParameter> gammaParameterInput = new Input<>("gammaParameter", "scale parameter of the gamma distribution over population sizes. " + "This makes this parameter half the expected population size on all branches for constant population function, " + "but a quarter of the expected population size for tip branches only for linear population functions.", Validate.REQUIRED); // public Input<RealParameter> m_rootHeightParameter = new Input<>("rootBranchHeight","height of the node above the root, representing the root branch", Validate.REQUIRED); /** * m_taxonSet is used by GeneTreeForSpeciesTreeDistribution * */ final public Input<TaxonSet> taxonSetInput = new Input<>("taxonset", "set of taxa mapping lineages to species", Validate.REQUIRED); private TreePopSizeFunction popFunction; private RealParameter popSizesBottom; private RealParameter popSizesTop; private Gamma gamma2Prior; private Gamma gamma4Prior; @Override public void initAndValidate() { popFunction = popFunctionInput.get(); popSizesBottom = popSizesBottomInput.get(); popSizesTop = popSizesTopInput.get(); // set up sizes of population functions final int speciesCount = treeInput.get().getLeafNodeCount(); final int nodeCount = treeInput.get().getNodeCount(); switch (popFunction) { case constant: popSizesBottom.setDimension(nodeCount); break; case linear: if (popSizesTop == null) { throw new IllegalArgumentException("topPopSize must be specified"); } popSizesBottom.setDimension(speciesCount); popSizesTop.setDimension(nodeCount); break; case linear_with_constant_root: if (popSizesTop == null) { throw new IllegalArgumentException("topPopSize must be specified"); } popSizesBottom.setDimension(speciesCount); popSizesTop.setDimension(nodeCount - 1); break; } // bottom prior = Gamma(2,Psi) gamma2Prior = new Gamma(); gamma2Prior.betaInput.setValue(gammaParameterInput.get(), gamma2Prior); // top prior = Gamma(4,Psi) gamma4Prior = new Gamma(); final RealParameter parameter = new RealParameter(new Double[]{4.0}); gamma4Prior.alphaInput.setValue(parameter, gamma4Prior); gamma4Prior.betaInput.setValue(gammaParameterInput.get(), gamma4Prior); if (popFunction != TreePopSizeFunction.constant && gamma4Prior == null) { throw new IllegalArgumentException("Top prior must be specified when population function is not constant"); } // make sure the m_taxonSet is a set of taxonsets // HACK to make Beauti initialise: skip the check here // for (Taxon taxon : m_taxonSet.get().m_taxonset.get()) { // if (!(taxon instanceof TaxonSet)) { // throw new IllegalArgumentException("taxonset should be sets of taxa only, not individual taxons"); // } // } } @Override public double calculateLogP() { logP = 0; // make sure the root branch length is positive // if (m_rootHeightParameter.get().getValue() < m_speciesTree.get().getRoot().getHeight()) { // logP = Double.NEGATIVE_INFINITY; // return logP; // } final Node[] speciesNodes = treeInput.get().getNodesAsArray(); try { switch (popFunction) { case constant: // constant pop size function logP += gamma2Prior.calcLogP(popSizesBottom); // for (int i = 0; i < speciesNodes.length; i++) { // double popSize = m_fPopSizesBottom.getValue(i); // logP += m_bottomPrior.logDensity(popSize); // } break; case linear: // linear pop size function // int speciesCount = m_tree.get().getLeafNodeCount(); // m_fPopSizesBottom.setDimension(speciesCount); // logP += m_gamma4Prior.calcLogP(m_fPopSizesBottom); // int nodeCount = m_tree.get().getNodeCount(); // m_fPopSizesTop.setDimension(nodeCount-1); // logP += m_gamma2Prior.calcLogP(m_fPopSizesTop); for (int i = 0; i < speciesNodes.length; i++) { final Node node = speciesNodes[i]; final double popSizeBottom; if (node.isLeaf()) { // Gamma(4, psi) prior popSizeBottom = popSizesBottom.getValue(i); logP += gamma4Prior.logDensity(popSizeBottom); } final double popSizeTop = popSizesTop.getValue(i); logP += gamma2Prior.logDensity(popSizeTop); } break; case linear_with_constant_root: // logP += m_gamma4Prior.calcLogP(m_fPopSizesBottom); // logP += m_gamma2Prior.calcLogP(m_fPopSizesTop); // int rootNr = m_tree.get().getRoot().getNr(); // double popSize = m_fPopSizesTop.getValue(rootNr); // logP -= m_gamma2Prior.logDensity(popSize); for (int i = 0; i < speciesNodes.length; i++) { final Node node = speciesNodes[i]; if (node.isLeaf()) { final double popSizeBottom = popSizesBottom.getValue(i); logP += gamma4Prior.logDensity(popSizeBottom); } if (!node.isRoot()) { if (i < speciesNodes.length - 1) { final double popSizeTop = popSizesTop.getArrayValue(i); logP += gamma2Prior.logDensity(popSizeTop); } else { final int nodeIndex = treeInput.get().getRoot().getNr(); final double popSizeTop = popSizesTop.getArrayValue(nodeIndex); logP += gamma2Prior.logDensity(popSizeTop); } } } break; } } catch (Exception e) { // exceptions can be thrown by the gamma priors e.printStackTrace(); return Double.NEGATIVE_INFINITY; } return logP; } @Override protected boolean requiresRecalculation() { return true; } @Override public List<String> getArguments() { return null; } @Override public List<String> getConditions() { return null; } @Override public void sample(final State state, final Random random) { } }