package beast.evolution.speciation;
import static java.lang.Math.abs;
import static java.lang.Math.max;
import static java.lang.Math.min;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.math.MathException;
import beast.core.*;
import beast.core.Input.Validate;
import beast.core.parameter.RealParameter;
import beast.evolution.alignment.Alignment;
import beast.evolution.alignment.Taxon;
import beast.evolution.alignment.TaxonSet;
import beast.evolution.alignment.distance.Distance;
import beast.evolution.alignment.distance.JukesCantorDistance;
import beast.evolution.tree.Node;
import beast.evolution.tree.RandomTree;
import beast.evolution.tree.Tree;
import beast.evolution.tree.coalescent.ConstantPopulation;
import beast.math.distributions.MRCAPrior;
import beast.util.ClusterTree;
/**
 * Sets a starting state for a *BEAST analysis: the species tree, the gene trees and, when they are
 * supplied, the birth-rate, population-mean and population-size parameters.
 *
 * <p>The strategy is chosen in {@link #initStateNodes()}:
 * <ul>
 *   <li>if a calibrated Yule model is given, its species tree is replaced by a tree compatible with
 *       its calibration points and every gene tree is made a caterpillar below the species root;</li>
 *   <li>else, if MRCA priors are attached to the species tree, the species tree is drawn with
 *       {@link RandomTree} under those constraints and gene trees are made caterpillars below its root;</li>
 *   <li>otherwise {@link #initMethod} selects either a UPGMA point estimate built from the gene
 *       alignments ({@code point-estimate}) or a rescaled random state ({@code random}).</li>
 * </ul>
 *
 * @author Joseph Heled
 */
@Description("Set a starting point for a *BEAST analysis from gene alignment data.")
public class StarBeastStartState extends Tree implements StateNodeInitialiser {

    /** Initialisation strategy used when neither calibrations nor MRCA priors are present. */
    enum Method {
        /** Point estimate derived from the gene alignments via UPGMA. */
        POINT("point-estimate"),
        /** Species tree rescaled to a birth-rate based height; caterpillar gene trees. */
        ALL_RANDOM("random");

        /** Name returned by {@link #toString()}. */
        private final String ename;

        Method(final String name) {
            this.ename = name;
        }

        @Override
        public String toString() {
            return ename;
        }
    }

    final public Input<Method> initMethod = new Input<>("method", "Initialise either with a totally random " +
            "state or a point estimate based on alignments data (default point-estimate)",
            Method.POINT, Method.values());

    final public Input<Tree> speciesTreeInput = new Input<>("speciesTree", "The species tree to initialize");

    // Optional (deliberately not Validate.REQUIRED): may be empty.
    final public Input<List<Tree>> genes = new Input<>("gene", "Gene trees to initialize", new ArrayList<>());

    // Exactly one of calibratedYule / speciesTree must be given (Validate.XOR).
    final public Input<CalibratedYuleModel> calibratedYule = new Input<>("calibratedYule",
            "The species tree (with calibrations) to initialize", Validate.XOR, speciesTreeInput);

    final public Input<RealParameter> popMean = new Input<>("popMean",
            "Population mean hyper prior to initialse");

    final public Input<RealParameter> birthRate = new Input<>("birthRate",
            "Tree prior birth rate to initialize");

    final public Input<SpeciesTreePrior> speciesTreePriorInput =
            new Input<>("speciesTreePrior", "Population size parameters to initialise");

    final public Input<Function> muInput = new Input<>("baseRate",
            "Main clock rate used to scale trees (default 1).");

    /** True when a calibrated Yule model (rather than a plain species tree) was supplied. */
    private boolean hasCalibrations;

    @Override
    public void initAndValidate() {
        // NOTE(review): this object is itself a Tree ("extends Tree"), so the inherited initialisation
        // runs here. This class never uses its own tree state, so it is unclear whether the call is needed.
        super.initAndValidate();
        hasCalibrations = calibratedYule.get() != null;
    }

    /**
     * Initialises the species tree, the gene trees and the optional parameters.
     *
     * @throws IllegalArgumentException if a calibrated Yule model is combined with MRCA priors on its
     *         tree, or if the calibrated Yule initialisation raises a {@link MathException}
     */
    @Override
    public void initStateNodes() {
        // Look for MRCA priors on the tree actually being initialised. Under the XOR constraint,
        // speciesTreeInput is null when a calibrated Yule model is used, so it cannot be read blindly.
        final List<MRCAPrior> calibrations = mrcaPriorsOf(targetSpeciesTree());

        if (hasCalibrations) {
            if (!calibrations.isEmpty()) {
                throw new IllegalArgumentException("Not implemented: mix of calibrated yule and MRCA priors: " +
                        "place all priors in the calibrated Yule");
            }
            try {
                initWithCalibrations();
            } catch (MathException e) {
                throw new IllegalArgumentException(e);
            }
            return;
        }

        if (!calibrations.isEmpty()) {
            initWithMRCACalibrations(calibrations);
            return;
        }

        switch (initMethod.get()) {
            case POINT:
                fullInit();
                break;
            case ALL_RANDOM:
                randomInit();
                break;
        }
    }

    /**
     * Returns the species tree this initialiser writes to: the calibrated Yule model's tree when one
     * was supplied, otherwise the {@code speciesTree} input.
     */
    private Tree targetSpeciesTree() {
        return hasCalibrations ? (Tree) calibratedYule.get().treeInput.get() : speciesTreeInput.get();
    }

    /** Collects the {@link MRCAPrior}s found among the outputs of the given tree. */
    private static List<MRCAPrior> mrcaPriorsOf(final Tree tree) {
        final List<MRCAPrior> priors = new ArrayList<>();
        for (final BEASTInterface output : tree.getOutputs()) {
            if (output instanceof MRCAPrior) {
                priors.add((MRCAPrior) output);
            }
        }
        return priors;
    }

    /**
     * For every unordered pair of species (i, j), finds the height of the lowest node whose subtree
     * contains tips of both species, i.e. where lineages of i and j first meet.
     *
     * @param gtree tree to scan (a gene tree or the species tree); internal nodes are assumed binary
     * @param tipName2Species maps each tip ID to its species index in [0, speciesCount)
     * @param speciesCount number of species
     * @return condensed upper-triangular array indexed by {@link #getDMindex}; a pair that never meets
     *         (a species with no tips in the tree) keeps the value {@link Double#MAX_VALUE}
     */
    private double[] firstMeetings(final Tree gtree, final Map<String, Integer> tipName2Species, final int speciesCount) {
        final Node[] nodes = gtree.listNodesPostOrder(null, null);

        // tipsSpecies[nr] = set of species with a tip below node number nr
        @SuppressWarnings("unchecked")
        final Set<Integer>[] tipsSpecies = new Set[nodes.length];
        for (int k = 0; k < tipsSpecies.length; ++k) {
            tipsSpecies[k] = new HashSet<>();
        }

        final double[] dmin = new double[(speciesCount * (speciesCount - 1)) / 2];
        Arrays.fill(dmin, Double.MAX_VALUE);

        // Post-order guarantees children are processed before their parent.
        for (final Node n : nodes) {
            if (n.isLeaf()) {
                tipsSpecies[n.getNr()].add(tipName2Species.get(n.getID()));
                continue;
            }
            assert n.getChildCount() == 2;
            final Set<Integer> left = tipsSpecies[n.getChild(0).getNr()];
            final Set<Integer> right = tipsSpecies[n.getChild(1).getNr()];

            // Species present on both sides already met lower down; only pairs split exclusively
            // across the two children meet here. The child sets are not needed again and may be mutated.
            final Set<Integer> shared = new HashSet<>(left);
            shared.retainAll(right);
            left.removeAll(shared);
            right.removeAll(shared);
            for (final Integer s1 : left) {
                for (final Integer s2 : right) {
                    final int i = getDMindex(speciesCount, s1, s2);
                    dmin[i] = min(dmin[i], n.getHeight());
                }
            }

            shared.addAll(left);
            shared.addAll(right);
            tipsSpecies[n.getNr()] = shared;
        }
        return dmin;
    }

    /**
     * Index of the unordered species pair (s1, s2), s1 != s2, in a condensed upper-triangular array of
     * length speciesCount*(speciesCount-1)/2, stored row by row: (0,1), (0,2), ..., (1,2), ...
     */
    private int getDMindex(final int speciesCount, final int s1, final int s2) {
        final int mij = min(s1, s2);
        // rows 0..mij-1 hold (n-1) + (n-2) + ... entries = mij*(2n-1-mij)/2 (always an exact division)
        return (mij * (2 * speciesCount - 1 - mij)) / 2 + (abs(s1 - s2) - 1);
    }

    /**
     * Point-estimate initialisation ({@code point-estimate} method).
     *
     * <ol>
     *   <li>Each gene tree is built by UPGMA from its alignment and scaled by 1/mu.</li>
     *   <li>Per species pair, the first-meeting heights are averaged over the gene trees, converted to
     *       distances (x2) and used to build the species tree by UPGMA.</li>
     *   <li>Any gene tree in which some species pair meets at or below the matching species-tree
     *       first-meeting height is rebuilt with between-species distances pushed above it.</li>
     *   <li>Birth rate, population mean and population sizes are then set from the species tree.</li>
     * </ol>
     *
     * @throws RuntimeException if a gene tree has no lineages for some species
     */
    private void fullInit() {
        final Function muFunction = muInput.get();
        final double mu = (muFunction != null) ? muFunction.getArrayValue() : 1;

        final Tree stree = speciesTreeInput.get();
        final TaxonSet species = stree.m_taxonset.get();
        final List<String> speciesNames = species.asStringList();
        final int speciesCount = speciesNames.size();

        final List<Tree> geneTrees = genes.get();

        // 1. UPGMA gene trees from their alignments, converted to time units by 1/mu.
        double maxNsites = 0;
        for (final Tree gtree : geneTrees) {
            final Alignment alignment = gtree.m_taxonset.get().alignmentInput.get();
            final ClusterTree ctree = new ClusterTree();
            ctree.initByName("initial", gtree, "clusterType", "upgma", "taxa", alignment);
            gtree.scale(1 / mu);
            maxNsites = max(maxNsites, alignment.getSiteCount());
        }

        // Gene tip ID -> species index (position within the species taxon set).
        final Map<String, Integer> geneTips2Species = new HashMap<>();
        final List<Taxon> taxonSets = species.taxonsetInput.get();
        for (int k = 0; k < speciesNames.size(); ++k) {
            final TaxonSet speciesTaxa = (TaxonSet) taxonSets.get(k);
            for (final Taxon tip : speciesTaxa.taxonsetInput.get()) {
                geneTips2Species.put(tip.getID(), k);
            }
        }

        // 2. Sum, per species pair, the first-meeting heights over all gene trees.
        final double[] dg = new double[(speciesCount * (speciesCount - 1)) / 2];
        final double[][] genesDmins = new double[geneTrees.size()][];
        for (int ng = 0; ng < geneTrees.size(); ++ng) {
            final Tree g = geneTrees.get(ng);
            final double[] dmin = firstMeetings(g, geneTips2Species, speciesCount);
            genesDmins[ng] = dmin;
            for (int i = 0; i < dmin.length; ++i) {
                if (dmin[i] == Double.MAX_VALUE) {
                    // A pair that never meets means at least one species has no tips in this gene tree.
                    throw new RuntimeException("Gene tree " + g.getID() + " has no lineages for species taxon "
                            + firstMissingSpecies(g, geneTips2Species, speciesNames) + " ");
                }
                dg[i] += dmin[i];
            }
        }

        // Mean first-meeting height -> pairwise distance (distance = 2 * height).
        // NOTE(review): with no gene trees this is 0/0 = NaN for every pair.
        for (int i = 0; i < dg.length; ++i) {
            double d = dg[i] / geneTrees.size();
            if (d == 0) {
                // avoid a zero distance: use half a site over the longest alignment, in time units
                d = (0.5 / maxNsites) * (1 / mu);
            } else {
                d *= 2;
            }
            dg[i] = d;
        }

        // 3. Species tree by UPGMA over the mean pairwise distances.
        final ClusterTree ctree = new ClusterTree();
        final Distance distance = new Distance() {
            @Override
            public double pairwiseDistance(final int s1, final int s2) {
                return dg[getDMindex(speciesCount, s1, s2)];
            }
        };
        ctree.initByName("initial", stree, "taxonset", species, "clusterType", "upgma", "distance", distance);

        // 4. Rebuild gene trees that coalesce at or below the species-tree divergences.
        final Map<String, Integer> sptips2SpeciesIndex = new HashMap<>();
        for (int i = 0; i < speciesNames.size(); ++i) {
            sptips2SpeciesIndex.put(speciesNames.get(i), i);
        }
        final double[] spmin = firstMeetings(stree, sptips2SpeciesIndex, speciesCount);

        for (int ng = 0; ng < geneTrees.size(); ++ng) {
            if (!geneTreeFitsSpeciesTree(genesDmins[ng], spmin)) {
                rebuildGeneTreeAboveSpeciesTree(geneTrees.get(ng), geneTips2Species, speciesCount, spmin, mu);
            }
        }

        // 5. Parameters derived from the species tree.
        initParametersFromSpeciesTree(stree, speciesCount);
    }

    /**
     * Returns the name of the first species (in taxon-set order) that has no tips in the given gene
     * tree, or "unknown taxon" if every species is represented.
     */
    private static String firstMissingSpecies(final Tree gtree, final Map<String, Integer> tipName2Species,
                                              final List<String> speciesNames) {
        final Set<Integer> present = new HashSet<>();
        for (final Node leaf : gtree.getExternalNodes()) {
            present.add(tipName2Species.get(leaf.getID()));
        }
        for (int k = 0; k < speciesNames.size(); ++k) {
            if (!present.contains(k)) {
                return speciesNames.get(k);
            }
        }
        return "unknown taxon";
    }

    /** True when every species pair meets strictly higher in the gene tree than in the species tree. */
    private static boolean geneTreeFitsSpeciesTree(final double[] geneDmin, final double[] speciesDmin) {
        for (int i = 0; i < speciesDmin.length; ++i) {
            if (geneDmin[i] <= speciesDmin[i]) {
                return false;
            }
        }
        return true;
    }

    /**
     * Rebuilds a gene tree by UPGMA over Jukes-Cantor distances (scaled by 1/mu). A between-species
     * distance that is not above twice the species first-meeting height is raised to 1.001 times that value.
     */
    private void rebuildGeneTreeAboveSpeciesTree(final Tree gtree, final Map<String, Integer> geneTips2Species,
                                                 final int speciesCount, final double[] spmin, final double mu) {
        final TaxonSet gtreeTaxa = gtree.m_taxonset.get();
        final Alignment alignment = gtreeTaxa.alignmentInput.get();
        final List<String> taxaNames = alignment.getTaxaNames();
        final int taxonCount = taxaNames.size();

        // Alignment taxon index -> species index, precomputed to avoid name lookups per distance.
        // NOTE(review): assumes ClusterTree's taxon indices follow the alignment's taxon order.
        final Map<Integer, Integer> g2s = new HashMap<>();
        for (int i = 0; i < taxonCount; ++i) {
            g2s.put(i, geneTips2Species.get(taxaNames.get(i)));
        }

        final JukesCantorDistance jc = new JukesCantorDistance();
        jc.setPatterns(alignment);
        final Distance gdistance = new Distance() {
            @Override
            public double pairwiseDistance(final int t1, final int t2) {
                final int s1 = g2s.get(t1);
                final int s2 = g2s.get(t2);
                double d = jc.pairwiseDistance(t1, t2) / mu;
                if (s1 != s2) {
                    final double minDist = 2 * spmin[getDMindex(speciesCount, s1, s2)];
                    if (d <= minDist) {
                        d = minDist * 1.001;
                    }
                }
                return d;
            }
        };
        final ClusterTree gtreec = new ClusterTree();
        gtreec.initByName("initial", gtree, "taxonset", gtreeTaxa,
                "clusterType", "upgma", "distance", gdistance);
    }

    /**
     * Sets the optional parameters from the (already built) species tree:
     * birth rate = H / rootHeight, where H = 1/2 + ... + 1/speciesCount;
     * popMean = mean non-root branch length / 2;
     * bottom population sizes = 2 * that value; top population sizes = that value.
     * Values are clamped to each parameter's bounds.
     */
    private void initParametersFromSpeciesTree(final Tree stree, final int speciesCount) {
        final RealParameter lambda = birthRate.get();
        if (lambda != null) {
            final double rh = stree.getRoot().getHeight();
            setParameterValue(lambda, (1 / rh) * harmonicFromTwo(speciesCount));
        }

        double totBranches = 0;
        final Node[] streeNodes = stree.getNodesAsArray();
        for (final Node n : streeNodes) {
            if (!n.isRoot()) {
                totBranches += n.getLength();
            }
        }
        totBranches /= 2 * (streeNodes.length - 1);

        final RealParameter popm = popMean.get();
        if (popm != null) {
            setParameterValue(popm, totBranches);
        }

        final SpeciesTreePrior speciesTreePrior = speciesTreePriorInput.get();
        if (speciesTreePrior != null) {
            final RealParameter popb = speciesTreePrior.popSizesBottomInput.get();
            if (popb != null) {
                for (int i = 0; i < popb.getDimension(); ++i) {
                    setParameterValue(popb, i, 2 * totBranches);
                }
            }
            final RealParameter popt = speciesTreePrior.popSizesTopInput.get();
            if (popt != null) {
                for (int i = 0; i < popt.getDimension(); ++i) {
                    setParameterValue(popt, i, totBranches);
                }
            }
        }
    }

    /**
     * Returns 1/2 + 1/3 + ... + 1/n (0 when n < 2). This class uses it as the species-tree root height
     * for a unit birth rate.
     */
    private static double harmonicFromTwo(final int n) {
        double sum = 0;
        for (int k = 2; k <= n; ++k) {
            sum += 1.0 / k;
        }
        return sum;
    }

    /** Sets element 0 of the parameter, clamped to its bounds. */
    private void setParameterValue(RealParameter p, double value) {
        setParameterValue(p, 0, value);
    }

    /** Sets one element of the parameter; values outside [lower, upper] are replaced by the nearest bound. */
    private void setParameterValue(RealParameter p, int index, double value) {
        if (value < p.getLower()) {
            value = p.getLower();
        }
        if (value > p.getUpper()) {
            value = p.getUpper();
        }
        p.setValue(index, value);
    }

    /**
     * Calls {@code makeCaterpillar(height, height / internalNodeCount, true)} on every gene tree.
     * The intent is presumably to place each gene tree's coalescences below the species-tree root height.
     */
    private void randomInitGeneTrees(double speciesTreeHeight) {
        for (final Tree gtree : genes.get()) {
            gtree.makeCaterpillar(speciesTreeHeight, speciesTreeHeight / gtree.getInternalNodeCount(), true);
        }
    }

    /**
     * Random initialisation ({@code random} method). The species tree is rescaled to root height
     * (1/lambda) * (1/2 + ... + 1/n), using lambda = 1 when no birth rate is given, and the gene trees
     * are then set from that height.
     */
    private void randomInit() {
        final RealParameter lambda = birthRate.get();
        final double lam = (lambda != null) ? lambda.getArrayValue() : 1;

        final Tree stree = speciesTreeInput.get();
        final int speciesCount = stree.m_taxonset.get().asStringList().size();

        final double rootHeight = (1 / lam) * harmonicFromTwo(speciesCount);
        stree.scale(rootHeight / stree.getRoot().getHeight());
        randomInitGeneTrees(rootHeight);
    }

    /**
     * Initialises from the calibrated Yule model. A temporary model with the same calibrations and
     * tree (type NONE) produces a tree compatible with the calibrations, which is copied into the
     * species tree. The gene trees are then set from its root height and the original model is
     * re-initialised.
     *
     * @throws MathException propagated from the calibrated Yule model calls below
     */
    private void initWithCalibrations() throws MathException {
        final CalibratedYuleModel cYule = calibratedYule.get();
        final Tree spTree = (Tree) cYule.treeInput.get();
        final List<CalibrationPoint> cals = cYule.calibrationsInput.get();

        final CalibratedYuleModel cym = new CalibratedYuleModel();
        // NOTE(review): presumably so the temporary model sees the same connected objects as the original.
        cym.getOutputs().addAll(cYule.getOutputs());
        for (final CalibrationPoint cal : cals) {
            cym.setInputValue("calibrations", cal);
        }
        cym.setInputValue("tree", spTree);
        cym.setInputValue("type", CalibratedYuleModel.Type.NONE);
        cym.initAndValidate();

        final Tree t = cym.compatibleInitialTree();
        assert spTree.getLeafNodeCount() == t.getLeafNodeCount();
        spTree.assignFromWithoutID(t);

        randomInitGeneTrees(spTree.getRoot().getHeight());
        cYule.initAndValidate();
    }

    /**
     * Initialises the species tree as a {@link RandomTree} built under the given MRCA constraints.
     * The random tree uses a constant population of size 1. The gene trees are then set from its root height.
     */
    private void initWithMRCACalibrations(List<MRCAPrior> calibrations) {
        final Tree spTree = speciesTreeInput.get();
        final RandomTree rnd = new RandomTree();
        rnd.setInputValue("taxonset", spTree.getTaxonset());
        for (final MRCAPrior cal : calibrations) {
            rnd.setInputValue("constraint", cal);
        }
        final ConstantPopulation pf = new ConstantPopulation();
        pf.setInputValue("popSize", new RealParameter("1.0"));
        rnd.setInputValue("populationModel", pf);
        rnd.initAndValidate();
        spTree.assignFromWithoutID((Tree) rnd);

        randomInitGeneTrees(spTree.getRoot().getHeight());
    }

    /**
     * Reports the state nodes this initialiser is responsible for. These are the species tree (the
     * calibrated Yule model's tree when one was given), all gene trees, and whichever of popMean,
     * birthRate and the species-tree-prior population sizes were supplied.
     */
    @Override
    public void getInitialisedStateNodes(final List<StateNode> stateNodes) {
        stateNodes.add(targetSpeciesTree());
        stateNodes.addAll(genes.get());

        final RealParameter popm = popMean.get();
        if (popm != null) {
            stateNodes.add(popm);
        }
        final RealParameter brate = birthRate.get();
        if (brate != null) {
            stateNodes.add(brate);
        }
        final SpeciesTreePrior speciesTreePrior = speciesTreePriorInput.get();
        if (speciesTreePrior != null) {
            final RealParameter popb = speciesTreePrior.popSizesBottomInput.get();
            if (popb != null) {
                stateNodes.add(popb);
            }
            final RealParameter popt = speciesTreePrior.popSizesTopInput.get();
            if (popt != null) {
                stateNodes.add(popt);
            }
        }
    }
}