package dr.evomodel.continuous;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
/**
* @author Marc Suchard
*/
public abstract class BivariateTraitBranchAttributeProvider extends TreeTrait.DefaultBehavior implements TreeTrait<Double> {
public static final String FORMAT = "%5.4f";
public BivariateTraitBranchAttributeProvider(AbstractMultivariateTraitLikelihood traitLikelihood) {
traitName = traitLikelihood.getTraitParameter().getId();
label = traitName + extensionName();
double[] rootTrait = traitLikelihood.getRootNodeTrait();
if (rootTrait.length != 2)
throw new RuntimeException("BivariateTraitBranchAttributeProvider only works for 2D traits");
}
protected abstract String extensionName();
protected double branchFunction(double[] startValue, double[] endValue, double startTime, double endTime) {
return convert(endValue[0]-startValue[0], endValue[1] - startValue[1], startTime - endTime);
}
protected abstract double convert(double latDifference, double longDifference, double timeDifference);
public String getTraitName() {
return label;
}
public Intent getIntent() {
return Intent.BRANCH;
}
public Class getTraitClass() {
return Double.class;
}
public int getDimension() {
return 1;
}
public Double getTrait(Tree tree, NodeRef node) {
if (tree != traitLikelihood.getTreeModel())
throw new RuntimeException("Bad bug.");
NodeRef parent = tree.getParent(node);
double[] startTrait = traitLikelihood.getTraitForNode(tree, parent, traitName);
double[] endTrait = traitLikelihood.getTraitForNode(tree, node, traitName);
double startTime = tree.getNodeHeight(parent);
double endTime = tree.getNodeHeight(node);
return branchFunction(startTrait, endTrait, startTime, endTime);
}
public String getTraitString(Tree tree, NodeRef node) {
NodeRef parent = tree.getParent(node);
double[] startTrait = traitLikelihood.getTraitForNode(tree, parent, traitName);
double[] endTrait = traitLikelihood.getTraitForNode(tree, node, traitName);
double startTime = tree.getNodeHeight(parent);
double endTime = tree.getNodeHeight(node);
return String.format(BivariateTraitBranchAttributeProvider.FORMAT,
branchFunction(startTrait, endTrait, startTime, endTime));
}
protected AbstractMultivariateTraitLikelihood traitLikelihood;
protected String traitName;
protected String label;
}