package beast.evolution.tree.coalescent;
import java.util.ArrayList;
import java.util.List;
import beast.core.CalculationNode;
import beast.core.Description;
import beast.core.Input;
import beast.core.Input.Validate;
import beast.evolution.tree.Node;
import beast.evolution.tree.Tree;
import beast.util.HeapSort;
/*
* TreeIntervals.java
*
* Copyright (C) 2002-2006 Alexei Drummond and Andrew Rambaut
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
/**
* Extracts the intervals from a beast.tree.
*
* @author Andrew Rambaut
* @author Alexei Drummond
* @version $Id: TreeIntervals.java,v 1.9 2005/05/24 20:25:56 rambaut Exp $
*/
@Description("Extracts the intervals from a tree. Points in the intervals " +
"are defined by the heights of nodes in the tree.")
public class TreeIntervals extends CalculationNode implements IntervalList {
final public Input<Tree> treeInput = new Input<>("tree", "tree for which to calculate the intervals", Validate.REQUIRED);
public TreeIntervals() {
super();
}
public TreeIntervals(Tree tree) {
init(tree);
}
@Override
public void initAndValidate() {
// this initialises data structures that store/restore might need
calculateIntervals();
intervalsKnown = false;
}
/**
* CalculationNode methods *
*/
@Override
protected boolean requiresRecalculation() {
// we only get here if the tree is dirty, which is a StateNode
// since the StateNode can only become dirty through an operation,
// we need to recalculate tree intervals
intervalsKnown = false;
return true;
}
@Override
protected void restore() {
//intervalsKnown = false;
double[] tmp = storedIntervals;
storedIntervals = intervals;
intervals = tmp;
int[] tmp2 = storedLineageCounts;
storedLineageCounts = lineageCounts;
lineageCounts = tmp2;
int tmp3 = storedIntervalCount;
storedIntervalCount = intervalCount;
intervalCount = tmp3;
super.restore();
}
@Override
protected void store() {
System.arraycopy(lineageCounts, 0, storedLineageCounts, 0, lineageCounts.length);
System.arraycopy(intervals, 0, storedIntervals, 0, intervals.length);
storedIntervalCount = intervalCount;
super.store();
}
/**
* Specifies that the intervals are unknown (i.e., the beast.tree has changed).
*/
public void setIntervalsUnknown() {
intervalsKnown = false;
}
/**
* Sets the limit for which adjacent events are merged.
*
* @param multifurcationLimit A value of 0 means merge addition of leafs (terminal nodes) when possible but
* return each coalescense as a separate event.
*/
public void setMultifurcationLimit(double multifurcationLimit) {
// invalidate only if changing anything
if (this.multifurcationLimit != multifurcationLimit) {
this.multifurcationLimit = multifurcationLimit;
intervalsKnown = false;
}
}
@Override
public int getSampleCount() {
// Assumes a binary tree!
return treeInput.get().getInternalNodeCount();
}
/**
* get number of intervals
*/
@Override
public int getIntervalCount() {
if (!intervalsKnown) {
calculateIntervals();
}
return intervalCount;
}
/**
* Gets an interval.
*/
@Override
public double getInterval(int i) {
if (!intervalsKnown) {
calculateIntervals();
}
if (i < 0 || i >= intervalCount) throw new IllegalArgumentException();
return intervals[i];
}
/**
* Defensive implementation creates copy
*
* @return
*/
public double[] getIntervals(double[] inters) {
if (!intervalsKnown) {
calculateIntervals();
}
if (inters == null) inters = new double[intervals.length];
System.arraycopy(intervals, 0, inters, 0, intervals.length);
return inters;
}
public double[] getCoalescentTimes(double[] coalescentTimes) {
if (!intervalsKnown) {
calculateIntervals();
}
if (coalescentTimes == null) coalescentTimes = new double[getSampleCount()];
double time = 0;
int coalescentIndex = 0;
for (int i = 0; i < intervals.length; i++) {
time += intervals[i];
for (int j = 0; j < getCoalescentEvents(i); j++) {
coalescentTimes[coalescentIndex] = time;
coalescentIndex += 1;
}
}
return coalescentTimes;
}
/**
* Returns the number of uncoalesced lineages within this interval.
* Required for s-coalescents, where new lineages are added as
* earlier samples are come across.
*/
@Override
public int getLineageCount(int i) {
if (!intervalsKnown) {
calculateIntervals();
}
if (i >= intervalCount) throw new IllegalArgumentException();
return lineageCounts[i];
}
/**
* @param interval the index of the interval
* @return a list of the nodes representing the lineages in the ith interval.
*/
// public final List<Node> getLineages(int interval) {
//
// if (lineages[interval] == null) {
//
// List<Node> lines = new ArrayList<>();
// for (int i = 0; i <= interval; i++) {
// if (lineagesAdded[i] != null) lines.addAll(lineagesAdded[i]);
// if (lineagesRemoved[i] != null) lines.removeAll(lineagesRemoved[i]);
// }
// lineages[interval] = Collections.unmodifiableList(lines);
//
// }
// return lineages[interval];
// }
/**
* Returns the number of coalescent events in an interval
*/
@Override
public int getCoalescentEvents(int i) {
if (!intervalsKnown) {
calculateIntervals();
}
if (i >= intervalCount) throw new IllegalArgumentException();
if (i < intervalCount - 1) {
return lineageCounts[i] - lineageCounts[i + 1];
} else {
return lineageCounts[i] - 1;
}
}
/**
* Returns the type of interval observed.
*/
@Override
public IntervalType getIntervalType(int i) {
if (!intervalsKnown) {
calculateIntervals();
}
if (i >= intervalCount) throw new IllegalArgumentException();
int numEvents = getCoalescentEvents(i);
if (numEvents > 0) return IntervalType.COALESCENT;
else if (numEvents < 0) return IntervalType.SAMPLE;
else return IntervalType.NOTHING;
}
// public Node getCoalescentNode(int interval) {
// if (getIntervalType(interval) == IntervalType.COALESCENT) {
// if (lineagesRemoved[interval] != null) {
// if (lineagesRemoved[interval].size() == 1) {
// return lineagesRemoved[interval].get(0);
// } else throw new IllegalArgumentException("multiple lineages lost over this interval!");
// } else throw new IllegalArgumentException("Inconsistent: no intervals lost over this interval!");
// } else throw new IllegalArgumentException("Interval " + interval + " is not a coalescent interval.");
// }
/**
* get the total height of the genealogy represented by these
* intervals.
*/
@Override
public double getTotalDuration() {
if (!intervalsKnown) {
calculateIntervals();
}
double height = 0.0;
for (int j = 0; j < intervalCount; j++) {
height += intervals[j];
}
return height;
}
/**
* Checks whether this set of coalescent intervals is fully resolved
* (i.e. whether is has exactly one coalescent event in each
* subsequent interval)
*/
@Override
public boolean isBinaryCoalescent() {
if (!intervalsKnown) {
calculateIntervals();
}
for (int i = 0; i < intervalCount; i++) {
if (getCoalescentEvents(i) > 0) {
if (getCoalescentEvents(i) != 1) return false;
}
}
return true;
}
/**
* Checks whether this set of coalescent intervals coalescent only
* (i.e. whether is has exactly one or more coalescent event in each
* subsequent interval)
*/
@Override
public boolean isCoalescentOnly() {
if (!intervalsKnown) {
calculateIntervals();
}
for (int i = 0; i < intervalCount; i++) {
if (getCoalescentEvents(i) < 1) return false;
}
return true;
}
/**
* Recalculates all the intervals for the given beast.tree.
*/
@SuppressWarnings("unchecked")
protected void calculateIntervals() {
Tree tree = treeInput.get();
final int nodeCount = tree.getNodeCount();
times = new double[nodeCount];
int[] childCounts = new int[nodeCount];
collectTimes(tree, times, childCounts);
indices = new int[nodeCount];
HeapSort.sort(times, indices);
if (intervals == null || intervals.length != nodeCount) {
intervals = new double[nodeCount];
lineageCounts = new int[nodeCount];
lineagesAdded = new List[nodeCount];
lineagesRemoved = new List[nodeCount];
// lineages = new List[nodeCount];
storedIntervals = new double[nodeCount];
storedLineageCounts = new int[nodeCount];
} else {
for (List<Node> l : lineagesAdded) {
if (l != null) {
l.clear();
}
}
for (List<Node> l : lineagesRemoved) {
if (l != null) {
l.clear();
}
}
}
// start is the time of the first tip
double start = times[indices[0]];
int numLines = 0;
int nodeNo = 0;
intervalCount = 0;
while (nodeNo < nodeCount) {
int lineagesRemoved = 0;
int lineagesAdded = 0;
double finish = times[indices[nodeNo]];
double next;
do {
final int childIndex = indices[nodeNo];
final int childCount = childCounts[childIndex];
// don't use nodeNo from here on in do loop
nodeNo += 1;
if (childCount == 0) {
addLineage(intervalCount, tree.getNode(childIndex));
lineagesAdded += 1;
} else {
lineagesRemoved += (childCount - 1);
// record removed lineages
final Node parent = tree.getNode(childIndex);
//assert childCounts[indices[nodeNo]] == beast.tree.getChildCount(parent);
//for (int j = 0; j < lineagesRemoved + 1; j++) {
for (int j = 0; j < childCount; j++) {
Node child = j == 0 ? parent.getLeft() : parent.getRight();
removeLineage(intervalCount, child);
}
// record added lineages
addLineage(intervalCount, parent);
// no mix of removed lineages when 0 th
if (multifurcationLimit == 0.0) {
break;
}
}
if (nodeNo < nodeCount) {
next = times[indices[nodeNo]];
} else break;
} while (Math.abs(next - finish) <= multifurcationLimit);
if (lineagesAdded > 0) {
if (intervalCount > 0 || ((finish - start) > multifurcationLimit)) {
intervals[intervalCount] = finish - start;
lineageCounts[intervalCount] = numLines;
intervalCount += 1;
}
start = finish;
}
// add sample event
numLines += lineagesAdded;
if (lineagesRemoved > 0) {
intervals[intervalCount] = finish - start;
lineageCounts[intervalCount] = numLines;
intervalCount += 1;
start = finish;
}
// coalescent event
numLines -= lineagesRemoved;
}
intervalsKnown = true;
}
/**
* Returns the time of the start of an interval
*
* @param i which interval
* @return start time
*/
public double getIntervalTime(int i) {
if (!intervalsKnown) {
calculateIntervals();
}
return times[indices[i]];
}
protected void addLineage(int interval, Node node) {
if (lineagesAdded[interval] == null) lineagesAdded[interval] = new ArrayList<>();
lineagesAdded[interval].add(node);
}
protected void removeLineage(int interval, Node node) {
if (lineagesRemoved[interval] == null) lineagesRemoved[interval] = new ArrayList<>();
lineagesRemoved[interval].add(node);
}
/**
* @return the delta parameter of Pybus et al (Node spread statistic)
*/
public double getDelta() {
return IntervalList.Utils.getDelta(this);
}
/**
* extract coalescent times and tip information into array times from beast.tree.
*
* @param tree the beast.tree
* @param times the times of the nodes in the beast.tree
* @param childCounts the number of children of each node
*/
protected static void collectTimes(Tree tree, double[] times, int[] childCounts) {
Node[] nodes = tree.getNodesAsArray();
for (int i = 0; i < nodes.length; i++) {
Node node = nodes[i];
times[i] = node.getHeight();
childCounts[i] = node.isLeaf() ? 0 : 2;
}
}
/**
* The beast.tree. RRB: not a good idea to keep a copy around, since it changes all the time.
*/
// private Tree tree = null;
/**
* The widths of the intervals.
*/
protected double[] intervals;
protected double[] storedIntervals;
/** interval times **/
double[] times;
int[] indices;
/**
* The number of uncoalesced lineages within a particular interval.
*/
protected int[] lineageCounts;
protected int[] storedLineageCounts;
/**
* The lineages in each interval (stored by node ref).
*/
protected List<Node>[] lineagesAdded;
protected List<Node>[] lineagesRemoved;
// private List<Node>[] lineages;
protected int intervalCount = 0;
protected int storedIntervalCount = 0;
/**
* are the intervals known?
*/
protected boolean intervalsKnown = false;
protected double multifurcationLimit = -1.0;
}