/*
* CalibrationPoints.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evomodel.speciation;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeUtils;
import dr.evolution.util.Taxa;
import dr.inference.model.Statistic;
import dr.math.distributions.Distribution;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
 * Calibration densities on the heights of a set of clades, together with the correction term that
 * {@link #getCorrection(Tree, double)} subtracts from their summed log densities.
 *
 * <p>Calibrated clades must be pairwise nested or disjoint. The correction is either the first value
 * of a user supplied {@link Statistic} or one of the built-in formulas selected by
 * {@link CorrectionType}; the built-in formulas are only available for Yule models.
 *
 * @author Joseph Heled
 * Date: 8/06/2011
 */
public class CalibrationPoints {

    /**
     * Selects which log marginal density {@link CalibrationPoints#getCorrection(Tree, double)} subtracts.
     */
    public static enum CorrectionType {
        /**
         * Closed forms for one clade or two nested clades; otherwise a sum over the configurations
         * produced by {@code CalibrationLineagesIterator}.
         */
        EXACT("exact"),
        /** Approximation computed from the calibrated heights and per-clade free-height counts. */
        APPROXIMATED("approximated"),
        /** Computed from the number of internal nodes between consecutive calibrated heights. */
        PEXACT("pexact"),
        /** No correction: only the summed calibration log densities are returned. */
        NONE("none");

        private final String name;

        CorrectionType(String name) {
            this.name = name;
        }

        @Override
        public String toString() {
            return name;
        }
    }

    /**
     * Creates the calibrations and pre-computes whatever the chosen correction needs.
     *
     * <p>The argument lists are copied and are not modified.
     *
     * @param tree           tree the clades refer to; only taxon indices and the external node count are read here
     * @param isYule         whether the tree prior is a Yule model; must be true when {@code userPDF} is null
     * @param dists          one calibration density per clade, in the same order as {@code clades}
     * @param clades         calibrated clades; any two must be disjoint or nested, and no two may be equal
     * @param forParent      per clade, true if its density applies to the parent of the clade's MRCA
     *                       rather than the MRCA itself
     * @param userPDF        optional statistic whose first value is the log marginal to subtract; may be null
     * @param correctionType which correction {@link #getCorrection(Tree, double)} applies
     * @throws IllegalArgumentException if {@code clades} is empty, the list sizes differ, two clades
     *         partially overlap or are equal, a taxon is not in {@code tree}, or the requested built-in
     *         correction is not implemented for this configuration
     */
    public CalibrationPoints(Tree tree, boolean isYule, List<Distribution> dists, List<Taxa> clades,
                             List<Boolean> forParent, Statistic userPDF, CorrectionType correctionType) {
        final int nClades = clades.size();
        if (nClades == 0) {
            throw new IllegalArgumentException("At least one calibrated clade is required.");
        }
        if (dists.size() != nClades || forParent.size() != nClades) {
            throw new IllegalArgumentException("Expected one density and one forParent flag per clade: got "
                    + dists.size() + " densities and " + forParent.size() + " flags for " + nClades + " clades.");
        }
        checkNestedOrDisjoint(clades);

        this.densities = new Distribution[nClades];
        this.clades = new int[nClades][];
        this.forParent = new boolean[nClades];

        // Order clades so that every clade comes after all clades it contains: repeatedly take a clade
        // not contained in any remaining clade and place it in the highest free slot.
        // Work on copies so the caller's lists are left intact.
        final List<Taxa> remainingClades = new ArrayList<Taxa>(clades);
        final List<Distribution> remainingDists = new ArrayList<Distribution>(dists);
        final List<Boolean> remainingForParent = new ArrayList<Boolean>(forParent);
        final Taxa[] taxaInOrder = new Taxa[nClades];
        for (int loc = nClades - 1; loc >= 0; --loc) {
            final int k = indexOfMaximal(remainingClades);
            if (k < 0) {
                // Partial overlaps were rejected above, so only mutually containing (equal) clades remain.
                throw new IllegalArgumentException("Duplicate clades: two calibrations are on the same set of taxa.");
            }
            final Taxa tk = remainingClades.remove(k);
            this.densities[loc] = remainingDists.remove(k);
            this.forParent[loc] = remainingForParent.remove(k);
            this.clades[loc] = taxonIndices(tree, tk);
            taxaInOrder[loc] = tk;
        }

        this.taxaPartialOrder = buildPartialOrder(taxaInOrder);
        this.freeHeights = computeFreeHeights(this.clades, this.forParent, this.taxaPartialOrder);

        // maximal[k] is true if clade k is not listed as directly contained in any other clade.
        final boolean[] maximal = new boolean[nClades];
        Arrays.fill(maximal, true);
        for (int[] contained : this.taxaPartialOrder) {
            for (int i : contained) {
                maximal[i] = false;
            }
        }

        // The last clade is maximal; this is true when it does not include every taxon of the tree.
        this.rootCorrection = this.clades[nClades - 1].length < tree.getExternalNodeCount();
        this.calibrationLogPDF = userPDF;
        this.correctionType = correctionType;

        if (userPDF == null) {
            prepareBuiltInCorrection(tree, isYule, taxaInOrder, maximal);
        }
    }

    // Rejects any pair of clades that share taxa without one containing the other.
    private static void checkNestedOrDisjoint(List<Taxa> clades) {
        for (int k = 0; k < clades.size(); ++k) {
            final Taxa tk = clades.get(k);
            for (int i = k + 1; i < clades.size(); ++i) {
                final Taxa ti = clades.get(i);
                if (ti.containsAny(tk) && !(ti.containsAll(tk) || tk.containsAll(ti))) {
                    throw new IllegalArgumentException("Overlapping clades??");
                }
            }
        }
    }

    // Index of the first clade in 'taxa' that no other entry contains, or -1 if there is none.
    private static int indexOfMaximal(List<Taxa> taxa) {
        for (int k = 0; k < taxa.size(); ++k) {
            if (isMaximal(taxa, k)) {
                return k;
            }
        }
        return -1;
    }

    // Tree taxon index of every taxon in 'taxa', in the order 'taxa' lists them.
    private static int[] taxonIndices(Tree tree, Taxa taxa) {
        final int n = taxa.getTaxonCount();
        final int[] indices = new int[n];
        for (int nt = 0; nt < n; ++nt) {
            final int taxonIndex = tree.getTaxonIndex(taxa.getTaxon(nt));
            if (taxonIndex < 0) {
                throw new IllegalArgumentException("Taxon not found in tree: " + taxa.getTaxon(nt));
            }
            indices[nt] = taxonIndex;
        }
        return indices;
    }

    // Entry i lists every k < i for which clade i is the first later clade containing clade k.
    // With the maximal-last ordering built by the constructor, that is clade k's smallest enclosing clade.
    private static int[][] buildPartialOrder(Taxa[] taxaInOrder) {
        final int n = taxaInOrder.length;
        final List<List<Integer>> contained = new ArrayList<List<Integer>>(n);
        for (int k = 0; k < n; ++k) {
            contained.add(new ArrayList<Integer>());
        }
        for (int k = 0; k < n; ++k) {
            for (int i = k + 1; i < n; ++i) {
                if (taxaInOrder[i].containsAll(taxaInOrder[k])) {
                    contained.get(i).add(k);
                    break;
                }
            }
        }
        final int[][] order = new int[n][];
        for (int k = 0; k < n; ++k) {
            final List<Integer> ck = contained.get(k);
            order[k] = new int[ck.size()];
            for (int j = 0; j < ck.size(); ++j) {
                order[k][j] = ck.get(j);
            }
        }
        return order;
    }

    // For each clade, its size minus 1 (forParent) or 2 (MRCA calibrated) minus, for each directly
    // contained clade, that clade's size (forParent) or size - 1.
    private static int[] computeFreeHeights(int[][] cladeTaxa, boolean[] parentFlags, int[][] partialOrder) {
        final int[] free = new int[cladeTaxa.length];
        for (int k = 0; k < cladeTaxa.length; ++k) {
            int taken = 0;
            for (int i : partialOrder[k]) {
                taken += cladeTaxa[i].length - (parentFlags[i] ? 0 : 1);
            }
            free[k] = cladeTaxa[k].length - (parentFlags[k] ? 1 : 2) - taken;
            assert free[k] >= 0;
        }
        return free;
    }

    // Validates that the selected built-in correction is implemented and builds the tables/iterator it uses.
    private void prepareBuiltInCorrection(Tree tree, boolean isYule, Taxa[] taxaInOrder, boolean[] maximal) {
        if (!isYule) {
            throw new IllegalArgumentException("Sorry, not implemented: conditional calibration prior for this non Yule models.");
        }
        if (correctionType == CorrectionType.EXACT) {
            if (densities.length == 1) {
                // closed form formula
                return;
            }
            for (boolean p : forParent) {
                if (p) {
                    throw new IllegalArgumentException("Sorry, not implemented: calibration on parent for more than one clade.");
                }
            }
            if (densities.length == 2 && taxaInOrder[1].containsAll(taxaInOrder[0])) {
                // closed form formulas
                return;
            }
            setUpTables(tree);
            linsIter = new CalibrationLineagesIterator(clades, taxaPartialOrder, maximal,
                    tree.getExternalNodeCount());
            lastHeights = new double[clades.length];
        } else if (correctionType == CorrectionType.PEXACT) {
            setUpTables(tree);
        }
    }

    // Fills the log lookup tables (indices 0 .. number of external nodes) used by EXACT and PEXACT.
    private void setUpTables(Tree tree) {
        final int MAX_N = tree.getExternalNodeCount() + 1;
        final double[] lints = new double[MAX_N];
        lc2 = new double[MAX_N];
        lfactorials = new double[MAX_N];
        lNR = new double[MAX_N];

        lints[0] = Double.NEGATIVE_INFINITY; //-infinity, should never be used
        lints[1] = 0.0;
        for (int i = 2; i < MAX_N; ++i) {
            lints[i] = Math.log(i);
        }

        // lc2[i] = log(i * (i - 1) / 2)
        lc2[0] = lc2[1] = Double.NEGATIVE_INFINITY;
        for (int i = 2; i < MAX_N; ++i) {
            lc2[i] = lints[i] + lints[i - 1] - LOG2;
        }

        // lfactorials[i] = log(i!)
        lfactorials[0] = 0.0;
        for (int i = 1; i < MAX_N; ++i) {
            lfactorials[i] = lfactorials[i - 1] + lints[i];
        }

        // lNR[i] = sum over j = 2..i of lc2[j]
        lNR[0] = Double.NEGATIVE_INFINITY; //-infinity, should never be used
        lNR[1] = 0.0;
        for (int i = 2; i < MAX_N; ++i) {
            lNR[i] = lNR[i - 1] + lc2[i];
        }
    }

    // True if no other entry of 'taxa' contains all taxa of entry k.
    private static boolean isMaximal(List<Taxa> taxa, int k) {
        final Taxa tk = taxa.get(k);
        for (int i = 0; i < taxa.size(); ++i) {
            if (i != k && taxa.get(i).containsAll(tk)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns the summed log calibration densities minus the selected log marginal correction.
     *
     * @param tree tree whose calibrated node heights are read
     * @param lam  rate passed to the built-in correction formulas (presumably the Yule birth rate; confirm
     *             at the call site)
     * @return {@code Double.NEGATIVE_INFINITY} if a calibrated clade of two or more taxa is not monophyletic
     *         in {@code tree}. Otherwise the density sum itself when it is infinite or the correction is
     *         {@link CorrectionType#NONE}. Otherwise the sum minus the correction, with a NaN or infinite
     *         user statistic (and a NaN APPROXIMATED result) mapped to {@code Double.NEGATIVE_INFINITY}.
     */
    public double getCorrection(Tree tree, final double lam) {
        final int nDists = densities.length;
        final double[] hs = new double[nDists];
        double logL = 0.0;

        for (int k = 0; k < nDists; ++k) {
            final int[] taxk = clades[k];
            NodeRef c;
            if (taxk.length > 1) {
                // check monophyly: the MRCA must subtend exactly the clade's taxa
                c = TreeUtils.getCommonAncestor(tree, taxk);
                if (TreeUtils.getLeafCount(tree, c) != taxk.length) {
                    return Double.NEGATIVE_INFINITY;
                }
            } else {
                // a single taxon can only be calibrated through its parent
                c = tree.getNode(taxk[0]);
                assert forParent[k];
            }
            if (forParent[k]) {
                c = tree.getParent(c);
            }
            hs[k] = tree.getNodeHeight(c);
            logL += densities[k].logPdf(hs[k]);
        }

        if (Double.isInfinite(logL) || correctionType == CorrectionType.NONE) {
            return logL;
        }

        if (calibrationLogPDF != null) {
            final double value = calibrationLogPDF.getStatisticValue(0);
            if (Double.isNaN(value) || Double.isInfinite(value)) {
                return Double.NEGATIVE_INFINITY;
            }
            return logL - value;
        }

        switch (correctionType) {
            case EXACT:
                return logL - exactLogMarginal(tree, lam, hs);
            case APPROXIMATED:
                return approximatedCorrection(logL, lam, hs);
            case PEXACT:
                return logL - pexactLogMarginal(tree, lam, hs);
            default:
                return logL;
        }
    }

    // Log marginal for CorrectionType.EXACT. The general (iterator) case is cached on (lam, heights).
    private double exactLogMarginal(Tree tree, final double lam, final double[] hs) {
        final int nDists = hs.length;
        if (nDists == 1) {
            return logMarginalDensity(lam, tree.getExternalNodeCount(), hs[0], clades[0].length, forParent[0]);
        }
        if (nDists == 2 && taxaPartialOrder[1].length == 1) {
            assert !forParent[0] && !forParent[1];
            return logMarginalDensity(lam, tree.getExternalNodeCount(), hs[0], clades[0].length,
                    hs[1], clades[1].length);
        }

        // This marginal depends only on lam and the calibrated heights, so only it is cached. The
        // calibration density values are recomputed on every call: a density's value at fixed heights
        // may change between calls (e.g. if its parameters are updated).
        if (lastLam == lam && sameHeights(hs, lastHeights)) {
            return lastValue;
        }

        // the slow and painful way: hss holds the heights in increasing order, ranks[k] is 1-based
        final double[] hss = new double[nDists];
        final int[] ranks = new int[nDists];
        for (int k = 0; k < nDists; ++k) {
            int r = 0;
            for (double h : hs) {
                if (h < hs[k]) {
                    ++r;
                }
            }
            ranks[k] = r + 1;
            hss[r] = hs[k];
        }
        final double logMarginal = logMarginalDensity(lam, hss, ranks, linsIter);
        lastLam = lam;
        System.arraycopy(hs, 0, lastHeights, 0, lastHeights.length);
        lastValue = logMarginal;
        return logMarginal;
    }

    // Element-wise '==' comparison (0.0 matches -0.0, NaN never matches) used for the EXACT cache.
    private static boolean sameHeights(double[] current, double[] cached) {
        for (int k = 0; k < current.length; ++k) {
            if (current[k] != cached[k]) {
                return false;
            }
        }
        return true;
    }

    // Applies the CorrectionType.APPROXIMATED terms to 'logL'; NaN results map to negative infinity.
    private double approximatedCorrection(double logL, final double lam, final double[] hs) {
        final double loglam = Math.log(lam);
        int maxh = 0;
        for (int k = 0; k < hs.length; ++k) {
            final double v = -lam * hs[k];
            if (freeHeights[k] > 0) {
                logL -= Math.log1p(-Math.exp(v)) * freeHeights[k];
            }
            logL -= v + loglam;
            if (hs[k] > hs[maxh]) {
                maxh = k;
            }
        }
        // Applied unconditionally; rootCorrection is computed by the constructor but not consulted here.
        logL -= -(forParent[maxh] ? 0 : 1) * lam * hs[maxh];
        return Double.isNaN(logL) ? Double.NEGATIVE_INFINITY : logL;
    }

    // Log marginal for CorrectionType.PEXACT. Sorts 'hs' in place.
    private double pexactLogMarginal(Tree tree, final double lam, final double[] hs) {
        final int nDists = hs.length;
        Arrays.sort(hs);

        // cs[i] (i < nDists): internal nodes strictly below hs[i] and strictly above hs[i - 1] (if any);
        // cs[nDists]: internal nodes above the highest calibrated height. Nodes exactly at a calibrated
        // height are not counted.
        final int[] cs = new int[nDists + 1];
        final int internalNodeCount = tree.getInternalNodeCount();
        for (int k = 0; k < internalNodeCount; ++k) {
            final double nhk = tree.getNodeHeight(tree.getInternalNode(k));
            int i = 0;
            while (i < nDists && hs[i] < nhk) {
                ++i;
            }
            if (i == nDists || nhk < hs[i]) {
                cs[i]++;
            }
        }

        double ll = 0;
        ll += cs[0] * Math.log1p(-Math.exp(-lam * hs[0])) - lam * hs[0] - lfactorials[cs[0]];
        for (int i = 1; i < cs.length - 1; ++i) {
            final int c = cs[i];
            ll += c * (Math.log1p(-Math.exp(-lam * (hs[i] - hs[i - 1]))) - lam * hs[i - 1]);
            ll += -lam * hs[i] - lfactorials[c];
        }
        ll += -lam * (cs[nDists] + 1) * hs[nDists - 1] - lfactorials[cs[nDists] + 1];
        ll += Math.log(lam) * nDists;
        return ll;
    }

    // Closed-form log marginal for a single calibrated clade of nClade taxa in a tree of nTaxa taxa.
    // Constant factors are omitted as noted inline.
    private double logMarginalDensity(final double lam, int nTaxa, final double h, int nClade, boolean forParent) {
        final double lh = lam * h;
        // log(1 - exp(-lh)), via log1p for accuracy when exp(-lh) is small
        final double log1mElh = Math.log1p(-Math.exp(-lh));
        double lgp;
        if (forParent) {
            // n(n+1) factor left out
            lgp = -2 * lh + Math.log(lam);
            if (nClade > 1) {
                lgp += (nClade - 1) * log1mElh;
            }
        } else {
            assert nClade > 1;
            lgp = -3 * lh + (nClade - 2) * log1mElh + Math.log(lam);
            if (nTaxa == nClade) {
                // root is a special case: n(n-1) factor left out
                lgp += lh;
            }
            // otherwise the (n^3-n)/2 factor is left out
        }
        return lgp;
    }

    // Closed-form log marginal for two nested clades: an inner clade of n taxa at height h2 inside an
    // outer clade of nm taxa at height h1. Constant factors are omitted as noted inline.
    private double logMarginalDensity(final double lam, final int nTaxa, double h2, final int n, double h1, int nm) {
        assert h2 <= h1 && n < nm;
        final int m = nm - n;
        final double elh2 = Math.exp(-lam * h2);
        final double elh1 = Math.exp(-lam * h1);

        double lgl = 2 * Math.log(lam);
        lgl += (n - 2) * Math.log1p(-elh2);
        lgl += (m - 3) * Math.log1p(-elh1);
        lgl += Math.log(1 - 2 * m * elh1 + 2 * (m - 1) * elh2
                - m * (m - 1) * elh1 * elh2 + (m * (m + 1) / 2.) * elh1 * elh1
                + ((m - 1) * (m - 2) / 2.) * elh2 * elh2);

        if (nm < nTaxa) {
            // left out: lgl += Math.log(0.5*(n*(n*n-1))*(n+1+m))
            lgl -= lam * (h2 + 3 * h1);
        } else {
            // outer clade contains every taxon; left out: lgl += Math.log(lam) + Math.log(n*(n*n-1))
            lgl -= lam * (h2 + 2 * h1);
        }
        return lgl;
    }

    // General log marginal: log-sum-exp over every lineage configuration produced by 'cli', plus
    // combinatorial and lam-dependent constants. 'hs' holds the calibrated heights, 'ranks' their ranks.
    private double logMarginalDensity(final double lam, double[] hs, int[] ranks, CalibrationLineagesIterator cli) {
        final int ni = cli.setup(ranks);
        final int nHeights = hs.length;

        // lehs[0] = 0, lehs[i] = -lam * hs[i - 1]
        final double[] lehs = new double[nHeights + 1];
        lehs[0] = 0.0;
        for (int i = 1; i < lehs.length; ++i) {
            lehs[i] = -lam * hs[i - 1];
        }

        final boolean noRoot = ni == lehs.length;
        final int nLevels = nHeights + (noRoot ? 1 : 0);

        // lebase[i] = log(exp(lehs[i]) - exp(lehs[i + 1])) for i < nHeights
        final double[] lebase = new double[nLevels];
        for (int i = 0; i < nHeights; ++i) {
            lebase[i] = lehs[i] + Math.log1p(-Math.exp(lehs[i + 1] - lehs[i]));
        }
        if (noRoot) {
            lebase[nHeights] = lehs[nHeights];
        }

        final int[] linsAtLevel = new int[nLevels];
        final int[][] joiners = cli.allJoiners();

        double val = 0;
        boolean first = true;
        int[][] linsInLevels;
        while ((linsInLevels = cli.next()) != null) {
            double v = countRankedTrees(nLevels, linsInLevels, joiners, linsAtLevel);
            // 1 for root formula, 1 for kludge in iterator which sets root as 2 lineages
            if (noRoot) {
                final int ll = linsAtLevel[nLevels - 1] + 2;
                linsAtLevel[nLevels - 1] = ll;
                v -= lc2[ll] + LOG2;
            }
            for (int i = 0; i < nLevels; ++i) {
                v += linsAtLevel[i] * lebase[i];
            }

            // numerically stable accumulation of log(exp(val) + exp(v))
            if (first) {
                val = v;
                first = false;
            } else if (val > v) {
                val += Math.log1p(Math.exp(v - val));
            } else {
                val = v + Math.log1p(Math.exp(val - v));
            }
        }

        double logc0 = 0.0;
        int totLin = 0;
        for (int i = 0; i < ni; ++i) {
            final int l = cli.nStart(i);
            if (l > 0) {
                logc0 += lNR[l];
                totLin += l;
            }
        }

        final double logc1 = lfactorials[totLin];

        double logc2 = nHeights * Math.log(lam);
        for (int i = 1; i < nHeights + 1; ++i) {
            logc2 += lehs[i];
        }
        if (!noRoot) {
            // we dont have an iterator for 0 free lineages
            logc2 += lehs[nHeights];
        }

        // Missing scale by total of all possible trees over all ranking orders.
        // Add it outside if needed for comparison.
        val += logc0 + logc1 + logc2;
        return val;
    }

    // Log count term for one lineage configuration; also writes, for each level i, the summed lineage
    // differences into linsAtLevel[i].
    private double countRankedTrees(final int nLevels, final int[][] linsAtCrossings, final int[][] joiners,
                                    int[] linsAtLevel) {
        double logCount = 0;
        for (int i = 0; i < nLevels; ++i) {
            int sumLins = 0;
            for (int k = i; k < nLevels; ++k) {
                final int[] lack = linsAtCrossings[k];
                int cki = lack[i];
                if (joiners[k][i] > 0) {
                    ++cki;
                    if (cki > 1) {
                        // can be 1 if iterator without lins - for joiners only - need to check this is correct
                        logCount += lc2[cki];
                    }
                }
                final int l = cki - lack[i + 1];
                logCount -= lfactorials[l];
                sumLins += l;
            }
            linsAtLevel[i] = sumLins;
        }
        return logCount;
    }

    // Flavour of marginal computation.
    private final CorrectionType correctionType;

    // Calibrated clades, each as an array of tree taxon indices.
    // Clades are partially ordered by inclusion - if X <= Y then X appears before Y.
    private final int[][] clades;

    // One calibration density associated with each clade.
    private final Distribution[] densities;

    // true if the density is for the parent of the clade's MRCA.
    private final boolean[] forParent;

    // For each clade, the clades directly nested in it (as indices into clades).
    private final int[][] taxaPartialOrder;

    // Per-clade counts used by the APPROXIMATED correction (see computeFreeHeights).
    private final int[] freeHeights;

    // true when the last (maximal) clade does not contain every taxon of the tree; not currently consulted.
    private final boolean rootCorrection;

    // User provided function to calculate the marginal density.
    private final Statistic calibrationLogPDF;

    // speedup constants
    private static final double LOG2 = Math.log(2.0);

    // lc2[i] = log(i*(i-1)/2); lNR[i] = sum_{j=2..i} lc2[j]; lfactorials[i] = log(i!).
    private double[] lc2;
    private double[] lNR;
    private double[] lfactorials;

    private CalibrationLineagesIterator linsIter = null;

    // Simple cache of the general EXACT log marginal. It can go a long way in a big tree with a few
    // calibration nodes, for non-global tree operators which do not change the calibration node heights.
    // lastValue holds the cached log marginal only, not the full getCorrection result.
    double lastLam = Double.NEGATIVE_INFINITY;
    double[] lastHeights;
    double lastValue = Double.NEGATIVE_INFINITY;
}