/* * MultivariateTraitUtils.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.evomodel.continuous; import dr.evolution.tree.MultivariateTraitTree; import dr.evolution.tree.NodeRef; import dr.evolution.tree.TreeUtils; import dr.math.KroneckerOperation; import dr.math.matrixAlgebra.Matrix; import dr.math.matrixAlgebra.SymmetricMatrix; import java.util.HashSet; import java.util.Set; /** * @author Marc A. Suchard */ public class MultivariateTraitUtils { public static NodeRef findMRCA(FullyConjugateMultivariateTraitLikelihood trait, int iTip, int jTip) { MultivariateTraitTree treeModel = trait.getTreeModel(); Set<String> leafNames = new HashSet<String>(); leafNames.add(treeModel.getTaxonId(iTip)); leafNames.add(treeModel.getTaxonId(jTip)); return TreeUtils.getCommonAncestorNode(treeModel, leafNames); } public static double[][] computeTreePrecision(FullyConjugateMultivariateTraitLikelihood trait, boolean conditionOnRoot) { if (trait.strengthOfSelection != null) { return new SymmetricMatrix(computeTreeVarianceOU(trait, conditionOnRoot)).inverse().toComponents(); } else { return new SymmetricMatrix(computeTreeVariance(trait, conditionOnRoot)).inverse().toComponents(); } } public static double[][] computeTreeTraitPrecision(FullyConjugateMultivariateTraitLikelihood trait, boolean conditionOnRoot) { double[][] treePrecision = computeTreePrecision(trait, conditionOnRoot); double[][] traitPrecision = trait.getDiffusionModel().getPrecisionmatrix(); return productKronecker(treePrecision, traitPrecision); } private static double[][] productKronecker(double[][] A, double[][] B) { if (B.length > 1) { A = KroneckerOperation.product(A, B); } else { final double b = B[0][0]; for (int i = 0; i < A.length; ++i) { for (int j = 0; j < A[i].length; ++j) { A[i][j] *= b; } } } return A; } private static double[][] productMatrices(double[][] A, double[][] B) { double[][] C = new double[A.length][B[0].length]; for (int i = 0; i < A.length; i++) { for (int k = 0; k < B[0].length; k++) { for (int j = 0; j < A[0].length; j++) { C[i][k] = C[i][k] + A[i][j] * B[j][k]; } } } return C; } private static double[][] transposeMatrix(double[][] A) { double[][] B = new double[A[0].length][A.length]; for (int i = 0; i < A.length; i++) { for (int j = 0; j < A[0].length; j++) { B[j][i] = A[i][j]; } } return B; } private static double[][] computeLinCombMatrix(FullyConjugateMultivariateTraitLikelihood trait) { MultivariateTraitTree treeModel = trait.getTreeModel(); final int tipCount = treeModel.getExternalNodeCount(); final int branchCount = 2 * tipCount - 2; double[][] linCombMatrix = new double[tipCount][branchCount]; double tempScalar; NodeRef tempNode; int tempNodeIndex; for (int k = 0; k < tipCount; k++) { tempNode = treeModel.getExternalNode(k); //check if treeModel.getExternalNode(k).getNumber() == k tempScalar = 1; tempNodeIndex = k; for (int r = 0; r < branchCount; r++) { if (r == tempNodeIndex) { linCombMatrix[k][r] = tempScalar; // tempScalar = tempScalar * (1 - treeModel.getBranchLength(tempNode) * trait.getTimeScaledSelection(tempNode)); tempScalar = tempScalar * Math.exp(-trait.getTimeScaledSelection(tempNode)); tempNode = treeModel.getParent(tempNode); tempNodeIndex = tempNode.getNumber(); } else { linCombMatrix[k][r] = 0; } } } return linCombMatrix; } private static double[] computeRootMultipliers(FullyConjugateMultivariateTraitLikelihood trait) { MultivariateTraitTree myTreeModel = trait.getTreeModel(); final int tipCount = myTreeModel.getExternalNodeCount(); double[] multiplierVect = new double[tipCount]; NodeRef tempNode; for (int i = 0; i < tipCount; i++) { tempNode = myTreeModel.getExternalNode(i); multiplierVect[i] = Math.exp(-trait.getTimeScaledSelection(tempNode)); tempNode = myTreeModel.getParent(tempNode); while (!myTreeModel.isRoot(tempNode)) { multiplierVect[i] = multiplierVect[i] * Math.exp(-trait.getTimeScaledSelection(tempNode)); tempNode = myTreeModel.getParent(tempNode); } } return multiplierVect; } private static double[] getShiftContributionToMean(NodeRef node, FullyConjugateMultivariateTraitLikelihood trait) { MultivariateTraitTree treeModel = trait.getTreeModel(); double shiftContribution[] = new double[trait.dimTrait]; if (!treeModel.isRoot(node)) { NodeRef parent = treeModel.getParent(node); double shiftContributionParent[] = getShiftContributionToMean(parent, trait); for (int i = 0; i < shiftContribution.length; ++i) { shiftContribution[i] = trait.getShiftForBranchLength(node)[i] + shiftContributionParent[i]; } } return shiftContribution; } public static double[] computeTreeTraitMean(FullyConjugateMultivariateTraitLikelihood trait, double[] rootValue, boolean conditionOnRoot) { double[] root = trait.getPriorMean(); if (conditionOnRoot) { System.err.println("WARNING: Not yet fully implemented (conditioning on root in simulator)"); //root = new double[root.length]; root = rootValue; } final int nTaxa = trait.getTreeModel().getExternalNodeCount(); double[] mean = new double[root.length * nTaxa]; for (int i = 0; i < nTaxa; ++i) { System.arraycopy(root, 0, mean, i * root.length, root.length); } if (trait.driftModels != null) { MultivariateTraitTree myTreeModel = trait.getTreeModel(); for (int i = 0; i < nTaxa; ++i) { double[] shiftContribution = getShiftContributionToMean(myTreeModel.getExternalNode(i), trait); for (int j = 0; j < trait.dimTrait; ++j) { mean[i * trait.dimTrait + j] = mean[i * trait.dimTrait + j] + shiftContribution[j]; } } } return mean; } public static double[] computeTreeTraitMeanOU(FullyConjugateMultivariateTraitLikelihood trait, double[] rootValue, boolean conditionOnRoot) { double[] root = trait.getPriorMean(); MultivariateTraitTree myTreeModel = trait.getTreeModel(); double[][] linCombMatrix = computeLinCombMatrix(trait); double[] rootMultiplierVect = computeRootMultipliers(trait); if (conditionOnRoot) { root = rootValue; } final int nTaxa = myTreeModel.getExternalNodeCount(); final int branchCount = 2 * nTaxa - 2; final int traitDim = trait.dimTrait; double[] mean = new double[root.length * nTaxa]; double[] displacementMeans = new double[branchCount * traitDim]; double[] tempVect = new double[nTaxa * traitDim]; NodeRef tempNode; for (int k = 0; k < branchCount; k++) { tempNode = myTreeModel.getNode(k); for (int t = 0; t < traitDim; t++) { displacementMeans[k * traitDim + t] = (1 - Math.exp(-trait.getTimeScaledSelection(tempNode))) * trait.getOptimalValue(tempNode)[t]; } } //check this //multiply linCombMatrix with displacement means for (int i = 0; i < nTaxa; i++) { for (int j = 0; j < branchCount; j++) { for (int k = 0; k < traitDim; k++) { tempVect[i * traitDim + k] = tempVect[i * traitDim + k] + linCombMatrix[i][j] * displacementMeans[j * traitDim + k]; } } } for (int i = 0; i < nTaxa; ++i) { System.arraycopy(root, 0, mean, i * root.length, root.length); for (int j = 0; j < traitDim; ++j) { mean[i * traitDim + j] = mean[i * traitDim + j] * rootMultiplierVect[i] + tempVect[i * traitDim + j]; } } return mean; } public static double[][] computeTreeTraitVariance(FullyConjugateMultivariateTraitLikelihood trait, boolean conditionOnRoot) { double[][] treeVariance = computeTreeVariance(trait, conditionOnRoot); double[][] traitVariance = new SymmetricMatrix(trait.getDiffusionModel().getPrecisionmatrix()).inverse().toComponents(); return productKronecker(treeVariance, traitVariance); } public static double[][] computeTreeVariance(FullyConjugateMultivariateTraitLikelihood trait, boolean conditionOnRoot) { MultivariateTraitTree treeModel = trait.getTreeModel(); final int tipCount = treeModel.getExternalNodeCount(); double[][] variance = new double[tipCount][tipCount]; for (int i = 0; i < tipCount; i++) { // Fill in diagonal double marginalTime = trait.getRescaledLengthToRoot(treeModel.getExternalNode(i)); variance[i][i] = marginalTime; // Fill in upper right triangle, for (int j = i + 1; j < tipCount; j++) { NodeRef mrca = findMRCA(trait, i, j); if (DEBUG) { System.err.println(trait.getTreeModel().getRoot().getNumber()); System.err.print("Taxa pair: " + i + " : " + j + " (" + mrca.getNumber() + ") = "); } double length = trait.getRescaledLengthToRoot(mrca); if (DEBUG) { System.err.println(length); } variance[i][j] = length; } } // Make symmetric for (int i = 0; i < tipCount; i++) { for (int j = i + 1; j < tipCount; j++) { variance[j][i] = variance[i][j]; } } if (DEBUG) { System.err.println(""); System.err.println("New tree conditional variance:\n" + new Matrix(variance)); } // TODO Handle missing values // variance = removeMissingTipsInTreeVariance(variance); // Automatically prune missing tips // // if (DEBUG) { // System.err.println(""); // System.err.println("New tree (trimmed) conditional variance:\n" + new Matrix(variance)); // } if (!conditionOnRoot) { double priorSampleSize = trait.getPriorSampleSize(); for (int i = 0; i < tipCount; ++i) { for (int j = 0; j < tipCount; ++j) { variance[i][j] += 1.0 / priorSampleSize; } } if (DEBUG) { System.err.println(""); System.err.println("New tree unconditional variance:\n" + new Matrix(variance)); } } return variance; } public static double[][] computeTreeVarianceOU(FullyConjugateMultivariateTraitLikelihood trait, boolean conditionOnRoot) { MultivariateTraitTree treeModel = trait.getTreeModel(); final int tipCount = treeModel.getExternalNodeCount(); final int branchCount = 2 * tipCount - 2; double[][] variance = new double[tipCount][tipCount]; double[][] tempMatrix = new double[tipCount][branchCount]; double[][] diagMatrix = new double[branchCount][branchCount]; /* double tempScalar; NodeRef tempNode; int tempNodeIndex; for(int k = 0; k < tipCount; k++){ tempNode = treeModel.getExternalNode(k); //check if treeModel.getExternalNode(k).getNumber() == k tempScalar = 1; tempNodeIndex = k; for(int r = 0; r < branchCount; r++){ if(r == tempNodeIndex){ tempMatrix[k][r] = tempScalar; tempScalar = tempScalar*(1-treeModel.getBranchLength(tempNode)*trait.getTimeScaledSelection(tempNode)); tempNode = treeModel.getParent(tempNode); tempNodeIndex = tempNode.getNumber(); }else{ tempMatrix[k][r]= 0; } } } */ tempMatrix = computeLinCombMatrix(trait); for (int i = 0; i < branchCount; i++) { // diagMatrix[i][i] = treeModel.getBranchLength(treeModel.getNode(i)); // diagMatrix[i][i] = (2*strengthOfSelection.getBranchRate(treeModel, node) ) / (1 - Math.exp(-2*getTimeScaledSelection(node))); diagMatrix[i][i] = (1 - Math.exp(-2 * trait.getTimeScaledSelection(treeModel.getNode(i)))) / (2 * trait.strengthOfSelection.getBranchRate(treeModel, treeModel.getNode(i))); } variance = productMatrices(productMatrices(tempMatrix, diagMatrix), transposeMatrix(tempMatrix)); return variance; } private static final boolean DEBUG = false; }