package beast.evolution.tree.coalescent; import java.util.Collections; import java.util.List; import java.util.Random; import beast.core.CalculationNode; import beast.core.Description; import beast.core.Input; import beast.core.Input.Validate; import beast.core.State; import beast.evolution.tree.TreeDistribution; import beast.math.Binomial; /** * @author Alexei Drummond */ @Description("Calculates the probability of a beast.tree conditional on a population size function. " + "Note that this does not take the number of possible tree interval/tree topology combinations " + "in account, in other words, the constant required for making this a proper distribution that integrates " + "to unity is not calculated (partly, because we don't know how for sequentially sampled data).") public class Coalescent extends TreeDistribution { final public Input<PopulationFunction> popSizeInput = new Input<>("populationModel", "A population size model", Validate.REQUIRED); TreeIntervals intervals; @Override public void initAndValidate() { intervals = treeIntervalsInput.get(); if (intervals == null) { throw new IllegalArgumentException("Expected treeIntervals to be specified"); } calculateLogP(); } /** * do the actual calculation * */ @Override public double calculateLogP() { logP = calculateLogLikelihood(intervals, popSizeInput.get()); if (Double.isInfinite(logP)) { logP = Double.NEGATIVE_INFINITY; } return logP; } @Override public void sample(State state, Random random) { // TODO this should eventually sample a coalescent tree conditional on population size function throw new UnsupportedOperationException("This should eventually sample a coalescent tree conditional on population size function."); } /** * @return a list of unique ids for the state nodes that form the argument */ @Override public List<String> getArguments() { return Collections.singletonList(treeIntervalsInput.get().getID()); } /** * @return a list of unique ids for the state nodes that make up the conditions */ @Override public List<String> getConditions() { return popSizeInput.get().getParameterIds(); } /** * Calculates the log likelihood of this set of coalescent intervals, * given a demographic model. * * @param intervals the intervals whose likelihood is computed * @param popSizeFunction the population size function * @return the log likelihood of the intervals given the population size function */ public double calculateLogLikelihood(IntervalList intervals, PopulationFunction popSizeFunction) { return calculateLogLikelihood(intervals, popSizeFunction, 0.0); } /** * Calculates the log likelihood of this set of coalescent intervals, * given a population size function. * * @param intervals the intervals whose likelihood is computed * @param popSizeFunction the population size function * @param threshold the minimum allowable coalescent interval size; negative infinity will be returned if * any non-zero intervals are smaller than this * @return the log likelihood of the intervals given the population size function */ public double calculateLogLikelihood(IntervalList intervals, PopulationFunction popSizeFunction, double threshold) { double logL = 0.0; double startTime = 0.0; final int n = intervals.getIntervalCount(); for (int i = 0; i < n; i++) { final double duration = intervals.getInterval(i); final double finishTime = startTime + duration; final double intervalArea = popSizeFunction.getIntegral(startTime, finishTime); if (intervalArea == 0 && duration != 0) { return Double.NEGATIVE_INFINITY; } final int lineageCount = intervals.getLineageCount(i); final double kChoose2 = Binomial.choose2(lineageCount); // common part logL += -kChoose2 * intervalArea; if (intervals.getIntervalType(i) == IntervalType.COALESCENT) { final double demographicAtCoalPoint = popSizeFunction.getPopSize(finishTime); // if value at end is many orders of magnitude different than mean over interval reject the interval // This is protection against cases where ridiculous infinitesimal // population size at the end of a linear interval drive coalescent values to infinity. if (duration == 0.0 || demographicAtCoalPoint * (intervalArea / duration) >= threshold) { // if( duration == 0.0 || demographicAtCoalPoint >= threshold * (duration/intervalArea) ) { logL -= Math.log(demographicAtCoalPoint); } else { // remove this at some stage // System.err.println("Warning: " + i + " " + demographicAtCoalPoint + " " + (intervalArea/duration) ); return Double.NEGATIVE_INFINITY; } } startTime = finishTime; } return logL; } @Override protected boolean requiresRecalculation() { return ((CalculationNode) popSizeInput.get()).isDirtyCalculation() || super.requiresRecalculation(); } }