package viz.process;
import java.io.PrintStream;
//import org.apache.commons.math.FunctionEvaluationException;
//import org.apache.commons.math.optimization.DifferentiableMultivariateRealOptimizer;
//import org.apache.commons.math.optimization.GoalType;
//import org.apache.commons.math.optimization.MultivariateRealOptimizer;
//import org.apache.commons.math.optimization.OptimizationException;
//import org.apache.commons.math.optimization.RealPointValuePair;
//import org.apache.commons.math.optimization.VectorialPointValuePair;
//import org.apache.commons.math.optimization.direct.PowellOptimizer;
//import org.apache.commons.math.optimization.general.LevenbergMarquardtOptimizer;
import viz.DensiTree;
import viz.Node;
/**
 * Re-positions the internal nodes of a summary tree so that its branch lengths
 * score better under a {@link BranchScorer}.
 * <p>
 * "Heights" in this class are the values accumulated by {@link #collectNodes}:
 * the root sits at its {@code m_fPosY} and every child lies at its parent's
 * height plus the child's branch length. Both the {@code heights} and
 * {@code nodes} arrays are indexed by {@code Node.m_iLabel}. The code assumes
 * leaves occupy indices {@code 0 .. nLabels-1} and internal nodes the remaining
 * indices, numbered so that children precede their parents (NOTE(review):
 * confirm against the numbering done by DensiTree).
 */
public class BranchLengthOptimiser {
    /** Owning DensiTree instance; supplies the taxon labels and the optional output file. */
    DensiTree m_dt;
    /** Scorer built for the tree currently being optimised by {@link #optimiseScore}. */
    BranchScorer scorer;
    /** Maximum number of sweeps over all internal nodes in {@link #optimiseTree}. */
    final static int MAX_ATTEMPTS = 500;
    /** Number of grid intervals used when searching for the best height of a node. */
    final static float RANGE = 100;

    public BranchLengthOptimiser(DensiTree dt) {
        m_dt = dt;
    }

    /**
     * Optimises the branch lengths of {@code tree} in place.
     * <p>
     * The steps are:
     * <ol>
     * <li>Collect the nodes and their heights.</li>
     * <li>Pre-position every internal node from its children.</li>
     * <li>Run a coordinate-wise grid search over all internal nodes.</li>
     * <li>Write the resulting heights back as {@code m_fLength} values.</li>
     * </ol>
     * Timing and start/end scores are reported on {@code System.err}. When
     * {@code m_dt.m_sOptFile} is set, the optimised tree is written to that file
     * and the JVM exits.
     *
     * @param tree root of the tree to optimise; its nodes' {@code m_fLength} values are modified
     */
    public void optimiseScore(Node tree) {
        long start = System.currentTimeMillis();
        // a rooted binary tree over n taxa has 2n-1 nodes
        int nodeCount = m_dt.m_sLabels.size() * 2 - 1;
        float[] heights = new float[nodeCount];
        Node[] nodes = new Node[nodeCount];
        collectNodes(tree, nodes, heights, tree.m_fPosY);
        scorer = new BranchScorer(m_dt, nodes);
        double startscore = scorer.score(heights);

        initialiseTree(heights, nodes);
        long start2 = System.currentTimeMillis();

        optimiseTree(heights, nodes, scorer);
        updateBranchLengths(heights, nodes);

        long end = System.currentTimeMillis();
        // optimisation time runs from the end of initialisation (start2), not from the very start
        System.err.println("\n\n\n" + (end - start2) / 1000.0 + " seconds optimising " + (start2 - start) / 1000.0 + " seconds initialising");
        double endscore = scorer.score(heights);
        System.err.println("Start score: " + startscore + " End score: " + endscore + "\n\n\n");

        if (m_dt.m_sOptFile != null) {
            writeTree(tree);
            // command-line mode: the optimised tree has been written, nothing else to do
            System.exit(0);
        }
    }

    /**
     * Converts node heights back into branch lengths (child height minus parent height).
     * The root is skipped explicitly, so its position in the array does not matter.
     */
    private static void updateBranchLengths(float[] heights, Node[] nodes) {
        for (int i = 0; i < nodes.length; i++) {
            Node node = nodes[i];
            if (!node.isRoot()) {
                node.m_fLength = heights[i] - heights[node.getParent().getNr()];
            }
        }
    }

    /**
     * Writes {@code tree} to {@code m_dt.m_sOptFile} on a best-effort basis.
     * Failures are reported on {@code System.err} but not propagated. The stream
     * is always closed.
     */
    private void writeTree(Node tree) {
        PrintStream out = null;
        try {
            out = new PrintStream(m_dt.m_sOptFile);
            out.println(tree.toString(m_dt.m_sLabels, false));
            // PrintStream never throws on write failure; it only records an error flag
            if (out.checkError()) {
                System.err.println("Error writing optimised tree to " + m_dt.m_sOptFile);
            }
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            if (out != null) {
                out.close();
            }
        }
    }

    /**
     * Coordinate-wise grid search over internal node heights.
     * <p>
     * Each sweep visits every internal node in index order. For each node, it
     * tries {@code RANGE - 1} evenly spaced heights strictly between
     * {@code minHeight} (the smaller of its children's heights) and
     * {@code maxHeight}. For a non-root node, {@code maxHeight} is the parent's
     * height. For the root, it is derived from the children's clades'
     * {@code getMaxLength()}. The node keeps whichever height gives the lowest
     * {@code scorer.score(heights)}. Sweeps stop when one makes no improvement,
     * or after {@link #MAX_ATTEMPTS} sweeps.
     *
     * @param heights node heights indexed by label; updated in place
     * @param nodes   nodes indexed by label
     * @param scorer  scorer used to evaluate complete height vectors (lower is better)
     */
    void optimiseTree(float[] heights, Node[] nodes, BranchScorer scorer) {
        // (2n-1)/2 + 1 == n: index of the first internal node
        int firstInternal = nodes.length / 2 + 1;
        boolean bProgress = true;
        for (int i = 0; i < MAX_ATTEMPTS && bProgress; i++) {
            bProgress = false;
            for (int k = firstInternal; k < nodes.length; k++) {
                Node node = nodes[k];
                int iCladeLeft = node.m_left.m_iClade;
                int iCladeRight = node.m_right.m_iClade;
                float leftHeight = heights[node.m_left.m_iLabel];
                float rightHeight = heights[node.m_right.m_iLabel];
                float minHeight = Math.min(leftHeight, rightHeight);
                float maxHeight;
                if (node.isRoot()) {
                    // no parent bounds the root: limit it by the longest branch recorded for each child clade
                    CladeBranchInfo infoLeft = scorer.m_cladeBranchInfo.get(iCladeLeft);
                    CladeBranchInfo infoRight = scorer.m_cladeBranchInfo.get(iCladeRight);
                    maxHeight = Math.min(leftHeight - infoLeft.getMaxLength(), rightHeight - infoRight.getMaxLength());
                } else {
                    maxHeight = heights[node.getParent().getNr()];
                }
                float bestHeight = heights[node.getNr()];
                heights[k] = bestHeight;
                double bestScore = scorer.score(heights);
                for (int j = 1; j < RANGE; j++) {
                    float height = j * (maxHeight - minHeight) / RANGE + minHeight;
                    heights[k] = height;
                    double score = scorer.score(heights);
                    if (score < bestScore) {
                        bProgress = true;
                        bestScore = score;
                        bestHeight = height;
                    }
                }
                heights[k] = bestHeight;
            }
            // progress indicator: one dot per sweep
            System.err.print(".");
        }
    }

    /**
     * Pre-optimisation pass: places each internal node at the grid height that
     * minimises the combined score of its two child branches only. Parents are
     * ignored at this stage. Nodes are processed in index order, so children
     * (lower indices) are placed before their parents. The branch lengths of
     * the node and of both its children are updated as a side effect.
     *
     * @param heights node heights indexed by label; internal entries are overwritten
     * @param nodes   nodes indexed by label
     */
    private void initialiseTree(float[] heights, Node[] nodes) {
        for (int k = m_dt.m_sLabels.size(); k < nodes.length; k++) {
            Node node = nodes[k];
            CladeBranchInfo infoLeft = scorer.m_cladeBranchInfo.get(node.m_left.m_iClade);
            CladeBranchInfo infoRight = scorer.m_cladeBranchInfo.get(node.m_right.m_iClade);
            float leftHeight = heights[node.m_left.m_iLabel];
            float rightHeight = heights[node.m_right.m_iLabel];
            float minHeight = Math.min(leftHeight, rightHeight);
            float maxHeight = Math.min(leftHeight - infoLeft.getMaxLength(), rightHeight - infoRight.getMaxLength());
            float bestHeight = Math.min(minHeight, heights[node.m_iLabel]);
            float bestScore = infoLeft.score(bestHeight, leftHeight) +
                    infoRight.score(bestHeight, rightHeight);
            // NOTE(review): this grid starts at j = 2 whereas optimiseTree starts at j = 1;
            // confirm whether skipping the first grid point is intentional.
            for (int j = 2; j < RANGE; j++) {
                float height = j * (maxHeight - minHeight) / RANGE + minHeight;
                float score = infoLeft.score(height, leftHeight) +
                        infoRight.score(height, rightHeight);
                if (score < bestScore) {
                    bestScore = score;
                    bestHeight = height;
                }
            }
            heights[k] = bestHeight;
            if (!node.isRoot()) {
                // the parent has not been placed yet; this length is recomputed when the parent is processed
                node.m_fLength = bestHeight - heights[node.getParent().m_iLabel];
            }
            node.m_left.m_fLength = leftHeight - bestHeight;
            node.m_right.m_fLength = rightHeight - bestHeight;
        }
    }

    /**
     * Recursively records every node of the subtree under its {@code m_iLabel}
     * and computes its height as the parent's height plus its own branch length.
     *
     * @param node    subtree root
     * @param nodes   output array of nodes, indexed by label
     * @param heights output array of heights, indexed by label
     * @param height  height assigned to {@code node}
     */
    private void collectNodes(Node node, Node[] nodes, float[] heights, float height) {
        nodes[node.m_iLabel] = node;
        heights[node.m_iLabel] = height;
        if (!node.isLeaf()) {
            collectNodes(node.m_left, nodes, heights, height + node.m_left.m_fLength);
            collectNodes(node.m_right, nodes, heights, height + node.m_right.m_fLength);
        }
    }
} // class BranchLengthOptimiser
//38.52506983048988