package com.ppfold.algo;
import java.util.ArrayList;
import java.util.List;
/**
* Contains methods to create a tree by neighbour-joining from the alignment.
*
* @author Z.Sukosd
*/
public class NeighbourJoining {
/** The golden ratio, (1 + sqrt(5)) / 2. */
static double _PHI = (1 + Math.sqrt(5)) / 2;
/**
 * Two minus the golden ratio (about 0.382); the fraction of an interval used by
 * the golden-section search in this class to place its probe points.
 */
static double _RESPHI = 2 - (1 + Math.sqrt(5)) / 2;
/**
 * Counter from which the ids of internal (joined) nodes are drawn.
 * Static, hence shared by every call into this class.
 */
static int idcounter = -1; //node ID counter
/**
 * Builds a tree for the given sequences by neighbour-joining.
 *
 * <p>A pairwise distance is estimated for every pair of sequences by maximising
 * the pairwise log-likelihood over the branch length with a golden-section
 * search on the interval [0, 10]. The resulting symmetric distance matrix is
 * then handed to the neighbour-joining routine.
 *
 * @param activity  progress handle; receives status texts and is polled via
 *                  {@code checkStop()} so the computation can be interrupted
 * @param sequences the sequences to compare, one string per sequence
 * @param columns   not used by this method
 * @param names     leaf names; {@code names.get(i)} labels {@code sequences.get(i)}
 * @param param     supplies the rate-model matrices/vectors ({@code getrD},
 *                  {@code getrV}, {@code getrV1}, {@code getPr}) for the likelihood
 * @return the neighbour-joining tree; for a single sequence, a tree consisting
 *         of one named node with id 0
 * @throws IllegalArgumentException if {@code sequences} is empty
 * @throws InterruptedException     propagated from {@code activity.checkStop()}
 */
public static Tree generateTreeNJ(Progress activity, List<String> sequences, List<int[]> columns,
        List<String> names, Parameters param) throws InterruptedException {
    long starttime = System.currentTimeMillis();
    int n = sequences.size();
    if (n == 0) {
        // Previously this surfaced much later as an IndexOutOfBoundsException.
        throw new IllegalArgumentException("Cannot build a neighbour-joining tree without sequences");
    }
    System.out.println("Number of sequences: "+ n);
    System.out.println("Generating tree by neighbour-joining...");
    if (n == 1) {
        // Special case: a single sequence is its own tree.
        Node one = new Node(0);
        one.setName(names.get(0));
        return new Tree(one);
    }

    activity.setCurrentActivity("Calculating distance matrix...");

    // Search interval for the branch length and the starting interior point.
    // NOTE(review): this interior point (8.82 for [0, 10]) lies nearer the upper
    // bound than the conventional lower + _RESPHI * (upper - lower); it is kept
    // unchanged so that the estimated distances stay exactly as before.
    final double lowerbound = 0;
    final double upperbound = 10;
    final double midpoint = (upperbound-lowerbound)/2 + _RESPHI * (upperbound - lowerbound);
    final double tolerance = Math.sqrt(1e-4);

    // The model parameters do not depend on the pair, so fetch them once.
    final double[][] rD = param.getrD();
    final double[][] rV = param.getrV();
    final double[][] rV1 = param.getrV1();
    final double[] pr = param.getPr();

    // Symmetric distance matrix. The diagonal stays at 0: the distance of a
    // sequence to itself needs no likelihood optimisation.
    double[][] d = new double[n][n];
    for (int i = 1; i < n; i++) {
        String seq1 = sequences.get(i);
        for (int j = 0; j < i; j++) {
            activity.checkStop();
            double distance = goldenSectionSearch(activity,
                    seq1, sequences.get(j), rD, rV, rV1, pr,
                    lowerbound, midpoint, upperbound, tolerance);
            if (distance > upperbound) {
                // A very large optimum is treated as "infinite" and capped.
                distance = upperbound;
            }
            d[i][j] = d[j][i] = distance;
        }
    }

    activity.setCurrentActivity("Neighbour joining...");
    Tree result = NJ(activity, d,names);
    System.out.println("TOTAL TIME ELAPSED IN NEIGHBOUR-JOINING: " + (System.currentTimeMillis()-starttime)/1000 + " seconds ");
    return result;
}
/**
 * Creates one leaf node per name and joins the leaves into a tree.
 *
 * <p>Leaf {@code i} gets id {@code i}, name {@code names.get(i)} and a distance
 * of 0 from its (not yet existing) parent. The shared id counter is then set
 * to the last leaf id so that internal nodes created while joining are
 * numbered from {@code names.size()} upwards on every call, not only on the
 * first call into this class.
 *
 * @param activity progress handle, passed on to the joining routine
 * @param d        symmetric pairwise distance matrix, indexed like {@code names}
 * @param names    names of the leaves; its size determines the number of leaves
 * @return the tree rooted at the node returned by the joining routine
 * @throws InterruptedException propagated from the joining routine
 */
private static Tree NJ(Progress activity, double [][] d, List<String> names) throws InterruptedException {
    int n = names.size();
    List<Node> taxa = new ArrayList<Node>(n);
    for (int i = 0; i < n; i++) {
        Node leaf = new Node();
        leaf.setId(i);
        leaf.setName(names.get(i));
        leaf.setDistanceFromParent(0);
        taxa.add(leaf);
    }
    // Leaves occupy ids 0..n-1; restart the counter there so that a previous
    // run cannot shift the ids of the internal nodes created below.
    idcounter = n - 1;
    Node root = joinNeighbourTaxa(activity, taxa,d);
    return new Tree(root);
}
/**
 * Performs neighbour-joining on the given taxa until a single root remains.
 *
 * <p>Each call joins one pair and recurses on the reduced problem:
 * <ul>
 *   <li>one taxon left: it is removed from the list and returned as root;</li>
 *   <li>two taxa left: the second becomes a child of the first at distance
 *       {@code d[1][0]}, and the first is returned as root;</li>
 *   <li>otherwise: the pair minimising
 *       {@code Q[i][j] = (n-2)*d[i][j] - sum_k d[i][k] - sum_k d[j][k]}
 *       is attached to a new internal node, which replaces the pair at the end
 *       of the taxa list, and a reduced distance matrix is built.</li>
 * </ul>
 *
 * @param activity progress handle; {@code checkStop()} is polled once per call
 * @param taxa     current list of nodes still to be joined; modified in place
 * @param d        symmetric distance matrix indexed like {@code taxa}
 * @return the root node of the joined tree
 * @throws IllegalStateException if no pair of distinct taxa could be selected
 * @throws InterruptedException  propagated from {@code activity.checkStop()}
 */
private static Node joinNeighbourTaxa(Progress activity, List<Node> taxa, double[][] d) throws InterruptedException{
    activity.checkStop();
    if (taxa.size() == 1) {
        // Stopping criterion: the only remaining node is the root.
        return taxa.remove(0);
    }
    if (taxa.size() == 2) {
        // Only two nodes left: hang the second below the first, which becomes root.
        Node root = taxa.get(0);
        Node child = taxa.get(1);
        root.addChild(child);
        child.setDistanceFromParent(d[1][0]);
        root.setDistanceFromParent(0);
        taxa.remove(1);
        double[][] newd = new double[1][1];
        newd[0][0] = 1;
        return joinNeighbourTaxa(activity, taxa,newd);
    }

    int n = taxa.size(); // number of taxa still to be joined

    // Off-diagonal row sums, computed once per row instead of once per Q cell.
    double[] rowSum = new double[n];
    for (int i = 0; i < n; i++) {
        double sum = 0;
        for (int k = 0; k < n; k++) {
            if (i != k) {
                sum = sum + d[i][k];
            }
        }
        rowSum[i] = sum;
    }
    double[][] Q = new double[n][n];
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < i; j++) {
            Q[i][j] = Q[j][i] = (n-2)*d[i][j] - rowSum[i] - rowSum[j];
        }
    }

    int[] tojoin = minCoords(Q);
    int f = tojoin[0];
    int g = tojoin[1];
    if (f == g) {
        // Joining a taxon with itself would silently corrupt the tree.
        throw new IllegalStateException("Neighbour joining could not select a pair of distinct taxa");
    }
    Node nodeF = taxa.get(f);
    Node nodeG = taxa.get(g);

    Node newnode = new Node();
    idcounter++;
    newnode.setId(idcounter);
    // Child order (g first, then f) is kept deliberately.
    newnode.addChild(nodeG);
    newnode.addChild(nodeF);

    // Full row sums of the two joined taxa (the diagonal entry is included here).
    double sumG = 0;
    for (int k = 0; k < n; k++) {
        sumG = sumG + d[g][k];
    }
    double sumF = 0;
    for (int k = 0; k < n; k++) {
        sumF = sumF + d[f][k];
    }

    double pairDistance = d[f][g];
    // Distance of f from the new node; g gets the remainder of the pair distance.
    double distanceF = 0.5*pairDistance + (sumF-sumG)/(2*n-4);
    nodeF.setDistanceFromParent(distanceF);
    nodeG.setDistanceFromParent(pairDistance-distanceF);

    // Negative branch lengths: clamp the negative one to zero and give the
    // other branch the whole pair distance.
    if (nodeF.getDistanceFromParent() < 0) {
        nodeF.setDistanceFromParent(0);
        nodeG.setDistanceFromParent(pairDistance);
    }
    else if (nodeG.getDistanceFromParent() < 0) {
        nodeG.setDistanceFromParent(0);
        nodeF.setDistanceFromParent(pairDistance);
    }

    // Old-matrix indices of the taxa that are not joined, in their old order.
    int[] kept = new int[n-2];
    int keptCount = 0;
    for (int old = 0; old < n; old++) {
        if (old != f && old != g) {
            kept[keptCount++] = old;
        }
    }

    // Reduced matrix: surviving taxa keep their mutual distances ...
    double[][] newd = new double[n-1][n-1];
    for (int i = 0; i < n-2; i++) {
        for (int j = 0; j < n-2; j++) {
            newd[i][j] = d[kept[i]][kept[j]];
        }
    }
    // ... and the new node, always placed last, gets the usual NJ distance.
    for (int k = 0; k < n-2; k++) {
        int old = kept[k];
        newd[k][n-2] = newd[n-2][k] = 0.5*(d[g][old] + d[f][old] - pairDistance);
    }

    // Remove the higher index first so the lower one is still valid afterwards.
    taxa.remove(Math.max(f, g));
    taxa.remove(Math.min(f, g));
    taxa.add(newnode);
    return joinNeighbourTaxa(activity, taxa,newd);
}
/**
 * Debug helper: writes a header line followed by the id and name of every
 * taxon in the list to standard output.
 *
 * @param taxa the nodes to list
 */
private static void printTaxa(List<Node> taxa){
    System.out.println("TAXON PRINT");
    for (int index = 0; index < taxa.size(); index++) {
        Node current = taxa.get(index);
        String line = "Taxon id: " + current.getId() + ", taxon name: " + current.getName();
        System.out.println(line);
    }
}
/**
 * Finds the position of the smallest entry strictly below the diagonal of a
 * square matrix.
 *
 * <p>Only cells {@code [i][j]} with {@code j < i} are inspected, so the first
 * returned coordinate is always larger than the second. On ties the cell
 * visited first (row by row, left to right) wins. NaN entries are never
 * selected; if no entry compares smaller than positive infinity the first
 * below-diagonal cell {@code (1, 0)} is returned, so the result always names
 * two distinct indices.
 *
 * @param matrix square matrix to search
 * @return {@code {row, column}} of the minimum, or {@code {0, 0}} (after
 *         printing a warning) when the matrix has fewer than two rows
 */
private static int[] minCoords(double [][] matrix){
    if (matrix.length < 2) {
        System.out.println("Warning: matrix has length 1");
        return new int[] {0, 0};
    }
    // Start from +infinity rather than a "very big number", and default to a
    // real off-diagonal cell instead of the diagonal cell (0, 0).
    double minimum = Double.POSITIVE_INFINITY;
    int[] point = {1, 0};
    for (int i = 1; i < matrix.length; i++) {
        for (int j = 0; j < i; j++) {
            if (matrix[i][j] < minimum) {
                minimum = matrix[i][j];
                point[0] = i;
                point[1] = j;
            }
        }
    }
    return point;
}
/**
 * Golden-section search for the branch length {@code t} that maximises the
 * pairwise log-likelihood of two sequences.
 *
 * <p>The bracket is the triple (t1, t2, t3) with t2 an interior point. Each
 * step places a probe between t2 and t3 and keeps the sub-bracket around the
 * point with the larger log-likelihood; when the probe loses, the bracket is
 * mirrored so the next probe falls on the other side of t2. The search stops
 * when {@code |t3 - t1| < tau * (|t2| + |probe|)} and returns the midpoint of
 * the bracket.
 *
 * <p>The log-likelihood at the interior point is carried over between steps
 * instead of being recomputed, so each step costs one likelihood evaluation
 * (two on the first step).
 *
 * @param act  progress handle; {@code checkStop()} is polled once per step
 * @param seq1 first sequence
 * @param seq2 second sequence
 * @param D    passed through to the likelihood computation
 * @param V    passed through to the likelihood computation
 * @param V1   passed through to the likelihood computation
 * @param Pr   passed through to the likelihood computation
 * @param t1   one end of the bracket
 * @param t2   interior point of the bracket
 * @param t3   other end of the bracket
 * @param tau  relative tolerance of the stopping criterion
 * @return midpoint of the final bracket
 * @throws InterruptedException propagated from {@code act.checkStop()}
 */
private static double goldenSectionSearch(Progress act,
        String seq1, String seq2,
        double[][] D, double[][] V, double [][] V1, double [] Pr,
        double t1, double t2, double t3, double tau) throws InterruptedException{
    double near = t1;
    double inner = t2;
    double far = t3;
    double innerValue = 0;
    boolean innerKnown = false; // evaluated lazily so an immediate exit costs nothing

    while (true) {
        act.checkStop();
        // New candidate between the interior point and the far end, closer to the former.
        double probe = inner + _RESPHI * (far - inner);
        if (Math.abs(far - near) < tau * (Math.abs(inner) + Math.abs(probe))) {
            return (far + near) / 2;
        }
        if (!innerKnown) {
            innerValue = lnProbability(seq1, seq2, inner, D, V, V1, Pr);
            innerKnown = true;
        }
        double probeValue = lnProbability(seq1, seq2, probe, D, V, V1, Pr);
        if (probeValue > innerValue) {
            // Probe wins: new bracket is (inner, probe, far).
            near = inner;
            inner = probe;
            innerValue = probeValue;
        }
        else {
            // Interior point wins: new bracket is (probe, inner, near), i.e. mirrored.
            double previousNear = near;
            near = probe;
            far = previousNear;
        }
    }
}
/**
 * Computes the product over all positions of {@code seq1} of a per-position
 * term built from {@code exp(Rt)} and {@code Pr}.
 *
 * <p>For each position the nucleotide vectors of both sequences are obtained
 * from {@code MatrixTools.createNtVector}; the term is the sum of
 * {@code expRT[a][b] * Pr[a]} over every pair (a, b) whose entries in the
 * first and second vector, respectively, are non-zero.
 *
 * @param seq1 first sequence; its length determines the positions visited
 * @param seq2 second sequence, read at the same positions
 * @param t    passed to {@code MatrixTools.expRT} together with D, V and V1
 * @param D    argument of {@code MatrixTools.expRT}
 * @param V    argument of {@code MatrixTools.expRT}
 * @param V1   argument of {@code MatrixTools.expRT}
 * @param Pr   weights indexed by the state of the first sequence
 * @return the product of the per-position terms (1 for an empty {@code seq1})
 */
private static double probability(String seq1, String seq2, double t,
        double[][] D, double[][] V, double [][] V1, double [] Pr){
    final double[][] transition = MatrixTools.expRT(D, t, V, V1);
    double product = 1;
    final int length = seq1.length();
    for (int site = 0; site < length; site++) {
        double[] from = MatrixTools.createNtVector(seq1.charAt(site));
        double[] to = MatrixTools.createNtVector(seq2.charAt(site));
        double siteTerm = 0;
        for (int a = 0; a < 4; a++) {
            if (from[a] != 0) {
                for (int b = 0; b < 4; b++) {
                    if (to[b] != 0) {
                        siteTerm = siteTerm + transition[a][b]*Pr[a];
                    }
                }
            }
        }
        product = product*siteTerm;
    }
    return product;
}
/**
 * Computes the sum over all positions of {@code seq1} of the natural logarithm
 * of a per-position term built from {@code exp(Rt)} and {@code Pr}.
 *
 * <p>For each position the nucleotide vectors of both sequences are obtained
 * from {@code MatrixTools.createNtVector}; the term is the sum of
 * {@code expRT[a][b] * Pr[a]} over every pair (a, b) whose entries in the
 * first and second vector, respectively, are non-zero.
 *
 * @param seq1 first sequence; its length determines the positions visited
 * @param seq2 second sequence, read at the same positions
 * @param t    passed to {@code MatrixTools.expRT} together with D, V and V1
 * @param D    argument of {@code MatrixTools.expRT}
 * @param V    argument of {@code MatrixTools.expRT}
 * @param V1   argument of {@code MatrixTools.expRT}
 * @param Pr   weights indexed by the state of the first sequence
 * @return the sum of the logged per-position terms (0 for an empty {@code seq1})
 */
private static double lnProbability(String seq1, String seq2, double t,
        double[][] D, double[][] V, double [][] V1, double [] Pr){
    final double[][] transition = MatrixTools.expRT(D, t, V, V1);
    double logSum = 0;
    final int length = seq1.length();
    for (int site = 0; site < length; site++) {
        double[] from = MatrixTools.createNtVector(seq1.charAt(site));
        double[] to = MatrixTools.createNtVector(seq2.charAt(site));
        double siteTerm = 0;
        for (int a = 0; a < 4; a++) {
            if (from[a] != 0) {
                for (int b = 0; b < 4; b++) {
                    if (to[b] != 0) {
                        siteTerm = siteTerm + transition[a][b]*Pr[a];
                    }
                }
            }
        }
        logSum = logSum+Math.log(siteTerm);
    }
    return logSum;
}
/**
 * Mean of the entries strictly above the diagonal of a square matrix.
 *
 * @param d square matrix
 * @return the sum of all {@code d[i][j]} with {@code j > i} divided by their
 *         count (NaN when there are no such entries)
 */
private static double average(double[][] d){
    final int size = d.length;
    double total = 0;
    int pairs = 0;
    for (int row = 0; row < size; row++) {
        final double[] values = d[row];
        int col = row + 1;
        while (col < size) {
            total = total + values[col];
            pairs++;
            col++;
        }
    }
    return total/pairs;
}
/**
 * Standard deviation of the entries strictly above the diagonal of a square
 * matrix, computed as {@code sqrt(mean of squares - mean^2)}.
 *
 * @param d    square matrix
 * @param mean mean of the same entries, supplied by the caller
 * @return the square root of the average squared entry minus {@code mean*mean}
 */
private static double std(double[][] d, double mean){
    final int size = d.length;
    double squares = 0;
    int pairs = 0;
    for (int row = 0; row < size; row++) {
        final double[] values = d[row];
        int col = row + 1;
        while (col < size) {
            squares = squares + values[col]*values[col];
            pairs++;
            col++;
        }
    }
    double meanOfSquares = squares/(double)pairs;
    return Math.sqrt(meanOfSquares-mean*mean);
}
/**
 * Largest entry strictly above the diagonal of a square matrix, with 0 as the
 * starting value.
 *
 * @param d square matrix
 * @return the largest {@code d[i][j]} with {@code j > i}, or 0 if none of
 *         them exceeds 0
 */
private static double max(double[][] d){
    final int size = d.length;
    double largest = 0;
    for (int row = 0; row < size; row++) {
        final double[] values = d[row];
        for (int col = row + 1; col < size; col++) {
            largest = values[col] > largest ? values[col] : largest;
        }
    }
    return largest;
}
/**
 * Range of the entries strictly above the diagonal of a square matrix.
 *
 * <p>The running maximum starts at 0 and the running minimum at 10, so entries
 * outside [0, 10] on the respective side are the only ones that can move them
 * beyond those starting values.
 *
 * @param d square matrix
 * @return running maximum minus running minimum
 */
private static double maxdiff(double[][] d){
    final int size = d.length;
    double largest = 0;
    double smallest = 10;
    for (int row = 0; row < size; row++) {
        final double[] values = d[row];
        for (int col = row + 1; col < size; col++) {
            final double value = values[col];
            largest = value > largest ? value : largest;
            smallest = value < smallest ? value : smallest;
        }
    }
    return largest-smallest;
}
}