package beast.evolution.tree.coalescent;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import beast.core.BEASTObject;
import beast.core.Description;
import beast.core.Function;
import beast.core.Input;
import beast.core.Input.Validate;
import beast.core.State;
import beast.core.parameter.IntegerParameter;
import beast.evolution.tree.Tree;
import beast.evolution.tree.TreeDistribution;
import beast.math.Binomial;
/**
* @author Alexei Drummond
*/
@Description("A likelihood function for the generalized skyline plot coalescent.")
public class BayesianSkyline extends TreeDistribution {

    final public Input<Function> popSizeParamInput = new Input<>("popSizes", "present-day population size. "
            + "If time units are set to Units.EXPECTED_SUBSTITUTIONS then"
            + "the N0 parameter will be interpreted as N0 * mu. "
            + "Also note that if you are dealing with a diploid population " + "N0 will be out by a factor of 2.",
            Validate.REQUIRED);
    final public Input<IntegerParameter> groupSizeParamInput = new Input<>("groupSizes",
            "the group sizes parameter", Validate.REQUIRED);

    // Per-epoch population sizes, one value per group.
    Function popSizes;
    // Number of coalescent events assigned to each epoch (group).
    IntegerParameter groupSizes;
    Tree tree;
    // Coalescent intervals of the tree; this distribution is configured via
    // treeIntervals only (a plain tree input is rejected in initAndValidate).
    TreeIntervals intervals;
    // Coalescent (internal node) times, refreshed by prepare().
    double[] coalescentTimes;
    // cumulativeGroupSizes[i] = total number of coalescent events in groups 0..i.
    int[] cumulativeGroupSizes;
    // True while the cached arrays above reflect the current state.
    boolean m_bIsPrepared = false;

    public BayesianSkyline() {
    }

    /**
     * Validates the inputs and, when the XML gave no usable initial group
     * sizes, distributes the tree's coalescent events evenly over the groups.
     *
     * @throws IllegalArgumentException if a tree (rather than tree intervals)
     *         was supplied, if there are more groups than coalescent events,
     *         or if explicitly set group sizes do not sum to the number of
     *         coalescent events.
     */
    @Override
    public void initAndValidate() {
        if (treeInput.get() != null) {
            // BUG FIX: message was a garbled double negative
            // ("should not be specified"); this branch fires when the user
            // supplied "tree" instead of the required "treeIntervals".
            throw new IllegalArgumentException("only tree intervals (not tree) should be specified");
        }
        intervals = treeIntervalsInput.get();
        groupSizes = groupSizeParamInput.get();
        popSizes = popSizeParamInput.get();

        // Make sure that the sum of group sizes == number of coalescent events.
        int events = intervals.treeInput.get().getInternalNodeCount();
        if (groupSizes.getDimension() > events) {
            throw new IllegalArgumentException("There are more groups than coalescent nodes in the tree.");
        }
        int paramDim2 = groupSizes.getDimension();

        int eventsCovered = 0;
        for (int i = 0; i < groupSizes.getDimension(); i++) {
            eventsCovered += groupSizes.getValue(i);
        }

        if (eventsCovered != events) {
            if (eventsCovered == 0 || eventsCovered == paramDim2) {
                // For these special cases we assume that the XML has not
                // specified initial group sizes (or has set all to 1), so we
                // set them here automatically: distribute the events as evenly
                // as possible, giving the remainder to the first groups.
                int eventsEach = events / paramDim2;
                int eventsExtras = events % paramDim2;
                Integer[] values = new Integer[paramDim2];
                for (int i = 0; i < paramDim2; i++) {
                    values[i] = (i < eventsExtras) ? eventsEach + 1 : eventsEach;
                }
                IntegerParameter parameter = new IntegerParameter(values);
                parameter.setBounds(1, Integer.MAX_VALUE);
                groupSizes.assignFromWithoutID(parameter);
            } else {
                // ... otherwise assume the user has made a mistake setting
                // initial group sizes.
                throw new IllegalArgumentException(
                        "The sum of the initial group sizes does not match the number of coalescent events in the tree.");
            }
        }
        prepare();
    }

    /**
     * Refreshes the cached cumulative group sizes and coalescent times.
     * Called lazily whenever the tree or group sizes may have changed
     * (see {@link #requiresRecalculation()}, {@link #store()}, {@link #restore()}).
     */
    public void prepare() {
        cumulativeGroupSizes = new int[groupSizes.getDimension()];
        int intervalCount = 0;
        for (int i = 0; i < cumulativeGroupSizes.length; i++) {
            intervalCount += groupSizes.getValue(i);
            cumulativeGroupSizes[i] = intervalCount;
        }

        coalescentTimes = intervals.getCoalescentTimes(coalescentTimes);

        // NOTE(review): assumes getSampleCount() equals the total number of
        // coalescent events covered by the groups — confirm against TreeIntervals.
        assert (intervals.getSampleCount() == intervalCount);
        m_bIsPrepared = true;
    }

    /**
     * CalculationNode methods: the cached arrays are invalidated on any state
     * change so that prepare() runs again before the next evaluation.
     */
    @Override
    protected boolean requiresRecalculation() {
        m_bIsPrepared = false;
        return true;
    }

    @Override
    public void store() {
        m_bIsPrepared = false;
        super.store();
    }

    @Override
    public void restore() {
        m_bIsPrepared = false;
        super.restore();
    }

    /** @return the IDs of the population-size and group-size parameters. */
    public List<String> getParameterIds() {
        List<String> paramIDs = new ArrayList<>();
        paramIDs.add(((BEASTObject) popSizes).getID());
        paramIDs.add(groupSizes.getID());
        return paramIDs;
    }

    /**
     * Calculates the log likelihood of this set of coalescent intervals, given
     * a demographic model (piecewise-constant population sizes, one per group).
     */
    @Override
    public double calculateLogP() {

        if (!m_bIsPrepared) {
            prepare();
        }

        logP = 0.0;

        double currentTime = 0.0;

        int groupIndex = 0;
        int subIndex = 0;

        for (int j = 0; j < intervals.getIntervalCount(); j++) {

            // Set the population size to the size at the middle of the current interval.
            final double ps = getPopSize(currentTime + (intervals.getInterval(j) / 2.0));
            if (intervals.getIntervalType(j) == IntervalType.COALESCENT) {
                subIndex += 1;
                if (subIndex >= groupSizes.getValue(groupIndex)) {
                    groupIndex += 1;
                    subIndex = 0;
                }
            }

            logP += calculateIntervalLikelihood(ps, intervals.getInterval(j), currentTime,
                    intervals.getLineageCount(j), intervals.getIntervalType(j));

            // Insert zero-length coalescent intervals (multifurcations /
            // simultaneous coalescences collapse into one interval).
            int diff = intervals.getCoalescentEvents(j) - 1;
            for (int k = 0; k < diff; k++) {
                double popSize = getPopSize(currentTime);
                logP += calculateIntervalLikelihood(popSize, 0.0, currentTime, intervals.getLineageCount(j) - k - 1,
                        IntervalType.COALESCENT);
                subIndex += 1;
                if (subIndex >= groupSizes.getValue(groupIndex)) {
                    groupIndex += 1;
                    subIndex = 0;
                }
            }

            currentTime += intervals.getInterval(j);
        }
        return logP;
    }

    /**
     * Log likelihood contribution of a single coalescent interval under a
     * constant population size.
     *
     * @param popSize        population size during this interval
     * @param width          length of the interval
     * @param timeOfPrevCoal time at the start of the interval
     * @param lineageCount   number of lineages present during the interval
     * @param type           interval type; only COALESCENT adds the -log(N) term
     * @return the log likelihood of the interval
     */
    public static double calculateIntervalLikelihood(double popSize, double width,
                                                     double timeOfPrevCoal, int lineageCount, IntervalType type) {
        final double timeOfThisCoal = width + timeOfPrevCoal;

        // For a constant population the coalescent intensity integral is width / N.
        final double intervalArea = (timeOfThisCoal - timeOfPrevCoal) / popSize;
        final double kchoose2 = Binomial.choose2(lineageCount);
        double like = -kchoose2 * intervalArea;

        switch (type) {
            case COALESCENT:
                final double demographic = Math.log(popSize);
                like += -demographic;
                break;
            default:
                break;
        }

        return like;
    }

    /**
     * @param t time
     * @return the population size of the epoch containing time t; times beyond
     *         the root map to the last epoch's population size.
     */
    public double getPopSize(double t) {

        if (!m_bIsPrepared) {
            prepare();
        }

        if (t > coalescentTimes[coalescentTimes.length - 1])
            return popSizes.getArrayValue(popSizes.getDimension() - 1);

        // binarySearch returns (-(insertion point) - 1) on a miss; recover the
        // index of the first coalescent time >= t.
        int epoch = Arrays.binarySearch(coalescentTimes, t);
        if (epoch < 0) {
            epoch = -epoch - 1;
        }

        // Map the epoch (coalescent-event index) to its group.
        int groupIndex = Arrays.binarySearch(cumulativeGroupSizes, epoch);

        if (groupIndex < 0) {
            groupIndex = -groupIndex - 1;
        } else {
            groupIndex++;
        }
        // Clamp in case epoch landed past the last cumulative boundary.
        if (groupIndex >= popSizes.getDimension()) {
            groupIndex = popSizes.getDimension() - 1;
        }

        return popSizes.getArrayValue(groupIndex);
    }

    @Override
    public List<String> getArguments() {
        // TODO Auto-generated method stub
        return null;
    }

    @Override
    public List<String> getConditions() {
        // TODO Auto-generated method stub
        return null;
    }

    @Override
    public void sample(State state, Random random) {
        // TODO Auto-generated method stub
    }

    // An alternative implementation of BayesianSkyline as
    // PopulationFunction.Abstract (getIntensity/getInverseIntensity) was
    // removed: it benchmarked somewhat slower than this Distribution-based
    // implementation (43s/Msamples against 41s/Msamples on Dengue data).
}