/*
* FullyConjugateMultivariateTraitLikelihood.java
*
* Copyright (c) 2002-2012 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evomodel.continuous;
import dr.evolution.tree.NodeRef;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.tree.TreeModel;
import dr.inference.model.CompoundParameter;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.distributions.WishartSufficientStatistics;
import dr.math.interfaces.ConjugateWishartStatisticsProvider;
import dr.math.matrixAlgebra.Matrix;
import dr.math.matrixAlgebra.Vector;
import java.util.List;
/**
* Integrated multivariate trait likelihood that assumes a fully-conjugate prior on the root.
 * The fully-conjugate prior is a multivariate normal distribution whose precision is the
 * diffusion-process precision matrix scaled by a prior sample size.
*
* @author Marc A. Suchard
*/
public class FullyConjugateMultivariateTraitLikelihood extends IntegratedMultivariateTraitLikelihood
        implements ConjugateWishartStatisticsProvider {

    /**
     * Constructs the fully-conjugate likelihood.  All tree/diffusion arguments are forwarded
     * to the superclass; this subclass adds only the conjugate multivariate-normal root prior.
     *
     * @param rootPriorMean       mean vector of the multivariate-normal root prior
     * @param rootPriorSampleSize pseudo-observation count; the root-prior precision is the
     *                            diffusion precision matrix scaled by this value
     */
    public FullyConjugateMultivariateTraitLikelihood(String traitName,
                                                     TreeModel treeModel,
                                                     MultivariateDiffusionModel diffusionModel,
                                                     CompoundParameter traitParameter,
                                                     Parameter deltaParameter,
                                                     List<Integer> missingIndices,
                                                     boolean cacheBranches,
                                                     boolean scaleByTime,
                                                     boolean useTreeLength,
                                                     BranchRateModel rateModel,
                                                     Model samplingDensity,
                                                     boolean reportAsMultivariate,
                                                     double[] rootPriorMean,
                                                     double rootPriorSampleSize,
                                                     boolean reciprocalRates) {

        super(traitName, treeModel, diffusionModel, traitParameter, deltaParameter, missingIndices, cacheBranches,
                scaleByTime, useTreeLength, rateModel, samplingDensity, reportAsMultivariate, reciprocalRates);

        // Fully-conjugate multivariate normal with its own mean and prior sample size.
        // NOTE(review): rootPriorMean is stored without a defensive copy; callers must not
        // mutate the array after construction — TODO confirm no caller relies on aliasing.
        this.rootPriorMean = rootPriorMean;
        this.rootPriorSampleSize = rootPriorSampleSize;

        priorInformationKnown = false; // force (re)computation of the cached prior sum-of-squares zBz
    }

    /**
     * Sum of rescaled branch lengths on the path from {@code nodeRef} up to the root.
     * Recurses through parents until the parent is the root; assumes {@code nodeRef}
     * is not itself the root (the root has no parent).
     */
    protected double getRescaledLengthToRoot(NodeRef nodeRef) {
        double length = 0;
        NodeRef parent = treeModel.getParent(nodeRef);
        if (!treeModel.isRoot(parent)) {
            length += getRescaledLengthToRoot(parent);
        }
        length += getRescaledBranchLength(nodeRef);
        return length;
    }

    /**
     * Log-density correction for ascertained data at tip {@code taxonIndex} under the
     * marginal distribution induced by the conjugate root prior.  Only the univariate
     * case ({@code dimTrait == 1}) is implemented; higher dimensions throw.
     *
     * @param taxonIndex index of the tip taxon whose datum was ascertained
     * @return summed log-likelihood of the ascertained data over all {@code numData} data
     * @throws RuntimeException if {@code dimTrait > 1}
     */
    protected double calculateAscertainmentCorrection(int taxonIndex) {

        NodeRef tip = treeModel.getNode(taxonIndex);
        int nodeIndex = tip.getNumber(); // fixed: avoid redundant second treeModel.getNode() lookup

        if (ascertainedData == null) { // Assumes that ascertained data are fixed
            ascertainedData = new double[dimTrait];
        }

        // diffusionModel.diffusionPrecisionMatrixParameter.setParameterValue(0,2); // For debugging non-1 values
        double[][] traitPrecision = diffusionModel.getPrecisionmatrix();
        double logDetTraitPrecision = Math.log(diffusionModel.getDeterminantPrecisionMatrix());

        double lengthToRoot = getRescaledLengthToRoot(tip);
        // NOTE(review): the marginal precision scalar is computed as 1/t + kappa; confirm this
        // is the intended form (versus 1/(t + 1/kappa) for a tip at rescaled distance t from
        // the root with root-prior sample size kappa).
        double marginalPrecisionScalar = 1.0 / lengthToRoot + rootPriorSampleSize;
        double logLikelihood = 0;

        for (int datum = 0; datum < numData; ++datum) {

            // Copy the observed trait value for this datum out of the flat mean cache.
            System.arraycopy(meanCache, nodeIndex * dim + datum * dimTrait, ascertainedData, 0, dimTrait);

            if (DEBUG_ASCERTAINMENT) {
                System.err.println("Datum #" + datum);
                System.err.println("Value: " + new Vector(ascertainedData));
                System.err.println("Cond : " + lengthToRoot);
                System.err.println("MargV: " + 1.0 / marginalPrecisionScalar);
                System.err.println("MargP: " + marginalPrecisionScalar);
                System.err.println("diffusion prec: " + new Matrix(traitPrecision));
            }

            double SSE;
            if (dimTrait > 1) {
                throw new RuntimeException("Still need to implement multivariate ascertainment correction");
            } else {
                double precision = traitPrecision[0][0] * marginalPrecisionScalar;
                SSE = ascertainedData[0] * precision * ascertainedData[0];
            }

            // Univariate normal log-density with precision traitPrecision * marginalPrecisionScalar,
            // centered at zero (the tip value in meanCache is used directly).
            double thisLogLikelihood = -LOG_SQRT_2_PI * dimTrait
                    + 0.5 * (logDetTraitPrecision + dimTrait * Math.log(marginalPrecisionScalar) - SSE);

            if (DEBUG_ASCERTAINMENT) {
                System.err.println("LogLik: " + thisLogLikelihood);
                dr.math.distributions.NormalDistribution normal = new dr.math.distributions.NormalDistribution(0,
                        Math.sqrt(1.0 / (traitPrecision[0][0] * marginalPrecisionScalar)));
                System.err.println("TTTLik: " + normal.logPdf(ascertainedData[0]));
                if (datum >= 10) {
                    System.exit(-1);
                }
            }
            logLikelihood += thisLogLikelihood;
        }
        return logLikelihood;
    }

    /**
     * Returns the Wishart sufficient statistics (outer-product scale matrix and degrees of
     * freedom) for the conjugate precision update.  Temporarily enables statistic
     * accumulation, recomputes the likelihood to populate {@code wishartStatistics},
     * then disables accumulation again.
     */
    public WishartSufficientStatistics getWishartStatistics() {
        computeWishartStatistics = true;
        calculateLogLikelihood();
        computeWishartStatistics = false;
        return wishartStatistics;
    }

    /**
     * Invalidates the cached prior sum-of-squares when the diffusion model changes
     * (zBz depends on the diffusion precision matrix), then defers to the superclass.
     */
    protected void handleModelChangedEvent(Model model, Object object, int index) {
        if (model == diffusionModel) {
            priorInformationKnown = false;
        }
        super.handleModelChangedEvent(model, object, index);
    }

    public void restoreState() {
        super.restoreState();
        priorInformationKnown = false; // conservatively recompute zBz after a restore
    }

    public void makeDirty() {
        super.makeDirty();
        priorInformationKnown = false;
    }

    @Override
    public boolean getComputeWishartSufficientStatistics() {
        return computeWishartStatistics;
    }

    /**
     * Debug hook: independently recomputes the root contribution as a multivariate-normal
     * density and prints both values to stderr for comparison.  Not used in normal operation.
     */
    protected void checkLogLikelihood(double loglikelihood, double logRemainders,
                                      double[] conditionalRootMean, double conditionalRootPrecision,
                                      double[][] traitPrecision) {

        // Combined precision scalar of conditional-on-tree and prior root distributions.
        double upperPrecision = conditionalRootPrecision * rootPriorSampleSize
                / (conditionalRootPrecision + rootPriorSampleSize);

        double[][] newPrec = new double[traitPrecision.length][traitPrecision.length];
        for (int i = 0; i < traitPrecision.length; ++i) {
            for (int j = 0; j < traitPrecision.length; ++j) {
                newPrec[i][j] = traitPrecision[i][j] * upperPrecision;
            }
        }
        MultivariateNormalDistribution mvn = new MultivariateNormalDistribution(rootPriorMean, newPrec);
        double logPdf = mvn.logPdf(conditionalRootMean);

        System.err.println("Got here subclass: " + loglikelihood);
        System.err.println("logValue         : " + (logRemainders + logPdf));
        System.err.println("");
    }

    /**
     * Analytically integrates the root trait out of the likelihood under the conjugate prior.
     * Also accumulates the root's contribution to the Wishart sufficient statistics when
     * {@code computeWishartStatistics} is set.
     *
     * @param conditionalRootMean      root mean conditional on the tip data (y)
     * @param marginalRootMean         output buffer; receives the precision-weighted average
     *                                 of conditional mean and prior mean (multivariate case)
     * @param notUsed                  unused (kept for the superclass call signature)
     * @param treePrecisionMatrix      diffusion precision matrix (P)
     * @param conditionalRootPrecision scalar precision of the conditional root distribution (A-scale)
     * @return log of the integrated root density contribution
     */
    protected double integrateLogLikelihoodAtRoot(double[] conditionalRootMean,
                                                  double[] marginalRootMean,
                                                  double[][] notUsed,
                                                  double[][] treePrecisionMatrix, double conditionalRootPrecision) {
        final double square;
        final double marginalPrecision = conditionalRootPrecision + rootPriorSampleSize;
        final double marginalVariance = 1.0 / marginalPrecision;

        // square : (Ay + Bz)' (A+B)^{-1} (Ay + Bz)
        // A = conditionalRootPrecision * treePrecisionMatrix
        // B = rootPriorSampleSize * treePrecisionMatrix
        if (dimTrait > 1) {
            computeWeightedAverage(conditionalRootMean, 0, conditionalRootPrecision, rootPriorMean, 0, rootPriorSampleSize,
                    marginalRootMean, 0, dimTrait);
            square = computeQuadraticProduct(marginalRootMean, treePrecisionMatrix, marginalRootMean, dimTrait)
                    * marginalPrecision;
            if (computeWishartStatistics) {
                // Accumulate the root's outer-product contribution, weighted by the
                // harmonic combination of conditional and prior precisions.
                final double[][] outerProducts = wishartStatistics.getScaleMatrix();
                final double weight = conditionalRootPrecision * rootPriorSampleSize * marginalVariance;
                for (int i = 0; i < dimTrait; i++) {
                    final double diffi = conditionalRootMean[i] - rootPriorMean[i];
                    for (int j = 0; j < dimTrait; j++) {
                        outerProducts[i][j] += diffi * weight * (conditionalRootMean[j] - rootPriorMean[j]);
                    }
                }
                wishartStatistics.incrementDf(1);
            }
        } else {
            // 1D is very simple
            final double x = conditionalRootMean[0] * conditionalRootPrecision + rootPriorMean[0] * rootPriorSampleSize;
            square = x * x * treePrecisionMatrix[0][0] * marginalVariance;
            if (computeWishartStatistics) {
                final double[][] outerProducts = wishartStatistics.getScaleMatrix();
                final double y = conditionalRootMean[0] - rootPriorMean[0];
                outerProducts[0][0] += y * y * conditionalRootPrecision * rootPriorSampleSize * marginalVariance;
                wishartStatistics.incrementDf(1);
            }
        }

        if (!priorInformationKnown) {
            setRootPriorSumOfSquares(treePrecisionMatrix);
        }

        final double retValue = 0.5 * (dimTrait * Math.log(rootPriorSampleSize * marginalVariance) - zBz + square);

        if (DEBUG) {
            System.err.println("(Ay+Bz)(A+B)^{-1}(Ay+Bz) = " + square);
            System.err.println("density = " + retValue);
            System.err.println("zBz = " + zBz);
        }
        return retValue;
    }

    /**
     * Caches the prior sum-of-squares zBz = rootPriorSampleSize * mu' P mu, where mu is the
     * prior mean and P the diffusion precision matrix.  Invalidated whenever the diffusion
     * model changes (see {@link #handleModelChangedEvent}).
     */
    private void setRootPriorSumOfSquares(double[][] treePrecisionMatrix) {
        zBz = computeQuadraticProduct(rootPriorMean, treePrecisionMatrix, rootPriorMean, dimTrait) * rootPriorSampleSize;
        priorInformationKnown = true;
    }

    /**
     * Combines the conditional root distribution with the conjugate prior to give the
     * marginal root mean and variance.
     * <p>
     * NOTE(review): this method overwrites {@code conditionalRootMean} in place with the
     * precision-weighted average, and the returned matrix aliases the shared {@code tmpM}
     * buffer — the result is valid only until the buffer's next use.
     */
    protected double[][] computeMarginalRootMeanAndVariance(double[] conditionalRootMean,
                                                            double[][] notUsed,
                                                            double[][] treeVarianceMatrix,
                                                            double conditionalRootPrecision) {

        final double[][] outVariance = tmpM; // Use a temporary buffer, will stay valid for only a short while

        computeWeightedAverage(conditionalRootMean, 0, conditionalRootPrecision, rootPriorMean, 0, rootPriorSampleSize,
                conditionalRootMean, 0, dimTrait);

        final double totalVariance = 1.0 / (conditionalRootPrecision + rootPriorSampleSize);
        for (int i = 0; i < dimTrait; i++) {
            for (int j = 0; j < dimTrait; j++) {
                outVariance[i][j] = treeVarianceMatrix[i][j] * totalVariance;
            }
        }
        return outVariance;
    }

    protected double[] rootPriorMean;          // prior mean of the root trait
    protected double rootPriorSampleSize;      // pseudo-count scaling the diffusion precision
    private boolean priorInformationKnown = false; // true once zBz is valid for the current precision
    private double zBz;                        // Prior sum-of-squares contribution
    protected boolean computeWishartStatistics = false; // accumulate Wishart stats during likelihood calc
    private double[] ascertainedData = null;   // scratch buffer for one tip datum
    private static final boolean DEBUG_ASCERTAINMENT = false;
}