package com.ppfold.algo; import java.util.List; /** * Contains methods for the optimization of branch lengths to obtain the maximum * likelihood estimate tree. * * @author Z.Sukosd */ public class MaximumLikelihoodTree { static double _PHI = (1 + Math.sqrt(5)) / 2; static double _RESPHI = 2 - (1 + Math.sqrt(5)) / 2; public static int optimizeBranchLengths(Progress act, Tree tree,List<int[]> columns_int, List<char[]> columns, List<String> names,Parameters param, int iterlimit) throws InterruptedException{ //optimizes the branch lengths of the input tree. System.out.println("Optimizing branch lengths..."); long starttime = System.currentTimeMillis(); double [][] D = param.getrD(); double [][] V = param.getrV(); double [][] V1 = param.getrV1(); double [] Pr = param.getPr(); double [][] upvectors = new double [columns.size()][4]; double [][] downvectors = new double [columns.size()][4]; boolean iteratemore = true; int cnt = 0; //System.out.println("Start tree: "); //tree.print(); //Create a list of nodes of the tree. List <Node> allnodes = tree.createListOfNodes(); //Create a list of leaves so they can be referenced quickly later. tree.generateLeafList(names); double probability = 0; while(iteratemore&&cnt<iterlimit){ act.checkStop(); act.setProgress((double)cnt*(double)1/(double)iterlimit); //System.out.println("Iteration " + cnt); act.setCurrentActivity("Optimizing tree: iteration " + (cnt+1) + "/" + iterlimit); //create matrices for the branches of the tree. //these will be the same for each iteration //but have to be recalculated after adjusted branch lengths. tree.getRoot().calculateChildrenMatrix(D,V,V1); tree.getRoot().initializeChildrenUpDownVectors(); for(Node branch_node:allnodes){ //System.out.println("Optimizing branch length from node " + branch_node.getId() + // " which is " + branch_node.getDistanceFromParent()); act.checkStop(); for (int col = 0; col < columns.size(); col++) { // stepping column number. // Each column has the nt in the same position from all sequences. char[] column = columns.get(col); // Reset all vectors for the new column. tree.getRoot().resetChildrenDownVectors(); tree.getRoot().resetChildrenUpVectors(); for (int row = 0; row < column.length; row++) { // find node corresponding to the rownumber (sequence) Node node = tree.findNodeWithName(row); if (node == null) { System.err.println("Can't find node with name " + names.get(row)); } // now set the down-bottom vectors of this Node. node.setDownBottomVector(MatrixTools.createNtVector(column[row])); // now set the up-top vector of this node's children for(Node n: node.getChildren()){ n.setUpTopVector(MatrixTools.createNtVector(column[row])); } } // When finished with the whole sequence, // recursively find the down- and up-vectors of all //other nodes for this column. tree.calculateDownVectors(); MatrixTools.copyFromTo(branch_node.getDownBottomVector(), downvectors[col]); tree.calculateUpVectors(); MatrixTools.copyFromTo(branch_node.getUpTopVector(), upvectors[col]); //tree.getRoot().printChildrenUpDownVectors(); } //optimize final probability of tree probability = lnProbability(upvectors, downvectors, branch_node.getDistanceFromParent(), D, V, V1, Pr); //System.out.println("Probability of tree: " + probability); double lowerbound = 0; double upperbound = 10; double midpoint = (upperbound-lowerbound)/2 + _RESPHI * (upperbound - lowerbound); //Change bounds as appropriate branch_node.setNewDistanceFromParent(goldenSectionSearch( upvectors, downvectors, param.getrD(), param.getrV(), param.getrV1(), param.getPr(), lowerbound, midpoint, upperbound, Math.sqrt(1e-4))); probability = lnProbability(upvectors, downvectors, branch_node.getNewDistanceFromParent(), D, V, V1, Pr); // System.out.println("NEW Probability of tree: " + probability); // System.out.println("OLD length: " + branch_node.getDistanceFromParent()); // System.out.println("NEW length: " + branch_node.getNewDistanceFromParent()); } iteratemore = !tree.setNewBranches(); if(cnt==0){ System.out.println("Start log-probability of tree: " + probability); } cnt++; } if(cnt==iterlimit){ //System.out.println(); System.out.println("WARNING! Iteration limit exceeded! (" + iterlimit + ")" + " Tree may not be optimal" + " (but it's probably good enough)."); System.out.println("End log-probability of tree: " + probability); } else{ System.out.println("All branch lengths converged after " + cnt + " iterations." ); System.out.println("End log-probability of tree: " + probability); } //System.out.println(" prob is = " + Math.exp(probability)); //System.out.println("Final tree: "); //tree.print(); System.out.println("TOTAL TIME ELAPSED IN MLE: " + (System.currentTimeMillis()-starttime)/1000 + " seconds "); return cnt; } private static double goldenSectionSearch( double[][] uptopvectors, double[][] downbottomvectors, double[][] D, double[][] V, double [][] V1, double [] Pr, double t1, double t2, double t3, double tau){ // t1 and t3 are the current bounds; the minimum is between them. // t2 is the center point, which is closer to t1 than to t3 //System.out.println("Current bounds: " + t1 + ", " + t3 + ", center " + t2); // Create a new possible center in the area between t2 and t3, closer to t2 double t4 = t2 + _RESPHI * (t3 - t2); if(Math.abs(t3 - t1) < tau * (Math.abs(t2) + Math.abs(t4))){ return (t3 + t1) / 2; } // System.out.println(probability(uptopvectors, downbottomvectors, t4, D, V, V1, Pr) + // " versus " + probability(uptopvectors, downbottomvectors, t2, D, V, V1, Pr)); if(lnProbability(uptopvectors, downbottomvectors, t4, D, V, V1, Pr) > lnProbability(uptopvectors, downbottomvectors, t2, D, V, V1, Pr)){ // System.out.println(" A: new bounds " + t2 + ", " + t3); return goldenSectionSearch(uptopvectors, downbottomvectors, D, V, V1, Pr, t2, t4, t3, tau); } else{ // System.out.println(" B: new bounds " + t4 + ", " + t1); return goldenSectionSearch(uptopvectors, downbottomvectors, D, V, V1, Pr, t4, t2, t1, tau); } } /*private static double probability(double[][] uptopvectors, double[][] downbottomvectors, double t, double[][] D, double[][] V, double [][] V1, double [] Pr){ //calculates the probability of the tree //as the product of (uptop[i] expRt) .* downbottom[i] <dot> p_eq, for all columns. double probability = 1; double number = 0; //FOR DEBUG! double[] partial = new double[4]; double[] tmpvector1 = new double[4]; double[][] expRT = MatrixTools.expRT(D, t, V, V1); //MatrixTools.print(expRT); for(int i = 0; i<uptopvectors.length; i++){ MatrixTools.resetVector(tmpvector1, 0); MatrixTools.resetVector(partial, 0); MatrixTools.copyFromTo(uptopvectors[i], partial); MatrixTools.multiplyVectorMatrix(partial, expRT, tmpvector1); MatrixTools.multiplySeries(partial, downbottomvectors[i]); number = MatrixTools.scalarProduct(partial, Pr); probability = probability*number; //System.out.println("Column "+ i + " has probability: " + number); } //FOR DEBUG: return probability; }*/ private static double lnProbability(double[][] uptopvectors, double[][] downbottomvectors, double t, double[][] D, double[][] V, double [][] V1, double [] Pr){ //calculates the probability of the tree //as the product of (uptop[i] expRt) .* downbottom[i] <dot> p_eq, for all columns. double probability = 0; double number = 0; //FOR DEBUG! double[] partial = new double[4]; double[] tmpvector1 = new double[4]; double[][] expRT = MatrixTools.expRT(D, t, V, V1); //MatrixTools.print(expRT); for(int i = 0; i<uptopvectors.length; i++){ MatrixTools.resetVector(tmpvector1, 0); MatrixTools.resetVector(partial, 0); MatrixTools.copyFromTo(uptopvectors[i], partial); //MatrixTools.copyFromTo(downbottomvectors[i], partial); MatrixTools.multiplyVectorMatrix(partial, expRT, tmpvector1); MatrixTools.multiplySeries(partial, downbottomvectors[i]); //MatrixTools.multiplySeries(partial, uptopvectors[i]); number = MatrixTools.scalarProduct(partial, Pr); probability = probability+Math.log(number); //System.out.println("Column "+ i + " has probability: " + Math.log(number)); } //FOR DEBUG: //System.out.println("P(tree | t = " + t + ") = " + probability); return probability; } public static int STARTREEoptimizeBranchLengths(Progress act, Tree tree,List<int[]> columns_int, List<char[]> columns, List<String> names,Parameters param, int iterlimit) throws InterruptedException{ //dummy method to set all branch lengths to 0; this is a "hack" to simulate a star-tree //for experiments. System.out.println("Setting all branch lengths to zero (simulating star-tree)..."); List <Node> allnodes = tree.createListOfNodes(); for(Node branch_node:allnodes){ branch_node.setDistanceFromParent(0); } return 0; } }