/*
* AbstractImportanceDistributionOperator.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
/**
*
*/
package dr.evomodel.operators;
import dr.evolution.tree.Clade;
import dr.evolution.tree.MutableTree.InvalidTreeException;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.tree.AbstractCladeImportanceDistribution;
import dr.evomodel.tree.TreeModel;
import dr.inference.model.Likelihood;
import dr.inference.operators.*;
import dr.math.MathUtils;
import java.util.*;
/**
 * Tree operator that, after an initial sampling phase, proposes a complete new topology
 * drawn from a clade-based importance distribution.
 * <p>
 * For the first {@code samples * sampleEvery} calls to {@link #doOperation(Likelihood)} the
 * operator delegates to a fixed schedule of standard tree operators and hands every
 * {@code sampleEvery}-th tree to {@link #probabilityEstimater}. Afterwards each call rebuilds
 * the topology by repeatedly splitting clades through the estimator and redistributes the
 * original internal node heights over the new topology.
 * <p>
 * NOTE(review): {@link #probabilityEstimater} is never assigned in this class; subclasses
 * presumably set it before the operator is used - confirm against the subclasses.
 *
 * @author Sebastian Hoehna
 */
// Cleaning out untouched stuff. Can be resurrected if needed
@Deprecated
public abstract class AbstractImportanceDistributionOperator extends
        SimpleMCMCOperator implements GeneralOperator {

    /** Number of trees collected when the short constructor is used. */
    private static final int DEFAULT_SAMPLES = 10000;

    /** Sampling interval (in operator calls) when the short constructor is used. */
    private static final int DEFAULT_SAMPLE_EVERY = 10;

    private long transitions = 0;

    /** Schedule of standard tree operators used by {@link #doUnguidedOperation()}. */
    private OperatorSchedule schedule;

    protected TreeModel tree;

    /** Estimator that collects trees and splits clades; not assigned in this class. */
    protected AbstractCladeImportanceDistribution probabilityEstimater;

    private int sampleEvery;

    private int samples;

    /** Number of calls to {@link #doOperation(Likelihood)} counted during the sampling phase. */
    private int sampleCount;

    /** Internal nodes still available for reuse while a new topology is being built. */
    private Queue<NodeRef> internalNodes;

    /** Tip nodes keyed by their node number. */
    private Map<Integer, NodeRef> externalNodes;

    /** Never set to {@code true} in this class; while {@code true} only unguided moves are made. */
    private boolean burnin = false;

    /**
     * Creates an operator that uses the default of 10000 samples taken every 10 calls.
     *
     * @param tree   the tree this operator works on
     * @param weight the operator weight passed to {@code setWeight}
     */
    public AbstractImportanceDistributionOperator(TreeModel tree, double weight) {
        this(tree, weight, DEFAULT_SAMPLES, DEFAULT_SAMPLE_EVERY);
    }

    /**
     * Creates an operator with an explicit sampling configuration.
     *
     * @param tree        the tree this operator works on
     * @param weight      the operator weight passed to {@code setWeight}
     * @param samples     number of trees to hand to the estimator before guided moves start
     * @param sampleEvery number of calls between two trees handed to the estimator
     */
    public AbstractImportanceDistributionOperator(TreeModel tree,
                                                  double weight, int samples, int sampleEvery) {
        super();
        this.tree = tree;
        setWeight(weight);
        this.samples = samples;
        this.sampleEvery = sampleEvery;
        init();
    }

    /**
     * Builds the unguided operator schedule, resets the sample counter and indexes the tips
     * of the current tree.
     */
    private void init() {
        schedule = getOperatorSchedule(tree);
        sampleCount = 0;
        internalNodes = new LinkedList<NodeRef>();
        externalNodes = new HashMap<Integer, NodeRef>();
        fillExternalNodes(tree.getRoot());
    }

    /**
     * Dummy implementation; the real work is done in {@link #doOperation(Likelihood)}.
     *
     * @return always {@code 0.0}
     */
    public double doOperation() {
        // dummy method
        return 0.0;
    }

    /**
     * Performs one move.
     * <p>
     * During the sampling phase (fewer than {@code samples * sampleEvery} calls so far) the
     * call is counted, every {@code sampleEvery}-th tree is added to the estimator, the
     * accept/reject/transition counters are reset to zero and an unguided move is made.
     * Once the sampling phase is over a guided move is made instead.
     *
     * @param likelihood passed through to {@link #doImportanceDistributionOperation(Likelihood)}
     * @return the value returned by the move that was performed
     */
    public double doOperation(Likelihood likelihood) {
        if (burnin) {
            return doUnguidedOperation();
        }
        if (sampleCount >= samples * sampleEvery) {
            return doImportanceDistributionOperation(likelihood);
        }

        sampleCount++;
        if (sampleCount % sampleEvery == 0) {
            probabilityEstimater.addTree(tree);
        }
        // statistics gathered while sampling are discarded
        setAcceptCount(0);
        setRejectCount(0);
        setTransitions(0);
        return doUnguidedOperation();
    }

    /**
     * Rebuilds the whole topology from the importance distribution and redistributes the
     * original internal node heights over it.
     * <p>
     * The returned value is {@code back - prob}: {@code back} is the estimator's value for the
     * current tree plus the node-height term of the current tree, {@code prob} is the
     * corresponding sum for the newly built tree. Both are combined additively with
     * {@code Math.log} terms, so they are presumably log probabilities - TODO confirm against
     * the estimator.
     *
     * @param likelihood not used by this implementation
     * @return {@code back - prob} as described above
     * @throws RuntimeException if the edited tree is reported invalid; the original
     *                          {@link InvalidTreeException} is attached as the cause
     */
    protected double doImportanceDistributionOperation(Likelihood likelihood) {
        final NodeRef root = tree.getRoot();

        // clade containing every tip: bits 0 .. (nodeCount + 1) / 2 - 1
        BitSet all = new BitSet();
        all.set(0, (tree.getNodeCount() + 1) / 2);
        Clade rootClade = new Clade(all, tree.getNodeHeight(root));

        internalNodes.clear();
        fillInternalNodes(root);
        // remove the root
        internalNodes.poll();

        externalNodes.clear();
        fillExternalNodes(root);

        double prob;
        double back = probabilityEstimater.getTreeProbability(tree);

        try {
            tree.beginTreeEdit();

            List<Clade> originalClades = new ArrayList<Clade>();
            extractClades(tree, tree.getRoot(), originalClades, null);
            double[] originalNodeHeights = getAbsoluteNodeHeights(originalClades);
            Arrays.sort(originalNodeHeights);

            back += getChanceForNodeHeights(originalNodeHeights);
            prob = createTree(root, rootClade);

            // placeholder heights; they are overwritten by setNodeHeights below
            assignDummyHeights(root);
            prob += setNodeHeights(originalNodeHeights);

            tree.endTreeEdit();
            tree.checkTreeIsValid();
        } catch (InvalidTreeException e) {
            // keep the original exception as cause so its stack trace is not lost
            throw new RuntimeException(e.getMessage(), e);
        }

        tree.pushTreeChangedEvent(root);
        return back - prob;
    }

    /**
     * Multiplies the height of {@code node} by the number of internal nodes and gives every
     * internal descendant half the height of its parent.
     *
     * @param node the node to start from (called with the root)
     */
    private void assignDummyHeights(NodeRef node) {
        double rootHeight = tree.getNodeHeight(node) * tree.getInternalNodeCount();
        tree.setNodeHeight(node, rootHeight);
        int childcount = tree.getChildCount(node);
        for (int i = 0; i < childcount; i++) {
            NodeRef child = tree.getChild(node, i);
            if (!tree.isExternal(child)) {
                assignDummyHeights(child, rootHeight / 2.0);
            }
        }
    }

    /**
     * Sets {@code node} to {@code height} and recursively gives every internal descendant half
     * the height of its parent.
     *
     * @param node   an internal node
     * @param height the height to assign to {@code node}
     */
    private void assignDummyHeights(NodeRef node, double height) {
        assert (!tree.isExternal(node));
        tree.setNodeHeight(node, height);
        int childcount = tree.getChildCount(node);
        for (int i = 0; i < childcount; i++) {
            NodeRef child = tree.getChild(node, i);
            if (!tree.isExternal(child)) {
                assignDummyHeights(child, height / 2.0);
            }
        }
    }

    /**
     * Recursively rebuilds the subtree below {@code node} so that it spans clade {@code c}.
     * <p>
     * A clade of two tips is wired up directly. Larger clades are split by the estimator;
     * each part that is not a single tip gets an internal node taken from
     * {@link #internalNodes} and is built recursively.
     *
     * @param node the internal node that becomes the root of the clade
     * @param c    the clade to build below {@code node}
     * @return the sum of the values returned by {@link #splitClade(Clade, Clade[])} for all
     *         splits made in this subtree ({@code 0.0} for a two-tip clade)
     * @throws InvalidTreeException declared for the tree edit operations
     */
    private double createTree(NodeRef node, Clade c)
            throws InvalidTreeException {
        double prob = 0.0;

        if (c.getSize() == 2) {
            // this clade only contains two tips
            // the split between them is trivial
            int leftTipIndex = c.getBits().nextSetBit(0);
            int rightTipIndex = c.getBits().nextSetBit(leftTipIndex + 1);
            NodeRef leftTip = externalNodes.get(leftTipIndex);
            NodeRef rightTip = externalNodes.get(rightTipIndex);
            replaceChildren(node, leftTip, rightTip);
        } else {
            Clade[] clades = new Clade[2];
            prob = splitClade(c, clades);

            NodeRef leftChild, rightChild;
            if (clades[0].getSize() == 1) {
                int tipIndex = clades[0].getBits().nextSetBit(0);
                leftChild = externalNodes.get(tipIndex);
            } else {
                leftChild = internalNodes.poll();
                // TODO set the node height for the new node
                tree.setNodeHeight(leftChild, tree.getNodeHeight(node) * 0.5);
                prob += createTree(leftChild, clades[0]);
            }

            if (clades[1].getSize() == 1) {
                int tipIndex = clades[1].getBits().nextSetBit(0);
                rightChild = externalNodes.get(tipIndex);
            } else {
                rightChild = internalNodes.poll();
                // TODO set the node height for the new node
                tree.setNodeHeight(rightChild, tree.getNodeHeight(node) * 0.5);
                prob += createTree(rightChild, clades[1]);
            }

            replaceChildren(node, leftChild, rightChild);
        }

        return prob;
    }

    /**
     * Makes {@code left} and {@code right} the children of {@code parent}: the current
     * children of {@code parent} are removed, both new children are detached from their
     * present parents (if any) and then added to {@code parent}.
     *
     * @param parent the node receiving the children
     * @param left   the first child to attach
     * @param right  the second child to attach
     * @throws InvalidTreeException declared for the tree edit operations
     */
    private void replaceChildren(NodeRef parent, NodeRef left, NodeRef right)
            throws InvalidTreeException {
        removeChildren(parent);

        NodeRef leftParent = tree.getParent(left);
        if (leftParent != null) {
            tree.removeChild(leftParent, left);
        }
        NodeRef rightParent = tree.getParent(right);
        if (rightParent != null) {
            tree.removeChild(rightParent, right);
        }

        tree.addChild(parent, left);
        tree.addChild(parent, right);
    }

    /**
     * Removes the children at index 0 and 1 from {@code parent}, skipping {@code null} ones.
     *
     * @param parent the node whose children are removed
     * @warning assumes strictly bifurcating trees
     */
    private void removeChildren(NodeRef parent) {
        // assumes strictly bifurcating trees
        NodeRef child = tree.getChild(parent, 0);
        if (child != null) {
            tree.removeChild(parent, child);
        }
        child = tree.getChild(parent, 1);
        if (child != null) {
            tree.removeChild(parent, child);
        }
    }

    /**
     * Delegates the split of clade {@code c} to the estimator.
     *
     * @param c        the clade to split
     * @param children array of length two that receives the two parts
     * @return the value returned by the estimator for this split
     */
    private double splitClade(Clade c, Clade[] children) {
        return probabilityEstimater.splitClade(c, children);
    }

    /**
     * Creates a list with all clades of the tree
     *
     * @param tree   - the tree from which the clades are extracted
     * @param node   - the starting node. All clades below starting at this branch
     *               are added
     * @param clades - the list in which the clades are stored
     * @param bits   - a bit set to which the current bits of the clades are added
     *               (may be {@code null})
     */
    private void extractClades(Tree tree, NodeRef node,
                               List<Clade> clades, BitSet bits) {
        // create a new bit set for this clade
        BitSet bits2 = new BitSet();

        // check if the node is external
        if (tree.isExternal(node)) {
            // if so, the only taxon in the clade is I
            int index = node.getNumber();
            bits2.set(index);
        } else {
            // otherwise, call all children and add its taxon together to one
            // clade
            for (int i = 0; i < tree.getChildCount(node); i++) {
                NodeRef child = tree.getChild(node, i);
                extractClades(tree, child, clades, bits2);
            }
            // add my bit set to the list
            clades.add(new Clade(bits2, tree.getNodeHeight(node)));
        }

        // add my bit set to the bit set I was given
        // this is needed for adding all children clades together
        if (bits != null) {
            bits.or(bits2);
        }
    }

    /**
     * Walks the subtree below {@code node} and, for every internal node whose clade is a key
     * of {@code clades}, sets the node height to the mapped value and removes that entry
     * from the map.
     *
     * @param node   - the starting node
     * @param clades - clade-to-height map; matched entries are removed
     * @param bits   - a bit set to which the bits of this node's clade are added
     *               (may be {@code null})
     */
    private void assignCladeHeights(NodeRef node,
                                    HashMap<Clade, Double> clades, BitSet bits) {
        // create a new bit set for this clade
        BitSet bits2 = new BitSet();

        // check if the node is external
        if (tree.isExternal(node)) {
            // if so, the only taxon in the clade is I
            int index = node.getNumber();
            bits2.set(index);
        } else {
            // otherwise, call all children and add its taxon together to one
            // clade
            for (int i = 0; i < tree.getChildCount(node); i++) {
                NodeRef child = tree.getChild(node, i);
                assignCladeHeights(child, clades, bits2);
            }
            Clade c = new Clade(bits2, tree.getNodeHeight(node));
            if (clades.containsKey(c)) {
                tree.setNodeHeight(node, clades.get(c));
                clades.remove(c);
            }
        }

        // add my bit set to the bit set I was given
        // this is needed for adding all children clades together
        if (bits != null) {
            bits.or(bits2);
        }
    }

    /**
     * Returns, for every internal node, its height divided by the height of its parent.
     * NOTE(review): the parent of the root is looked up as well - confirm how the tree
     * handles that before using this method.
     *
     * @param tree the tree to read
     * @return one ratio per internal node, in internal-node index order
     */
    private double[] getRelativeNodeHeights(Tree tree) {
        int count = tree.getInternalNodeCount();
        double[] nodeHeights = new double[count];
        for (int i = 0; i < count; i++) {
            NodeRef node = tree.getInternalNode(i);
            NodeRef parent = tree.getParent(node);
            nodeHeights[i] = tree.getNodeHeight(node)
                    / tree.getNodeHeight(parent);
        }
        return nodeHeights;
    }

    /**
     * Returns the height of every internal node of {@code tree}.
     *
     * @param tree the tree to read
     * @return one height per internal node, in internal-node index order
     */
    private double[] getAbsoluteNodeHeights(Tree tree) {
        int count = tree.getInternalNodeCount();
        double[] nodeHeights = new double[count];
        for (int i = 0; i < count; i++) {
            NodeRef node = tree.getInternalNode(i);
            nodeHeights[i] = tree.getNodeHeight(node);
        }
        return nodeHeights;
    }

    /**
     * Returns the heights stored in the given clades.
     *
     * @param clades the clades to read
     * @return one height per clade, in list order
     */
    private double[] getAbsoluteNodeHeights(List<Clade> clades) {
        double[] nodeHeights = new double[clades.size()];
        int count = 0;
        for (Clade c : clades) {
            nodeHeights[count] = c.getHeight();
            count++;
        }
        return nodeHeights;
    }

    /**
     * Node-height term for the current tree; uses the permutation scheme.
     *
     * @param nodeHeights the node heights sorted in ascending order
     * @return the value of {@link #getChanceOfPermutation(double[])}
     */
    private double getChanceForNodeHeights(double[] nodeHeights) {
        return getChanceOfPermutation(nodeHeights);
    }

    /**
     * Sums {@code log(1 / parentHeight)} for every internal, non-root node below
     * {@code parent}.
     *
     * @param parent the node to start from; assumed to have two children
     * @return the accumulated sum
     */
    private double getChanceOfUniformNodeHeights(NodeRef parent) {
        double prob = 0.0;
        NodeRef leftChild = tree.getChild(parent, 0);
        NodeRef rightChild = tree.getChild(parent, 1);

        if (!tree.isExternal(leftChild)) {
            prob += Math.log(1.0 / tree.getNodeHeight(parent));
            prob += getChanceOfUniformNodeHeights(leftChild);
        }
        if (!tree.isExternal(rightChild)) {
            prob += Math.log(1.0 / tree.getNodeHeight(parent));
            prob += getChanceOfUniformNodeHeights(rightChild);
        }
        return prob;
    }

    /**
     * Walks the internal nodes from the root downwards, always taking the highest node of the
     * current frontier, and adds {@code log(1 / frontierSize)} for each step.
     * <p>
     * Each visited node is also written with {@code nodeHeights[pointer]}, where
     * {@code pointer} starts at the second-to-last entry and moves towards index 0.
     *
     * @param nodeHeights heights to write; the last entry is skipped
     * @return the accumulated sum of {@code log(1 / frontierSize)}
     */
    private double getChanceOfPermutation(double[] nodeHeights) {
        List<NodeRef> nodes = new LinkedList<NodeRef>();
        addInternalChildren(tree.getRoot(), nodes);

        int pointer = nodeHeights.length - 2;
        double prob = 0.0;
        while (!nodes.isEmpty()) {
            int index = getHighestNode(nodes);
            prob += Math.log(1.0 / nodes.size());
            NodeRef n = nodes.remove(index);
            tree.setNodeHeight(n, nodeHeights[pointer]);
            pointer--;
            addInternalChildren(n, nodes);
        }
        return prob;
    }

    /**
     * Appends the children at index 0 and 1 of {@code parent} to {@code frontier}, in that
     * order, skipping external nodes.
     *
     * @param parent   the node whose children are inspected
     * @param frontier the list that receives the internal children
     */
    private void addInternalChildren(NodeRef parent, List<NodeRef> frontier) {
        NodeRef leftChild = tree.getChild(parent, 0);
        NodeRef rightChild = tree.getChild(parent, 1);
        if (!tree.isExternal(leftChild)) {
            frontier.add(leftChild);
        }
        if (!tree.isExternal(rightChild)) {
            frontier.add(rightChild);
        }
    }

    /**
     * Finds the first node whose height is strictly greater than that of all nodes before it
     * and greater than zero.
     *
     * @param nodes the candidate nodes
     * @return the index of the highest node, or {@code 0} if no height exceeds zero
     */
    private int getHighestNode(List<NodeRef> nodes) {
        double maxHeight = 0;
        int index = 0;
        for (int i = 0; i < nodes.size(); i++) {
            NodeRef n = nodes.get(i);
            if (tree.getNodeHeight(n) > maxHeight) {
                maxHeight = tree.getNodeHeight(n);
                index = i;
            }
        }
        return index;
    }

    /**
     * Assigns node heights to the newly built tree; uses the permutation scheme.
     *
     * @param nodeHeights the node heights sorted in ascending order
     * @return the value of {@link #assignPermutedNodeHeights(double[])}
     */
    private double setNodeHeights(double[] nodeHeights) {
        return assignPermutedNodeHeights(nodeHeights);
    }

    /**
     * Draws a height uniformly between zero and the parent's height for every internal node
     * below {@code parent}, top-down.
     *
     * @param parent the node to start from; assumed to have two children
     * @return the accumulated sum of {@code log(1 / parentHeight)} over all drawn heights
     */
    private double setUniformNodeHeights(NodeRef parent) {
        double prob = 0.0;
        NodeRef leftChild = tree.getChild(parent, 0);
        NodeRef rightChild = tree.getChild(parent, 1);

        if (!tree.isExternal(leftChild)) {
            double max = tree.getNodeHeight(parent);
            double height = max * MathUtils.nextDouble();
            tree.setNodeHeight(leftChild, height);
            prob += Math.log(1.0 / max);
            prob += setUniformNodeHeights(leftChild);
        }
        if (!tree.isExternal(rightChild)) {
            double max = tree.getNodeHeight(parent);
            double height = max * MathUtils.nextDouble();
            tree.setNodeHeight(rightChild, height);
            prob += Math.log(1.0 / max);
            prob += setUniformNodeHeights(rightChild);
        }
        return prob;
    }

    /**
     * Walks the internal nodes from the root downwards, picking a random node of the current
     * frontier in every step, and writes {@code nodeHeights} onto the picked nodes starting
     * at the second-to-last entry and moving towards index 0.
     *
     * @param nodeHeights heights to write; the last entry is skipped
     * @return the accumulated sum of {@code log(1 / frontierSize)}
     */
    private double assignPermutedNodeHeights(double[] nodeHeights) {
        List<NodeRef> nodes = new LinkedList<NodeRef>();
        addInternalChildren(tree.getRoot(), nodes);

        int pointer = nodeHeights.length - 2;
        double prob = 0.0;
        while (!nodes.isEmpty()) {
            int index = MathUtils.nextInt(nodes.size());
            prob += Math.log(1.0 / nodes.size());
            NodeRef n = nodes.remove(index);
            tree.setNodeHeight(n, nodeHeights[pointer]);
            pointer--;
            addInternalChildren(n, nodes);
        }
        return prob;
    }

    /**
     * Draws, bottom-up, a new height for every internal node in the subtree of {@code node},
     * uniformly between the lowest child height and the parent height. If that interval is
     * empty the root height is used as upper bound.
     *
     * @param node the root of the subtree to process; must have a parent if it is internal
     * @return the sum of {@code log(1 / (max - min))} over all heights drawn in the subtree
     */
    private double setMissingNodeHeights(NodeRef node) {
        double prob = 0.0;
        // check if the node is external
        if (!tree.isExternal(node)) {
            // children first, so their heights are fixed before this node is bounded by them
            for (int i = 0; i < tree.getChildCount(node); i++) {
                NodeRef child = tree.getChild(node, i);
                // accumulate the children's terms; they used to be silently discarded
                prob += setMissingNodeHeights(child);
            }

            double min = getMinNodeHeight(node);
            double max = getMaxNodeHeight(node);
            if (max <= min) {
                max = tree.getNodeHeight(tree.getRoot());
            }
            prob += Math.log(1.0 / (max - min));
            double height = min + MathUtils.nextDouble() * (max - min);
            tree.setNodeHeight(node, height);
        }
        return prob;
    }

    /**
     * @param node the node whose children are inspected
     * @return the smallest child height, or {@code Double.MAX_VALUE} if there are no children
     */
    private double getMinNodeHeight(NodeRef node) {
        double min = Double.MAX_VALUE;
        for (int i = 0; i < tree.getChildCount(node); i++) {
            NodeRef child = tree.getChild(node, i);
            double height = tree.getNodeHeight(child);
            if (height < min) {
                min = height;
            }
        }
        return min;
    }

    /**
     * @param node the node whose upper bound is requested
     * @return the height of the parent of {@code node}
     */
    private double getMaxNodeHeight(NodeRef node) {
        return tree.getNodeHeight(tree.getParent(node));
    }

    /**
     * Appends {@code node} and all internal nodes below it to {@link #internalNodes} in
     * pre-order.
     *
     * @param node the node to start from
     */
    private void fillInternalNodes(NodeRef node) {
        if (!tree.isExternal(node)) {
            internalNodes.add(node);
            int childCount = tree.getChildCount(node);
            for (int i = 0; i < childCount; i++) {
                fillInternalNodes(tree.getChild(node, i));
            }
        }
    }

    /**
     * Stores every tip below {@code node} in {@link #externalNodes}, keyed by node number.
     *
     * @param node the node to start from
     */
    private void fillExternalNodes(NodeRef node) {
        if (!tree.isExternal(node)) {
            int childCount = tree.getChildCount(node);
            for (int i = 0; i < childCount; i++) {
                fillExternalNodes(tree.getChild(node, i));
            }
        } else {
            Integer i = node.getNumber();
            externalNodes.put(i, node);
        }
    }

    /**
     * Builds the schedule of standard tree operators used for unguided moves: narrow and
     * wide exchange, subtree slide, NNI, Wilson-Balding and FNPR.
     *
     * @param treeModel the tree the operators act on
     * @return the populated schedule
     */
    private OperatorSchedule getOperatorSchedule(TreeModel treeModel) {
        ExchangeOperator narrowExchange = new ExchangeOperator(
                ExchangeOperator.NARROW, treeModel, 10);
        ExchangeOperator wideExchange = new ExchangeOperator(
                ExchangeOperator.WIDE, treeModel, 3);
        SubtreeSlideOperator subtreeSlide = new SubtreeSlideOperator(treeModel,
                10.0, 1.0, true, false, false, false, CoercionMode.COERCION_ON);
        NNI nni = new NNI(treeModel, 10.0);
        WilsonBalding wilsonBalding = new WilsonBalding(treeModel, 3.0);
        FNPR fnpr = new FNPR(treeModel, 5.0);

        OperatorSchedule schedule = new SimpleOperatorSchedule();
        schedule.addOperator(narrowExchange);
        schedule.addOperator(wideExchange);
        schedule.addOperator(subtreeSlide);
        schedule.addOperator(nni);
        schedule.addOperator(wilsonBalding);
        schedule.addOperator(fnpr);
        return schedule;
    }

    /**
     * Picks the next operator from the internal schedule and runs it.
     *
     * @return the value returned by that operator's {@code doOperation()}
     */
    protected double doUnguidedOperation() {
        int index = schedule.getNextOperatorIndex();
        SimpleMCMCOperator operator = (SimpleMCMCOperator) schedule
                .getOperator(index);
        return operator.doOperation();
    }

    /**
     * @return the number of transitions since last call to reset().
     */
    public long getTransitions() {
        return transitions;
    }

    /**
     * Set the number of transitions since last call to reset(). This is used to
     * restore the state of the operator
     *
     * @param transitions the new transition count
     */
    public void setTransitions(int transitions) {
        this.transitions = transitions;
    }

    /**
     * @return the transition count divided by the sum of accepted and rejected moves; when
     *         that sum is zero the floating-point division yields {@code NaN} or infinity
     */
    public double getTransistionProbability() {
        long accepted = getAcceptCount();
        long rejected = getRejectCount();
        long transition = getTransitions();
        return (double) transition / (double) (accepted + rejected);
    }

    /**
     * Resets the superclass state and the transition counter.
     */
    public void reset() {
        super.reset();
        transitions = 0;
    }

    /**
     * @return {@code 0.50}
     */
    public double getMinimumAcceptanceLevel() {
        return 0.50;
    }

    /**
     * @return {@code 1.0}
     */
    public double getMaximumAcceptanceLevel() {
        return 1.0;
    }

    /**
     * @return {@code 0.75}
     */
    public double getMinimumGoodAcceptanceLevel() {
        return 0.75;
    }

    /**
     * @return {@code 1.0}
     */
    public double getMaximumGoodAcceptanceLevel() {
        return 1.0;
    }

    /*
     * (non-Javadoc)
     *
     * @see dr.inference.operators.SimpleMCMCOperator#getOperatorName()
     */
    @Override
    public abstract String getOperatorName();

    /*
     * (non-Javadoc)
     *
     * @see dr.inference.operators.MCMCOperator#getPerformanceSuggestion()
     */
    public abstract String getPerformanceSuggestion();
}