/*
* File ClusterTree.java
*
* Copyright (C) 2010 Remco Bouckaert remco@cs.auckland.ac.nz
*
* This file is part of BEAST2.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package beast.util;
import beast.core.Description;
import beast.core.Input;
import beast.core.StateNode;
import beast.core.StateNodeInitialiser;
import beast.core.parameter.RealParameter;
import beast.evolution.alignment.Alignment;
import beast.evolution.alignment.distance.Distance;
import beast.evolution.alignment.distance.JukesCantorDistance;
import beast.evolution.tree.Node;
import beast.evolution.tree.Tree;
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.util.*;
/**
* Adapted from Weka's HierarchicalClustering class *
*/
@Description("Create initial beast.tree by hierarchical clustering, either through one of the classic link methods " +
"or by neighbor joining. The following link methods are supported: " +
"<br/>o single link, " +
"<br/>o complete link, " +
"<br/>o UPGMA=average link, " +
"<br/>o mean link, " +
"<br/>o centroid, " +
"<br/>o Ward and " +
"<br/>o adjusted complete link " +
"<br/>o neighborjoining " +
"<br/>o neighborjoining2 - corrects tree for tip data, unlike plain neighborjoining")
public class ClusterTree extends Tree implements StateNodeInitialiser {
public enum Type {single, average, complete, upgma, mean, centroid, ward, adjcomplete, neighborjoining, neighborjoining2}
double EPSILON = 1e-10;
final public Input<Type> clusterTypeInput = new Input<>("clusterType", "type of clustering algorithm used for generating initial beast.tree. " +
"Should be one of " + Arrays.toString(Type.values()) + " (default " + Type.average + ")", Type.average, Type.values());
final public Input<Alignment> dataInput = new Input<>("taxa", "alignment data used for calculating distances for clustering");
final public Input<Distance> distanceInput = new Input<>("distance", "method for calculating distance between two sequences (default Jukes Cantor)");
final public Input<RealParameter> clockRateInput = new Input<>("clock.rate",
"the clock rate parameter, used to divide all divergence times by, to convert from substitutions to times. (default 1.0)",
new RealParameter(new Double[] {1.0}));
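// Usage sketch (illustrative only; attribute/input names are taken from the declarations above,
// the "initial" attribute is inherited from Tree):
//   XML:  <init spec="beast.util.ClusterTree" initial="@Tree.t:data" clusterType="upgma" taxa="@data"/>
//   Java: assuming an Alignment 'data' has already been created,
//     ClusterTree tree = new ClusterTree();
//     tree.initByName("clusterType", "upgma", "taxa", data);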
/**
* Whether the distance represents node height (if false) or branch length (if true).
*/
protected boolean distanceIsBranchLength = false;
Distance distance;
List<String> taxaNames;
/**
* Holds the link type used to calculate the distance between clusters
*/
Type linkType = Type.single;
@Override
public void initAndValidate() {
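// Initialisation proceeds in five steps:
// 1. determine the taxon names from the alignment ('taxa') or from the taxon set;
// 2. when resuming a run whose tree will be read from the state file anyway, skip clustering
//    and build a cheap caterpillar tree instead;
// 3. otherwise build the cluster tree from pairwise distances (buildClusterer());
// 4. process date traits, or set all tip heights to zero when no dates are available;
// 5. divide internal node heights by the clock rate to convert from substitutions to time.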
RealParameter clockRate = clockRateInput.get();
if (dataInput.get() != null) {
taxaNames = dataInput.get().getTaxaNames();
} else {
if (m_taxonset.get() == null) {
throw new RuntimeException("At least one of taxa and taxonset input needs to be specified");
}
taxaNames = m_taxonset.get().asStringList();
}
if (Boolean.valueOf(System.getProperty("beast.resume")) &&
(isEstimatedInput.get() || (m_initial.get() != null && m_initial.get().isEstimatedInput.get()))) {
// don't bother creating a cluster tree (to save some time), since the tree is read from the state file anyway;
// make a caterpillar
Node left = newNode();
left.setNr(0);
left.setID(taxaNames.get(0));
left.setHeight(0);
for (int i = 1; i < taxaNames.size(); i++) {
final Node right = newNode();
right.setNr(i);
right.setID(taxaNames.get(i));
right.setHeight(0);
final Node parent = newNode();
parent.setNr(taxaNames.size() + i - 1);
parent.setHeight(i);
left.setParent(parent);
parent.setLeft(left);
right.setParent(parent);
parent.setRight(right);
left = parent;
}
root = left;
leafNodeCount = taxaNames.size();
nodeCount = leafNodeCount * 2 - 1;
internalNodeCount = leafNodeCount - 1;
super.initAndValidate();
return;
}
distance = distanceInput.get();
if (distance == null) {
distance = new JukesCantorDistance();
}
if (distance instanceof Distance.Base) {
if (dataInput.get() == null) {
// Distance.Base calculates distances from alignment patterns, so an alignment is required
throw new RuntimeException("Distance requires the taxa (alignment) input to be specified");
}
((Distance.Base) distance).setPatterns(dataInput.get());
}
linkType = clusterTypeInput.get();
if (linkType == Type.upgma) linkType = Type.average;
if (linkType == Type.neighborjoining || linkType == Type.neighborjoining2) {
distanceIsBranchLength = true;
}
final Node root = buildClusterer();
setRoot(root);
root.labelInternalNodes((getNodeCount() + 1) / 2);
super.initAndValidate();
if (linkType == Type.neighborjoining2) {
// reset tip heights to zero
final Node[] nodes = getNodesAsArray();
for (int i = 0; i < getLeafNodeCount(); i++) {
nodes[i].setHeight(0);
}
super.initAndValidate();
}
if (m_initial.get() != null)
processTraits(m_initial.get().m_traitList.get());
else
processTraits(m_traitList.get());
if (timeTraitSet != null)
adjustTreeNodeHeights(root);
else {
// all tips should be at zero height if no date trait is available
for (int i = 0; i < getLeafNodeCount(); i++) {
getNode(i).setHeight(0);
}
}
// divide all internal node heights by the clock rate to convert from substitutions to time
for (Node node : getInternalNodes()) {
double height = node.getHeight();
node.setHeight(height/clockRate.getValue());
}
initStateNodes();
}
public ClusterTree() {
} // c'tor
/**
* class representing node in cluster hierarchy *
*/
class NodeX {
NodeX m_left;
NodeX m_right;
NodeX m_parent;
int m_iLeftInstance;
int m_iRightInstance;
double m_fLeftLength = 0;
double m_fRightLength = 0;
double m_fHeight = 0;
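// Note: m_iLeftInstance / m_iRightInstance are only meaningful when the corresponding child
// pointer (m_left / m_right) is null, in which case they hold the taxon index of a leaf.
// setHeight() is used by the link-clustering methods, where the merge distance is interpreted
// as a node height; setLength() is used by neighbor joining (distanceIsBranchLength == true),
// where it is interpreted as a branch length.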
void setHeight(double height1, double height2) {
if (height1 < EPSILON) {
height1 = EPSILON;
}
if (height2 < EPSILON) {
height2 = EPSILON;
}
m_fHeight = height1;
if (m_left == null) {
m_fLeftLength = height1;
} else {
m_fLeftLength = height1 - m_left.m_fHeight;
}
if (m_right == null) {
m_fRightLength = height2;
} else {
m_fRightLength = height2 - m_right.m_fHeight;
}
}
void setLength(double length1, double length2) {
if (length1 < EPSILON) {
length1 = EPSILON;
}
if (length2 < EPSILON) {
length2 = EPSILON;
}
m_fLeftLength = length1;
m_fRightLength = length2;
m_fHeight = length1;
if (m_left != null) {
m_fHeight += m_left.m_fHeight;
}
}
@Override
public String toString() {
final DecimalFormat myFormatter = new DecimalFormat("#.#####", new DecimalFormatSymbols(Locale.US));
if (m_left == null) {
if (m_right == null) {
return "(" + taxaNames.get(m_iLeftInstance) + ":" + myFormatter.format(m_fLeftLength) + "," +
taxaNames.get(m_iRightInstance) + ":" + myFormatter.format(m_fRightLength) + ")";
} else {
return "(" + taxaNames.get(m_iLeftInstance) + ":" + myFormatter.format(m_fLeftLength) + "," +
m_right.toString() + ":" + myFormatter.format(m_fRightLength) + ")";
}
} else {
if (m_right == null) {
return "(" + m_left.toString() + ":" + myFormatter.format(m_fLeftLength) + "," +
taxaNames.get(m_iRightInstance) + ":" + myFormatter.format(m_fRightLength) + ")";
} else {
return "(" + m_left.toString() + ":" + myFormatter.format(m_fLeftLength) + "," + m_right.toString() + ":" + myFormatter.format(m_fRightLength) + ")";
}
}
}
Node toNode() {
final Node node = newNode();
node.setHeight(m_fHeight);
if (m_left == null) {
node.setLeft(newNode());
node.getLeft().setNr(m_iLeftInstance);
node.getLeft().setID(taxaNames.get(m_iLeftInstance));
node.getLeft().setHeight(m_fHeight - m_fLeftLength);
if (m_right == null) {
node.setRight(newNode());
node.getRight().setNr(m_iRightInstance);
node.getRight().setID(taxaNames.get(m_iRightInstance));
node.getRight().setHeight(m_fHeight - m_fRightLength);
} else {
node.setRight(m_right.toNode());
}
} else {
node.setLeft(m_left.toNode());
if (m_right == null) {
node.setRight(newNode());
node.getRight().setNr(m_iRightInstance);
node.getRight().setID(taxaNames.get(m_iRightInstance));
node.getRight().setHeight(m_fHeight - m_fRightLength);
} else {
node.setRight(m_right.toNode());
}
}
if (node.getHeight() < node.getLeft().getHeight() + EPSILON) {
node.setHeight(node.getLeft().getHeight() + EPSILON);
}
if (node.getHeight() < node.getRight().getHeight() + EPSILON) {
node.setHeight(node.getRight().getHeight() + EPSILON);
}
node.getRight().setParent(node);
node.getLeft().setParent(node);
return node;
}
} // class NodeX
/**
* represents a candidate pair of clusters to merge; stored in the priority queue for efficient retrieval of the closest pair *
*/
class Tuple {
public Tuple(final double d, final int i, final int j, final int size1, final int size2) {
m_fDist = d;
m_iCluster1 = i;
m_iCluster2 = j;
m_nClusterSize1 = size1;
m_nClusterSize2 = size2;
}
double m_fDist;
int m_iCluster1;
int m_iCluster2;
int m_nClusterSize1;
int m_nClusterSize2;
}
/**
* comparator used by priority queue*
*/
class TupleComparator implements Comparator<Tuple> {
@Override
public int compare(final Tuple o1, final Tuple o2) {
if (o1.m_fDist < o2.m_fDist) {
return -1;
} else if (o1.m_fDist == o2.m_fDist) {
return 0;
}
return 1;
}
}
// return distance according to distance metric
double distance(final int taxon1, final int taxon2) {
return distance.pairwiseDistance(taxon1, taxon2);
} // distance
// 1-norm: pattern-weighted Manhattan distance between two pattern vectors, normalised by the number of sites
double distance(final double[] pattern1, final double[] pattern2) {
double dist = 0;
for (int i = 0; i < dataInput.get().getPatternCount(); i++) {
dist += dataInput.get().getPatternWeight(i) * Math.abs(pattern1[i] - pattern2[i]);
}
return dist / dataInput.get().getSiteCount();
}
@SuppressWarnings("unchecked")
public Node buildClusterer() {
final int taxonCount = taxaNames.size();
if (taxonCount == 1) {
// pathological case
final Node node = newNode();
node.setHeight(1);
node.setNr(0);
return node;
}
// use array of integer vectors to store cluster indices,
// starting with one cluster per instance
final List<Integer>[] clusterID = new ArrayList[taxonCount];
for (int i = 0; i < taxonCount; i++) {
clusterID[i] = new ArrayList<>();
clusterID[i].add(i);
}
// calculate distance matrix
final int clusters = taxonCount;
// used for keeping track of hierarchy
final NodeX[] clusterNodes = new NodeX[taxonCount];
if (linkType == Type.neighborjoining || linkType == Type.neighborjoining2) {
neighborJoining(clusters, clusterID, clusterNodes);
} else {
doLinkClustering(clusters, clusterID, clusterNodes);
}
// return the hierarchy of the single remaining non-empty cluster
for (int i = 0; i < taxonCount; i++) {
if (clusterID[i].size() > 0) {
return clusterNodes[i].toNode();
}
}
return null;
} // buildClusterer
/**
* use the neighbor joining algorithm for clustering.
* This is roughly based on the simple RapidNJ implementation and runs in O(n^3);
* more efficient implementations exist, see RapidNJ (or my GPU implementation :-)).
* The update formulas used below are sketched in the comment following this method.
*
* @param clusters
* @param clusterID
* @param clusterNodes
*/
void neighborJoining(int clusters, final List<Integer>[] clusterID, final NodeX[] clusterNodes) {
final int n = taxaNames.size();
final double[][] dist = new double[clusters][clusters];
for (int i = 0; i < clusters; i++) {
dist[i][i] = 0;
for (int j = i + 1; j < clusters; j++) {
dist[i][j] = getDistance0(clusterID[i], clusterID[j]);
dist[j][i] = dist[i][j];
}
}
final double[] separationSums = new double[n];
final double[] separations = new double[n];
final int[] nextActive = new int[n];
//calculate initial separation rows
for (int i = 0; i < n; i++) {
double sum = 0;
for (int j = 0; j < n; j++) {
sum += dist[i][j];
}
separationSums[i] = sum;
separations[i] = sum / (clusters - 2);
nextActive[i] = i + 1;
}
while (clusters > 2) {
// find minimum
int min1 = -1;
int min2 = -1;
double min = Double.MAX_VALUE;
{
int i = 0;
while (i < n) {
final double sep1 = separations[i];
final double[] row = dist[i];
int j = nextActive[i];
while (j < n) {
final double sep2 = separations[j];
final double val = row[j] - sep1 - sep2;
if (val < min) {
// new minimum
min1 = i;
min2 = j;
min = val;
}
j = nextActive[j];
}
i = nextActive[i];
}
}
// record distance
final double minDistance = dist[min1][min2];
clusters--;
final double sep1 = separations[min1];
final double sep2 = separations[min2];
final double dist1 = (0.5 * minDistance) + (0.5 * (sep1 - sep2));
final double dist2 = (0.5 * minDistance) + (0.5 * (sep2 - sep1));
if (clusters > 2) {
// update separations & distance
double newSeparationSum = 0;
final double mutualDistance = dist[min1][min2];
final double[] row1 = dist[min1];
final double[] row2 = dist[min2];
for (int i = 0; i < n; i++) {
if (i == min1 || i == min2 || clusterID[i].size() == 0) {
row1[i] = 0;
} else {
final double val1 = row1[i];
final double val2 = row2[i];
final double distance = (val1 + val2 - mutualDistance) / 2.0;
newSeparationSum += distance;
// update the separationsum of cluster i.
separationSums[i] += (distance - val1 - val2);
separations[i] = separationSums[i] / (clusters - 2);
row1[i] = distance;
dist[i][min1] = distance;
}
}
separationSums[min1] = newSeparationSum;
separations[min1] = newSeparationSum / (clusters - 2);
separationSums[min2] = 0;
merge(min1, min2, dist1, dist2, clusterID, clusterNodes);
int prev = min2;
// since min1 < min2 and clusterID[min1] is non-empty after the merge, this loop stops at or before min1, so it is safe
while (clusterID[prev].size() == 0) {
prev--;
}
nextActive[prev] = nextActive[min2];
} else {
merge(min1, min2, dist1, dist2, clusterID, clusterNodes);
break;
}
}
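// At this point exactly two active clusters remain (the while-loop never runs when n == 2).
// Join them: if one side is a single taxon, the whole remaining distance is placed on its
// branch; otherwise the distance is split evenly between the two branches.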
for (int i = 0; i < n; i++) {
if (clusterID[i].size() > 0) {
for (int j = i + 1; j < n; j++) {
if (clusterID[j].size() > 0) {
final double dist1 = dist[i][j];
if (clusterID[i].size() == 1) {
merge(i, j, dist1, 0, clusterID, clusterNodes);
} else if (clusterID[j].size() == 1) {
merge(i, j, 0, dist1, clusterID, clusterNodes);
} else {
merge(i, j, dist1 / 2.0, dist1 / 2.0, clusterID, clusterNodes);
}
break;
}
}
}
}
} // neighborJoining
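// Sketch of how the loop above maps onto the textbook neighbor-joining formulas
// (variable names in parentheses refer to neighborJoining() above):
//   r(i)   = sum_k d(i,k) / (clusters - 2)                        (separations[i])
//   Q(i,j) = d(i,j) - r(i) - r(j)                                 (val; the pair minimising Q is joined)
//   branch from the new node u to i: d(i,j)/2 + (r(i) - r(j))/2   (dist1)
//   branch from the new node u to j: d(i,j)/2 + (r(j) - r(i))/2   (dist2)
//   distance from u to any other cluster k: (d(i,k) + d(j,k) - d(i,j)) / 2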
/**
* Perform clustering using a link method
* This implementation uses a priority queue resulting in an O(n^2 log(n)) algorithm
*
* @param clusters number of clusters
* @param clusterID
* @param clusterNodes
*/
void doLinkClustering(int clusters, final List<Integer>[] clusterID, final NodeX[] clusterNodes) {
final int instances = taxaNames.size();
final PriorityQueue<Tuple> queue = new PriorityQueue<>(clusters * clusters / 2, new TupleComparator());
final double[][] distance0 = new double[clusters][clusters];
for (int i = 0; i < clusters; i++) {
distance0[i][i] = 0;
for (int j = i + 1; j < clusters; j++) {
distance0[i][j] = getDistance0(clusterID[i], clusterID[j]);
distance0[j][i] = distance0[i][j];
queue.add(new Tuple(distance0[i][j], i, j, 1, 1));
}
}
while (clusters > 1) {
int min1 = -1;
int min2 = -1;
// use priority queue to find next best pair to cluster
Tuple t;
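// Lazy deletion: a queued tuple is stale when either of its clusters has changed size since
// the tuple was queued (i.e. that cluster has since been merged away); stale tuples are skipped.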
do {
t = queue.poll();
}
while (t != null && (clusterID[t.m_iCluster1].size() != t.m_nClusterSize1 || clusterID[t.m_iCluster2].size() != t.m_nClusterSize2));
min1 = t.m_iCluster1;
min2 = t.m_iCluster2;
merge(min1, min2, t.m_fDist/2.0, t.m_fDist/2.0, clusterID, clusterNodes);
// update distances to the newly merged cluster & add fresh tuples to the queue
for (int i = 0; i < instances; i++) {
if (i != min1 && clusterID[i].size() != 0) {
final int i1 = Math.min(min1, i);
final int i2 = Math.max(min1, i);
final double distance = getDistance(distance0, clusterID[i1], clusterID[i2]);
queue.add(new Tuple(distance, i1, i2, clusterID[i1].size(), clusterID[i2].size()));
}
}
clusters--;
}
} // doLinkClustering
void merge(int min1, int min2, double dist1, double dist2, final List<Integer>[] clusterID, final NodeX[] clusterNodes) {
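// Merge cluster min2 into cluster min1: afterwards clusterID[min2] is empty,
// clusterID[min1] holds the union, and clusterNodes[min1] points to the new parent NodeX.
// dist1/dist2 are the height or branch-length contributions of the two children.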
if (min1 > min2) {
final int h = min1;
min1 = min2;
min2 = h;
final double f = dist1;
dist1 = dist2;
dist2 = f;
}
clusterID[min1].addAll(clusterID[min2]);
clusterID[min2].clear();
// track hierarchy
final NodeX node = new NodeX();
if (clusterNodes[min1] == null) {
node.m_iLeftInstance = min1;
} else {
node.m_left = clusterNodes[min1];
clusterNodes[min1].m_parent = node;
}
if (clusterNodes[min2] == null) {
node.m_iRightInstance = min2;
} else {
node.m_right = clusterNodes[min2];
clusterNodes[min2].m_parent = node;
}
if (distanceIsBranchLength) {
node.setLength(dist1, dist2);
} else {
node.setHeight(dist1, dist2);
}
clusterNodes[min1] = node;
} // merge
/**
* calculate the distance between two clusters the first time, when setting up the distance matrix; at that point every cluster is still a singleton *
*/
double getDistance0(final List<Integer> cluster1, final List<Integer> cluster2) {
double bestDist = Double.MAX_VALUE;
switch (linkType) {
case single:
case neighborjoining:
case neighborjoining2:
case centroid:
case complete:
case adjcomplete:
case average:
case mean:
// all clusters are singletons at set-up time, so use the pairwise distance between their single members
bestDist = distance(cluster1.get(0), cluster2.get(0));
break;
case ward: {
// Ward's method: the distance is the increase in error sum of squares (ESS) caused by merging
// the two clusters, where the ESS of a cluster is measured from its members to its centroid:
//   d(C1,C2) = |C1 u C2| * ESS(C1 u C2) - |C1| * ESS(C1) - |C2| * ESS(C2)
final double ESS1 = calcESS(cluster1);
final double ESS2 = calcESS(cluster2);
final List<Integer> merged = new ArrayList<>();
merged.addAll(cluster1);
merged.addAll(cluster2);
final double ESS = calcESS(merged);
bestDist = ESS * merged.size() - ESS1 * cluster1.size() - ESS2 * cluster2.size();
}
break;
default:
break;
}
return bestDist;
} // getDistance0
/**
* calculate the distance between two clusters
*
* @param distance matrix of pairwise distances between taxa
* @param cluster1 list of indices of instances in the first cluster
* @param cluster2 ditto for the second cluster
* @return distance between clusters based on link type
*/
double getDistance(final double[][] distance, final List<Integer> cluster1, final List<Integer> cluster2) {
double bestDist = Double.MAX_VALUE;
switch (linkType) {
case single:
// find single link distance aka minimum link, which is the closest distance between
// any item in cluster1 and any item in cluster2
bestDist = Double.MAX_VALUE;
for (int i = 0; i < cluster1.size(); i++) {
final int i1 = cluster1.get(i);
for (int j = 0; j < cluster2.size(); j++) {
final int i2 = cluster2.get(j);
final double dist = distance[i1][i2];
if (bestDist > dist) {
bestDist = dist;
}
}
}
break;
case complete:
case adjcomplete:
// find complete link distance aka maximum link, which is the largest distance between
// any item in cluster1 and any item in cluster2
bestDist = 0;
for (int i = 0; i < cluster1.size(); i++) {
final int i1 = cluster1.get(i);
for (int j = 0; j < cluster2.size(); j++) {
final int i2 = cluster2.get(j);
final double dist = distance[i1][i2];
if (bestDist < dist) {
bestDist = dist;
}
}
}
if (linkType == Type.complete) {
break;
}
// calculate adjustment, which is the largest within cluster distance
double maxDist = 0;
for (int i = 0; i < cluster1.size(); i++) {
final int i1 = cluster1.get(i);
for (int j = i + 1; j < cluster1.size(); j++) {
final int i2 = cluster1.get(j);
final double dist = distance[i1][i2];
if (maxDist < dist) {
maxDist = dist;
}
}
}
for (int i = 0; i < cluster2.size(); i++) {
final int i1 = cluster2.get(i);
for (int j = i + 1; j < cluster2.size(); j++) {
final int i2 = cluster2.get(j);
final double dist = distance[i1][i2];
if (maxDist < dist) {
maxDist = dist;
}
}
}
bestDist -= maxDist;
break;
case average:
// finds average distance between the elements of the two clusters
bestDist = 0;
for (int i = 0; i < cluster1.size(); i++) {
final int i1 = cluster1.get(i);
for (int j = 0; j < cluster2.size(); j++) {
final int i2 = cluster2.get(j);
bestDist += distance[i1][i2];
}
}
bestDist /= (cluster1.size() * cluster2.size());
break;
case mean: {
// calculates the mean pairwise distance within the merged cluster (aka group-average agglomerative clustering)
final List<Integer> merged = new ArrayList<>();
merged.addAll(cluster1);
merged.addAll(cluster2);
bestDist = 0;
for (int i = 0; i < merged.size(); i++) {
final int i1 = merged.get(i);
for (int j = i + 1; j < merged.size(); j++) {
final int i2 = merged.get(j);
bestDist += distance[i1][i2];
}
}
final int n = merged.size();
bestDist /= (n * (n - 1.0) / 2.0);
}
break;
case centroid:
// finds the distance between the centroids of the two clusters
final int patterns = dataInput.get().getPatternCount();
final double[] centroid1 = new double[patterns];
for (int i = 0; i < cluster1.size(); i++) {
final int taxonIndex = cluster1.get(i);
for (int j = 0; j < patterns; j++) {
centroid1[j] += dataInput.get().getPattern(taxonIndex, j);
}
}
final double[] centroid2 = new double[patterns];
for (int i = 0; i < cluster2.size(); i++) {
final int taxonIndex = cluster2.get(i);
for (int j = 0; j < patterns; j++) {
centroid2[j] += dataInput.get().getPattern(taxonIndex, j);
}
}
for (int j = 0; j < patterns; j++) {
centroid1[j] /= cluster1.size();
centroid2[j] /= cluster2.size();
}
bestDist = distance(centroid1, centroid2);
break;
case ward: {
// Ward's method: the distance is the increase in error sum of squares (ESS) caused by merging
// the two clusters, where the ESS of a cluster is measured from its members to its centroid:
//   d(C1,C2) = |C1 u C2| * ESS(C1 u C2) - |C1| * ESS(C1) - |C2| * ESS(C2)
final double ESS1 = calcESS(cluster1);
final double ESS2 = calcESS(cluster2);
final List<Integer> merged = new ArrayList<>();
merged.addAll(cluster1);
merged.addAll(cluster2);
final double ESS = calcESS(merged);
bestDist = ESS * merged.size() - ESS1 * cluster1.size() - ESS2 * cluster2.size();
}
break;
default:
break;
}
return bestDist;
} // getDistance
/**
* calculate the error sum of squares (ESS) of a cluster's instances w.r.t. its centroid; implemented here as the mean 1-norm distance of each instance's pattern vector to the centroid *
*/
double calcESS(final List<Integer> cluster) {
final int patterns = dataInput.get().getPatternCount();
final double[] centroid = new double[patterns];
for (int i = 0; i < cluster.size(); i++) {
final int taxonIndex = cluster.get(i);
for (int j = 0; j < patterns; j++) {
centroid[j] += dataInput.get().getPattern(taxonIndex, j);
}
}
for (int j = 0; j < patterns; j++) {
centroid[j] /= cluster.size();
}
// accumulate the 1-norm distance of each instance's pattern vector to the centroid
double eSS = 0;
for (int i = 0; i < cluster.size(); i++) {
final double[] instance = new double[patterns];
final int taxonIndex = cluster.get(i);
for (int j = 0; j < patterns; j++) {
instance[j] += dataInput.get().getPattern(taxonIndex, j);
}
eSS += distance(centroid, instance);
}
return eSS / cluster.size();
} // calcESS
@Override
public void initStateNodes() {
if (m_initial.get() != null) {
m_initial.get().assignFromWithoutID(this);
}
}
@Override
public void getInitialisedStateNodes(final List<StateNode> stateNodes) {
if (m_initial.get() != null) {
stateNodes.add(m_initial.get());
}
}
} // class ClusterTree