/* Copyright (C) 2009 Diego Darriba This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 2 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program; if not, write to the Free Software Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA */ package es.uvigo.darwin.jmodeltest.tree; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.Comparator; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.PriorityQueue; import java.util.Set; import pal.misc.IdGroup; import pal.tree.Node; import pal.tree.NodeFactory; import pal.tree.SimpleTree; import pal.tree.Tree; import es.uvigo.darwin.jmodeltest.exception.InternalException; import es.uvigo.darwin.jmodeltest.model.Model; import es.uvigo.darwin.jmodeltest.selection.InformationCriterion; import es.uvigo.darwin.jmodeltest.utilities.FixedBitSet; import es.uvigo.darwin.jmodeltest.utilities.MyFormattedOutput; import es.uvigo.darwin.jmodeltest.utilities.Utilities; /** * Phylogenetic consensus tree builder. * * @author Diego Darriba * @since 3.0 */ public class Consensus { /** Display branch suport as percent. */ final static public boolean SUPPORT_AS_PERCENT = false; /** Calculate branch lengths as weighted average. */ final static public int BRANCH_LENGTHS_AVERAGE = 1; /** Calculate branch lengths as weighted median. */ final static public int BRANCH_LENGTHS_MEDIAN = 2; /** Default branch lengths algorithm */ private static final BranchDistances DEFAULT_BRANCH_DISTANCES = BranchDistances.WeightedMedian; /** The Constant FIRST (just for source code visibility). */ private static final int FIRST = 0; /** The weighted trees in consensus. */ private List<WeightedTree> trees; /** The cummulative weight. */ private double cumWeight = 0.0; /** The number of taxa. */ private int numTaxa; /** The common id group of the tree set. */ private IdGroup idGroup; /** The set of clade supports. */ private Map<FixedBitSet, Support> support = new HashMap<FixedBitSet, Support>(); /** The set of clade supports to get from outside this class. */ private Map<FixedBitSet, Double> cladeSupport; /** The inner consensus tree. */ private Tree consensusTree; /** The splits included in consensus tree */ private List<FixedBitSet> splitsInConsensus = new ArrayList<FixedBitSet>(); /** The splits not included in consensus tree */ private List<FixedBitSet> splitsOutFromConsensus = new ArrayList<FixedBitSet>(); /** * Gets the clade support, with Support instances * * @return the map of the support for each bitSet */ private Map<FixedBitSet, Support> getSupport() { return support; } /** * Gets the double precision clade support * * @return the map of the support for each bitSet */ public Map<FixedBitSet, Double> getCladeSupport() { if (cladeSupport == null) { cladeSupport = new HashMap<FixedBitSet, Double>(support.size()); FixedBitSet[] keys = support.keySet().toArray(new FixedBitSet[0]); Arrays.sort(keys); for (FixedBitSet fbs : keys) { cladeSupport.put(fbs, support.get(fbs).treesWeightWithClade / cumWeight); } } return cladeSupport; } /** * Gets the Id Group of the set of trees * * @return the id group */ public IdGroup getIdGroup() { return idGroup; } /** * Gets the consensus tree * * @return the consensus tree */ public Tree getConsensusTree() { return consensusTree; } /** * Gets the set of trees included in the consensus. * * @return the trees */ public Collection<WeightedTree> getTrees() { return trees; } /** * Adds a weighted tree to the set. * * @param wTree the weighted tree * * @return true, if successful */ private boolean addTree(WeightedTree wTree) { //check integrity if (wTree.getTree() == null || wTree.getWeight() < 0.0) { throw new InternalException(); } //check compatibility if (trees.isEmpty()) { trees.add(wTree); numTaxa = wTree.getTree().getIdCount(); idGroup = pal.tree.TreeUtils.getLeafIdGroup(wTree.getTree()); } else { if (wTree.getTree().getIdCount() != numTaxa) { return false; } Tree pTree = trees.get(FIRST).getTree(); for (int i = 0; i < numTaxa; i++) { boolean found = false; for (int j = 0; j < numTaxa; j++) { if (wTree.getTree().getIdentifier(i).equals(pTree.getIdentifier(j))) { found = true; break; } } if (!found) { System.out.println("NOT COMPATIBLE TREES"); return false; } } trees.add(wTree); } cumWeight += wTree.getWeight(); return true; } /** * Instantiates a new consensus tree builder. * * @param ic the information criterion to build the weighted trees * @param supportThreshold the minimum support for a clade */ public Consensus(InformationCriterion ic, double supportThreshold) { this(ic, supportThreshold, 0); } /** * Instantiates a new consensus tree builder. * * @param ic the information criterion to build the weighted trees * @param supportThreshold the minimum support for a clade * @param branchDistances the method to get the consensus branch lengths */ public Consensus(InformationCriterion ic, double supportThreshold, int branchDistances) { this.trees = new ArrayList<WeightedTree>(); for (Model model : ic.getConfidenceModels()) { WeightedTree wTree = new WeightedTree( model.getTree(), ic.getWeight(model)); this.addTree(wTree); } consensusTree = buildTree(supportThreshold, getBranchDistances(branchDistances)); } /** * Instantiates a new unweighted consensus builder. * * @param trees the trees * @param supportThreshold the minimum support for a clade * @param branchDistances the method to get the consensus branch lengths */ public Consensus(List<WeightedTree> trees, double supportThreshold, int branchDistances) { this.trees = new ArrayList<WeightedTree>(); for (WeightedTree tree : trees) { this.addTree(tree); } consensusTree = buildTree(supportThreshold, getBranchDistances(branchDistances)); } /** * Calculates rooted support. * * @param wTree the weighted tree instance * @param node the node * @param support the support * * @return the fixed bit set */ private FixedBitSet rootedSupport(WeightedTree wTree, Node node, Map<FixedBitSet, Support> support) { FixedBitSet clade = new FixedBitSet(numTaxa); if (node.isLeaf()) { clade.set(idGroup.whichIdNumber(node.getIdentifier().getName())); } else { for (int i = 0; i < node.getChildCount(); i++) { Node n = node.getChild(i); FixedBitSet childClade = rootedSupport(wTree, n, support); clade.union(childClade); } } Support s = support.get(clade); if (s == null) { s = new Support(); support.put(clade, s); } s.add(wTree.getWeight(), TreeUtilities.safeNodeHeight(wTree.getTree(), node), node.getBranchLength()); return clade; } /** * Detach the children of a tree. * * @param tree the tree * @param node the node to detach * @param split the split * * @return the node */ public Node detachChildren(Tree tree, Node node, List<Integer> split) { assert (split.size() > 1); List<Node> detached = new ArrayList<Node>(); for (int n : split) { detached.add(node.getChild(n)); } Node saveRoot = tree.getRoot(); List<Integer> toRemove = new ArrayList<Integer>(); for (int i = 0; i < node.getChildCount(); i++) { Node n = node.getChild(i); if (detached.contains(n)) { toRemove.add(0, i); } } for (int i : toRemove) { node.removeChild(i); } Node dnode = NodeFactory.createNode(detached.toArray(new Node[0])); node.addChild(dnode); tree.setRoot(saveRoot); return dnode; } /** * Builds the consensus tree over a set of weighted trees. * * @param supportThreshold the minimum support to consider a split into the consensus tree * * @return the consensus tree */ private Tree buildTree(double supportThreshold, BranchDistances branchDistances) { if (trees.isEmpty()) { throw new InternalException("There are no trees to consense"); } if (supportThreshold < 0.5 || supportThreshold > 1.0) { throw new InternalException("Invalid threshold value: " + supportThreshold); } double effectiveThreshold = supportThreshold; if (supportThreshold == 0.5) { effectiveThreshold += 1.0/(numTaxa+1); } else if (supportThreshold == 1.0) { effectiveThreshold -= 1.0/(numTaxa+1); } // establish support support = new HashMap<FixedBitSet, Support>(); for (WeightedTree wTree : trees) { rootedSupport(wTree, wTree.getTree().getRoot(), support); } Tree cons = new SimpleTree(); // Contains all internal nodes in the tree so far, ordered so descendants // appear later than ancestors List<Node> internalNodes = new ArrayList<Node>(numTaxa); // For each internal node, a bit-set with the complete set of tips for it's clade List<FixedBitSet> internalNodesTips = new ArrayList<FixedBitSet>(numTaxa); assert idGroup.getIdCount() == numTaxa; // establish a tree with one root having all tips as descendants internalNodesTips.add(new FixedBitSet(numTaxa)); FixedBitSet rooNode = internalNodesTips.get(0); Node[] nodes = new Node[numTaxa]; for (int nt = 0; nt < numTaxa; ++nt) { nodes[nt] = NodeFactory.createNode(idGroup.getIdentifier(nt)); rooNode.set(nt); } Node rootNode = NodeFactory.createNode(nodes); internalNodes.add(rootNode); cons.setRoot(rootNode); // sorts support from largest to smallest final Comparator<Map.Entry<FixedBitSet, Support>> comparator = new Comparator<Map.Entry<FixedBitSet, Support>>() { @Override public int compare(Map.Entry<FixedBitSet, Support> o1, Map.Entry<FixedBitSet, Support> o2) { double diff = o2.getValue().treesWeightWithClade - o1.getValue().treesWeightWithClade; if (diff > 0.0) { return 1; } else if (diff < 0.0) { return -1; } else { return 0; } } }; // add everything to queue PriorityQueue<Map.Entry<FixedBitSet, Support>> queue = new PriorityQueue<Map.Entry<FixedBitSet, Support>>(support.size(), comparator); for (Map.Entry<FixedBitSet, Support> se : support.entrySet()) { Support s = se.getValue(); FixedBitSet clade = se.getKey(); final int cladeSize = clade.cardinality(); if (cladeSize == numTaxa) { // root cons.getRoot().setNodeHeight(s.sumBranches / trees.size()); cons.getRoot().setBranchLength(branchDistances.build(s.branchLengths)); continue; } if (Math.abs(s.treesWeightWithClade - this.cumWeight) < 1e-5 && cladeSize == 1) { // leaf/external node final int nt = clade.nextOnBit(FIRST); final Node leaf = cons.getExternalNode(nt); leaf.setNodeHeight(s.sumBranches / trees.size()); leaf.setBranchLength(branchDistances.build(s.branchLengths)); } else { queue.add(se); } } while (queue.peek() != null) { Map.Entry<FixedBitSet, Support> e = queue.poll(); final Support s = e.getValue(); final double psupport = (1.0 * s.treesWeightWithClade) / cumWeight; if (psupport < effectiveThreshold) { break; } final FixedBitSet cladeTips = e.getKey(); boolean found = false; /* locate the node containing the clade. going in reverse order ensures the lowest one is hit first */ for (int nsub = internalNodesTips.size() - 1; nsub >= 0; --nsub) { FixedBitSet allNodeTips = internalNodesTips.get(nsub); // size of intersection between tips & split final int nSplit = allNodeTips.intersectCardinality(cladeTips); if (nSplit == cladeTips.cardinality()) { // node contains all of clade // Locate node descendants containing the split found = true; List<Integer> split = new ArrayList<Integer>(); Node n = internalNodes.get(nsub); int l = 0; for (int j = 0; j < n.getChildCount(); j++) { Node ch = n.getChild(j); if (ch.isLeaf()) { if (cladeTips.contains(idGroup.whichIdNumber(ch.getIdentifier().getName()))) { split.add(l); } } else { // internal final int o = internalNodes.indexOf(ch); final int i = internalNodesTips.get(o).intersectCardinality(cladeTips); if (i == internalNodesTips.get(o).cardinality()) { split.add(l); } else if (i > 0) { // Non compatible found = false; break; } } ++l; } if (!(found && split.size() < n.getChildCount())) { found = false; break; } if (split.isEmpty()) { System.err.println("Bug??"); assert (false); } final Node detached = detachChildren(cons, n, split); final double height = s.sumBranches / s.nTreesWithClade; detached.setNodeHeight(height); detached.setBranchLength(branchDistances.build(s.branchLengths)); cons.setAttribute(detached, TreeUtilities.TREE_CLADE_SUPPORT_ATTRIBUTE, SUPPORT_AS_PERCENT ? 100 * psupport : psupport); // insert just after parent, so before any descendants internalNodes.add(nsub + 1, detached); internalNodesTips.add(nsub + 1, new FixedBitSet(cladeTips)); break; } } } TreeUtilities.insureConsistency(cons, cons.getRoot()); String thresholdAsPercent = String.valueOf(supportThreshold * 100); cons.setAttribute(cons.getRoot(), TreeUtilities.TREE_NAME_ATTRIBUTE, "cons_" + thresholdAsPercent + "_majRule"); Set<FixedBitSet> keySet = getSupport().keySet(); FixedBitSet[] keys = keySet.toArray(new FixedBitSet[0]); Arrays.sort(keys); for (FixedBitSet fbs : keys) { if (fbs.cardinality() > 1) { double psupport = (1.0 * getSupport().get(fbs).getTreesWeightWithClade()) / cumWeight; if (psupport < effectiveThreshold) { splitsOutFromConsensus.add(fbs); } else { splitsInConsensus.add(fbs); } } } return cons; } /** * Enum to calculate the branch lengths */ private enum BranchDistances { WeightedAverage { /** * Calculates the weighted average. * * @param values the weighted values * @param cumWeight the sum of weights * * @return the weighted average of the set */ @Override public double build(List<WeightLengthPair> values) { double avg = 0.0; double cumWeight = 0.0; for (WeightLengthPair pair : values) { avg += pair.branchLength * pair.weight; cumWeight += pair.weight; } avg /= cumWeight; return avg; } }, WeightedMedian { /** * Calculates the weighted median. * * @param values the weighted values * @param cumWeight the sum of weights * * @return the weighted median of the set */ @Override public double build(List<WeightLengthPair> values) { Collections.sort(values); double median = -1; double cumWeight = 0.0; for (WeightLengthPair pair : values) { cumWeight += pair.weight; } double halfWeight = cumWeight / 2.0; double cumValue = 0.0; for (WeightLengthPair pair : values) { cumValue += pair.weight; if (cumValue >= halfWeight) { median = pair.branchLength; break; } } return median; } }; public abstract double build(List<WeightLengthPair> values); } /** * One clade support. */ static final class Support { /** number of trees containing the clade. */ private int nTreesWithClade; /** The trees weight with clade. */ private double treesWeightWithClade; /** The branch lengths. */ private ArrayList<WeightLengthPair> branchLengths; /** Sum of node heights of trees containing the clade. */ private double sumBranches; public double getTreesWeightWithClade() { return treesWeightWithClade; } /** * Instantiates a new support. */ Support() { sumBranches = 0.0; treesWeightWithClade = 0.0; nTreesWithClade = 0; branchLengths = new ArrayList<WeightLengthPair>(); } /** * Adds the branch to the map of branch lengths. * * @param weight the weight * @param height the height * @param branchLength the branch length */ public final void add(double weight, double height, double branchLength) { sumBranches += height; branchLengths.add(new WeightLengthPair(weight, branchLength)); treesWeightWithClade += weight; ++nTreesWithClade; // double testW = 0.0; // for (WeightLengthPair wlp : branchLengths) { // testW += wlp.weight; // } } } static class WeightLengthPair implements Comparable<WeightLengthPair> { private double weight; private double branchLength; WeightLengthPair(double weight, double branchLength) { this.weight = weight; this.branchLength = branchLength; } @Override public int compareTo(WeightLengthPair o) { if (branchLength < o.branchLength) { return -1; } else if (branchLength > o.branchLength) { return 1; } return 0; } } /** * A extension of Weighted tree but every tree * has the same weight. */ static class UnweightedTree extends WeightedTree { /** * Instantiates a new unweighted tree. * * @param tree the tree */ UnweightedTree(Tree tree) { super(tree, 1.0); } } public String getTaxaHeader() { StringBuilder taxaHeader = new StringBuilder(); for (int i = 0; i < numTaxa; i++) { taxaHeader.append(String.valueOf(i + 1).charAt(0)); } if (numTaxa >= 10) { taxaHeader.append('\n'); taxaHeader.append(MyFormattedOutput.space(4 + 9, ' ')); for (int i = 9; i < numTaxa; i++) { taxaHeader.append(String.valueOf(i + 1).charAt(1)); } } if (numTaxa >= 100) { taxaHeader.append('\n'); taxaHeader.append(MyFormattedOutput.space(4 + 99, ' ')); for (int i = 99; i < numTaxa; i++) { taxaHeader.append(String.valueOf(i + 1).charAt(2)); } } if (numTaxa >= 1000) { taxaHeader.append('\n'); taxaHeader.append(MyFormattedOutput.space(4 + 999, ' ')); for (int i = 999; i < numTaxa; i++) { taxaHeader.append(String.valueOf(i + 1).charAt(3)); } } return taxaHeader.toString(); } public String getSetsIncluded() { StringBuilder setsIncluded = new StringBuilder(); setsIncluded.append(" "); setsIncluded.append(getTaxaHeader()); setsIncluded.append('\n'); for (FixedBitSet fbs : splitsInConsensus) { setsIncluded.append(" ") .append(fbs.splitRepresentation()) .append(" ( ") .append(Utilities.roundDoubleTo(getCladeSupport().get(fbs), 5)) .append(" )") .append('\n'); } return setsIncluded.toString(); } public String getSetsNotIncluded() { StringBuilder setsIncluded = new StringBuilder(); setsIncluded.append(" "); setsIncluded.append(getTaxaHeader()); setsIncluded.append('\n'); for (FixedBitSet fbs : splitsOutFromConsensus) { setsIncluded.append(" ") .append(fbs.splitRepresentation()) .append(" ( ") .append(Utilities.roundDoubleTo(getCladeSupport().get(fbs), 5)) .append(" )") .append('\n'); } return setsIncluded.toString(); } private BranchDistances getBranchDistances(int value) { BranchDistances bd; switch (value) { case BRANCH_LENGTHS_AVERAGE: bd = BranchDistances.WeightedAverage; break; case BRANCH_LENGTHS_MEDIAN: bd = BranchDistances.WeightedMedian; break; default: // Weighted average bd = DEFAULT_BRANCH_DISTANCES; } return bd; } }