package beast.evolution.operators;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import beast.core.Description;
import beast.core.Input;
import beast.core.Input.Validate;
import beast.core.util.Log;
import beast.evolution.alignment.Taxon;
import beast.evolution.alignment.TaxonSet;
import beast.evolution.tree.Node;
import beast.evolution.tree.Tree;
import beast.util.Randomizer;
@Description("Tree operator which randomly changes the height of a node, " +
"then reconstructs the tree from node heights.")
public class NodeReheight extends TreeOperator {
public final Input<TaxonSet> taxonSetInput = new Input<>("taxonset", "taxon set describing species tree taxa and their gene trees", Validate.REQUIRED);
public final Input<List<Tree>> geneTreesInput = new Input<>("genetree", "list of gene trees that constrain species tree movement", new ArrayList<>());
Node[] m_nodes;
/**
* map node number of leafs in gene trees to leaf nr in species tree *
*/
List<Map<Integer, Integer>> m_taxonMap;
int nrOfGeneTrees;
int nrOfSpecies;
@Override
public void initAndValidate() {
/** maps gene taxa names to species number **/
final Map<String, Integer> taxonMap = new HashMap<>();
final List<Taxon> list = taxonSetInput.get().taxonsetInput.get();
if (list.size() <= 1) {
Log.err.println("NodeReheight operator requires at least 2 taxa while the taxon set (id=" + taxonSetInput.get().getID() +") has only " + list.size() + " taxa. "
+ "If the XML file was set up in BEAUti, this probably means a taxon assignment needs to be set up in the taxonset panel.");
// assume we are in BEAUti, back off for now
return;
}
for (int i = 0; i < list.size(); i++) {
final Taxon taxa = list.get(i);
// cast should be ok if taxon-set is the set for the species tree
final TaxonSet set = (TaxonSet) taxa;
for (final Taxon taxon : set.taxonsetInput.get()) {
taxonMap.put(taxon.getID(), i);
}
}
/** build the taxon map for each gene tree **/
m_taxonMap = new ArrayList<>();
for (final Tree tree : geneTreesInput.get()) {
final Map<Integer, Integer> map = new HashMap<>();
setupTaxaMap(tree.getRoot(), map, taxonMap);
m_taxonMap.add(map);
}
nrOfGeneTrees = geneTreesInput.get().size();
nrOfSpecies = treeInput.get().getLeafNodeCount();
}
// initialisation code: create node number in gene tree to node number in species tree map
private void setupTaxaMap(final Node node, final Map<Integer, Integer> map, final Map<String, Integer> taxonMap) {
if (node.isLeaf()) {
map.put(node.getNr(), taxonMap.get(node.getID()));
} else {
setupTaxaMap(node.getLeft(), map, taxonMap);
setupTaxaMap(node.getRight(), map, taxonMap);
}
}
@Override
public double proposal() {
final Tree tree = treeInput.get();
m_nodes = tree.getNodesAsArray();
final int nodeCount = tree.getNodeCount();
// randomly change left/right order
tree.startEditing(this); // we change the tree
reorder(tree.getRoot());
// collect heights
final double[] heights = new double[nodeCount];
final int[] reverseOrder = new int[nodeCount];
collectHeights(tree.getRoot(), heights, reverseOrder, 0);
// change height of an internal node
int nodeIndex = Randomizer.nextInt(heights.length);
while (m_nodes[reverseOrder[nodeIndex]].isLeaf()) {
nodeIndex = Randomizer.nextInt(heights.length);
}
final double maxHeight = calcMaxHeight(reverseOrder, nodeIndex);
heights[nodeIndex] = Randomizer.nextDouble() * maxHeight;
m_nodes[reverseOrder[nodeIndex]].setHeight(heights[nodeIndex]);
// reconstruct tree from heights
final Node root = reconstructTree(heights, reverseOrder, 0, heights.length, new boolean[heights.length]);
assert checkConsistency(root, new boolean[heights.length]) ;
// System.err.println("Inconsisten tree");
// }
root.setParent(null);
tree.setRoot(root);
return 0;
}
private boolean checkConsistency(final Node node, final boolean[] used) {
if (used[node.getNr()]) {
// used twice? tha's bad
return false;
}
used[node.getNr()] = true;
if ( node.isLeaf() ) {
return true;
}
return checkConsistency(node.getLeft(), used) && checkConsistency(node.getRight(), used);
}
/**
* calculate maximum height that node nodeIndex can become restricted
* by nodes on the left and right
*/
private double calcMaxHeight(final int[] reverseOrder, final int nodeIndex) {
// find maximum height between two species. Only upper right part is populated
final double[][] maxHeight = new double[nrOfSpecies][nrOfSpecies];
for (int i = 0; i < nrOfSpecies; i++) {
Arrays.fill(maxHeight[i], Double.POSITIVE_INFINITY);
}
// calculate for every species tree the maximum allowable merge point
for (int i = 0; i < nrOfGeneTrees; i++) {
final Tree tree = geneTreesInput.get().get(i);
findMaximaInGeneTree(tree.getRoot(), new boolean[nrOfSpecies], m_taxonMap.get(i), maxHeight);
}
// find species on the left of selected node
final boolean[] isLowerSpecies = new boolean[nrOfSpecies];
final Node[] nodes = treeInput.get().getNodesAsArray();
for (int i = 0; i < nodeIndex; i++) {
final Node node = nodes[reverseOrder[i]];
if (node.isLeaf()) {
isLowerSpecies[node.getNr()] = true;
}
}
// find species on the right of selected node
final boolean[] isUpperSpecies = new boolean[nrOfSpecies];
for (int i = nodeIndex + 1; i < nodes.length; i++) {
final Node node = nodes[reverseOrder[i]];
if (node.isLeaf()) {
isUpperSpecies[node.getNr()] = true;
}
}
// find max
double max = Double.POSITIVE_INFINITY;
for (int i = 0; i < nrOfSpecies; i++) {
if (isLowerSpecies[i]) {
for (int j = 0; j < nrOfSpecies; j++) {
if (j != i && isUpperSpecies[j]) {
final int x = Math.min(i, j);
final int y = Math.max(i, j);
max = Math.min(max, maxHeight[x][y]);
}
}
}
}
return max;
} // calcMaxHeight
/**
* for every species in the left on the gene tree and for every species in the right
* cap the maximum join height by the lowest place the two join in the gene tree
*/
private void findMaximaInGeneTree(final Node node, final boolean[] taxonSet, final Map<Integer, Integer> taxonMap, final double[][] maxHeight) {
if (node.isLeaf()) {
final int species = taxonMap.get(node.getNr());
taxonSet[species] = true;
} else {
final boolean[] isLeftTaxonSet = new boolean[nrOfSpecies];
findMaximaInGeneTree(node.getLeft(), isLeftTaxonSet, taxonMap, maxHeight);
final boolean[] isRightTaxonSet = new boolean[nrOfSpecies];
findMaximaInGeneTree(node.getRight(), isRightTaxonSet, taxonMap, maxHeight);
for (int i = 0; i < nrOfSpecies; i++) {
if (isLeftTaxonSet[i]) {
for (int j = 0; j < nrOfSpecies; j++) {
if (j != i && isRightTaxonSet[j]) {
final int x = Math.min(i, j);
final int y = Math.max(i, j);
maxHeight[x][y] = Math.min(maxHeight[x][y], node.getHeight());
}
}
}
}
for (int i = 0; i < nrOfSpecies; i++) {
taxonSet[i] = isLeftTaxonSet[i] | isRightTaxonSet[i];
}
}
}
/**
* construct tree top down by joining heighest left and right nodes *
*/
private Node reconstructTree(final double[] heights, final int[] reverseOrder, final int from, final int to, final boolean[] hasParent) {
//nodeIndex = maxIndex(heights, 0, heights.length);
int nodeIndex = -1;
double max = Double.NEGATIVE_INFINITY;
for (int j = from; j < to; j++) {
if (max < heights[j] && !m_nodes[reverseOrder[j]].isLeaf()) {
max = heights[j];
nodeIndex = j;
}
}
if (nodeIndex < 0) {
return null;
}
final Node node = m_nodes[reverseOrder[nodeIndex]];
//int left = maxIndex(heights, 0, nodeIndex);
int left = -1;
max = Double.NEGATIVE_INFINITY;
for (int j = from; j < nodeIndex; j++) {
if (max < heights[j] && !hasParent[j]) {
max = heights[j];
left = j;
}
}
//int right = maxIndex(heights, nodeIndex+1, heights.length);
int right = -1;
max = Double.NEGATIVE_INFINITY;
for (int j = nodeIndex + 1; j < to; j++) {
if (max < heights[j] && !hasParent[j]) {
max = heights[j];
right = j;
}
}
node.setLeft(m_nodes[reverseOrder[left]]);
node.getLeft().setParent(node);
node.setRight(m_nodes[reverseOrder[right]]);
node.getRight().setParent(node);
if (node.getLeft().isLeaf()) {
heights[left] = Double.NEGATIVE_INFINITY;
}
if (node.getRight().isLeaf()) {
heights[right] = Double.NEGATIVE_INFINITY;
}
hasParent[left] = true;
hasParent[right] = true;
heights[nodeIndex] = Double.NEGATIVE_INFINITY;
reconstructTree(heights, reverseOrder, from, nodeIndex, hasParent);
reconstructTree(heights, reverseOrder, nodeIndex, to, hasParent);
return node;
}
// helper for reconstructTree, to find maximum in range
// private int maxIndex(final double[] heights, final int from, final int to) {
// int maxIndex = -1;
// double max = Double.NEGATIVE_INFINITY;
// for (int i = from; i < to; i++) {
// if (max < heights[i]) {
// max = heights[i];
// maxIndex = i;
// }
// }
// return maxIndex;
// }
/**
** gather height of each node, and the node index associated with the height.*
**/
private int collectHeights(final Node node, final double[] heights, final int[] reverseOrder, int current) {
if (node.isLeaf()) {
heights[current] = node.getHeight();
reverseOrder[current] = node.getNr();
current++;
} else {
current = collectHeights(node.getLeft(), heights, reverseOrder, current);
heights[current] = node.getHeight();
reverseOrder[current] = node.getNr();
current++;
current = collectHeights(node.getRight(), heights, reverseOrder, current);
}
return current;
}
/**
* randomly changes left and right children in every internal node *
*/
private void reorder(final Node node) {
if (!node.isLeaf()) {
if (Randomizer.nextBoolean()) {
final Node tmp = node.getLeft();
node.setLeft(node.getRight());
node.setRight(tmp);
}
reorder(node.getLeft());
reorder(node.getRight());
}
}
} // class NodeReheight