/*
* LeastSquaresClockTree.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evolution.tree;
import dr.math.ConjugateDirectionSearch;
import dr.math.MultivariateFunction;
import dr.math.MultivariateMinimum;
/**
*
* @version $Id: LeastSquaresClockTree.java,v 1.4 2005/05/24 20:25:56 rambaut Exp $
*
* @author Andrew Rambaut
* @author Alexei Drummond
*/
public class LeastSquaresClockTree extends SimpleTree {
/**
* constructor
*/
public LeastSquaresClockTree(Tree sourceTree) {
super(sourceTree);
this.sourceTree = sourceTree;
this.mu = 1.0;
this.optimizeMu = true;
}
/**
* constructor with a specific mutation rate
*/
public LeastSquaresClockTree(Tree sourceTree, double mu) {
this.sourceTree = sourceTree;
this.mu = mu;
this.optimizeMu = false;
}
public double getMu() { return mu; }
public void optimize() {
nodeCount = getInternalNodeCount();
int argumentCount = nodeCount;
if (optimizeMu) {
argumentCount++;
muIndex = nodeCount;
}
MultivariateMinimum optimizer = new ConjugateDirectionSearch();
nodeValues = new double[nodeCount];
double[] xvec = new double[argumentCount];
for (int i = 0; i < nodeCount; i++) {
xvec[i] = 1.0;
}
if (optimizeMu) {
xvec[muIndex] = mu;
}
optimizer.optimize(leastSquaresClock, xvec, 1E-8, 1E-8);
}
public double getSumOfSquares() {
double[] score = new double[] { 0.0 };
NodeRef root = getRoot();
if (getChildCount(root) != 2) {
throw new IllegalArgumentException("The tree must have a bifurcating root node");
}
NodeRef node1 = getChild(root, 0);
NodeRef node2 = getChild(root, 1);
if (!isExternal(node1)) {
getSumOfSquaresAtNode(node1, score);
}
if (!isExternal(node2)) {
getSumOfSquaresAtNode(node2, score);
}
double dist1 = sourceTree.getBranchLength(sourceTree.getNode(node1.getNumber())) +
sourceTree.getBranchLength(sourceTree.getNode(node2.getNumber()));
double time = getNodeHeight(root) - getNodeHeight(node1) +
getNodeHeight(root) - getNodeHeight(node2);
double dist2 = time * mu;
double diff = dist1 - dist2;
score[0] += diff * diff;
return score[0];
}
//
// Private stuff
//
private void getSumOfSquaresAtNode(NodeRef node, double[] score) {
if (!isExternal(node)) {
for (int i = 0; i < getChildCount(node); i++) {
NodeRef child = getChild(node, i);
if (!isExternal(child)) {
getSumOfSquaresAtNode(child, score);
score[0] += getScoreAtNode(child);
}
}
}
}
/**
* returns the rate on the branch to the node above.
*/
private double getScoreAtNode(NodeRef node) {
double dist1 = sourceTree.getBranchLength(sourceTree.getNode(node.getNumber()));
double time = getNodeHeight(getParent(node)) - getNodeHeight(node);
double dist2 = time * mu;
double diff = dist1 - dist2;
return diff * diff;
}
private double setNodeHeightsFromValues(NodeRef node) {
if (!isExternal(node)) {
double maxHeight = setNodeHeightsFromValues(getChild(node, 0));
for (int i = 1; i < getChildCount(node); i++) {
double height = setNodeHeightsFromValues(getChild(node, i));
if (height > maxHeight) maxHeight = height;
}
setNodeHeight(node, maxHeight + nodeValues[node.getNumber() - getExternalNodeCount()]);
}
return getNodeHeight(node);
}
private MultivariateFunction leastSquaresClock = new MultivariateFunction() {
public double evaluate(double[] argument) {
for (int i = 0; i < getInternalNodeCount(); i++) {
nodeValues[i] = argument[i];
}
setNodeHeightsFromValues(getRoot());
if (optimizeMu) {
mu = argument[muIndex];
}
double score = getSumOfSquares();
return score;
}
public int getNumArguments() {
if (optimizeMu) {
return getInternalNodeCount() + 1;
} else {
return getInternalNodeCount();
}
}
public double getLowerBound(int n) {
if (optimizeMu && n == muIndex) {
return Double.MIN_VALUE;
} else {
return 0.0;
}
}
public double getUpperBound(int n) {
if (optimizeMu && n == muIndex) {
return Double.MAX_VALUE;
} else {
return Double.MAX_VALUE;
}
}
};
private int nodeCount;
private double[] nodeValues;
private Tree sourceTree;
private double mu;
private boolean optimizeMu;
private int muIndex;
}