package beast.evolution.speciation;
import java.util.Arrays;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Random;
import beast.core.Description;
import beast.core.Input;
import beast.core.Input.Validate;
import beast.core.State;
import beast.core.parameter.RealParameter;
import beast.core.util.Log;
import beast.evolution.alignment.Taxon;
import beast.evolution.alignment.TaxonSet;
import beast.evolution.speciation.SpeciesTreePrior.TreePopSizeFunction;
import beast.evolution.tree.Node;
import beast.evolution.tree.TreeDistribution;
import beast.evolution.tree.TreeInterface;
@Description("Calculates probability of gene tree conditioned on a species tree (multi-species coalescent)")
public class GeneTreeForSpeciesTreeDistribution extends TreeDistribution {
final public Input<TreeInterface> speciesTreeInput =
new Input<>("speciesTree", "species tree containing the associated gene tree", Validate.REQUIRED);
// public enum PLOIDY {autosomal_nuclear, X, Y, mitrochondrial};
final public Input<Double> ploidyInput =
new Input<>("ploidy", "ploidy (copy number) for this gene, typically a whole number or half (default 2 for autosomal_nuclear)", 2.0);
// public Input<PLOIDY> m_ploidy =
// new Input<>("ploidy", "ploidy for this gene (default X, Possible values: " + PLOIDY.values(), PLOIDY.X, PLOIDY.values());
final public Input<SpeciesTreePrior> speciesTreePriorInput =
new Input<>("speciesTreePrior", "defines population function and its parameters", Validate.REQUIRED);
final public Input<TreeTopFinder> treeTopFinderInput =
new Input<>("treetop", "calculates height of species tree, required only for linear *beast analysis");
// intervals for each of the species tree branches
private PriorityQueue<Double>[] intervalsInput;
// count nr of lineages at the bottom of species tree branches
private int[] nrOfLineages;
// maps gene tree leaf nodes to species tree leaf nodes. Indexed by node number.
protected int[] nrOfLineageToSpeciesMap;
beast.evolution.speciation.SpeciesTreePrior.TreePopSizeFunction isConstantPopFunction;
RealParameter popSizesBottom;
RealParameter popSizesTop;
// Ploidy is a constant - cache value of input here
private double ploidy;
//???
public GeneTreeForSpeciesTreeDistribution() {
treeInput.setRule(Validate.REQUIRED);
}
@SuppressWarnings("unchecked")
@Override
public void initAndValidate() {
ploidy = ploidyInput.get();
// switch (m_ploidy.get()) {
// case autosomal_nuclear: m_fPloidy = 2.0; break;
// case X: m_fPloidy = 1.5; break;
// case Y: m_fPloidy = 0.5; break;
// case mitrochondrial: m_fPloidy = 0.5; break;
// default: throw new Exception("Unknown value for ploidy");
// }
final Node[] gtNodes = treeInput.get().getNodesAsArray();
final int gtLineages = treeInput.get().getLeafNodeCount();
final Node[] sptNodes = speciesTreeInput.get().getNodesAsArray();
final int speciesCount = speciesTreeInput.get().getNodeCount();
if (speciesCount <= 1 && sptNodes[0].getID().equals("Beauti2DummyTaxonSet")) {
// we are in Beauti, don't initialise
return;
}
// reserve memory for priority queues
intervalsInput = new PriorityQueue[speciesCount];
for (int i = 0; i < speciesCount; i++) {
intervalsInput[i] = new PriorityQueue<>();
}
// sanity check lineage nodes are all at height=0
for (int i = 0; i < gtLineages; i++) {
if (gtNodes[i].getHeight() != 0) {
throw new IllegalArgumentException("Cannot deal with taxon " + gtNodes[i].getID() +
", which has non-zero height + " + gtNodes[i].getHeight());
}
}
// set up m_nLineageToSpeciesMap
nrOfLineageToSpeciesMap = new int[gtLineages];
Arrays.fill(nrOfLineageToSpeciesMap, -1);
for (int i = 0; i < gtLineages; i++) {
final String speciesID = getSetID(gtNodes[i].getID());
// ??? can this be a startup check? can this happen during run due to tree change?
if (speciesID == null) {
throw new IllegalArgumentException("Cannot find species for lineage " + gtNodes[i].getID());
}
for (int species = 0; species < speciesCount; species++) {
if (speciesID.equals(sptNodes[species].getID())) {
nrOfLineageToSpeciesMap[i] = species;
break;
}
}
if (nrOfLineageToSpeciesMap[i] < 0) {
throw new IllegalArgumentException("Cannot find species with name " + speciesID + " in species tree");
}
}
// calculate nr of lineages per species
nrOfLineages = new int[speciesCount];
// for (final Node node : gtNodes) {
// if (node.isLeaf()) {
// final int species = m_nLineageToSpeciesMap[node.getNr()];
// m_nLineages[species]++;
// }
// }
final SpeciesTreePrior popInfo = speciesTreePriorInput.get();
isConstantPopFunction = popInfo.popFunctionInput.get();
popSizesBottom = popInfo.popSizesBottomInput.get();
popSizesTop = popInfo.popSizesTopInput.get();
assert( ! (isConstantPopFunction == TreePopSizeFunction.linear && treeTopFinderInput.get() == null ) );
}
/**
* @param lineageID
* @return species ID to which the lineage ID belongs according to the TaxonSets
*/
String getSetID(final String lineageID) {
final TaxonSet taxonSuperset = speciesTreePriorInput.get().taxonSetInput.get();
final List<Taxon> taxonSets = taxonSuperset.taxonsetInput.get();
for (final Taxon taxonSet : taxonSets) {
final List<Taxon> taxa = ((TaxonSet) taxonSet).taxonsetInput.get();
for (final Taxon aTaxa : taxa) {
if (aTaxa.getID().equals(lineageID)) {
return taxonSet.getID();
}
}
}
return null;
}
@Override
public double calculateLogP() {
logP = 0;
for (final PriorityQueue<Double> m_interval : intervalsInput) {
m_interval.clear();
}
Arrays.fill(nrOfLineages, 0);
final TreeInterface stree = speciesTreeInput.get();
final Node[] speciesNodes = stree.getNodesAsArray();
traverseLineageTree(speciesNodes, treeInput.get().getRoot());
// System.err.println(getID());
// for (int i = 0; i < m_intervals.length; i++) {
// System.err.println(m_intervals[i]);
// }
// if the gene tree does not fit the species tree, logP = -infinity by now
if (logP == 0) {
traverseSpeciesTree(stree.getRoot());
}
// System.err.println("logp=" + logP);
return logP;
}
/**
* calculate contribution to logP for each of the branches of the species tree
*
* @param node*
*/
private void traverseSpeciesTree(final Node node) {
if (!node.isLeaf()) {
traverseSpeciesTree(node.getLeft());
traverseSpeciesTree(node.getRight());
}
// calculate contribution of a branch in the species tree to the log probability
final int nodeIndex = node.getNr();
// k, as defined in the paper
//System.err.println(Arrays.toString(m_nLineages));
final int k = intervalsInput[nodeIndex].size();
final double[] times = new double[k + 2];
times[0] = node.getHeight();
for (int i = 1; i <= k; i++) {
times[i] = intervalsInput[nodeIndex].poll();
}
if (!node.isRoot()) {
times[k + 1] = node.getParent().getHeight();
} else {
if (isConstantPopFunction == TreePopSizeFunction.linear) {
times[k + 1] = treeTopFinderInput.get().getHighestTreeHeight();
} else {
times[k + 1] = Math.max(node.getHeight(), treeInput.get().getRoot().getHeight());
}
}
// sanity check
for (int i = 0; i <= k; i++) {
if (times[i] > times[i + 1]) {
Log.warning.println("invalid times");
calculateLogP();
}
}
final int lineagesBottom = nrOfLineages[nodeIndex];
switch (isConstantPopFunction) {
case constant:
calcConstantPopSizeContribution(lineagesBottom, popSizesBottom.getValue(nodeIndex), times, k);
break;
case linear:
logP += calcLinearPopSizeContributionJH(lineagesBottom, nodeIndex, times, k, node);
break;
case linear_with_constant_root:
if (node.isRoot()) {
final double popSize = getTopPopSize(node.getLeft().getNr()) + getTopPopSize(node.getRight().getNr());
calcConstantPopSizeContribution(lineagesBottom, popSize, times, k);
} else {
logP += calcLinearPopSizeContribution(lineagesBottom, nodeIndex, times, k, node);
}
break;
}
}
/* the contribution of a branch in the species tree to
* the log probability, for constant population function.
*/
private void calcConstantPopSizeContribution(final int lineagesBottom, final double popSize2,
final double[] times, final int k) {
final double popSize = popSize2 * ploidy;
logP += -k * Math.log(popSize);
// System.err.print(logP);
for (int i = 0; i <= k; i++) {
logP += -((lineagesBottom - i) * (lineagesBottom - i - 1.0) / 2.0) * (times[i + 1] - times[i]) / popSize;
}
// System.err.println(" " + logP + " " + Arrays.toString(times) + " " + nodeIndex + " " + k);
}
/* the contribution of a branch in the species tree to
* the log probability, for linear population function.
*/
private double calcLinearPopSizeContribution(final int lineagesBottom, final int nodeIndex, final double[] times,
final int k, final Node node) {
double lp = 0.0;
final double popSizeBottom;
if (node.isLeaf()) {
popSizeBottom = popSizesBottom.getValue(nodeIndex) * ploidy;
} else {
// use sum of left and right child branches for internal nodes
popSizeBottom = (getTopPopSize(node.getLeft().getNr()) + getTopPopSize(node.getRight().getNr())) * ploidy;
}
final double popSizeTop = getTopPopSize(nodeIndex) * ploidy;
final double a = (popSizeTop - popSizeBottom) / (times[k + 1] - times[0]);
final double b = popSizeBottom;
for (int i = 0; i < k; i++) {
//double popSize = popSizeBottom + (popSizeTop-popSizeBottom) * times[i+1]/(times[k]-times[0]);
final double popSize = a * (times[i + 1] - times[0]) + b;
lp += -Math.log(popSize);
}
for (int i = 0; i <= k; i++) {
if (Math.abs(popSizeTop - popSizeBottom) < 1e-10) {
// slope = 0, so population function is constant
final double popSize = a * (times[i + 1] - times[0]) + b;
lp += -((lineagesBottom - i) * (lineagesBottom - i - 1.0) / 2.0) * (times[i + 1] - times[i]) / popSize;
} else {
final double f = (a * (times[i + 1] - times[0]) + b) / (a * (times[i] - times[0]) + b);
lp += -((lineagesBottom - i) * (lineagesBottom - i - 1.0) / 2.0) * Math.log(f) / a;
}
}
return lp;
}
private double calcLinearPopSizeContributionJH(final int lineagesBottom, final int nodeIndex, final double[] times,
final int k, final Node node) {
double lp = 0.0;
double popSizeBottom;
if (node.isLeaf()) {
popSizeBottom = popSizesBottom.getValue(nodeIndex);
} else {
// use sum of left and right child branches for internal nodes
popSizeBottom = (getTopPopSize(node.getLeft().getNr()) + getTopPopSize(node.getRight().getNr()));
}
popSizeBottom *= ploidy;
final double popSizeTop = getTopPopSize(nodeIndex) * ploidy;
final double d5 = popSizeTop - popSizeBottom;
final double time0 = times[0];
final double a = d5 / (times[k + 1] - time0);
final double b = popSizeBottom;
if (Math.abs(d5) < 1e-10) {
// use approximation for small values to bypass numerical instability
for (int i = 0; i <= k; i++) {
final double timeip1 = times[i + 1];
final double popSize = a * (timeip1 - time0) + b;
if( i < k ) {
lp += -Math.log(popSize);
}
// slope = 0, so population function is constant
final int i1 = lineagesBottom - i;
lp -= (i1 * (i1 - 1.0) / 2.0) * (timeip1 - times[i]) / popSize;
}
} else {
final double vv = b - a * time0;
for (int i = 0; i <= k; i++) {
final double popSize = a * times[i + 1] + vv;
if( i < k ) {
lp += -Math.log(popSize);
}
final double f = popSize / (a * times[i] + vv);
final int i1 = lineagesBottom - i;
lp += -(i1 * (i1 - 1.0) / 2.0) * Math.log(f) / a;
}
}
return lp;
}
/**
* collect intervals for each of the branches of the species tree
* as defined by the lineage tree.
*
* @param speciesNodes
* @param node
* @return
*/
private int traverseLineageTree(final Node[] speciesNodes, final Node node) {
if (node.isLeaf()) {
final int species = nrOfLineageToSpeciesMap[node.getNr()];
nrOfLineages[species]++;
return species;
} else {
int speciesLeft = traverseLineageTree(speciesNodes, node.getLeft());
int speciesRight = traverseLineageTree(speciesNodes, node.getRight());
final double height = node.getHeight();
while (!speciesNodes[speciesLeft].isRoot() && height > speciesNodes[speciesLeft].getParent().getHeight()) {
speciesLeft = speciesNodes[speciesLeft].getParent().getNr();
nrOfLineages[speciesLeft]++;
}
while (!speciesNodes[speciesRight].isRoot() && height > speciesNodes[speciesRight].getParent().getHeight()) {
speciesRight = speciesNodes[speciesRight].getParent().getNr();
nrOfLineages[speciesRight]++;
}
// validity check
if (speciesLeft != speciesRight) {
// if we got here, it means the gene tree does
// not fit in the species tree
logP = Double.NEGATIVE_INFINITY;
}
intervalsInput[speciesRight].add(height);
return speciesRight;
}
}
/* return population size at top. For linear with constant root, there is no
* entry for the root. An internal node can have the number equal to dimension
* of m_fPopSizesTop, then the root node can be numbered with a lower number
* and we can use that entry in m_fPopSizesTop for the rogue internal node.
*/
private double getTopPopSize(final int nodeIndex) {
if (nodeIndex < popSizesTop.getDimension()) {
return popSizesTop.getArrayValue(nodeIndex);
}
return popSizesTop.getArrayValue(speciesTreeInput.get().getRoot().getNr());
}
@Override
public boolean requiresRecalculation() {
// TODO: check whether this is worth optimising?
return true;
}
@Override
public List<String> getArguments() {
return null;
}
@Override
public List<String> getConditions() {
return null;
}
@Override
public void sample(final State state, final Random random) {
}
}