/* * IntegratedMultivariateTraitLikelihood.java * * Copyright (c) 2002-2012 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.evomodel.continuous; import dr.evolution.tree.NodeRef; import dr.evolution.tree.Tree; import dr.evomodel.branchratemodel.BranchRateModel; import dr.evomodel.tree.TreeModel; import dr.inference.loggers.LogColumn; import dr.inference.model.CompoundParameter; import dr.inference.model.Model; import dr.inference.model.Parameter; import dr.inference.model.Variable; import dr.math.distributions.MultivariateNormalDistribution; import dr.math.distributions.WishartSufficientStatistics; import dr.math.matrixAlgebra.Matrix; import dr.math.matrixAlgebra.SymmetricMatrix; import dr.math.matrixAlgebra.Vector; import dr.util.Author; import dr.util.Citation; import java.util.Arrays; import java.util.List; import java.util.logging.Logger; /** * A multivariate trait likelihood that analytically integrates out the unobserved trait values at all internal * and root nodes * * @author Marc A. Suchard */ public abstract class IntegratedMultivariateTraitLikelihood extends AbstractMultivariateTraitLikelihood { public static final double LOG_SQRT_2_PI = 0.5 * Math.log(2 * Math.PI); public IntegratedMultivariateTraitLikelihood(String traitName, TreeModel treeModel, MultivariateDiffusionModel diffusionModel, CompoundParameter traitParameter, List<Integer> missingIndices, boolean cacheBranches, boolean scaleByTime, boolean useTreeLength, BranchRateModel rateModel, Model samplingDensity, boolean reportAsMultivariate, boolean reciprocalRates) { this(traitName, treeModel, diffusionModel, traitParameter, null, missingIndices, cacheBranches, scaleByTime, useTreeLength, rateModel, samplingDensity, reportAsMultivariate, reciprocalRates); } public IntegratedMultivariateTraitLikelihood(String traitName, TreeModel treeModel, MultivariateDiffusionModel diffusionModel, CompoundParameter traitParameter, Parameter deltaParameter, List<Integer> missingIndices, boolean cacheBranches, boolean scaleByTime, boolean useTreeLength, BranchRateModel rateModel, Model samplingDensity, boolean reportAsMultivariate, boolean reciprocalRates) { super(traitName, treeModel, diffusionModel, traitParameter, deltaParameter, missingIndices, cacheBranches, scaleByTime, useTreeLength, rateModel, samplingDensity, reportAsMultivariate, reciprocalRates); meanCache = new double[dim * treeModel.getNodeCount()]; drawnStates = new double[dim * treeModel.getNodeCount()]; upperPrecisionCache = new double[treeModel.getNodeCount()]; lowerPrecisionCache = new double[treeModel.getNodeCount()]; logRemainderDensityCache = new double[treeModel.getNodeCount()]; if (cacheBranches) { storedMeanCache = new double[dim * treeModel.getNodeCount()]; storedUpperPrecisionCache = new double[treeModel.getNodeCount()]; storedLowerPrecisionCache = new double[treeModel.getNodeCount()]; storedLogRemainderDensityCache = new double[treeModel.getNodeCount()]; } missing = new boolean[treeModel.getNodeCount()]; Arrays.fill(missing, true); // All internal and root nodes are missing // Set up reusable temporary storage Ay = new double[dimTrait]; tmpM = new double[dimTrait][dimTrait]; tmp2 = new double[dimTrait]; zeroDimVector = new double[dim]; setTipDataValuesForAllNodes(missingIndices); } private void setTipDataValuesForAllNodes(List<Integer> missingIndices) { for (int i = 0; i < treeModel.getExternalNodeCount(); i++) { NodeRef node = treeModel.getExternalNode(i); setTipDataValuesForNode(node); } for (Integer i : missingIndices) { int whichTip = i / dim; Logger.getLogger("dr.evomodel").info( "\tMarking taxon " + treeModel.getTaxonId(whichTip) + " as completely missing"); missing[whichTip] = true; } } public double getTotalTreePrecision() { getLogLikelihood(); // Do peeling if necessary final int rootIndex = treeModel.getRoot().getNumber(); return lowerPrecisionCache[rootIndex]; } private void setTipDataValuesForNode(NodeRef node) { // Set tip data values int index = node.getNumber(); double[] traitValue = traitParameter.getParameter(index).getParameterValues(); if (traitValue.length < dim) { throw new RuntimeException("The trait parameter for the tip with index, " + index + ", is too short"); } System.arraycopy(traitValue, 0, meanCache, dim * index, dim); missing[index] = false; } public double[] getTipDataValues(int index) { double[] traitValue = new double[dim]; System.arraycopy(meanCache, dim * index, traitValue, 0, dim); return traitValue; } public void setTipDataValuesForNode(int index, double[] traitValue) { // Set tip data values System.arraycopy(traitValue, 0, meanCache, dim * index, dim); makeDirty(); } protected String extraInfo() { return "\tSample internal node traits: false\n"; } public List<Citation> getCitations() { List<Citation> citations = super.getCitations(); citations.add( new Citation( new Author[]{ new Author("O", "Pybus"), new Author("P", "Lemey"), new Author("A", "Rambaut"), new Author("MA", "Suchard") }, Citation.Status.IN_PREPARATION ) ); return citations; } public double getLogDataLikelihood() { return getLogLikelihood(); } public abstract boolean getComputeWishartSufficientStatistics(); public double calculateLogLikelihood() { double logLikelihood = 0; double[][] traitPrecision = diffusionModel.getPrecisionmatrix(); double logDetTraitPrecision = Math.log(diffusionModel.getDeterminantPrecisionMatrix()); double[] conditionalRootMean = tmp2; final boolean computeWishartStatistics = getComputeWishartSufficientStatistics(); if (computeWishartStatistics) { // if (wishartStatistics == null) { wishartStatistics = new WishartSufficientStatistics(dimTrait); // } else { // wishartStatistics.clear(); // } } // Use dynamic programming to compute conditional likelihoods at each internal node postOrderTraverse(treeModel, treeModel.getRoot(), traitPrecision, logDetTraitPrecision, computeWishartStatistics); if (DEBUG) { System.err.println("mean: " + new Vector(meanCache)); System.err.println("upre: " + new Vector(upperPrecisionCache)); System.err.println("lpre: " + new Vector(lowerPrecisionCache)); System.err.println("cach: " + new Vector(logRemainderDensityCache)); } // Compute the contribution of each datum at the root final int rootIndex = treeModel.getRoot().getNumber(); // Precision scalar of datum conditional on root double conditionalRootPrecision = lowerPrecisionCache[rootIndex]; for (int datum = 0; datum < numData; datum++) { double thisLogLikelihood = 0; // Get conditional mean of datum conditional on root System.arraycopy(meanCache, rootIndex * dim + datum * dimTrait, conditionalRootMean, 0, dimTrait); if (DEBUG) { System.err.println("Datum #" + datum); System.err.println("root mean: " + new Vector(conditionalRootMean)); System.err.println("root prec: " + conditionalRootPrecision); System.err.println("diffusion prec: " + new Matrix(traitPrecision)); } // B = root prior precision // z = root prior mean // A = likelihood precision // y = likelihood mean // y'Ay double yAy = computeWeightedAverageAndSumOfSquares(conditionalRootMean, Ay, traitPrecision, dimTrait, conditionalRootPrecision); // Also fills in Ay if (conditionalRootPrecision != 0) { thisLogLikelihood += -LOG_SQRT_2_PI * dimTrait + 0.5 * (logDetTraitPrecision + dimTrait * Math.log(conditionalRootPrecision) - yAy); } if (DEBUG) { double[][] T = new double[dimTrait][dimTrait]; for (int i = 0; i < dimTrait; i++) { for (int j = 0; j < dimTrait; j++) { T[i][j] = traitPrecision[i][j] * conditionalRootPrecision; } } System.err.println("Conditional root MVN precision = \n" + new Matrix(T)); System.err.println("Conditional root MVN density = " + MultivariateNormalDistribution.logPdf( conditionalRootMean, new double[dimTrait], T, Math.log(MultivariateNormalDistribution.calculatePrecisionMatrixDeterminate(T)), 1.0)); } if (integrateRoot) { // Integrate root trait out against rootPrior thisLogLikelihood += integrateLogLikelihoodAtRoot(conditionalRootMean, Ay, tmpM, traitPrecision, conditionalRootPrecision); // Ay is destroyed } if (DEBUG) { System.err.println("yAy = " + yAy); System.err.println("logLikelihood (before remainders) = " + thisLogLikelihood + " (should match conditional root MVN density when root not integrated out)"); } logLikelihood += thisLogLikelihood; } logLikelihood += sumLogRemainders(); if (DEBUG) { // Root trait is univariate!!! System.err.println("logLikelihood (final) = " + logLikelihood); // checkViaLargeMatrixInversion(); } if (DEBUG_PNAS) { checkLogLikelihood(logLikelihood, sumLogRemainders(), conditionalRootMean, conditionalRootPrecision, traitPrecision); } areStatesRedrawn = false; // Should redraw internal node states when needed return logLikelihood; } protected void checkLogLikelihood(double loglikelihood, double logRemainders, double[] conditionalRootMean, double conditionalRootPrecision, double[][] traitPrecision) { // Do nothing; for checking PNAS paper } protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) { if (variable == traitParameter) { // A tip value got updated if (index > dimTrait * treeModel.getExternalNodeCount()) { throw new RuntimeException("Attempting to update an invalid index"); } meanCache[index] = traitParameter.getValue(index); likelihoodKnown = false; } super.handleVariableChangedEvent(variable, index, type); } protected static double computeWeightedAverageAndSumOfSquares(double[] y, double[] Ay, double[][] A, int dim, double scale) { // returns Ay and yAy double yAy = 0; for (int i = 0; i < dim; i++) { Ay[i] = 0; for (int j = 0; j < dim; j++) Ay[i] += A[i][j] * y[j] * scale; yAy += y[i] * Ay[i]; } return yAy; } private double sumLogRemainders() { double sumLogRemainders = 0; for (double r : logRemainderDensityCache) sumLogRemainders += r; // Could skip leafs return sumLogRemainders; } protected abstract double integrateLogLikelihoodAtRoot(double[] conditionalRootMean, double[] marginalRootMean, double[][] temporaryStorage, double[][] treePrecisionMatrix, double conditionalRootPrecision); public void makeDirty() { super.makeDirty(); areStatesRedrawn = false; } void postOrderTraverse(TreeModel treeModel, NodeRef node, double[][] precisionMatrix, double logDetPrecisionMatrix, boolean cacheOuterProducts) { final int thisNumber = node.getNumber(); if (treeModel.isExternal(node)) { // Fill in precision scalar, traitValues already filled in if (missing[thisNumber]) { upperPrecisionCache[thisNumber] = 0; lowerPrecisionCache[thisNumber] = 0; // Needed in the pre-order traversal } else { // not missing tip trait upperPrecisionCache[thisNumber] = 1.0 / getRescaledBranchLength(node); lowerPrecisionCache[thisNumber] = Double.POSITIVE_INFINITY; } return; } final NodeRef childNode0 = treeModel.getChild(node, 0); final NodeRef childNode1 = treeModel.getChild(node, 1); postOrderTraverse(treeModel, childNode0, precisionMatrix, logDetPrecisionMatrix, cacheOuterProducts); postOrderTraverse(treeModel, childNode1, precisionMatrix, logDetPrecisionMatrix, cacheOuterProducts); final int childNumber0 = childNode0.getNumber(); final int childNumber1 = childNode1.getNumber(); final int meanOffset0 = dim * childNumber0; final int meanOffset1 = dim * childNumber1; final int meanThisOffset = dim * thisNumber; final double precision0 = upperPrecisionCache[childNumber0]; final double precision1 = upperPrecisionCache[childNumber1]; final double totalPrecision = precision0 + precision1; lowerPrecisionCache[thisNumber] = totalPrecision; // Multiple child0 and child1 densities if (totalPrecision == 0) { System.arraycopy(zeroDimVector, 0, meanCache, meanThisOffset, dim); } else { // computeWeightedMeanCache(meanThisOffset, meanOffset0, meanOffset1, precision0, precision1); computeWeightedAverage( meanCache, meanOffset0, precision0, meanCache, meanOffset1, precision1, meanCache, meanThisOffset, dim); } if (!treeModel.isRoot(node)) { // Integrate out trait value at this node double thisPrecision = 1.0 / getRescaledBranchLength(node); upperPrecisionCache[thisNumber] = totalPrecision * thisPrecision / (totalPrecision + thisPrecision); } // Compute logRemainderDensity logRemainderDensityCache[thisNumber] = 0; if (precision0 != 0 && precision1 != 0) { incrementRemainderDensities( precisionMatrix, logDetPrecisionMatrix, thisNumber, meanThisOffset, meanOffset0, meanOffset1, precision0, precision1, cacheOuterProducts); } } private void incrementRemainderDensities(double[][] precisionMatrix, double logDetPrecisionMatrix, int thisIndex, int thisOffset, int childOffset0, int childOffset1, double precision0, double precision1, boolean cacheOuterProducts) { final double remainderPrecision = precision0 * precision1 / (precision0 + precision1); if (cacheOuterProducts) { incrementOuterProducts(thisOffset, childOffset0, childOffset1, precision0, precision1); } for (int k = 0; k < numData; k++) { double childSS0 = 0; double childSS1 = 0; double crossSS = 0; for (int i = 0; i < dimTrait; i++) { final double wChild0i = meanCache[childOffset0 + k * dimTrait + i] * precision0; final double wChild1i = meanCache[childOffset1 + k * dimTrait + i] * precision1; for (int j = 0; j < dimTrait; j++) { final double child0j = meanCache[childOffset0 + k * dimTrait + j]; final double child1j = meanCache[childOffset1 + k * dimTrait + j]; childSS0 += wChild0i * precisionMatrix[i][j] * child0j; childSS1 += wChild1i * precisionMatrix[i][j] * child1j; crossSS += (wChild0i + wChild1i) * precisionMatrix[i][j] * meanCache[thisOffset + k * dimTrait + j]; } } logRemainderDensityCache[thisIndex] += -dimTrait * LOG_SQRT_2_PI + 0.5 * (dimTrait * Math.log(remainderPrecision) + logDetPrecisionMatrix) - 0.5 * (childSS0 + childSS1 - crossSS); } } private void incrementOuterProducts(int thisOffset, int childOffset0, int childOffset1, double precision0, double precision1) { final double[][] outerProduct = wishartStatistics.getScaleMatrix(); for (int k = 0; k < numData; k++) { for (int i = 0; i < dimTrait; i++) { final double wChild0i = meanCache[childOffset0 + k * dimTrait + i] * precision0; final double wChild1i = meanCache[childOffset1 + k * dimTrait + i] * precision1; for (int j = 0; j < dimTrait; j++) { final double child0j = meanCache[childOffset0 + k * dimTrait + j]; final double child1j = meanCache[childOffset1 + k * dimTrait + j]; outerProduct[i][j] += wChild0i * child0j; outerProduct[i][j] += wChild1i * child1j; outerProduct[i][j] -= (wChild0i + wChild1i) * meanCache[thisOffset + k * dimTrait + j]; } } } wishartStatistics.incrementDf(1); // Peeled one node } // private void computeWeightedMeanCache(int thisOffset, // int childOffset0, // int childOffset1, // double precision0, // double precision1) { // // final double totalVariance = 1.0 / (precision0 + precision1); // for (int i = 0; i < dim; i++) { // meanCache[thisOffset + i] = (meanCache[childOffset0 + i] * precision0 + // meanCache[childOffset1 + i] * precision1) // * totalVariance; // } // } protected double[] getRootNodeTrait() { return getTraitForNode(treeModel, treeModel.getRoot(), traitName); } public double[] getTraitForNode(Tree tree, NodeRef node, String traitName) { // if (tree != treeModel) { // throw new RuntimeException("Can only reconstruct states on treeModel given to constructor"); // } getLogLikelihood(); if (!areStatesRedrawn) redrawAncestralStates(); int index = node.getNumber(); double[] trait = new double[dim]; System.arraycopy(drawnStates, index * dim, trait, 0, dim); return trait; } public void redrawAncestralStates() { double[][] treePrecision = diffusionModel.getPrecisionmatrix(); double[][] treeVariance = new SymmetricMatrix(treePrecision).inverse().toComponents(); preOrderTraverseSample(treeModel, treeModel.getRoot(), 0, treePrecision, treeVariance); if (DEBUG) { System.err.println("all draws = " + new Vector(drawnStates)); } areStatesRedrawn = true; } public void storeState() { super.storeState(); if (cacheBranches) { System.arraycopy(meanCache, 0, storedMeanCache, 0, meanCache.length); System.arraycopy(upperPrecisionCache, 0, storedUpperPrecisionCache, 0, upperPrecisionCache.length); System.arraycopy(lowerPrecisionCache, 0, storedLowerPrecisionCache, 0, lowerPrecisionCache.length); System.arraycopy(logRemainderDensityCache, 0, storedLogRemainderDensityCache, 0, logRemainderDensityCache.length); } } public void restoreState() { super.restoreState(); if (cacheBranches) { double[] tmp; tmp = storedMeanCache; storedMeanCache = meanCache; meanCache = tmp; tmp = storedUpperPrecisionCache; storedUpperPrecisionCache = upperPrecisionCache; upperPrecisionCache = tmp; tmp = storedLowerPrecisionCache; storedLowerPrecisionCache = lowerPrecisionCache; lowerPrecisionCache = tmp; tmp = storedLogRemainderDensityCache; storedLogRemainderDensityCache = logRemainderDensityCache; logRemainderDensityCache = tmp; } } // Computes x^t A y, used many times in these computations protected static double computeQuadraticProduct(double[] x, double[][] A, double[] y, int dim) { double sum = 0; for (int i = 0; i < dim; i++) { for (int j = 0; j < dim; j++) { sum += x[i] * A[i][j] * y[j]; } } return sum; } // Computes the weighted average of two vectors, used many times in these computations protected static void computeWeightedAverage(double[] in0, int offset0, double weight0, double[] in1, int offset1, double weight1, double[] out2, int offset2, int length) { final double totalInverseWeight = 1.0 / (weight0 + weight1); for (int i = 0; i < length; i++) { out2[offset2 + i] = (in0[offset0 + i] * weight0 + in1[offset1 + i] * weight1) * totalInverseWeight; } } protected abstract double[][] computeMarginalRootMeanAndVariance(double[] conditionalRootMean, double[][] treePrecisionMatrix, double[][] treeVarianceMatrix, double conditionalRootPrecision); private void preOrderTraverseSample(TreeModel treeModel, NodeRef node, int parentIndex, double[][] treePrecision, double[][] treeVariance) { final int thisIndex = node.getNumber(); if (treeModel.isRoot(node)) { // draw root double[] rootMean = new double[dimTrait]; final int rootIndex = treeModel.getRoot().getNumber(); double rootPrecision = lowerPrecisionCache[rootIndex]; for (int datum = 0; datum < numData; datum++) { System.arraycopy(meanCache, thisIndex * dim + datum * dimTrait, rootMean, 0, dimTrait); double[][] variance = computeMarginalRootMeanAndVariance(rootMean, treePrecision, treeVariance, rootPrecision); double[] draw = MultivariateNormalDistribution.nextMultivariateNormalVariance(rootMean, variance); if (DEBUG_PREORDER) { Arrays.fill(draw, 1.0); } System.arraycopy(draw, 0, drawnStates, rootIndex * dim + datum * dimTrait, dimTrait); if (DEBUG) { System.err.println("Root mean: " + new Vector(rootMean)); System.err.println("Root var : " + new Matrix(variance)); System.err.println("Root draw: " + new Vector(draw)); } } } else { // draw conditional on parentState if (!missing[thisIndex]) { System.arraycopy(meanCache, thisIndex * dim, drawnStates, thisIndex * dim, dim); } else { // This code should work for sampling a missing tip trait as well, but needs testing // parent trait at drawnStates[parentOffset] double precisionToParent = 1.0 / getRescaledBranchLength(node); double precisionOfNode = lowerPrecisionCache[thisIndex]; double totalPrecision = precisionOfNode + precisionToParent; double[] mean = Ay; // temporary storage double[][] var = tmpM; // temporary storage for (int datum = 0; datum < numData; datum++) { int parentOffset = parentIndex * dim + datum * dimTrait; int thisOffset = thisIndex * dim + datum * dimTrait; if (DEBUG) { double[] parentValue = new double[dimTrait]; System.arraycopy(drawnStates, parentOffset, parentValue, 0, dimTrait); System.err.println("Parent draw: " + new Vector(parentValue)); if (parentValue[0] != drawnStates[parentOffset]) { throw new RuntimeException("Error in setting indices"); } } for (int i = 0; i < dimTrait; i++) { mean[i] = (drawnStates[parentOffset + i] * precisionToParent + meanCache[thisOffset + i] * precisionOfNode) / totalPrecision; for (int j = 0; j < dimTrait; j++) { var[i][j] = treeVariance[i][j] / totalPrecision; } } double[] draw = MultivariateNormalDistribution.nextMultivariateNormalVariance(mean, var); System.arraycopy(draw, 0, drawnStates, thisOffset, dimTrait); if (DEBUG) { System.err.println("Int prec: " + totalPrecision); System.err.println("Int mean: " + new Vector(mean)); System.err.println("Int var : " + new Matrix(var)); System.err.println("Int draw: " + new Vector(draw)); System.err.println(""); } } } } if (peel() && !treeModel.isExternal(node)) { preOrderTraverseSample(treeModel, treeModel.getChild(node, 0), thisIndex, treePrecision, treeVariance); preOrderTraverseSample(treeModel, treeModel.getChild(node, 1), thisIndex, treePrecision, treeVariance); } } protected boolean peel() { return true; } public LogColumn[] getColumns() { return new LogColumn[]{ new LikelihoodColumn(getId())}; } protected boolean areStatesRedrawn = false; protected double[] meanCache; protected double[] upperPrecisionCache; protected double[] lowerPrecisionCache; private double[] logRemainderDensityCache; protected boolean[] missing; private double[] storedMeanCache; private double[] storedUpperPrecisionCache; private double[] storedLowerPrecisionCache; private double[] storedLogRemainderDensityCache; private double[] drawnStates; protected final boolean integrateRoot = true; // Set to false if conditioning on root value (not fully implemented) protected static boolean DEBUG = false; protected static boolean DEBUG_PREORDER = false; protected static boolean DEBUG_PNAS = false; private double[] zeroDimVector; protected WishartSufficientStatistics wishartStatistics; // Reusable temporary storage protected double[] Ay; protected double[][] tmpM; protected double[] tmp2; }