package dr.evomodel.speciation;
import java.util.*;
import dr.evolution.tree.*;
import dr.evomodel.tree.TreeLogger;
import dr.evomodelxml.speciation.PopsIOSpeciesTreeModelParser;
import dr.inference.loggers.LogColumn;
import dr.inference.model.AbstractModel;
import dr.inference.model.Model;
import dr.inference.model.Variable;
import dr.inference.operators.OperatorFailedException;
import dr.inference.operators.Scalable;
import dr.util.AlloppMisc;
import dr.util.Author;
import dr.util.Citation;
import jebl.util.FixedBitSet;
import dr.evolution.util.Taxon;
import dr.math.MathUtils;
/**
* User: Graham
* Date: 10/05/12
*/
/*
pionodes[] and rootn are the fundamental data here. They are used for
the SlidableTree implementation, not just supporting the MCMC move,
but also the calculations dealing with coalescences (compatibility
and likelihood).
oldpionodes[] and oldrootn store the state so it can be restored after a rejected move.
Only the taxa at tips, topology and node times are essential there. The other data
(unions, lineages, coalescences) can be reconstructed from that, although the node
copy constructor currently copies those working fields as well.
stree is used to implement Tree, which is used for logging and for
calculation of the tree prior (eg Yule model). stree is a copy of
topology and node times from pionodes[].
*/
public class PopsIOSpeciesTreeModel extends AbstractModel implements SlidableTree, Tree, Scalable, TreeLogger.LogUpon {
// Bindings supplied to the constructor; queried for species, gene tree heights and lineage counts.
private PopsIOSpeciesBindings piosb;
// Mixture components (weight, alpha, beta) that branchLL() sums over.
private PriorComponent[] priorComponents;
// The working species tree. Nodes refer to each other by index into this array.
private PopsIONode[] pionodes;
// Index into pionodes of the current root.
private int rootn;
// Copies of pionodes and rootn made by storeState() and put back by restoreState().
private PopsIONode[] oldpionodes;
private int oldrootn;
// SimpleTree rebuilt from pionodes by makeSimpleTree(); the Tree interface methods delegate to it.
private SimpleTree stree;
/**
 * Decides whether the tree should be logged at the given state.
 *
 * Every state up to and including 100 is logged, then every 100th state
 * up to and including 10000, and after that every 10000th state.
 *
 * For debugging, set logEvery=0 in XML:
 * <!-- species tree log file. -->
 * <logTree id="pioTreeFileLog" logEvery="0" fileName="C:\Users\....
 *
 * @param state the current state number
 * @return true if the tree should be logged at this state
 */
@Override
public boolean logNow(long state) {
    final long interval;
    if (state <= 100) {
        interval = 1;
    } else if (state <= 10000) {
        interval = 100;
    } else {
        interval = 10000;
    }
    // state % 1 is always 0, so the first band logs unconditionally
    return (state % interval) == 0;
}
/**
 * One component of the mixture that branchLL() sums over: a weight together
 * with the two parameters of an inverse gamma density.
 *
 * The inverse gamma pdf is parameterized as b^a/Gamma(a) x^(-a-1) exp(-b/x);
 * its mean is b/(a-1) if a>1 and its variance is b^2/((a-1)^2 (a-2)) if a>2.
 *
 * Instances are immutable: the fields are set once by the constructor.
 */
public static class PriorComponent {
    private final double weight;
    private final double alpha;
    private final double beta;

    /**
     * @param weight weight of this component in the mixture
     * @param alpha  the parameter a of the inverse gamma density
     * @param beta   the parameter b of the inverse gamma density
     */
    public PriorComponent(double weight, double alpha, double beta) {
        this.weight = weight;
        this.alpha = alpha;
        this.beta = beta;
    }
}
/*
 * The anc, lft, rgt, height, taxon fields implement the basic binary tree.
 * anc, lft and rgt are indices into pionodes[]; -1 means "no such node".
 * The other fields are 'working' fields.
 *
 * nodeNumber is required to implement NodeRef, needed for replaceSlidableRoot()
 * (possibly one can manage this another way?)
 *
 * union, coalheights, nlineages are used for calculations: determining compatibility
 * with gene trees and calculating likelihood.
 */
private class PopsIONode extends AlloppNode.Abstract implements AlloppNode, NodeRef {
    private int anc;
    private int lft;
    private int rgt;
    private double height;
    private Taxon taxon;
    private FixedBitSet union;
    private ArrayList<Double> coalheights;
    private int nlineages;
    private int nodeNumber;

    // dud constructor: an unconnected node with placeholder height, empty taxon name and no union
    PopsIONode(int nn) {
        anc = -1;
        lft = -1;
        rgt = -1;
        height = -1.0;
        coalheights = new ArrayList<Double>();
        taxon = new Taxon("");
        union = null;
        nodeNumber = nn;
    }

    // copy constructor: copies the topology indices and node number, then all other fields
    public PopsIONode(PopsIONode node) {
        anc = node.anc;
        lft = node.lft;
        rgt = node.rgt;
        nodeNumber = node.nodeNumber;
        copyNonTopologyFields(node);
    }

    // Copies height, taxon (by id), lineage count, union and coalescent heights from node.
    // The taxon, union and list are fresh objects, so the copy shares no mutable state.
    private void copyNonTopologyFields(PopsIONode node) {
        height = node.height;
        taxon = new Taxon(node.taxon.getId());
        nlineages = node.nlineages;
        if (node.union == null) {
            union = null;
        } else {
            union = new FixedBitSet(node.union);
        }
        coalheights = new ArrayList<Double>(node.coalheights);
    }

    /**
     * Formats this node as one line of the text dump of the tree.
     *
     * @param indentlen length of the indent already written for this line; the
     *                  name column is padded with spaces up to 20 - indentlen characters
     * @return taxon id (tips) or "+" (internal nodes), then height, union,
     *         lineage count and the comma-terminated coalescent heights
     */
    public String asText(int indentlen) {
        StringBuilder s = new StringBuilder();
        Formatter formatter = new Formatter(s, Locale.US);
        if (lft < 0) {
            formatter.format("%s ", taxon.getId());
        } else {
            formatter.format("%s ", "+");
        }
        while (s.length() < 20-indentlen) {
            formatter.format("%s", " ");
        }
        formatter.format("%s ", AlloppMisc.nonnegIn8Chars(height));
        formatter.format("%20s ", AlloppMisc.FixedBitSetasText(union));
        formatter.format("%3d ", nlineages);
        for (int c = 0; c < coalheights.size(); c++) {
            // pass the value as an argument, never as the format string itself
            formatter.format("%s,", AlloppMisc.nonnegIn8Chars(coalheights.get(c)));
        }
        return s.toString();
    }

    @Override
    public int getNumber() {
        return nodeNumber;
    }

    @Override
    public void setNumber(int n) {
        nodeNumber = n;
    }

    // A node is a tip (no children) when lft is negative; otherwise it has exactly two children.
    @Override
    public int nofChildren() {
        return (lft < 0) ? 0 : 2;
    }

    // Returns the left child for ch == 0 and the right child for any other value.
    // NOTE(review): indexes pionodes[] with -1 (and so throws) when called on a tip.
    @Override
    public AlloppNode getChild(int ch) {
        return ch==0 ? pionodes[lft] : pionodes[rgt];
    }

    // NOTE(review): indexes pionodes[] with -1 (and so throws) when this node has no
    // ancestor, i.e. for the root.
    @Override
    public AlloppNode getAnc() {
        return pionodes[anc];
    }

    @Override
    public Taxon getTaxon() {
        return taxon;
    }

    @Override
    public double getHeight() {
        return height;
    }

    @Override
    public FixedBitSet getUnion() {
        return union;
    }

    // Sets the left child for ch == 0 and the right child otherwise.
    // Only this node's child index changes; the child's anc is not touched.
    @Override
    public void setChild(int ch, AlloppNode newchild) {
        int newch = ((PopsIONode)newchild).nodeNumber;
        if (ch == 0) {
            lft = newch;
        } else {
            rgt = newch;
        }
    }

    @Override
    public void setAnc(AlloppNode anc) {
        this.anc = ((PopsIONode)anc).nodeNumber;
    }

    @Override
    public void setTaxon(String name) {
        this.taxon = new Taxon(name);
    }

    @Override
    public void setHeight(double height) {
        this.height = height;
    }

    @Override
    public void setUnion(FixedBitSet union) {
        this.union = union;
    }

    // Makes c0 the left and c1 the right child, and points both children's anc back at this node.
    @Override
    public void addChildren(AlloppNode c0, AlloppNode c1) {
        lft = ((PopsIONode)c0).nodeNumber;
        pionodes[lft].anc = nodeNumber;
        rgt = ((PopsIONode)c1).nodeNumber;
        pionodes[rgt].anc = nodeNumber;
    }
}
/**
 * Builds a random starting species tree for the species in piosb.
 *
 * The first nTaxa nodes are tips at height 0. The remaining nodes are created
 * by repeatedly joining two randomly chosen unjoined nodes under a new parent
 * whose height exceeds the previous parent's by a random increment. All heights
 * are then rescaled so that the root sits at 0.99 of
 * piosb.initialMinGeneNodeHeight(), unions are filled in, and the SimpleTree
 * copy is built.
 *
 * @param piosb           the species bindings
 * @param priorComponents the mixture components used by the branch likelihood
 */
public PopsIOSpeciesTreeModel(PopsIOSpeciesBindings piosb, PriorComponent[] priorComponents) {
    super(PopsIOSpeciesTreeModelParser.PIO_SPECIES_TREE);
    this.piosb = piosb;
    this.priorComponents = priorComponents;

    final PopsIOSpeciesBindings.SpInfo[] species = piosb.getSpecies();
    final int nTaxa = species.length;
    pionodes = new PopsIONode[2 * nTaxa - 1];
    for (int n = 0; n < pionodes.length; n++) {
        pionodes[n] = new PopsIONode(n);
    }

    // tips: one per species, all at height zero
    final ArrayList<Integer> unjoined = new ArrayList<Integer>(nTaxa);
    for (int tip = 0; tip < nTaxa; tip++) {
        final PopsIONode tipNode = pionodes[tip];
        tipNode.setTaxon(species[tip].name);
        tipNode.setHeight(0.0);
        tipNode.setUnion(piosb.tipUnionFromTaxon(tipNode.getTaxon()));
        unjoined.add(tip);
    }

    // internal nodes: join two random unjoined nodes at a time, moving up the tree
    final double rate = 1.0;
    double treeheight = 0.0;
    for (int parent = nTaxa; parent < pionodes.length; parent++) {
        final int numtojoin = unjoined.size();
        // remove(int index) both fetches and deletes the chosen entry
        final int child0 = unjoined.remove(MathUtils.nextInt(numtojoin));
        final int child1 = unjoined.remove(MathUtils.nextInt(numtojoin - 1));
        pionodes[parent].addChildren(pionodes[child0], pionodes[child1]);
        pionodes[parent].setHeight(treeheight + randomnodeheight(numtojoin * rate));
        treeheight = pionodes[parent].getHeight();
        unjoined.add(parent);
    }

    // the last node created is the root
    rootn = pionodes.length - 1;
    final double scale = 0.99 * piosb.initialMinGeneNodeHeight() / pionodes[rootn].height;
    scaleAllHeights(scale);
    pionodes[rootn].fillinUnionsInSubtree(piosb.getSpecies().length);
    stree = makeSimpleTree();
}
/**
 * Returns the citations for this model: a single entry, marked as in
 * preparation, with a working title and no journal yet.
 *
 * @return a new list holding that one citation
 */
public List<Citation> getCitations() {
    final Author[] authors = { new Author("Graham", "Jones") };
    final Citation workingPaper = new Citation(
            authors,
            "WORKING TITLE: A multi-species coalescent model with population parameters integrated out",
            "??", // journal
            Citation.Status.IN_PREPARATION);
    final List<Citation> result = new ArrayList<Citation>();
    result.add(workingPaper);
    return result;
}
/**
 * Returns a multi-line text dump: a blank line, the species tree as produced
 * by pioTreeAsText(), then each gene tree preceded by a "Gene tree g" heading,
 * and a final blank line.
 */
public String toString() {
    final int ngt = piosb.numberOfGeneTrees();
    final String nl = System.getProperty("line.separator");
    final StringBuilder out = new StringBuilder();
    out.append(nl).append(pioTreeAsText()).append(nl);
    for (int g = 0; g < ngt; g++) {
        out.append("Gene tree ").append(g).append(nl);
        out.append(piosb.genetreeAsText(g)).append(nl);
    }
    out.append(nl);
    return out.toString();
}
/**
 * Returns the log columns for this model: a single default column whose
 * value object is this model itself.
 */
public LogColumn[] getColumns() {
    return new LogColumn[] {
        new LogColumn.Default(" species-tree and gene trees", this)
    };
}
/**
 * Multiplies the height of every node in pionodes[] by scale.
 *
 * @param scale the factor applied to each height
 * @return the number of nodes, i.e. the number of heights that were scaled
 */
private int scaleAllHeights(double scale) {
    for (PopsIONode node : pionodes) {
        node.height *= scale;
    }
    return pionodes.length;
}
/**
 * Returns the species tree as text: a header line naming the columns,
 * followed by one line per node, starting from the root.
 */
public String pioTreeAsText() {
    final String header = "topology height union nlin coalheights" + System.getProperty("line.separator");
    // the stack starts empty: no vertical bars are open above the root
    final Stack<Integer> openBars = new Stack<Integer>();
    final String body = subtreeAsText(pionodes[rootn], "", openBars, 0, "");
    return header + body;
}
/*
 * Called from PopsIOSpeciesBindings to check if a node in a gene tree
 * is compatible with the species tree: the species tree node found for the
 * given union must be no higher than the coalescence.
 */
public boolean coalescenceIsCompatible(double height, FixedBitSet union) {
    final PopsIONode root = pionodes[rootn];
    final PopsIONode match = (PopsIONode) root.nodeOfUnionInSubtree(union);
    return height >= match.height;
}
/*
 * Called from PopsIOSpeciesBindings to remove coalescent information
 * from every branch of the species tree. Required before calls to recordCoalescence().
 */
public void clearCoalescences() {
    final PopsIONode root = pionodes[rootn];
    clearSubtreeCoalescences(root);
}
/*
 * Called from PopsIOSpeciesBindings to add a node from a gene tree
 * to its branch in the species tree. Starting at the node found for the union,
 * this climbs while the ancestor's height does not exceed the coalescence
 * height, then records the height on the node reached.
 */
public void recordCoalescence(double height, FixedBitSet union) {
    PopsIONode branch = (PopsIONode) pionodes[rootn].nodeOfUnionInSubtree(union);
    assert (branch.height <= height);
    for (int up = branch.anc; up >= 0 && pionodes[up].height <= height; up = branch.anc) {
        branch = pionodes[up];
    }
    branch.coalheights.add(height);
}
/**
 * Sorts the recorded coalescent heights of every node into ascending order.
 */
public void sortCoalescences() {
    for (int n = 0; n < pionodes.length; n++) {
        Collections.sort(pionodes[n].coalheights);
    }
}
/*
 * Records the number of gene lineages at every node of the species tree,
 * working up from the tips.
 */
public void recordLineageCounts() {
    final PopsIONode root = pionodes[rootn];
    recordSubtreeLineageCounts(root);
}
// Rebuilds stree from the current pionodes[] so that the Tree view matches
// the working tree again.
public void fixupAfterNodeSlide() {
stree = makeSimpleTree();
}
/*
 * Calculates the log-likelihood for a single gene tree in the species tree,
 * as the sum of the branch log-likelihoods over all nodes.
 *
 * Requires that clearCoalescences(), recordCoalescence(), recordLineageCounts()
 * have been called first to fill the tree in pionodes[] with information about
 * the gene tree coalescences.
 *
 * The formula comes from my note at http://www.indriid.com/goteborg/2011-09-23-simple-pop-model.pdf
 * See branchLL() for more.
 */
public double geneTreeInSpeciesTreeLogLikelihood() {
    final PopsIONode root = pionodes[rootn];
    return geneTreeInPopsIOSubtreeLogLikelihood(root);
}
/**
 * Renders the subtree below node as text, one line per node, parent first,
 * then the left subtree, then the right subtree.
 *
 * The layout uses two columns per tree level: a vertical bar is drawn at column
 * 2*d for every depth d on the stack x, and the branch marker b overwrites the
 * last b.length() columns of the indent. The indent must therefore be 2*depth
 * characters wide; a narrower indent makes the bar and marker positions fall
 * outside the buffer.
 *
 * @param node  root of the subtree to render
 * @param s     text to which the rendering of this node is appended
 * @param x     depths of the ancestors whose right subtree is still to come
 *              (a vertical bar is drawn for each); pushed and popped during recursion
 * @param depth depth of node below the root of the whole dump
 * @param b     branch marker: "" for the root, "-" for a left child, "`-" for a right child
 * @return s followed by the lines for node and all of its descendants
 */
private String subtreeAsText(PopsIONode node, String s, Stack<Integer> x, int depth, String b) {
    final StringBuilder indent = new StringBuilder();
    // two columns per level, written character by character so the width is unambiguous
    for (int i = 0; i < 2 * depth; i++) {
        indent.append(' ');
    }
    for (Integer barDepth : x) {
        indent.replace(2 * barDepth, 2 * barDepth + 1, "|");
    }
    if (b.length() > 0) {
        indent.replace(indent.length() - b.length(), indent.length(), b);
    }

    final StringBuilder text = new StringBuilder();
    text.append(s);
    text.append(indent);
    text.append(node.asText(indent.length()));
    text.append(System.getProperty("line.separator"));

    if (node.lft >= 0) {
        // the bar at this depth stays open while the left subtree is printed ...
        x.push(depth);
        text.append(subtreeAsText(pionodes[node.lft], "", x, depth + 1, "-"));
        // ... and is closed before the right (last) subtree
        x.pop();
        text.append(subtreeAsText(pionodes[node.rgt], "", x, depth + 1, "`-"));
    }
    return text.toString();
}
/**
 * Sums the branch log-likelihoods over the subtree rooted at node:
 * left subtree, then right subtree, then the branch of node itself.
 *
 * @param node root of the subtree
 * @return the summed log-likelihood
 */
private double geneTreeInPopsIOSubtreeLogLikelihood(PopsIONode node) {
    double total = 0.0;
    final boolean hasChildren = node.lft >= 0;
    if (hasChildren) {
        final double left = geneTreeInPopsIOSubtreeLogLikelihood(pionodes[node.lft]);
        total += left;
        final double right = geneTreeInPopsIOSubtreeLogLikelihood(pionodes[node.rgt]);
        total += right;
    }
    return total + branchLLInPopsIOtree(node);
}
/**
 * Computes the log-likelihood contribution of the branch above node.
 *
 * Builds the time array expected by branchLL(): the node's own height, then
 * its recorded coalescent heights in list order, then the top of the branch.
 * For the root the top is piosb.maxGeneTreeHeight(); otherwise it is the
 * height of the node's ancestor.
 *
 * @param node the node at the tipward end of the branch
 * @return the value of branchLL() for this branch
 */
private double branchLLInPopsIOtree(PopsIONode node) {
    final int ncoal = node.coalheights.size();
    final double[] times = new double[ncoal + 2];
    times[0] = node.height;
    final boolean isRoot = node.anc < 0;
    times[ncoal + 1] = isRoot ? piosb.maxGeneTreeHeight() : pionodes[node.anc].height;
    for (int i = 0; i < ncoal; i++) {
        times[i + 1] = node.coalheights.get(i);
    }
    return branchLL(times, node.nlineages);
}
/*
 * For one branch with tipward time t[0], rootward time t[k+1], k coalescent times t[1]...t[k],
 * and n lineages at tipward end, set
 * x = sum from i=0 to k of
 * ((n-i) choose 2) * (t[i+1]-t[i])
 * Then sum over j (j is component index) of
 * weight[j] * b[j]^a[j] * (b[j] + x)^-(a[j]+k+1) * GAMMA(a[j]+k+1) / GAMMA(a[j])
 *
 * G(z+1) = zG(z)
 * GAMMA(a[j]+k+1) = (a[j]+k)GAMMA(a[j]+k)
 * = (a[j]+k)(a[j]+k-1)GAMMA(a[j]+k-1)
 * = ...
 * = (a[j]+k)(a[j]+k-1)...a[j]GAMMA(a[j])
 * so GAMMA(a[j]+k+1) / GAMMA(a[j]) = (a[j]+k)(a[j]+k-1)...a[j]
 *
 * Here a[j] is the component's alpha and b[j] its beta. The factor b[j]^a[j] is
 * beta raised to the power alpha (the normalising constant of the inverse gamma
 * density), so the call must be Math.pow(b, a), not Math.pow(a, b).
 *
 * Returns the natural log of the summed mixture.
 */
private double branchLL(double t[], int n) {
    // k is the number of coalescences on the branch: t holds both ends plus the k times between
    final int k = t.length - 2;

    // x: time in each interval weighted by the number of lineage pairs, (n-i) choose 2
    double x = 0.0;
    for (int i = 0; i <= k; i++) {
        x += (t[i+1] - t[i]) * 0.5*(n-i)*(n-i-1);
    }

    double lhood = 0.0;
    for (int j = 0; j < priorComponents.length; j++) {
        final double w = priorComponents[j].weight;
        final double a = priorComponents[j].alpha;
        final double b = priorComponents[j].beta;
        // G = GAMMA(a+k+1) / GAMMA(a) as the product a(a+1)...(a+k)
        double G = 1.0;
        for (int i = 0; i <= k; i++) {
            G *= (a+i);
        }
        lhood += w * Math.pow(b, a) * Math.pow(b+x, -(a+k+1)) * G;
    }
    return Math.log(lhood);
}
/**
 * Builds a SimpleTree with the same topology, tip taxa and node heights as
 * the working tree in pionodes[].
 *
 * The SimpleNodes are filled in by makesimplesubtree() with children before
 * parents, so the root ends up in the last slot of the array.
 *
 * @return a new SimpleTree rooted at that last node
 */
private SimpleTree makeSimpleTree() {
    final int count = pionodes.length;
    final SimpleNode[] snodes = new SimpleNode[count];
    for (int n = 0; n < count; n++) {
        final SimpleNode blank = new SimpleNode();
        blank.setTaxon(null);
        snodes[n] = blank;
    }
    makesimplesubtree(snodes, 0, pionodes[rootn]);
    return new SimpleTree(snodes[count - 1]);
}
// for makeSimpleTree()
/**
 * Copies the subtree rooted at pionode into snodes, children before parents,
 * starting at slot nextsn.
 *
 * A tip gets a copy of the taxon; an internal node gets as children the roots
 * of its left and right subtrees, each of which occupies the last slot filled
 * by the corresponding recursive call. The node's own slot receives its height.
 *
 * @param snodes  preallocated SimpleNodes to fill in
 * @param nextsn  index of the first free slot
 * @param pionode root of the subtree to copy
 * @return index of the first free slot after this subtree
 */
private int makesimplesubtree(SimpleNode[] snodes, int nextsn, PopsIONode pionode) {
    int slot = nextsn;
    final boolean isTip = pionode.lft < 0;
    if (isTip) {
        final Taxon tx = new Taxon(pionode.taxon.getId());
        if (slot >= snodes.length) {
            System.out.println("BUG: makesimplesubtree()");
        }
        snodes[slot].setTaxon(tx);
    } else {
        final int afterLeft = makesimplesubtree(snodes, slot, pionodes[pionode.lft]);
        final int afterRight = makesimplesubtree(snodes, afterLeft, pionodes[pionode.rgt]);
        // each subtree's root is the last slot that subtree filled
        slot = afterRight;
        snodes[slot].addChild(snodes[afterLeft - 1]);
        snodes[slot].addChild(snodes[afterRight - 1]);
    }
    snodes[slot].setHeight(pionode.height);
    return slot + 1;
}
/**
 * Empties the list of recorded coalescent heights on node and on every node
 * below it.
 *
 * The sanity checks run before the node is used, and fail with a descriptive
 * exception instead of printing "BUG" and then hitting a NullPointerException.
 *
 * @param node root of the subtree to clear
 * @throws IllegalStateException if node, or its list of coalescent heights, is null
 */
private void clearSubtreeCoalescences(PopsIONode node) {
    if (node == null) {
        throw new IllegalStateException("BUG: clearSubtreeCoalescences() reached a null node");
    }
    if (node.coalheights == null) {
        throw new IllegalStateException(
                "BUG: clearSubtreeCoalescences() found node " + node.nodeNumber + " without a coalheights list");
    }
    if (node.lft >= 0) {
        clearSubtreeCoalescences(pionodes[node.lft]);
        clearSubtreeCoalescences(pionodes[node.rgt]);
    }
    node.coalheights.clear();
}
/**
 * Sets nlineages on node and on every node below it.
 *
 * A tip takes its count from piosb for its species. An internal node receives,
 * from each child, that child's lineage count minus the number of coalescences
 * recorded on the child.
 *
 * @param node root of the subtree to process
 */
private void recordSubtreeLineageCounts(PopsIONode node) {
    if (node.lft < 0) {
        final int spIndex = piosb.speciesId2index(node.getTaxon().getId());
        node.nlineages = piosb.nLineages(spIndex);
        return;
    }
    final PopsIONode left = pionodes[node.lft];
    final PopsIONode right = pionodes[node.rgt];
    recordSubtreeLineageCounts(left);
    final int fromLeft = left.nlineages - left.coalheights.size();
    recordSubtreeLineageCounts(right);
    final int fromRight = right.nlineages - right.coalheights.size();
    node.nlineages = fromLeft + fromRight;
}
/**
 * Draws a random height increment: an exponential variate with the given
 * rate plus a small offset.
 *
 * @param rate rate passed to MathUtils.nextExponential()
 * @return the random increment
 */
private double randomnodeheight(double rate) {
    final double draw = MathUtils.nextExponential(rate);
    // 1e-6/rate to avoid very tiny heights
    final double offset = 1e-6 / rate;
    return draw + offset;
}
/*******************************************************************************/
// for Scalable.
/*******************************************************************************/
/**
 * Scales every node height by factor and rebuilds the SimpleTree copy.
 * The nDims argument is not used: all heights are always scaled.
 *
 * @param factor the factor applied to each height
 * @param nDims  ignored
 * @return the number of heights scaled
 */
@Override
public int scale(double factor, int nDims) throws OperatorFailedException {
    final int scaledCount = scaleAllHeights(factor);
    stree = makeSimpleTree();
    return scaledCount;
}
// Returns the same constant that the constructor passes to the superclass as the model name.
@Override
public String getName() {
return PopsIOSpeciesTreeModelParser.PIO_SPECIES_TREE;
}
/*******************************************************************************/
// for AbstractModel.
/*******************************************************************************/
// Whenever this callback is invoked, fires a model-changed event for this model.
// The arguments are not inspected.
@Override
protected void handleModelChangedEvent(Model model, Object object, int index) {
fireModelChanged();
}
// Intentionally empty: variable changes trigger no action in this model.
@Override
protected void handleVariableChangedEvent(Variable variable, int index, Variable.ChangeType type) {
}
/**
 * Saves the current tree: a node-by-node copy of pionodes[] goes into
 * oldpionodes[] and the root index into oldrootn.
 */
@Override
protected void storeState() {
    final PopsIONode[] snapshot = new PopsIONode[pionodes.length];
    for (int n = 0; n < snapshot.length; n++) {
        snapshot[n] = new PopsIONode(pionodes[n]);
    }
    oldpionodes = snapshot;
    oldrootn = rootn;
}
/**
 * Restores the tree saved by storeState(): pionodes[] becomes a fresh
 * node-by-node copy of oldpionodes[], the root index is restored, and the
 * SimpleTree copy is rebuilt from the restored nodes.
 */
@Override
protected void restoreState() {
    final int count = oldpionodes.length;
    final PopsIONode[] restored = new PopsIONode[count];
    for (int n = 0; n < count; n++) {
        restored[n] = new PopsIONode(oldpionodes[n]);
    }
    pionodes = restored;
    rootn = oldrootn;
    stree = makeSimpleTree();
}
// Intentionally empty: nothing is done when a state is accepted.
@Override
protected void acceptState() {
}
/*******************************************************************************/
// for SlidableTree.
/*******************************************************************************/
/**
 * Returns the root node of the working tree. With assertions enabled, checks
 * that the root has no ancestor.
 */
@Override
public NodeRef getSlidableRoot() {
    final PopsIONode root = pionodes[rootn];
    assert root.anc < 0;
    return root;
}
/**
 * Makes the node with newroot's number the root of the working tree and
 * clears that node's ancestor index.
 *
 * @param newroot the node that becomes the root
 */
@Override
public void replaceSlidableRoot(NodeRef newroot) {
    final int newRootIndex = newroot.getNumber();
    rootn = newRootIndex;
    pionodes[newRootIndex].anc = -1;
}
// Returns the total number of nodes (tips and internal) in the working tree.
@Override
public int getSlidableNodeCount() {
return pionodes.length;
}
/**
 * Returns the taxon of a node of the working tree. With assertions enabled,
 * checks that node is the object stored in pionodes[] under its own number.
 */
@Override
public Taxon getSlidableNodeTaxon(NodeRef node) {
    assert node == pionodes[node.getNumber()];
    final PopsIONode own = (PopsIONode) node;
    return own.getTaxon();
}
/**
 * Returns the height of a node of the working tree. With assertions enabled,
 * checks that node is the object stored in pionodes[] under its own number.
 */
@Override
public double getSlidableNodeHeight(NodeRef node) {
    assert node == pionodes[node.getNumber()];
    final PopsIONode own = (PopsIONode) node;
    return own.getHeight();
}
/**
 * Sets the height of a node of the working tree. With assertions enabled,
 * checks that node is the object stored in pionodes[] under its own number.
 * The SimpleTree copy is not rebuilt here.
 */
@Override
public void setSlidableNodeHeight(NodeRef node, double height) {
    assert node == pionodes[node.getNumber()];
    final PopsIONode own = (PopsIONode) node;
    own.height = height;
}
/**
 * Reports whether the node with the given number is a tip, i.e. has no left child.
 */
@Override
public boolean isExternalSlidable(NodeRef node) {
    final PopsIONode own = pionodes[node.getNumber()];
    return own.lft < 0;
}
/**
 * Returns a child of the node with the given number: the left child for
 * j == 0, the right child for any other j.
 */
@Override
public NodeRef getSlidableChild(NodeRef node, int j) {
    final PopsIONode parent = pionodes[node.getNumber()];
    if (j == 0) {
        return pionodes[parent.lft];
    }
    return pionodes[parent.rgt];
}
/**
 * Gives an internal node new left and right children and points both
 * children's ancestor index back at it. With assertions enabled, checks that
 * the node already has children.
 *
 * @param node the parent
 * @param lft  the new left child
 * @param rgt  the new right child
 */
@Override
public void replaceSlidableChildren(NodeRef node, NodeRef lft, NodeRef rgt) {
    final int nn = node.getNumber();
    final int lftn = lft.getNumber();
    final int rgtn = rgt.getNumber();
    final PopsIONode parent = pionodes[nn];
    assert parent.lft >= 0;
    parent.lft = lftn;
    parent.rgt = rgtn;
    final int parentNumber = parent.nodeNumber;
    pionodes[lftn].anc = parentNumber;
    pionodes[rgtn].anc = parentNumber;
}
/*******************************************************************************/
// For Tree
/*******************************************************************************/
// Every method in this section delegates directly to stree, the SimpleTree copy
// of the working tree. stree is replaced by a newly built SimpleTree in scale(),
// restoreState() and fixupAfterNodeSlide(), so NodeRefs obtained here belong to
// the stree that was current at the time of the call.
// Returns the root of stree.
@Override
public NodeRef getRoot() {
return stree.getRoot();
}
// Returns the number of nodes in stree.
@Override
public int getNodeCount() {
return stree.getNodeCount();
}
// Returns node i of stree.
@Override
public NodeRef getNode(int i) {
return stree.getNode(i);
}
// Returns internal node i of stree.
@Override
public NodeRef getInternalNode(int i) {
return stree.getInternalNode(i);
}
// Returns external node i of stree.
@Override
public NodeRef getExternalNode(int i) {
return stree.getExternalNode(i);
}
// Returns the number of external nodes in stree.
@Override
public int getExternalNodeCount() {
return stree.getExternalNodeCount();
}
// Returns the number of internal nodes in stree.
@Override
public int getInternalNodeCount() {
return stree.getInternalNodeCount();
}
// Returns the taxon stree holds for the given node.
@Override
public Taxon getNodeTaxon(NodeRef node) {
return stree.getNodeTaxon(node);
}
// Delegates to stree.hasNodeHeights().
@Override
public boolean hasNodeHeights() {
return stree.hasNodeHeights();
}
// Returns the height stree holds for the given node.
@Override
public double getNodeHeight(NodeRef node) {
return stree.getNodeHeight(node);
}
// Delegates to stree.hasBranchLengths().
@Override
public boolean hasBranchLengths() {
return stree.hasBranchLengths();
}
// Returns the branch length stree reports for the given node.
@Override
public double getBranchLength(NodeRef node) {
return stree.getBranchLength(node);
}
// Returns the rate stree reports for the given node.
@Override
public double getNodeRate(NodeRef node) {
return stree.getNodeRate(node);
}
// Returns the named attribute of the given node, as held by stree.
@Override
public Object getNodeAttribute(NodeRef node, String name) {
return stree.getNodeAttribute(node, name);
}
// Returns the attribute names of the given node, as held by stree.
@Override
public Iterator getNodeAttributeNames(NodeRef node) {
return stree.getNodeAttributeNames(node);
}
// Delegates to stree.isExternal().
@Override
public boolean isExternal(NodeRef node) {
return stree.isExternal(node);
}
// Delegates to stree.isRoot().
@Override
public boolean isRoot(NodeRef node) {
return stree.isRoot(node);
}
// Returns the number of children stree reports for the given node.
@Override
public int getChildCount(NodeRef node) {
return stree.getChildCount(node);
}
// Returns child j of the given node in stree.
@Override
public NodeRef getChild(NodeRef node, int j) {
return stree.getChild(node, j);
}
// Returns the parent of the given node in stree.
@Override
public NodeRef getParent(NodeRef node) {
return stree.getParent(node);
}
// Returns whatever stree.getCopy() returns, i.e. a copy made by stree rather than by this model.
@Override
public Tree getCopy() {
return stree.getCopy();
}
// Sets an attribute on the current stree only; it is not carried over when
// stree is replaced by a newly built SimpleTree.
@Override
public void setAttribute(String name, Object value) {
stree.setAttribute(name, value);
}
// Returns the named attribute of the current stree.
@Override
public Object getAttribute(String name) {
return stree.getAttribute(name);
}
// Returns the attribute names of the current stree.
@Override
public Iterator<String> getAttributeNames() {
return stree.getAttributeNames();
}
// Returns the number of taxa in stree.
@Override
public int getTaxonCount() {
return stree.getTaxonCount();
}
// Returns the taxon at the given index in stree.
@Override
public Taxon getTaxon(int taxonIndex) {
return stree.getTaxon(taxonIndex);
}
// Returns the id of the taxon at the given index in stree.
@Override
public String getTaxonId(int taxonIndex) {
return stree.getTaxonId(taxonIndex);
}
// Returns the index stree reports for the taxon with the given id.
@Override
public int getTaxonIndex(String id) {
return stree.getTaxonIndex(id);
}
// Returns the index stree reports for the given taxon.
@Override
public int getTaxonIndex(Taxon taxon) {
return stree.getTaxonIndex(taxon);
}
// Returns stree's taxa as a list.
@Override
public List<Taxon> asList() {
return stree.asList();
}
// Returns the named attribute of the taxon at the given index, as held by stree.
@Override
public Object getTaxonAttribute(int taxonIndex, String name) {
return stree.getTaxonAttribute(taxonIndex, name);
}
// Returns stree's iterator over its taxa.
@Override
public Iterator<Taxon> iterator() {
return stree.iterator();
}
// Returns the units of the current stree.
@Override
public Type getUnits() {
return stree.getUnits();
}
// Sets the units on the current stree only; they are not carried over when
// stree is replaced by a newly built SimpleTree.
@Override
public void setUnits(Type units) {
stree.setUnits(units);
}
}