/*
* TreeSummary.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.app.tools;
import dr.app.beast.BeastVersion;
import dr.app.util.Arguments;
import dr.evolution.io.Importer;
import dr.evolution.io.NexusImporter;
import dr.evolution.io.TreeImporter;
import dr.evolution.tree.*;
import dr.evolution.util.Taxon;
import dr.evolution.util.TaxonList;
import dr.util.Version;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintStream;
import java.util.*;
/**
* @author Andrew Rambaut
*/
public class TreeSummary {
/** Version information shown in the title banner printed by {@code printTitle()}. */
private static final Version version = new BeastVersion();
/** Progress bars, banners and statistics go to stderr; results go to stdout or the output file. */
private static final PrintStream progressStream = System.err;
/**
 * Reads the trees of a NEXUS file, tallies every clade of the trees that
 * survive burn-in and then writes a clade presence/absence table and/or a
 * consensus tree.
 * <p>
 * Burnin can be specified as the number of trees or the number of states
 * (one or other should be zero). A tree is used when its index in the file
 * is at least {@code burninTrees} and its state is at least
 * {@code burninStates}. The state is parsed from tree ids of the form
 * {@code STATE_<n>}; a tree without such an id is never excluded by the
 * state threshold.
 * <p>
 * Problems with the content of the input (parse errors, no trees, burn-in
 * consuming every tree) are reported on stderr and end the constructor early
 * instead of being thrown.
 *
 * @param burninTrees       the number of leading trees to discard
 * @param burninStates      the state below which trees are discarded
 * @param posteriorLimit    the minimum credibility a clade needs to be part of the consensus tree
 * @param createSummaryTree whether to build and write the consensus tree
 * @param createCladeMap    whether to write the per-tree clade presence/absence table
 * @param inputFileName     the NEXUS tree file to read
 * @param outputFileName    the file to write to, or {@code null} to write to stdout
 * @throws java.io.IOException if the input file cannot be opened or read, or
 *                             the clade table output file cannot be opened
 */
public TreeSummary(final int burninTrees,
                   final int burninStates,
                   double posteriorLimit,
                   boolean createSummaryTree,
                   boolean createCladeMap,
                   String inputFileName,
                   String outputFileName
) throws IOException {

    this.posteriorLimit = posteriorLimit;

    final CladeSystem cladeSystem = new CladeSystem();

    // index of the first tree that passed both burn-in thresholds (-1 until one does)
    int burnin = -1;

    // the real number of trees is unknown until the file has been read once
    totalTrees = 10000;
    totalTreesUsed = 0;

    progressStream.println("Reading trees (bar assumes 10,000 trees)...");
    progressStream.println("0 25 50 75 100");
    progressStream.println("|--------------|--------------|--------------|--------------|");

    final int stepSize = Math.max(1, totalTrees / 60);

    final FileReader fileReader = new FileReader(inputFileName);
    try {
        final TreeImporter importer = new NexusImporter(fileReader);

        totalTrees = 0;
        while (importer.hasTree()) {
            final Tree tree = importer.importNextTree();

            // trees whose state cannot be determined always pass the state threshold
            int state = Integer.MAX_VALUE;
            if (burninStates > 0) {
                state = parseState(tree, state);
            }

            if (totalTrees >= burninTrees && state >= burninStates) {
                if (burnin < 0) {
                    // first tree past burn-in: remember how many trees were skipped
                    burnin = totalTrees;
                }
                cladeSystem.add(tree, true);
                totalTreesUsed += 1;
            }

            if (totalTrees > 0 && totalTrees % stepSize == 0) {
                progressStream.print("*");
                progressStream.flush();
            }
            totalTrees++;
        }
    } catch (Importer.ImportException e) {
        System.err.println("Error Parsing Input Tree: " + e.getMessage());
        return;
    } finally {
        // release the file handle on every path, including the early return above
        fileReader.close();
    }

    progressStream.println();
    progressStream.println();

    if (totalTrees < 1) {
        System.err.println("No trees");
        return;
    }
    // When burn-in swallows every tree, burnin is still -1, so the number of
    // used trees has to be checked on its own; otherwise an empty clade system
    // would be summarised.
    if (totalTreesUsed < 1 || (totalTreesUsed == 1 && burnin > 0)) {
        System.err.println("No trees to use: burnin too high");
        return;
    }

    cladeSystem.calculateCladeCredibilities(totalTreesUsed);

    progressStream.println("Total trees read: " + totalTrees);
    if (burninTrees > 0) {
        progressStream.println("Ignoring first " + burninTrees + " trees" +
                (burninStates > 0 ? " (" + burninStates + " states)." : "." ));
    } else if (burninStates > 0) {
        progressStream.println("Ignoring first " + burninStates + " states (" + burnin + " trees).");
    }
    progressStream.println("Total unique clades: " + cladeSystem.getCladeMap().keySet().size());
    progressStream.println();

    if (createCladeMap
            && !writeCladeMap(cladeSystem, burninTrees, burninStates, inputFileName, outputFileName)) {
        // a parse error during the second pass has already been reported
        return;
    }

    if (createSummaryTree) {
        writeSummaryTree(cladeSystem, outputFileName);
    }
}

/**
 * Extracts the state number from a tree id of the form {@code STATE_<n>}.
 *
 * @param tree         the tree whose id is inspected
 * @param defaultState the value returned when the id is missing or has a different form
 * @return the parsed state, or {@code defaultState}
 */
private static int parseState(Tree tree, int defaultState) {
    final String id = tree.getId();
    if (id != null) {
        final String name = id.trim();
        if (name.startsWith("STATE_")) {
            return Integer.parseInt(name.split("_")[1]);
        }
    }
    return defaultState;
}

/**
 * Prints the list of clades to stdout, then re-reads the tree file and writes
 * one row per post-burn-in tree marking which of those clades it contains.
 *
 * @return {@code false} if the tree file could not be parsed (already reported on stderr)
 */
private boolean writeCladeMap(CladeSystem cladeSystem, int burninTrees, int burninStates,
                              String inputFileName, String outputFileName) throws IOException {

    final Map<BitSet, Integer> cladeCountMap = cladeSystem.getCladeCounts();
    // snapshot the clades once so that every row uses the same column order
    final List<BitSet> clades = new ArrayList<BitSet>(cladeCountMap.keySet());

    System.out.println("No.\tSize\tCred\tMembers");
    int n = 1;
    for (BitSet bits : clades) {
        System.out.print(n);
        System.out.print("\t");
        System.out.print(bits.cardinality());
        System.out.print("\t");
        System.out.print(cladeSystem.getCladeCredibility(bits));
        System.out.print("\t");
        System.out.println(cladeSystem.getCladeString(bits));
        n++;
    }
    System.out.println();

    progressStream.println("Reading trees...");
    progressStream.println("0 25 50 75 100");
    progressStream.println("|--------------|--------------|--------------|--------------|");

    // never zero: it is used as a modulus below and files may hold fewer than 60 trees
    final int stepSize = Math.max(1, totalTrees / 60);

    final boolean toFile = outputFileName != null;
    final FileReader fileReader = new FileReader(inputFileName);
    try {
        final TreeImporter importer = new NexusImporter(fileReader);
        final PrintStream stream = toFile ?
                new PrintStream(new FileOutputStream(outputFileName)) :
                System.out;
        try {
            stream.print("Clade");
            for (int column = 1; column <= clades.size(); column++) {
                stream.print("\t");
                stream.print(column);
            }
            stream.println();

            stream.print("State");
            for (BitSet bits : clades) {
                stream.print("\t");
                stream.print(cladeSystem.getCladeCredibility(bits));
            }
            stream.println();

            int treeIndex = 0;
            while (importer.hasTree()) {
                final Tree tree = importer.importNextTree();

                // Same selection rule as the first pass, so the rows cover exactly
                // the trees the credibilities were computed from. The row label is
                // the parsed state when there is one, otherwise the tree index.
                int state = Integer.MAX_VALUE;
                int label = treeIndex;
                if (burninStates > 0) {
                    state = parseState(tree, state);
                    if (state != Integer.MAX_VALUE) {
                        label = state;
                    }
                }

                if (treeIndex >= burninTrees && state >= burninStates) {
                    stream.print(label);
                    final Set<BitSet> cladeSet = cladeSystem.getCladeSet(tree);
                    for (BitSet bits : clades) {
                        stream.print("\t");
                        stream.print(cladeSet.contains(bits) ? "1" : "0");
                    }
                    stream.println();
                }

                if (treeIndex > 0 && treeIndex % stepSize == 0) {
                    progressStream.print("*");
                    progressStream.flush();
                }
                treeIndex++;
            }
        } finally {
            // only close streams opened here; stdout must stay usable
            if (toFile) {
                stream.close();
            } else {
                stream.flush();
            }
        }
    } catch (Importer.ImportException e) {
        System.err.println("Error Parsing Input Tree: " + e.getMessage());
        return false;
    } finally {
        fileReader.close();
    }

    progressStream.println();
    progressStream.println();
    return true;
}

/**
 * Builds the consensus tree and exports it in NEXUS format to the output
 * file, or to stdout when no file name is given. Export failures are
 * reported on stderr.
 */
private void writeSummaryTree(CladeSystem cladeSystem, String outputFileName) {
    progressStream.println("Finding summary tree...");
    final Tree consensusTree = buildConsensusTree(cladeSystem);

    progressStream.println("Writing consensus tree....");
    PrintStream stream = null;
    try {
        stream = outputFileName != null ?
                new PrintStream(new FileOutputStream(outputFileName)) :
                System.out;
        new NexusExporter(stream).exportTree(consensusTree);
    } catch (Exception e) {
        System.err.println("Error writing consensus tree file: " + e.getMessage());
    } finally {
        if (stream != null) {
            if (stream != System.out) {
                stream.close();
            } else {
                stream.flush();
            }
        }
    }
}
/**
 * Builds a consensus tree from the clades whose credibility is at least
 * {@code posteriorLimit}.
 * <p>
 * Clades are visited from smallest to largest. Each single-taxon clade gets
 * a fresh tip node, and each sufficiently credible clade is attached as a
 * child of the first larger clade in that order that contains it and is
 * itself sufficiently credible. The node of a clade that contains every
 * taxon becomes the root.
 *
 * @param cladeSystem the clade system whose credibilities have been calculated
 * @return the consensus tree
 * @throws RuntimeException if a clade has to be attached before it has a node,
 *                          or if no root could be determined
 */
private Tree buildConsensusTree(CladeSystem cladeSystem) {

    final Map<BitSet, CladeSystem.Clade> cladeMap = cladeSystem.getCladeMap();

    final List<BitSet> bitSets = new ArrayList<BitSet>(cladeMap.keySet());
    Collections.sort(bitSets, new Comparator<BitSet>() {
        @Override
        public int compare(BitSet b1, BitSet b2) {
            // cardinalities are non-negative, so the difference cannot overflow
            return b1.cardinality() - b2.cardinality();
        }
    });

    SimpleNode root = null;

    for (int i = 0; i < bitSets.size(); i++) {
        final BitSet key1 = bitSets.get(i);
        final CladeSystem.Clade clade1 = cladeMap.get(key1);

        if (key1.cardinality() == 1) {
            // a tip: create its node here, internal nodes are created on demand below
            final Taxon taxon = cladeSystem.getTaxon(key1.nextSetBit(0));
            clade1.node = new SimpleNode();
            clade1.node.setTaxon(taxon);
            clade1.node.setAttribute("clade", clade1);
        }

        if (clade1.credibility >= posteriorLimit) {
            for (int j = i + 1; j < bitSets.size(); j++) {
                final BitSet key2 = bitSets.get(j);
                final CladeSystem.Clade clade2 = cladeMap.get(key2);

                if (isSubSet(key1, key2) && clade2.credibility >= posteriorLimit) {
                    // the clades are ordered by size so this is the smallest clade for which
                    // clade1 is a subset and has high credibility.
                    if (clade2.node == null) {
                        clade2.node = new SimpleNode();
                    }
                    if (clade1.node == null) {
                        throw new RuntimeException("null node");
                    }

                    clade2.node.addChild(clade1.node);
                    clade2.node.setAttribute("credibility", clade2.credibility);
                    clade2.node.setAttribute("clade", clade2);

                    if (key2.cardinality() == cladeSystem.taxonList.getTaxonCount()) {
                        root = clade2.node;
                    }
                    break;
                }
            }
        }
    }

    if (root == null) {
        // fail with a message instead of handing a null root to SimpleTree
        throw new RuntimeException("Unable to build consensus tree: no clade with credibility >= "
                + posteriorLimit + " was attached to a clade containing all taxa");
    }

    return new SimpleTree(root);
}
/**
 * Re-reads the tree file, feeds every tree from index {@code burnin} onwards
 * to {@code cladeSystem.addSubTrees} and then asks the clade system for its
 * best tree given the consensus tree.
 *
 * @param burnin        the number of leading trees to skip
 * @param consensusTree the tree passed on to {@code getBestTree}
 * @param cladeSystem   the clade system to update and query
 * @param inputFileName the NEXUS tree file to read
 * @return the tree returned by {@code getBestTree}, or {@code null} if the
 *         file could not be parsed (reported on stderr)
 * @throws IOException if the file cannot be opened or read
 */
private Tree summarizeTrees(int burnin, Tree consensusTree, CladeSystem cladeSystem, String inputFileName)
        throws IOException {

    progressStream.println("Analyzing " + totalTreesUsed + " trees...");
    progressStream.println("0 25 50 75 100");
    progressStream.println("|--------------|--------------|--------------|--------------|");

    final int stepSize = Math.max(1, totalTrees / 60);

    final FileReader fileReader = new FileReader(inputFileName);
    try {
        final TreeImporter importer = new NexusImporter(fileReader);

        int counter = 0;
        while (importer.hasTree()) {
            final Tree tree = importer.importNextTree();

            if (counter >= burnin) {
                cladeSystem.addSubTrees(tree);
            }
            if (counter > 0 && counter % stepSize == 0) {
                progressStream.print("*");
                progressStream.flush();
            }
            counter++;
        }
    } catch (Importer.ImportException e) {
        System.err.println("Error Parsing Input Tree: " + e.getMessage());
        return null;
    } finally {
        // the reader used to be left open on every path
        fileReader.close();
    }

    return cladeSystem.getBestTree(consensusTree);
}
private class CladeSystem {
/**
* Creates an empty clade system. The taxon list is taken from the first
* tree passed to add().
*/
public CladeSystem() {
}
/**
* Creates an empty clade system that indexes taxa using the given list.
*
* @param taxonList the taxon list used to map taxa to bit indices
*/
public CladeSystem(TaxonList taxonList) {
this.taxonList = taxonList;
}
/**
 * Registers every clade of the given tree, incrementing the count of clades
 * that are already known. The root clade is registered too.
 * <p>
 * The first tree added also serves as the taxon list that defines the bit
 * index of each taxon, unless a taxon list was supplied at construction.
 *
 * @param tree        the tree whose clades are added
 * @param includeTips whether single-taxon clades are registered as well
 */
public void add(Tree tree, boolean includeTips) {
    final boolean needsTaxonList = (taxonList == null);
    if (needsTaxonList) {
        taxonList = tree;
    }

    final NodeRef root = tree.getRoot();
    addClades(tree, root, includeTips);
}
/**
 * Recursively registers the clade below {@code node} and all clades nested in it.
 *
 * @param tree        the tree being traversed
 * @param node        the node whose clade is registered
 * @param includeTips whether single-taxon clades are registered as well
 * @return the set of taxon indices found below {@code node}
 */
private BitSet addClades(Tree tree, NodeRef node, boolean includeTips) {
    final BitSet members = new BitSet();

    if (tree.isExternal(node)) {
        // a tip contributes exactly one bit: the index of its taxon
        final String taxonId = tree.getNodeTaxon(node).getId();
        members.set(taxonList.getTaxonIndex(taxonId));
        if (includeTips) {
            addClade(members);
        }
        return members;
    }

    // an internal node is the union of its children
    for (int child = 0; child < tree.getChildCount(node); child++) {
        final BitSet below = addClades(tree, tree.getChild(node, child), includeTips);
        members.or(below);
    }
    addClade(members);

    return members;
}
/**
 * Increments the count of the clade with the given members, creating the
 * clade first if it has not been seen before.
 *
 * @param bits the taxon indices that make up the clade
 */
private void addClade(BitSet bits) {
    final Clade known = cladeMap.get(bits);
    if (known != null) {
        known.setCount(known.getCount() + 1);
        return;
    }

    final Clade fresh = new Clade(bits);
    cladeMap.put(bits, fresh);
    fresh.setCount(fresh.getCount() + 1);
}
/**
 * Adds, for every sufficiently credible clade of the tree, the subtree
 * below that clade to the clade's conditional clade system.
 *
 * @param tree the tree to process
 */
public void addSubTrees(Tree tree) {
    addSubTrees(tree, tree.getRoot());
}

/**
 * Recursive worker for {@link #addSubTrees(Tree)}.
 *
 * @param tree the tree being traversed
 * @param node the node currently visited
 * @return the set of taxon indices found below {@code node}
 */
private BitSet addSubTrees(Tree tree, NodeRef node) {

    BitSet bits = new BitSet();

    if (tree.isExternal(node)) {

        int index = taxonList.getTaxonIndex(tree.getNodeTaxon(node).getId());
        bits.set(index);

    } else {

        for (int i = 0; i < tree.getChildCount(node); i++) {
            NodeRef node1 = tree.getChild(node, i);
            bits.or(addSubTrees(tree, node1));
        }

        // A clade that was never registered has nothing to attach a conditional
        // clade system to; skip it, in line with getCladeCredibility and
        // getCladeCount which also tolerate unknown clades.
        Clade clade = cladeMap.get(bits);
        if (clade != null && clade.credibility >= posteriorLimit) {
            if (clade.conditionalCladeSystem == null) {
                clade.conditionalCladeSystem = new CladeSystem(taxonList);
            }
            clade.conditionalCladeSystem.addClades(tree, node, false);
        }
    }

    return bits;
}
/**
* @return the live map from clade membership bits to clade; not a copy
*/
public Map<BitSet, Clade> getCladeMap() {
return cladeMap;
}
/**
 * Sets the credibility of every clade to its count divided by the number of
 * trees used, and does the same for any conditional clade systems.
 *
 * @param totalTreesUsed the number of trees the counts were collected from
 * @throws AssertionError if a clade was counted more often than there are trees
 */
public void calculateCladeCredibilities(int totalTreesUsed) {
    final double treeCount = (double) totalTreesUsed;

    for (Clade clade : cladeMap.values()) {
        final int seen = clade.getCount();
        if (seen > totalTreesUsed) {
            throw new AssertionError("clade.getCount=(" + seen +
                    ") should be <= totalTreesUsed = (" + totalTreesUsed + ")");
        }

        clade.setCredibility(((double) seen) / treeCount);

        final CladeSystem conditional = clade.conditionalCladeSystem;
        if (conditional != null) {
            conditional.calculateCladeCredibilities(totalTreesUsed);
        }
    }
}
/**
 * Sums the log credibilities of all internal clades at and below {@code node}.
 * A clade that is not in this clade system has credibility 0 and therefore
 * contributes negative infinity.
 *
 * @param tree the tree being scored
 * @param node the node to start from
 * @param bits if not {@code null}, receives the taxon indices found below
 *             {@code node}; callers that do not need them may pass {@code null}
 * @return the summed log clade credibility (0.0 for a tip)
 */
public double getLogCladeCredibility(Tree tree, NodeRef node, BitSet bits) {

    double logCladeCredibility = 0.0;

    if (tree.isExternal(node)) {

        // bits is optional here too, exactly as on the internal-node branch
        if (bits != null) {
            int index = taxonList.getTaxonIndex(tree.getNodeTaxon(node).getId());
            bits.set(index);
        }

    } else {

        BitSet bits2 = new BitSet();
        for (int i = 0; i < tree.getChildCount(node); i++) {
            NodeRef node1 = tree.getChild(node, i);
            logCladeCredibility += getLogCladeCredibility(tree, node1, bits2);
        }

        logCladeCredibility += Math.log(getCladeCredibility(bits2));

        if (bits != null) {
            bits.or(bits2);
        }
    }

    return logCladeCredibility;
}
/**
 * @param bits the taxon indices that make up the clade
 * @return the credibility of the clade, or 0.0 if the clade is unknown
 */
private double getCladeCredibility(BitSet bits) {
    final Clade found = cladeMap.get(bits);
    return (found != null) ? found.getCredibility() : 0.0;
}
/**
 * @param bits the taxon indices that make up the clade
 * @return how often the clade has been added, or 0 if the clade is unknown
 */
private int getCladeCount(BitSet bits) {
    final Clade found = cladeMap.get(bits);
    return (found != null) ? found.getCount() : 0;
}
/**
* Wraps the result of getBestSubTree for the root of the consensus tree
* in a new SimpleTree.
*
* @param consensusTree the tree whose root is handed to getBestSubTree
* @return a new SimpleTree rooted at whatever getBestSubTree returns
*/
public Tree getBestTree(Tree consensusTree) {
return new SimpleTree(getBestSubTree(consensusTree, consensusTree.getRoot()));
}
/**
* Presumably meant to return the best resolved subtree below the given node;
* the implementation looks unfinished.
*
* NOTE(review): as written, the only assignment to subTree is the recursive
* call below, so no path appears to produce a non-null node.
*
* @param tree the tree being traversed
* @param node the node to start from
* @return the subtree found, or null
*/
public SimpleNode getBestSubTree(Tree tree, NodeRef node) {
SimpleNode subTree = null;
if (tree.isExternal(node)) {
// tips: nothing is built, so null is returned
} else {
// buildConsensusTree stores the Clade of a node under the "clade" attribute
Clade clade = (Clade) tree.getNodeAttribute(node, "clade");
if (clade.conditionalCladeSystem != null) {
// NOTE(review): this recurses with the same tree and node, and the clade is
// looked up from the node again, so this branch looks like it never terminates - confirm
subTree = clade.conditionalCladeSystem.getBestSubTree(tree, node);
} else {
for (int i = 0; i < tree.getChildCount(node); i++) {
NodeRef child = tree.getChild(node, i);
// NOTE(review): the result for each child is computed and then discarded
SimpleNode newChild = getBestSubTree(tree, child);
}
}
}
return subTree;
}
/**
* @param index the bit index of a taxon
* @return the taxon at that index in the taxon list
*/
public Taxon getTaxon(int index) {
return taxonList.getTaxon(index);
}
/**
 * Recursively collects the clade below {@code node} and all clades nested
 * in it into {@code cladeSet}, without changing any counts.
 *
 * @param tree        the tree being traversed
 * @param node        the node whose clade is collected
 * @param includeTips whether single-taxon clades are collected as well
 * @param cladeSet    the set that receives the clades
 * @return the set of taxon indices found below {@code node}
 */
private BitSet getClades(Tree tree, NodeRef node, boolean includeTips, Set<BitSet> cladeSet) {
    final BitSet members = new BitSet();

    if (tree.isExternal(node)) {
        final String taxonId = tree.getNodeTaxon(node).getId();
        members.set(taxonList.getTaxonIndex(taxonId));
        if (includeTips) {
            cladeSet.add(members);
        }
        return members;
    }

    // an internal node is the union of its children
    for (int child = 0; child < tree.getChildCount(node); child++) {
        final BitSet below = getClades(tree, tree.getChild(node, child), includeTips, cladeSet);
        members.or(below);
    }
    cladeSet.add(members);

    return members;
}
/**
 * @param tree the tree to inspect
 * @return the clades of all internal nodes of the tree (tips are left out)
 */
public Set<BitSet> getCladeSet(Tree tree) {
    final Set<BitSet> internalClades = new HashSet<BitSet>();
    final NodeRef root = tree.getRoot();
    getClades(tree, root, false, internalClades);
    return internalClades;
}
/**
 * @return a new map from clade to count holding only the clades that were
 *         counted more than once
 */
public Map<BitSet, Integer> getCladeCounts() {
    final Map<BitSet, Integer> repeated = new HashMap<BitSet, Integer>();

    for (Map.Entry<BitSet, Clade> entry : cladeMap.entrySet()) {
        final Clade clade = entry.getValue();
        final int seen = (clade == null) ? 0 : clade.getCount();
        if (seen <= 1) {
            continue;
        }
        repeated.put(entry.getKey(), seen);
    }

    return repeated;
}
/**
 * Formats a clade as the comma-separated ids of its taxa in braces,
 * e.g. <code>{A,B,C}</code>; an empty clade gives <code>{}</code>.
 *
 * @param bits the taxon indices that make up the clade
 * @return the formatted clade
 */
public String getCladeString(BitSet bits) {
    final StringBuilder text = new StringBuilder("{");

    String separator = "";
    for (int i = bits.nextSetBit(0); i != -1; i = bits.nextSetBit(i + 1)) {
        text.append(separator);
        text.append(taxonList.getTaxon(i).getId());
        separator = ",";
    }

    text.append("}");
    return text.toString();
}
/**
 * A clade identified by the bit indices of its taxa, together with how often
 * it was seen, its credibility and the bookkeeping used while building trees.
 * Equality and hashing depend on the member bits only.
 */
class Clade {

    /** how often the clade has been added */
    int count;
    /** count divided by the number of trees used; 0.0 until calculated */
    double credibility;
    /** the taxon indices that make up the clade */
    BitSet bits;
    /** the consensus tree node created for this clade, if any */
    SimpleNode node;
    /** clades of the subtrees seen below this clade, if any were collected */
    CladeSystem conditionalCladeSystem;

    public Clade(BitSet bits) {
        this.bits = bits;
        this.count = 0;
        this.credibility = 0.0;
    }

    public int getCount() {
        return this.count;
    }

    public void setCount(int count) {
        this.count = count;
    }

    public double getCredibility() {
        return this.credibility;
    }

    public void setCredibility(double credibility) {
        this.credibility = credibility;
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (o == null || o.getClass() != getClass()) {
            return false;
        }

        final Clade other = (Clade) o;
        if (bits == null) {
            return other.bits == null;
        }
        return bits.equals(other.bits);
    }

    @Override
    public int hashCode() {
        if (bits == null) {
            return 0;
        }
        return bits.hashCode();
    }

    @Override
    public String toString() {
        return "clade " + bits.toString();
    }
}
//
// State of the clade system
//
// Defines the bit index of each taxon; set by the constructor or by the first tree passed to add().
TaxonList taxonList = null;
// Every clade seen so far, keyed by the bit set of its taxon indices.
Map<BitSet, Clade> cladeMap = new HashMap<BitSet, Clade>();
// NOTE(review): not referenced anywhere in this class - confirm whether it is still needed.
Tree targetTree;
}
// Number of trees read from the input file (a placeholder of 10000 is used while the first pass runs).
int totalTrees = 0;
// Number of trees that passed burn-in and were added to the clade system.
int totalTreesUsed = 0;
// Minimum credibility a clade needs to be included when building trees.
double posteriorLimit = 0.0;
// NOTE(review): not referenced anywhere in this class - confirm whether other code in the package uses it.
Set<String> attributeNames = new HashSet<String>();
/**
 * Prints the program banner (name, version, author and affiliation) to the
 * progress stream, centred on a 60 character wide page.
 */
public static void printTitle() {
    final int pageWidth = 60;

    progressStream.println();

    final String[] heading = {
            "TreeSummary " + version.getVersionString() + ", " + version.getDateString(),
            "MCMC tree set summarizer",
            "by",
            "Andrew Rambaut"
    };
    for (String line : heading) {
        centreLine(line, pageWidth);
    }

    progressStream.println();

    final String[] affiliation = {
            "Institute of Evolutionary Biology",
            "University of Edinburgh",
            "a.rambaut@ed.ac.uk"
    };
    for (String line : affiliation) {
        centreLine(line, pageWidth);
    }

    progressStream.println();
    progressStream.println();
}
/**
 * Prints a line to the progress stream, padded on the left so that it is
 * centred on a page of the given width. Lines wider than the page are
 * printed without padding.
 *
 * @param line      the text to print
 * @param pageWidth the width of the page in characters
 */
public static void centreLine(String line, int pageWidth) {
    final int leftMargin = (pageWidth - line.length()) / 2;

    final StringBuilder padding = new StringBuilder();
    int remaining = leftMargin;
    while (remaining > 0) {
        padding.append(" ");
        remaining--;
    }

    progressStream.print(padding.toString());
    progressStream.println(line);
}
/**
 * Prints the command-line usage followed by two example invocations to the
 * progress stream.
 *
 * @param arguments the argument parser describing the available options
 */
public static void printUsage(Arguments arguments) {

    arguments.printUsage("treesummary", "<input-file-name> [<output-file-name>]");
    progressStream.println();
    progressStream.println(" Example: treesummary test.trees out.txt");
    // the example used to show "-heights mean", which is not an option of this tool
    progressStream.println(" Example: treesummary -burnin 100 -limit 0.5 test.trees out.txt");
    progressStream.println();
}
/**
 * Command-line entry point: parses the options, then summarises the trees
 * in the input file. With {@code -clademap} the clade table is written,
 * otherwise the consensus tree.
 *
 * @param args the options followed by the input file name and, optionally,
 *             the output file name
 * @throws IOException if reading the input or opening the output fails
 */
public static void main(String[] args) throws IOException {

    // There is a major issue with languages that use the comma as a decimal separator.
    // To ensure compatibility between programs in the package, enforce the US locale.
    Locale.setDefault(Locale.US);

    String inputFileName = null;
    String outputFileName = null;

    printTitle();

    Arguments arguments = new Arguments(
            new Arguments.Option[]{
                    new Arguments.IntegerOption("burnin", "the number of states to be considered as 'burn-in'"),
                    new Arguments.IntegerOption("burninTrees", "the number of trees to be considered as 'burn-in'"),
                    new Arguments.Option("clademap", "show states of all clades over chain length"),
                    new Arguments.RealOption("limit", "the minimum posterior probability for a subtree to be included"),
                    new Arguments.Option("help", "option to print this message")
            });

    try {
        arguments.parseArguments(args);
    } catch (Arguments.ArgumentException ae) {
        progressStream.println(ae);
        printUsage(arguments);
        System.exit(1);
    }

    if (arguments.hasOption("help")) {
        printUsage(arguments);
        System.exit(0);
    }

    int burninStates = -1;
    int burninTrees = -1;
    if (arguments.hasOption("burnin")) {
        burninStates = arguments.getIntegerOption("burnin");
    }
    if (arguments.hasOption("burninTrees")) {
        burninTrees = arguments.getIntegerOption("burninTrees");
    }

    double posteriorLimit = 0.5;
    if (arguments.hasOption("limit")) {
        posteriorLimit = arguments.getRealOption("limit");
    }

    final String[] args2 = arguments.getLeftoverArguments();

    // the clade map and the summary tree are alternative outputs
    boolean createCladeMap = false;
    boolean createSummaryTree = true;
    if (arguments.hasOption("clademap")) {
        createCladeMap = true;
        createSummaryTree = false;
    }

    switch (args2.length) {
        case 2:
            outputFileName = args2[1];
            // fall to
        case 1:
            inputFileName = args2[0];
            break;
        case 0:
            // there is no args2[2] to report here; the input file is simply missing
            System.err.println("Missing input file name");
            System.err.println();
            printUsage(arguments);
            System.exit(1);
            break;
        default: {
            System.err.println("Unknown option: " + args2[2]);
            System.err.println();
            printUsage(arguments);
            System.exit(1);
        }
    }

    new TreeSummary(burninTrees, burninStates, posteriorLimit, createSummaryTree, createCladeMap, inputFileName, outputFileName);

    System.exit(0);
}
/**
 * Is x a subset of y? Neither argument is modified.
 *
 * @param x the candidate subset
 * @param y the candidate superset
 * @return true if every bit set in x is also set in y
 */
static boolean isSubSet(BitSet x, BitSet y) {
    // whatever is left of x after removing the bits of y lies outside y
    final BitSet outsideY = (BitSet) x.clone();
    outsideY.andNot(y);
    return outsideY.isEmpty();
}
}