/*
* WeightedMultiplicativeBinary.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
/**
*
*/
package dr.evomodel.tree;
import dr.evolution.io.Importer;
import dr.evolution.io.TreeTrace;
import dr.evolution.tree.Clade;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.SimpleTree;
import dr.evolution.tree.Tree;
import dr.inference.model.Likelihood;
import dr.math.MathUtils;
import java.io.IOException;
import java.io.Reader;
import java.util.*;
/**
* @author Sebastian Hoehna
*/
public class WeightedMultiplicativeBinary extends AbstractCladeImportanceDistribution {
// Number of external nodes (taxa) of the tree seen at construction; assigned once in each constructor.
private final int TAXA_COUNT;
// Pseudo-count added to clade occurrences so that unobserved clades keep a non-zero probability.
// Assigned only in the constructors, although it is not declared final.
private double EPSILON;
// Number of trees counted so far; incremented by every addTree(...) call.
private long samples = 0;
// Clades recorded by addTree(...), keyed by the clade's bit set.
private HashMap<BitSet, Clade> cladeProbabilities;
// Tree traces handed to the trace-based constructor; the single-tree constructor leaves this unassigned (null).
private TreeTrace[] traces;
// Burn-in used when reading the traces; every use site multiplies it by the trace's step size.
private int burnin;
/**
 * Creates an empty distribution for the taxa of the given tree. No tree is
 * counted by this constructor; trees are recorded later through
 * {@code addTree}.
 *
 * @param tree    - the tree whose external node count is used as the number of taxa
 * @param epsilon - the default number of occurrences for each clade which wasn't
 *                observed to guarantee non-zero probabilities
 */
public WeightedMultiplicativeBinary(Tree tree, double epsilon) {
    this.TAXA_COUNT = tree.getExternalNodeCount();
    this.EPSILON = epsilon;
    this.cladeProbabilities = new HashMap<BitSet, Clade>();
}
/**
 * Builds the distribution from one or more tree traces and immediately
 * analyzes them (see {@code analyzeTrace}).
 *
 * @param traces  - samples of trees in a tree traces array; must contain at least one trace
 * @param epsilon - the default number of occurrences for each clade which wasn't
 *                observed to guarantee non-zero probabilities
 * @param burnIn  - number of trees discarded from the trace; if negative or not
 *                smaller than the smallest maximum state of all traces, a
 *                fallback of 10% of the smallest trace is used instead
 * @param verbose - if true, warnings and progress are printed to stdout
 * @throws IllegalArgumentException if {@code traces} is null or empty
 */
public WeightedMultiplicativeBinary(TreeTrace[] traces, double epsilon,
        int burnIn, boolean verbose) {
    if (traces == null || traces.length == 0) {
        throw new IllegalArgumentException("At least one tree trace is required");
    }

    // initializing global variables
    cladeProbabilities = new HashMap<BitSet, Clade>();
    EPSILON = epsilon;
    this.traces = traces;

    // smallest maximum state over all traces; the burn-in is validated against it
    int minMaxState = Integer.MAX_VALUE;
    for (TreeTrace trace : traces) {
        if (trace.getMaximumState() < minMaxState) {
            minMaxState = trace.getMaximumState();
        }
    }

    // falls back to 10% of the smallest trace if the burn-in is out of bounds
    if (burnIn < 0 || burnIn >= minMaxState) {
        this.burnin = minMaxState / (10 * traces[0].getStepSize());
        if (verbose) {
            System.out
                    .println("WARNING: Burn-in larger than total number of states - using 10% of smallest trace");
        }
    } else {
        this.burnin = burnIn;
    }

    // Read the taxa count only after the burn-in has been validated, and through
    // getTree(int) so the same burn-in offset is used as everywhere else in this
    // class. Previously the raw, unvalidated burnIn argument was passed to the trace.
    Tree tree = getTree(0);
    TAXA_COUNT = tree.getExternalNodeCount();

    // analyzing the whole trace -> reading the trees
    analyzeTrace(verbose);
}
/**
 * Walks through every trace, reads its trees after the burn-in and hands
 * each of them to {@code addTree}. When verbose, a 60-mark progress bar is
 * written to stdout for each trace.
 *
 * @param verbose if true then progress is logged to stdout
 */
public void analyzeTrace(boolean verbose) {
    if (verbose && traces.length > 1) {
        System.out.println("Combining " + traces.length + " traces.");
    }

    // look up the first tree; the result is not used further, but the lookup
    // throws if no tree is available at all
    getTree(0);

    for (TreeTrace trace : traces) {
        final int offset = burnin * trace.getStepSize();
        final int treeCount = trace.getTreeCount(offset);
        // number of trees represented by one mark of the progress bar
        final double treesPerMark = treeCount / 60.0;
        int nextMark = 1;

        if (verbose) {
            System.out.println("Analyzing " + treeCount + " trees...");
            System.out.println("0 25 50 75 100");
            System.out.println("|--------------|--------------|--------------|--------------|");
            System.out.print("*");
        }

        // NOTE(review): the index starts at 1, so the tree at index 0 of each
        // trace is never passed to addTree - confirm this is intended.
        int i = 1;
        while (i < treeCount) {
            addTree(trace.getTree(i, offset));

            final boolean markDue = nextMark <= 60
                    && i >= (int) Math.round(nextMark * treesPerMark);
            if (markDue) {
                if (verbose) {
                    System.out.print("*");
                    System.out.flush();
                }
                nextMark++;
            }
            i++;
        }

        if (verbose) {
            System.out.println("*");
        }
    }
}
/**
 * Prints a short report to stdout: a header line followed by the value
 * returned by {@code getTreeProbability} for a {@code SimpleTree} copy of
 * the given tree. A status line is written to stderr first.
 *
 * @param tree - the tree whose estimated probability is reported
 * @throws IOException declared for callers; nothing in this method throws it directly
 */
public void report(Tree tree) throws IOException {
    System.err.println("making report");

    final SimpleTree copy = new SimpleTree(tree);

    System.out.println("Estimated marginal posterior by condiational clade frequencies:");
    final double logProbability = getTreeProbability(copy);
    System.out.println(logProbability);
    System.out.flush();
}
/**
 * Estimates the log probability of a tree from the recorded clade counts.
 * This overload delegates to the non-recursive calculation; note that the
 * {@code Tree} overload of this method uses the recursive one instead.
 *
 * @param tree - the tree to be analyzed
 * @return estimated posterior probability in log
 */
public double getTreeProbability(SimpleTree tree) {
    final double logProbability = calculateTreeProbabilityLog(tree);
    return logProbability;
}
/**
 * Estimates the log probability of a tree from the recorded clade counts,
 * extracting the tree's clades with the given taxon map and weighting the
 * observed counts (see the map-based {@code calculateTreeProbabilityLog}).
 *
 * @param tree     - the tree to be analyzed
 * @param taxonMap - map handed on to the clade extraction
 *                 (presumably taxon name to bit index - confirm against getClades)
 * @return estimated posterior probability in log
 */
public double getTreeProbability(SimpleTree tree, HashMap<String, Integer> taxonMap) {
    final double logProbability = calculateTreeProbabilityLog(tree, taxonMap);
    return logProbability;
}
/**
 * Sums, over all clades of the tree, the log of
 * {@code (EPSILON + observed count) / (samples + splits * EPSILON)}, where
 * {@code splits = 2^(n-1) - 1} for a tree with n external nodes.
 *
 * @param tree - the tree to be analyzed
 * @return estimated posterior probability in log
 */
private double calculateTreeProbabilityLog(Tree tree) {
    // number of possible splits for this number of tips
    final double splits = Math.pow(2, tree.getExternalNodeCount() - 1) - 1;
    // the same denominator applies to every clade
    final double normalizer = samples + (splits * EPSILON);

    // collect the clades contained in the tree
    final List<Clade> treeClades = new ArrayList<Clade>();
    final List<Clade> parentClades = new ArrayList<Clade>();
    getClades(tree, tree.getRoot(), parentClades, treeClades);

    double logProbability = 0.0;
    for (Clade clade : treeClades) {
        // every clade starts with the pseudo-count; observed clades add their count
        double occurrences = EPSILON;
        final Clade observed = cladeProbabilities.get(clade.getBits());
        if (observed != null) {
            occurrences += observed.getSampleCount();
        }
        // adding logs corresponds to multiplying the clade probabilities
        logProbability += Math.log(occurrences / normalizer);
    }
    return logProbability;
}
/**
 * Variant of the clade-wise log probability in which the clades are
 * extracted with a taxon map and each observed count is divided by
 * {@code getTrees(k) * getTrees(TAXA_COUNT - k + 1)}, k being the number of
 * taxa in the clade, before the pseudo-count is added.
 *
 * @param tree     - the tree to be analyzed
 * @param taxonMap - map handed on to the clade extraction
 *                 (presumably taxon name to bit index - confirm against getClades)
 * @return estimated posterior probability in log
 */
private double calculateTreeProbabilityLog(Tree tree, HashMap<String, Integer> taxonMap) {
    // number of possible splits for this number of tips
    final double splits = Math.pow(2, tree.getExternalNodeCount() - 1) - 1;
    // the same denominator applies to every clade
    final double normalizer = samples + (splits * EPSILON);

    // collect the clades contained in the tree
    final List<Clade> treeClades = new ArrayList<Clade>();
    final List<Clade> parentClades = new ArrayList<Clade>();
    getClades(tree, tree.getRoot(), parentClades, treeClades, taxonMap);

    double logProbability = 0.0;
    for (Clade clade : treeClades) {
        // every clade starts with the pseudo-count
        double occurrences = EPSILON;
        final Clade observed = cladeProbabilities.get(clade.getBits());
        if (observed != null) {
            // observed clades add their count, scaled down by the weight
            // derived from the clade size
            final int cladeSize = clade.getBits().cardinality();
            final double cladesInTreeSpace = getTrees(cladeSize) * getTrees(TAXA_COUNT - cladeSize + 1);
            occurrences += (observed.getSampleCount() / cladesInTreeSpace);
        }
        // adding logs corresponds to multiplying the clade probabilities
        logProbability += Math.log(occurrences / normalizer);
    }
    return logProbability;
}
/**
 * Returns the product of {@code (2 * i - 3)} for i = 3..n, i.e.
 * 3 * 5 * ... * (2n - 3); the result is 1 for n below 3.
 * NOTE(review): this equals (2n - 3)!!, which looks like the number of
 * rooted binary trees on n labelled tips - confirm that is the intent.
 *
 * @param n - upper bound of the product (used as a number of taxa by the caller)
 * @return the product as a double
 */
private double getTrees(int n) {
    double count = 1;
    int tips = 3;
    while (tips <= n) {
        count *= (2 * tips - 3);
        tips++;
    }
    return count;
}
/**
 * Recursively accumulates the log probability of the subtree below a node.
 * The node is expected to have children at index 0 and 1. A node with two
 * external children contributes 0. A node with two internal children
 * contributes the log of the mean smoothed frequency of its two child
 * clades. A node with exactly one internal child contributes the log of the
 * summed smoothed frequencies of its child clades with more than one taxon.
 *
 * @param tree - the tree to be analyzed
 * @param node - the node at which the subtree is rooted for which the
 *             probability has to be calculated
 * @return estimated posterior probability in log
 */
private double calculateTreeProbabilityLogRecursive(Tree tree, NodeRef node) {
    final NodeRef left = tree.getChild(node, 0);
    final NodeRef right = tree.getChild(node, 1);
    final boolean leftIsTip = tree.isExternal(left);
    final boolean rightIsTip = tree.isExternal(right);

    // a cherry: nothing to score below this node
    if (leftIsTip && rightIsTip) {
        return 0.0;
    }

    final Clade leftClade = getClade(tree, left);
    final Clade rightClade = getClade(tree, right);

    if (!leftIsTip && !rightIsTip) {
        // both children are internal nodes: average the two clade frequencies
        final double sum = smoothedCladeFrequency(leftClade) + smoothedCladeFrequency(rightClade);
        double logProbability = Math.log(sum / 2.0);
        logProbability += calculateTreeProbabilityLogRecursive(tree, left);
        logProbability += calculateTreeProbabilityLogRecursive(tree, right);
        return logProbability;
    }

    // exactly one child is internal: only clades with more than one taxon count
    double sum = 0.0;
    if (leftClade.getSize() > 1) {
        sum += smoothedCladeFrequency(leftClade);
    }
    if (rightClade.getSize() > 1) {
        sum += smoothedCladeFrequency(rightClade);
    }

    double logProbability = Math.log(sum);
    if (!leftIsTip) {
        logProbability += calculateTreeProbabilityLogRecursive(tree, left);
    }
    if (!rightIsTip) {
        logProbability += calculateTreeProbabilityLogRecursive(tree, right);
    }
    return logProbability;
}

/**
 * Returns {@code (observed count + EPSILON) / samples} for a recorded clade
 * and {@code EPSILON / samples} for a clade that was never recorded.
 */
private double smoothedCladeFrequency(Clade clade) {
    final Clade observed = cladeProbabilities.get(clade.getBits());
    if (observed == null) {
        return EPSILON / samples;
    }
    return (observed.getSampleCount() + EPSILON) / samples;
}
/**
 * Placeholder implementation: the arguments are ignored and 0 is returned.
 *
 * @param tree       - ignored
 * @param likelihood - ignored
 * @return always 0
 */
@Override
public double getChanceForNodeHeights(TreeModel tree, Likelihood likelihood) {
    // TODO not implemented for this distribution
    return 0.0;
}
/**
 * Placeholder implementation: the arguments are ignored, the tree is not
 * modified and 0 is returned.
 *
 * @param tree       - ignored
 * @param likelihood - ignored
 * @return always 0
 */
@Override
public double setNodeHeights(TreeModel tree, Likelihood likelihood) {
    // TODO not implemented for this distribution
    return 0.0;
}
/**
 * Returns the tree at the given position when the trees of all traces
 * (after burn-in) are viewed as one concatenated sequence.
 *
 * @param index - position in the concatenated sequence of trees
 * @return the i'th tree of the trace
 * @throws RuntimeException if the index is not smaller than the total number of trees
 */
public final Tree getTree(int index) {
    // number of trees contained in the traces visited so far
    int treesBefore = 0;
    for (TreeTrace trace : traces) {
        final int offset = burnin * trace.getStepSize();
        final int treesIncluding = treesBefore + trace.getTreeCount(offset);
        if (index < treesIncluding) {
            // translate the global index into an index local to this trace
            return trace.getTree(index - treesBefore, offset);
        }
        treesBefore = treesIncluding;
    }
    throw new RuntimeException("Couldn't find tree " + index);
}
/**
 * Counts one more sample and records every clade of the tree. A clade seen
 * for the first time is stored under its bit set; for a clade already
 * stored, the height of the new occurrence is added to the stored clade.
 * (The occurrence count is presumably maintained by {@code Clade.addHeight}
 * - confirm against Clade.)
 *
 * @param tree - the tree to be added
 */
public void addTree(Tree tree) {
    samples++;

    // collect the clades contained in the tree
    final List<Clade> treeClades = new ArrayList<Clade>();
    final List<Clade> parentClades = new ArrayList<Clade>();
    getClades(tree, tree.getRoot(), parentClades, treeClades);

    for (Clade clade : treeClades) {
        final BitSet bits = clade.getBits();
        final Clade known = cladeProbabilities.get(bits);
        if (known == null) {
            // first occurrence: register the clade's own height, then store it
            clade.addHeight(clade.getHeight());
            cladeProbabilities.put(bits, clade);
        } else {
            known.addHeight(clade.getHeight());
        }
    }
}
/**
 * Counts one more sample and records every clade of the tree, extracting
 * the clades with the given taxon map. A clade seen for the first time is
 * stored under its bit set; for a clade already stored, the height of the
 * new occurrence is added to the stored clade. (The occurrence count is
 * presumably maintained by {@code Clade.addHeight} - confirm against Clade.)
 *
 * @param tree     - the tree to be added
 * @param taxonMap - map handed on to the clade extraction
 *                 (presumably taxon name to bit index - confirm against getClades)
 */
public void addTree(Tree tree, HashMap<String, Integer> taxonMap) {
    samples++;

    // collect the clades contained in the tree
    final List<Clade> treeClades = new ArrayList<Clade>();
    final List<Clade> parentClades = new ArrayList<Clade>();
    getClades(tree, tree.getRoot(), parentClades, treeClades, taxonMap);

    for (Clade clade : treeClades) {
        final BitSet bits = clade.getBits();
        final Clade known = cladeProbabilities.get(bits);
        if (known == null) {
            // first occurrence: register the clade's own height, then store it
            clade.addHeight(clade.getHeight());
            cladeProbabilities.put(bits, clade);
        } else {
            known.addHeight(clade.getHeight());
        }
    }
}
/**
 * Loads one tree trace per reader and builds an analysis from them. Each
 * reader is closed once it has been processed, whether or not loading its
 * trace succeeded.
 * <p>
 * NOTE(review): the result is a {@code ConditionalCladeFrequency}, not an
 * instance of this class - this looks copied from that class; confirm it is
 * intended. The return type is kept so existing callers still compile.
 *
 * @param reader  the readers to be analyzed
 * @param e       the epsilon (pseudo-count) handed to the analysis
 * @param burnin  the burnin in states
 * @param verbose true if progress should be logged to stdout
 * @return an analyses of the trees in a log file.
 * @throws java.io.IOException if general I/O error occurs
 * @throws RuntimeException    if a trace cannot be imported; the import
 *                             exception is attached as the cause
 */
public static ConditionalCladeFrequency analyzeLogFile(Reader[] reader,
        double e, int burnin, boolean verbose) throws IOException {
    TreeTrace[] trace = new TreeTrace[reader.length];
    for (int i = 0; i < reader.length; i++) {
        try {
            trace[i] = TreeTrace.loadTreeTrace(reader[i]);
        } catch (Importer.ImportException ie) {
            // same message text as before, but keep the cause and its stack trace
            throw new RuntimeException(ie.toString(), ie);
        } finally {
            // previously the reader stayed open when loading failed
            reader[i].close();
        }
    }
    return new ConditionalCladeFrequency(trace, e, burnin, verbose);
}
/**
 * Estimates the log probability of a tree with the recursive calculation,
 * starting at the root. Note that the {@code SimpleTree} overloads of this
 * method use the non-recursive calculation instead.
 *
 * @param tree - the tree to be analyzed
 * @return estimated probability in log
 * @see dr.evolution.tree.ImportanceDistribution#getTreeProbability(dr.evolution.tree.Tree)
 */
public double getTreeProbability(Tree tree) {
    final NodeRef root = tree.getRoot();
    return calculateTreeProbabilityLogRecursive(tree, root);
}
/**
* Randomly splits a parent clade into two child clades and returns the log
* of the value computed for the chosen split.
* <p>
* A candidate is drawn among the already observed sub-clades of the parent,
* each weighted by its sample count plus EPSILON. The remaining weight
* (out of EPSILON times the split count) leads to a randomly generated,
* so far unobserved sub-clade. The second child is always the complement of
* the first within the parent.
*
* @param parent - the clade to be split
* @param children - output array; index 0 and 1 are overwritten with the
* two child clades
* @return the log of the value computed for the chosen split
* @see dr.evolution.tree.ImportanceDistribution#splitClade(dr.evolution.tree.Clade,
* dr.evolution.tree.Clade[])
*/
public double splitClade(Clade parent, Clade[] children) {
// each clade of a split is considered separately with its own
// probability, because every clade has a different chance for itself.
// #splits = 2^(n) - 1 with n the number of tips of the parent
// NOTE(review): an older comment here spoke of reducing 2^n by 2 (no
// empty and no full clade); the code subtracts 1 - confirm which is meant.
final double splits = Math.pow(2, parent.getSize()) - 1;
double prob = 0;
double sum = 0.0;
// total weight: observed counts of all candidate children plus the
// pseudo-count for every possible split
List<Clade> childClades = getPossibleChildren(parent);
for (Clade child : childClades) {
sum += child.getSampleCount();
}
sum += EPSILON * splits;
// NOTE(review): this draw uses Math.random() while the random subset
// further down uses MathUtils.nextBoolean() - confirm whether both
// should come from MathUtils.
double randomNumber = Math.random() * sum;
// walk through the candidates, consuming count + EPSILON each, until
// the draw is used up; that candidate becomes the first child
for (Clade child : childClades) {
randomNumber -= (child.getSampleCount() + EPSILON);
if (randomNumber < 0) {
children[0] = child;
double chance = (child.getSampleCount() + EPSILON) / samples;
// the other clade which would have resulted into the same split
// (complement of the first child within the parent)
BitSet secondChild = (BitSet) children[0].getBits().clone();
secondChild.xor(parent.getBits());
if (secondChild.cardinality() > 1) {
// complement has more than one taxon: average the two chances
Clade counterClade = cladeProbabilities.get(secondChild);
if (counterClade != null) {
chance += (counterClade.getSampleCount() + EPSILON)
/ samples;
} else {
chance += EPSILON / samples;
}
prob = chance / 2.0;
} else {
// complement is a single taxon: only the first child counts
prob = chance;
}
break;
}
}
// the draw was not used up by the observed candidates:
// we take a clade which we haven't seen so far
if (randomNumber >= 0) {
BitSet newChild;
// NOTE(review): the outer loop repeats until the generated subset is
// not a key of cladeProbabilities; it would not terminate if every
// non-empty proper subset of the parent were already recorded -
// confirm this cannot happen.
do {
// repeat until the subset is neither empty nor the whole parent
do {
newChild = (BitSet) parent.getBits().clone();
int index = -1;
// visit every set bit of the copy and clear it when
// MathUtils.nextBoolean() returns true
do {
index = newChild.nextSetBit(index + 1);
if (index > -1 && MathUtils.nextBoolean()) {
newChild.clear(index);
}
} while (index > -1);
} while (newChild.cardinality() == 0
|| newChild.cardinality() == parent.getSize());
} while (cladeProbabilities.containsKey(newChild));
// second constructor argument is presumably the clade height - confirm against Clade
Clade randomClade = new Clade(newChild, 0.5);
children[0] = randomClade;
// the second child is the complement; reuse the recorded clade if there is one
BitSet secondChild = (BitSet) children[0].getBits().clone();
secondChild.xor(parent.getBits());
if (cladeProbabilities.containsKey(secondChild)) {
children[1] = cladeProbabilities.get(secondChild);
} else {
children[1] = new Clade(secondChild, 0.5);
}
if (children[0].getSize() > 1 && children[1].getSize() > 1) {
// both children have more than one taxon: average of both smoothed frequencies
prob = (children[0].getSampleCount()
+ children[1].getSampleCount() + (2.0 * EPSILON))
/ (samples * 2.0);
} else {
// otherwise use the first child if it has more than one taxon,
// else the second child
if (children[0].getSize() > 1) {
prob = (children[0].getSampleCount() + EPSILON) / samples;
} else {
prob = (children[1].getSampleCount() + EPSILON) / samples;
}
}
} else {
// an observed candidate was chosen above (prob is already set):
// only the second child still has to be filled in
BitSet secondChild = (BitSet) children[0].getBits().clone();
secondChild.xor(parent.getBits());
children[1] = cladeProbabilities.get(secondChild);
if (children[1] == null) {
// complement was never recorded: create it and add one height entry
children[1] = new Clade(secondChild, 0.5);
children[1].addHeight(0.5);
}
}
return Math.log(prob);
}
/**
 * Collects every recorded clade that can be a child of the given parent: a
 * clade with fewer taxa than the parent for which
 * {@code containsClade(parent bits, clade bits)} holds.
 *
 * @param parent - the parent clade of which we want to find the possible
 *               children
 * @return a List<Clade> of the possible child clades
 */
private List<Clade> getPossibleChildren(Clade parent) {
    final List<Clade> candidates = new ArrayList<Clade>();
    for (Map.Entry<BitSet, Clade> entry : cladeProbabilities.entrySet()) {
        final BitSet bits = entry.getKey();
        // the cheap size check runs first; the containment test only if it passes
        if (bits.cardinality() < parent.getSize()
                && containsClade(parent.getBits(), bits)) {
            candidates.add(entry.getValue());
        }
    }
    return candidates;
}
}