/*
* FullyConjugateMultivariateTraitLikelihood.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evomodel.continuous;
import dr.evolution.tree.MultivariateTraitTree;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.TreeUtils;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.inference.model.*;
import dr.math.KroneckerOperation;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.distributions.WishartSufficientStatistics;
import dr.math.interfaces.ConjugateWishartStatisticsProvider;
import dr.math.matrixAlgebra.IllegalDimension;
import dr.math.matrixAlgebra.Matrix;
import dr.math.matrixAlgebra.Vector;
import dr.xml.Reportable;
import java.util.*;
/**
* Integrated multivariate trait likelihood that assumes a fully-conjugate prior on the root.
* The fully-conjugate prior is a multivariate normal distribution whose precision is
* scaled relative to the precision of the diffusion process.
*
* @author Marc A. Suchard
*/
public class FullyConjugateMultivariateTraitLikelihood extends IntegratedMultivariateTraitLikelihood
implements ConjugateWishartStatisticsProvider, GibbsSampleFromTreeInterface, Reportable {
// public FullyConjugateMultivariateTraitLikelihood(String traitName,
// MultivariateTraitTree treeModel,
// MultivariateDiffusionModel diffusionModel,
// CompoundParameter traitParameter,
// Parameter deltaParameter,
// List<Integer> missingIndices,
// boolean cacheBranches,
// boolean scaleByTime,
// boolean useTreeLength,
// BranchRateModel rateModel,
// Model samplingDensity,
// boolean reportAsMultivariate,
// double[] rootPriorMean,
// double rootPriorSampleSize,
// boolean reciprocalRates) {
//
// super(traitName, treeModel, diffusionModel, traitParameter, deltaParameter, missingIndices, cacheBranches, scaleByTime,
// useTreeLength, rateModel, samplingDensity, reportAsMultivariate, reciprocalRates);
//
// // fully-conjugate multivariate normal with own mean and prior sample size
// this.rootPriorMean = rootPriorMean;
// this.rootPriorSampleSize = rootPriorSampleSize;
//
// priorInformationKnown = false;
// }
//
//
// public FullyConjugateMultivariateTraitLikelihood(String traitName,
// MultivariateTraitTree treeModel,
// MultivariateDiffusionModel diffusionModel,
// CompoundParameter traitParameter,
// Parameter deltaParameter,
// List<Integer> missingIndices,
// boolean cacheBranches,
// boolean scaleByTime,
// boolean useTreeLength,
// BranchRateModel rateModel,
// List<BranchRateModel> driftModels,
// Model samplingDensity,
// boolean reportAsMultivariate,
// double[] rootPriorMean,
// double rootPriorSampleSize,
// boolean reciprocalRates) {
//
// super(traitName, treeModel, diffusionModel, traitParameter, deltaParameter, missingIndices, cacheBranches, scaleByTime,
// useTreeLength, rateModel, driftModels, samplingDensity, reportAsMultivariate, reciprocalRates);
//
// // fully-conjugate multivariate normal with own mean and prior sample size
// this.rootPriorMean = rootPriorMean;
// this.rootPriorSampleSize = rootPriorSampleSize;
//
// priorInformationKnown = false;
// }
/**
 * Constructs the fully-conjugate likelihood: integrated multivariate trait evolution
 * on the tree combined with a conjugate multivariate-normal prior on the root state,
 * parameterized by a prior mean and a prior (pseudo-)sample size.
 * All tree/diffusion/rate arguments are forwarded to the superclass unchanged.
 *
 * @param rootPriorMean       mean of the conjugate root prior (length dimTrait)
 * @param rootPriorSampleSize pseudo-count scaling the root prior precision relative
 *                            to the diffusion precision
 */
public FullyConjugateMultivariateTraitLikelihood(String traitName,
MultivariateTraitTree treeModel,
MultivariateDiffusionModel diffusionModel,
CompoundParameter traitParameter,
Parameter deltaParameter,
List<Integer> missingIndices,
boolean cacheBranches,
boolean scaleByTime,
boolean useTreeLength,
BranchRateModel rateModel,
List<BranchRateModel> driftModels,
List<BranchRateModel> optimalValues,
BranchRateModel strengthOfSelection,
Model samplingDensity,
boolean reportAsMultivariate,
double[] rootPriorMean,
List<RestrictedPartials> partials,
double rootPriorSampleSize,
boolean reciprocalRates) {
super(traitName, treeModel, diffusionModel, traitParameter, deltaParameter, missingIndices, cacheBranches, scaleByTime,
useTreeLength, rateModel, driftModels, optimalValues, strengthOfSelection, samplingDensity, partials, reportAsMultivariate, reciprocalRates);
// fully-conjugate multivariate normal with own mean and prior sample size
this.rootPriorMean = rootPriorMean;
this.rootPriorSampleSize = rootPriorSampleSize;
// Force the prior sum-of-squares (zBz) to be recomputed on first use.
priorInformationKnown = false;
}
//TODO temporary function so everything will compile. Need to actually write this.
// NOTE(review): this stub ignores its argument and returns this instance unchanged;
// a real implementation should construct a copy that uses the supplied traitParameter.
public FullyConjugateMultivariateTraitLikelihood semiClone(CompoundParameter traitParameter){
return this;
}
// public FullyConjugateMultivariateTraitLikelihood semiClone(CompoundParameter traitParameter){
// return new FullyConjugateMultivariateTraitLikelihood(this.traitName, this.treeModel, this.diffusionModel, traitParameter,
// this.deltaParameter, this.missingIndices, this.cacheBranches, this.scaleByTime, this.useTreeLength, this.getBranchRateModel(),
// this.optimalValues, this.strengthOfSelection, this.samplingDensity, this.reportAsMultivariate, this.rootPriorMean,
// this.rootPriorSampleSize, this.reciprocalRates);
// }
/**
 * Sums the rescaled (precision-scale) branch lengths along the path from
 * {@code nodeRef} up to, but excluding, the root.
 *
 * @param nodeRef starting node; returns 0 when it is the root
 * @return cumulative rescaled path length to the root
 */
public double getRescaledLengthToRoot(NodeRef nodeRef) {
    double total = 0.0;
    // Iterative walk up the tree (replaces the original tail recursion).
    for (NodeRef current = nodeRef; !treeModel.isRoot(current); current = treeModel.getParent(current)) {
        total += getRescaledBranchLengthForPrecision(current);
    }
    return total;
}
/**
 * Ascertainment correction for conditioning on the observed trait values at one
 * taxon: the marginal log-density of that taxon's trait under the tree plus
 * root-prior model, summed over the independent data columns.
 * Currently implemented for univariate traits only (dimTrait == 1).
 *
 * @param taxonIndex index of the ascertained taxon
 * @return log of the ascertainment correction
 * @throws RuntimeException when dimTrait > 1 (multivariate case not implemented)
 */
protected double calculateAscertainmentCorrection(int taxonIndex) {
NodeRef tip = treeModel.getNode(taxonIndex);
int nodeIndex = treeModel.getNode(taxonIndex).getNumber();
// Lazily allocate the scratch buffer for the ascertained values.
if (ascertainedData == null) { // Assumes that ascertained data are fixed
ascertainedData = new double[dimTrait];
}
// diffusionModel.diffusionPrecisionMatrixParameter.setParameterValue(0,2); // For debugging non-1 values
double[][] traitPrecision = diffusionModel.getPrecisionmatrix();
double logDetTraitPrecision = Math.log(diffusionModel.getDeterminantPrecisionMatrix());
double lengthToRoot = getRescaledLengthToRoot(tip);
// NOTE(review): this evaluates (1.0 / lengthToRoot) + rootPriorSampleSize, i.e. a sum
// of precisions; confirm the intended marginal is not 1/(lengthToRoot + 1/rootPriorSampleSize).
double marginalPrecisionScalar = 1.0 / lengthToRoot + rootPriorSampleSize;
double logLikelihood = 0;
for (int datum = 0; datum < numData; ++datum) {
// Get observed trait value
System.arraycopy(meanCache, nodeIndex * dim + datum * dimTrait, ascertainedData, 0, dimTrait);
if (DEBUG_ASCERTAINMENT) {
System.err.println("Datum #" + datum);
System.err.println("Value: " + new Vector(ascertainedData));
System.err.println("Cond : " + lengthToRoot);
System.err.println("MargV: " + 1.0 / marginalPrecisionScalar);
System.err.println("MargP: " + marginalPrecisionScalar);
System.err.println("diffusion prec: " + new Matrix(traitPrecision));
}
double SSE;
if (dimTrait > 1) {
throw new RuntimeException("Still need to implement multivariate ascertainment correction");
} else {
// Univariate quadratic form: x * (precision) * x.
double precision = traitPrecision[0][0] * marginalPrecisionScalar;
SSE = ascertainedData[0] * precision * ascertainedData[0];
}
// Log of a multivariate normal density with the marginal precision above.
double thisLogLikelihood = -LOG_SQRT_2_PI * dimTrait
+ 0.5 * (logDetTraitPrecision + dimTrait * Math.log(marginalPrecisionScalar) - SSE);
if (DEBUG_ASCERTAINMENT) {
System.err.println("LogLik: " + thisLogLikelihood);
dr.math.distributions.NormalDistribution normal = new dr.math.distributions.NormalDistribution(0,
Math.sqrt(1.0 / (traitPrecision[0][0] * marginalPrecisionScalar)));
System.err.println("TTTLik: " + normal.logPdf(ascertainedData[0]));
if (datum >= 10) {
System.exit(-1);
}
}
logLikelihood += thisLogLikelihood;
}
return logLikelihood;
}
// public double getRootPriorSampleSize() {
// return rootPriorSampleSize;
// }
// public double[] getRootPriorMean() {
// double[] out = new double[rootPriorMean.length];
// System.arraycopy(rootPriorMean, 0, out, 0, out.length);
// return out;
// }
/**
 * Computes and returns the Wishart sufficient statistics (outer products and
 * degrees of freedom) accumulated during one likelihood traversal.
 * Temporarily enables statistic accumulation around the likelihood computation.
 */
public WishartSufficientStatistics getWishartStatistics() {
    computeWishartStatistics = true;
    try {
        // Side effect: the traversal populates wishartStatistics.
        calculateLogLikelihood();
    } finally {
        // Always reset the flag, even if the likelihood computation throws;
        // the original left it stuck at true on an exception.
        computeWishartStatistics = false;
    }
    return wishartStatistics;
}
// NOTE(review): "Paramter" is misspelled, but this is an @Override — presumably the
// name is dictated by the ConjugateWishartStatisticsProvider interface and cannot be
// corrected here without changing the interface and all implementors.
@Override
public MatrixParameterInterface getPrecisionParamter() {
return diffusionModel.getPrecisionParameter();
}
// private double getLogPrecisionDetermination() {
// return Math.log(diffusionModel.getDeterminantPrecisionMatrix()) + dimTrait * Math.log(rootPriorSampleSize);
// }
/**
 * Invalidates the cached root-prior sum-of-squares when the diffusion model
 * changes (its precision enters zBz), then delegates to the superclass handler.
 */
protected void handleModelChangedEvent(Model model, Object object, int index) {
if (model == diffusionModel) {
priorInformationKnown = false;
}
super.handleModelChangedEvent(model, object, index);
}
/**
 * Handles trait-parameter resizing: when entries are added to or removed from the
 * compound trait parameter, recomputes dim/numData and reallocates all per-node
 * caches to match. Any variable change invalidates the pre/post-order caches.
 */
@Override
protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type){
if(variable==traitParameter &&(Parameter.ChangeType.ADDED==type || Parameter.ChangeType.REMOVED==type)){
// Buffers sized by dim must be rebuilt on next use.
dimKnown = false;
dim = traitParameter.getParameter(0).getDimension();
numData = dim / getDimTrait();
meanCache = new double[dim * treeModel.getNodeCount()];
storedMeanCache = new double[meanCache.length];
drawnStates = new double[dim * treeModel.getNodeCount()];
}
// Pre/post-order traversal results are stale after any variable change.
PostPreKnown=false;
super.handleVariableChangedEvent(variable,index,type);
}
/**
 * Saves the pre-order traversal caches (precision scalars and mean vectors)
 * alongside the superclass state, for potential restoration on MCMC rejection.
 * Buffers are deep-copied; null checks guard against use before first traversal.
 */
@Override
public void storeState() {
super.storeState();
storedPostPreKnown=PostPreKnown;
storedDimKnown=dimKnown;
if(preP!=null)
System.arraycopy(preP, 0, storedPreP, 0,preP.length);
if(preMeans!=null){
// Deep copy: each row is cloned so later in-place updates cannot leak through.
for(int i = 0; i < preMeans.length; i++)
storedPreMeans[i] = preMeans[i].clone();
}
}
/**
 * Restores the pre-order traversal caches saved by {@link #storeState()} by
 * swapping the live and stored buffers, and marks the prior sum-of-squares for
 * recomputation.
 */
@Override
public void restoreState() {
    super.restoreState();
    PostPreKnown = storedPostPreKnown;
    priorInformationKnown = false;
    // Swap the pre-order precision buffers.
    double[] tmpPreP = preP;
    preP = storedPreP;
    storedPreP = tmpPreP;
    // Swap the pre-order mean buffers. BUG FIX: the original assigned
    // preMeans = storedPreMeans BEFORE capturing the temporary, so both fields
    // ended up aliasing the same array and the current means were lost.
    double[][] tmpPreMeans = preMeans;
    preMeans = storedPreMeans;
    storedPreMeans = tmpPreMeans;
    dimKnown = storedDimKnown;
}
/**
 * Marks all cached quantities as stale; additionally invalidates the cached
 * root-prior sum-of-squares so it is recomputed on the next likelihood call.
 */
@Override // added: this overrides the superclass implementation it delegates to
public void makeDirty() {
    super.makeDirty();
    priorInformationKnown = false;
}
/** @return the pseudo-count (sample size) of the conjugate root prior. */
public double getPriorSampleSize() {
return rootPriorSampleSize;
}
/**
 * @return the conjugate root prior mean.
 * NOTE(review): returns the internal array directly; callers must not modify it.
 */
public double[] getPriorMean() {
return rootPriorMean;
}
/** @return whether traversals currently accumulate Wishart sufficient statistics. */
@Override
public boolean getComputeWishartSufficientStatistics() {
return computeWishartStatistics;
}
/**
 * Pre-order (root-to-tip) traversal computing, for every node, the precision
 * scalar (preP) and mean vector (preMeans) of the trait distribution conditional
 * on everything NOT below that node, by combining the parent's pre-order
 * statistics with the sibling's post-order statistics. Assumes a strictly
 * bifurcating tree; requires the post-order pass to have filled
 * upperPrecisionCache and the mean cache first.
 */
public void doPreOrderTraversal(NodeRef node) {
// Lazily allocate per-node buffers sized by the current node count.
if(preP==null){
preP=new double[treeModel.getNodeCount()];
storedPreP=new double[treeModel.getNodeCount()];
}
// Mean buffers depend on dim, which may change; reallocate when flagged stale.
if(!dimKnown){
preMeans=new double[treeModel.getNodeCount()][dim];
storedPreMeans=new double[treeModel.getNodeCount()][dim];
dimKnown=true;
}
final int thisNumber = node.getNumber();
if (treeModel.isRoot(node)) {
// At the root the pre-order distribution is just the conjugate prior;
// the prior mean is tiled across the numData columns (j % dimTrait).
preP[thisNumber] = rootPriorSampleSize;
for (int j = 0; j < dim; j++) {
preMeans[thisNumber][j]
= rootPriorMean[j % dimTrait];
}
} else {
final NodeRef parentNode = treeModel.getParent(node);
final NodeRef sibNode = getSisterNode(node);
final int parentNumber = parentNode.getNumber();
final int sibNumber = sibNode.getNumber();
// Combine the parent's pre-order precision with the sibling's post-order
// ("upper") precision, then propagate across this branch:
// preP = tp * b / (tp + b), with b the branch precision.
final double precisionParent = preP[parentNumber];
final double precisionSib = upperPrecisionCache[sibNumber];
final double thisPrecision = 1 / getRescaledBranchLengthForPrecision(node);
double tp = precisionParent + precisionSib;
preP[thisNumber] = tp * thisPrecision / (tp + thisPrecision);
// Precision-weighted average of the parent's pre-order mean and the
// sibling's cached mean.
for (int j = 0; j < dim; j++) {
preMeans[thisNumber][j] = (precisionParent * preMeans[parentNumber][j] + precisionSib * cacheHelper.getMeanCache()[sibNumber*dim+j]) / (precisionParent + precisionSib);
}
}
if (treeModel.isExternal(node)) {
return;
} else {
doPreOrderTraversal(treeModel.getChild(node, 0));
doPreOrderTraversal(treeModel.getChild(node, 1));
}
}
/**
 * Returns the other child of {@code node}'s parent (assumes a bifurcating tree).
 */
public NodeRef getSisterNode(NodeRef node) {
    final NodeRef parent = treeModel.getParent(node);
    final NodeRef firstChild = treeModel.getChild(parent, 0);
    return (firstChild == node) ? treeModel.getChild(parent, 1) : firstChild;
}
/**
 * Returns a copy of the pre-order (conditional) mean vector for the given node,
 * running the traversals first if the caches are stale.
 *
 * @param taxa node number indexing into the pre-order mean cache
 * @return defensive copy of the conditional mean (length dim)
 */
public double[] getConditionalMean(int taxa){
    setup();
    // clone() replaces the original element-by-element copy loop; preMeans rows
    // are allocated with length dim, so the result is identical.
    return preMeans[taxa].clone();
}
/**
 * Returns the pre-order (conditional) means for all nodes, indexed [node][dim],
 * running the traversals first if the caches are stale.
 * NOTE(review): exposes the internal buffer directly; callers must not modify it.
 */
public double[][] getConditionalMeans(){
setup();
return preMeans;
}
/** @return the pre-order (conditional) precision scalar for the given node number. */
public double getPrecisionFactor(int taxa){
setup();
return preP[taxa];
}
/**
 * @return pre-order precision scalars for all nodes.
 * NOTE(review): exposes the internal buffer directly; callers must not modify it.
 */
public double[] getPrecisionFactors(){
setup();
return preP;
}
/**
 * Returns the full dim-by-dim conditional precision matrix for the given node:
 * a block-diagonal matrix with one copy of the trait diffusion precision per
 * datum, each scaled by the node's pre-order precision scalar.
 */
public double[][] getConditionalPrecision(int taxa){
    setup();
    final double[][] diffusionPrecision = diffusionModel.getPrecisionmatrix();
    final double scalar = getPrecisionFactor(taxa);
    final int traitDim = getDimTrait();
    final double[][] result = new double[dim][dim];
    // One scaled trait-precision block per independent datum (Kronecker-style).
    for (int datum = 0; datum < getNumData(); datum++) {
        final int offset = datum * traitDim;
        for (int row = 0; row < traitDim; row++) {
            for (int col = 0; col < traitDim; col++) {
                result[offset + row][offset + col] = scalar * diffusionPrecision[row][col];
            }
        }
    }
    return result;
}
/**
 * Ensures the post-order and pre-order traversal caches are current, recomputing
 * both passes when they have been invalidated. Idempotent once caches are valid.
 */
private void setup(){
    if (!PostPreKnown) {
        final double[][] traitPrecision = diffusionModel.getPrecisionmatrix();
        final double logDetTraitPrecision = Math.log(diffusionModel.getDeterminantPrecisionMatrix());
        final boolean accumulateWishart = getComputeWishartSufficientStatistics();
        if (accumulateWishart) {
            wishartStatistics = new WishartSufficientStatistics(dimTrait);
        }
        // Post-order pass fills the conditional (tipward) statistics; the
        // pre-order pass then propagates rootward information back down.
        postOrderTraverse(treeModel, treeModel.getRoot(), traitPrecision, logDetTraitPrecision, accumulateWishart);
        doPreOrderTraversal(treeModel.getRoot());
    }
    PostPreKnown = true;
}
/**
 * Debug check: recomputes the root contribution analytically as an MVN density of
 * the conditional root mean under the prior, with precision scaled by the
 * combined (conditional x prior / sum) scalar, and prints a warning when it
 * disagrees with the decomposed likelihood by more than 1E-3.
 */
protected void checkLogLikelihood(double loglikelihood, double logRemainders,
double[] conditionalRootMean, double conditionalRootPrecision,
double[][] traitPrecision) {
// Combined scalar: product of precisions over their sum (series combination).
double upperPrecision = conditionalRootPrecision * rootPriorSampleSize / (conditionalRootPrecision + rootPriorSampleSize);
double[][] newPrec = new double[traitPrecision.length][traitPrecision.length];
for (int i = 0; i < traitPrecision.length; ++i) {
for (int j = 0; j < traitPrecision.length; ++j) {
newPrec[i][j] = traitPrecision[i][j] * upperPrecision;
}
}
MultivariateNormalDistribution mvn = new MultivariateNormalDistribution(rootPriorMean, newPrec);
double logPdf = mvn.logPdf(conditionalRootMean);
if (Math.abs(loglikelihood - logRemainders - logPdf) > 1E-3) {
System.err.println("Got here subclass: " + loglikelihood);
System.err.println("logValue : " + (logRemainders + logPdf));
System.err.println("logRemainder = " + logRemainders);
System.err.println("");
}
}
/**
 * Integrates the trait likelihood over the root value against the conjugate
 * multivariate-normal prior (prior precision = rootPriorSampleSize * tree
 * precision). Optionally accumulates Wishart sufficient statistics for Gibbs
 * sampling of the diffusion precision.
 *
 * @param conditionalRootMean      post-order conditional mean at the root
 * @param marginalRootMean         output buffer for the weighted mean (dimTrait > 1 path)
 * @param notUsed                  unused; kept for superclass signature compatibility
 * @param treePrecisionMatrix      trait diffusion precision matrix
 * @param conditionalRootPrecision scalar precision of the conditional root density
 * @return log of the integrated root contribution
 */
protected double integrateLogLikelihoodAtRoot(double[] conditionalRootMean,
double[] marginalRootMean,
double[][] notUsed,
double[][] treePrecisionMatrix, double conditionalRootPrecision) {
final double square;
final double marginalPrecision = conditionalRootPrecision + rootPriorSampleSize;
final double marginalVariance = 1.0 / marginalPrecision;
// square : (Ay + Bz)' (A+B)^{-1} (Ay + Bz)
// A = conditionalRootPrecision * treePrecisionMatrix
// B = rootPriorSampleSize * treePrecisionMatrix
if (dimTrait > 1) {
// Precision-weighted average of conditional mean (y) and prior mean (z).
computeWeightedAverage(conditionalRootMean, 0, conditionalRootPrecision, rootPriorMean, 0, rootPriorSampleSize,
marginalRootMean, 0, dimTrait);
square = computeQuadraticProduct(marginalRootMean, treePrecisionMatrix, marginalRootMean, dimTrait)
* marginalPrecision;
if (computeWishartStatistics) {
// Accumulate the outer product of (y - z), weighted by A*B/(A+B) scalar.
final double[] outerProducts = wishartStatistics.getScaleMatrix();
final double weight = conditionalRootPrecision * rootPriorSampleSize * marginalVariance;
for (int i = 0; i < dimTrait; i++) {
final double diffi = conditionalRootMean[i] - rootPriorMean[i];
for (int j = 0; j < dimTrait; j++) {
outerProducts[i * dimTrait + j] += diffi * weight * (conditionalRootMean[j] - rootPriorMean[j]);
}
}
wishartStatistics.incrementDf(1);
}
} else {
// 1D is very simple
final double x = conditionalRootMean[0] * conditionalRootPrecision + rootPriorMean[0] * rootPriorSampleSize;
square = x * x * treePrecisionMatrix[0][0] * marginalVariance;
if (computeWishartStatistics) {
final double[] outerProducts = wishartStatistics.getScaleMatrix();
final double y = conditionalRootMean[0] - rootPriorMean[0];
outerProducts[0] += y * y * conditionalRootPrecision * rootPriorSampleSize * marginalVariance;
wishartStatistics.incrementDf(1);
}
}
// Cache the prior sum-of-squares zBz; invalidated when the diffusion model changes.
if (!priorInformationKnown) {
setRootPriorSumOfSquares(treePrecisionMatrix);
}
final double retValue = 0.5 * (dimTrait * Math.log(rootPriorSampleSize * marginalVariance) - zBz + square);
if (DEBUG) {
System.err.println("(Ay+Bz)(A+B)^{-1}(Ay+Bz) = " + square);
System.err.println("density = " + retValue);
System.err.println("zBz = " + zBz);
}
return retValue;
}
/**
 * Caches zBz = rootPriorSampleSize * (z' P z) for the prior mean z and tree
 * precision P, and marks the cached value as current.
 */
private void setRootPriorSumOfSquares(double[][] treePrecisionMatrix) {
zBz = computeQuadraticProduct(rootPriorMean, treePrecisionMatrix, rootPriorMean, dimTrait) * rootPriorSampleSize;
priorInformationKnown = true;
}
/**
 * Computes the marginal root variance after combining the conditional root
 * density with the conjugate prior.
 * NOTE(review): writes the precision-weighted mean back into
 * {@code conditionalRootMean} IN PLACE, and returns the shared tmpM scratch
 * buffer — the result is valid only until the buffer's next use.
 *
 * @param notUsed unused; kept for superclass signature compatibility
 */
protected double[][] computeMarginalRootMeanAndVariance(double[] conditionalRootMean,
double[][] notUsed,
double[][] treeVarianceMatrix,
double conditionalRootPrecision) {
final double[][] outVariance = tmpM; // Use a temporary buffer, will stay valid for only a short while
computeWeightedAverage(conditionalRootMean, 0, conditionalRootPrecision, rootPriorMean, 0, rootPriorSampleSize,
conditionalRootMean, 0, dimTrait);
final double totalVariance = 1.0 / (conditionalRootPrecision + rootPriorSampleSize);
for (int i = 0; i < dimTrait; i++) {
for (int j = 0; j < dimTrait; j++) {
outVariance[i][j] = treeVarianceMatrix[i][j] * totalVariance;
}
}
return outVariance;
}
// Conjugate root prior: mean vector (length dimTrait) and pseudo-count precision scale.
protected double[] rootPriorMean;
protected double rootPriorSampleSize;
// Pre-order traversal caches: per-node precision scalars and mean vectors,
// with stored copies for MCMC store/restore.
double[] preP;
double[][] preMeans;
double[] storedPreP;
double[][] storedPreMeans;
// NOTE(review): boxed Boolean (rather than primitive boolean) seems unnecessary here.
Boolean PostPreKnown=false;
Boolean storedPostPreKnown=false;
// Whether the cached zBz (prior sum-of-squares) is current.
private boolean priorInformationKnown = false;
private double zBz; // Prior sum-of-squares contribution
// Whether preMeans/storedPreMeans are allocated for the current dim.
private boolean dimKnown=false;
private boolean storedDimKnown=false;
// When true, traversals accumulate Wishart sufficient statistics.
protected boolean computeWishartStatistics = false;
private double[] ascertainedData = null; // scratch buffer for the ascertainment correction
private static final boolean DEBUG_ASCERTAINMENT = false;
/** Smallest element of {@code vec}; Double.MAX_VALUE when empty. */
private double vectorMin(double[] vec) {
    double smallest = Double.MAX_VALUE;
    for (double value : vec) {
        if (value < smallest) {
            smallest = value;
        }
    }
    return smallest;
}
/** Smallest entry over all rows of {@code mat}; Double.MAX_VALUE when empty. */
private double matrixMin(double[][] mat) {
    double smallest = Double.MAX_VALUE;
    for (double[] row : mat) {
        smallest = Math.min(smallest, vectorMin(row));
    }
    return smallest;
}
/** Largest element of {@code vec}; -Double.MAX_VALUE when empty. */
private double vectorMax(double[] vec) {
    double largest = -Double.MAX_VALUE;
    for (double value : vec) {
        if (value > largest) {
            largest = value;
        }
    }
    return largest;
}
/** Largest entry over all rows of {@code mat}; -Double.MAX_VALUE when empty. */
private double matrixMax(double[][] mat) {
    double largest = -Double.MAX_VALUE;
    for (double[] row : mat) {
        largest = Math.max(largest, vectorMax(row));
    }
    return largest;
}
/** Sum of the elements of {@code vec}; 0.0 when empty. */
private double vectorSum(double[] vec) {
    double total = 0.0;
    for (double value : vec) {
        total += value;
    }
    return total;
}
/** Sum of all entries of {@code mat}; 0.0 when empty. */
private double matrixSum(double[][] mat) {
    double total = 0.0;
    for (double[] row : mat) {
        total += vectorSum(row);
    }
    return total;
}
/**
 * Diagnostic report: rebuilds the joint tree-by-trait covariance directly as a
 * Kronecker product of the tree variance and the trait variance, evaluates the
 * observed data under that full MVN, and compares against the dynamic-programming
 * likelihood and the accumulated Wishart outer products.
 */
@Override
public String getReport() {
StringBuilder sb = new StringBuilder();
sb.append("Tree:\n");
sb.append(getId()).append("\t");
sb.append(treeModel.toString());
sb.append("\n\n");
double[][] treeVariance = computeTreeVariance(true);
double[][] traitPrecision = getDiffusionModel().getPrecisionmatrix();
Matrix traitVariance = new Matrix(traitPrecision).inverse();
double[][] jointVariance = KroneckerOperation.product(treeVariance, traitVariance.toComponents());
sb.append("Tree variance:\n");
sb.append(new Matrix(treeVariance));
sb.append(matrixMin(treeVariance)).append("\t").append(matrixMax(treeVariance)).append("\t").append(matrixSum(treeVariance));
sb.append("\n\n");
sb.append("Trait variance:\n");
sb.append(traitVariance);
sb.append("\n\n");
sb.append("Tree dim: " + treeVariance.length + "\n");
sb.append("data dim: " + jointVariance.length);
sb.append("\n\n");
// Assemble observed data: tip means followed by any clamped (restricted) partials.
double[] data = new double[jointVariance.length];
System.arraycopy(meanCache, 0, data, 0, jointVariance.length);
if (nodeToClampMap != null) {
int offset = treeModel.getExternalNodeCount() * getDimTrait();
for(Map.Entry<NodeRef, RestrictedPartials> clamps : nodeToClampMap.entrySet()) {
double[] partials = clamps.getValue().getPartials();
for (int i = 0; i < partials.length; ++i) {
data[offset] = partials[i];
++offset;
}
}
}
sb.append("Data:\n");
sb.append(new Vector(data)).append("\n");
sb.append(data.length).append("\t").append(vectorMin(data)).append("\t").append(vectorMax(data)).append("\t").append(vectorSum(data));
sb.append(treeModel.getNodeTaxon(treeModel.getExternalNode(0)).getId());
sb.append("\n\n");
// Direct MVN evaluation should agree with the DP likelihood.
MultivariateNormalDistribution mvn = new MultivariateNormalDistribution(new double[data.length], new Matrix(jointVariance).inverse().toComponents());
double logDensity = mvn.logPdf(data);
sb.append("logLikelihood: " + getLogLikelihood() + " == " + logDensity + "\n\n");
final WishartSufficientStatistics sufficientStatistics = getWishartStatistics();
final double[] outerProducts = sufficientStatistics.getScaleMatrix();
sb.append("Outer-products (DP):\n");
sb.append(new Vector(outerProducts));
sb.append(sufficientStatistics.getDf() + "\n");
// Cross-check the outer products via the matrix-normal form y' P y.
Matrix treePrecision = new Matrix(treeVariance).inverse();
final int n = data.length / traitPrecision.length;
final int p = traitPrecision.length;
double[][] tmp = new double[n][p];
for (int i = 0; i < n; ++i) {
for (int j = 0; j < p; ++j) {
tmp[i][j] = data[i * p + j];
}
}
Matrix y = new Matrix(tmp);
Matrix S = null;
try {
S = y.transpose().product(treePrecision).product(y); // Using Matrix-Normal form
} catch (IllegalDimension illegalDimension) {
illegalDimension.printStackTrace();
}
sb.append("Outer-products (from tree variance:\n");
sb.append(S);
sb.append("\n\n");
return sb.toString();
}
/**
 * Immutable (node, distance-to-root) pair used when building tip-to-root path
 * lists for the tree-variance computation.
 */
class NodeToRootDistance {
    // Made final: both fields are assigned only in the constructor and read-only
    // elsewhere, so immutability is safe and self-documenting.
    final NodeRef node;
    final double distance; // cumulative rescaled path length from the root to this node

    NodeToRootDistance(NodeRef node, double distance) {
        this.node = node;
        this.distance = distance;
    }
}
/**
 * Typed list of NodeToRootDistance entries. The copy constructor produces a
 * shallow copy used to fork a path list at internal nodes, so shared path
 * prefixes keep identical element references (relied upon by
 * getTimeBetweenNodeToRootLists).
 * NOTE(review): inherits ArrayList's serializability without declaring a
 * serialVersionUID.
 */
class NodeToRootDistanceList extends ArrayList<NodeToRootDistance> {
NodeToRootDistanceList(NodeToRootDistanceList parentList) {
super(parentList);
}
NodeToRootDistanceList() {
super();
}
}
/**
 * Recursively builds, for every tip, the list of (node, cumulative distance)
 * pairs along its path from the root. Child 1 extends the parent's list in place
 * while child 0 receives a shallow copy, so shared path prefixes share
 * NodeToRootDistance instances — getTimeBetweenNodeToRootLists depends on that
 * reference identity.
 *
 * @param thisNode   current node in the traversal
 * @param parentList path accumulated from the root down to thisNode's parent
 * @param tipLists   output: per-tip path lists, indexed by tip node number
 */
private void addNodeToList(final NodeRef thisNode, NodeToRootDistanceList parentList, NodeToRootDistanceList[] tipLists) {
if (!treeModel.isRoot(thisNode)) {
double increment = getRescaledBranchLengthForPrecision(thisNode);
if (parentList.size() > 0) {
increment += parentList.get(parentList.size() - 1).distance;
}
parentList.add(new NodeToRootDistance(thisNode, increment));
}
if (treeModel.isExternal(thisNode)) {
tipLists[thisNode.getNumber()] = parentList;
} else { // recurse
NodeToRootDistanceList shallowCopy = new NodeToRootDistanceList(parentList);
addNodeToList(treeModel.getChild(thisNode, 0), shallowCopy, tipLists);
addNodeToList(treeModel.getChild(thisNode, 1), parentList, tipLists);
}
}
/**
 * Shared-path length (tree covariance entry) of two tips, given their root-to-tip
 * path lists: walks both lists while entries are the identical object and returns
 * the cumulative distance at the last shared node (0.0 when the paths share no edge).
 * NOTE(review): assumes the lists diverge before either ends; two identical lists
 * would throw IndexOutOfBoundsException — confirm callers only pass distinct tips.
 */
private double getTimeBetweenNodeToRootLists(List<NodeToRootDistance> x, List<NodeToRootDistance> y) {
if (x.get(0) != y.get(0)) {
return 0.0;
}
int index = 1;
while (x.get(index) == y.get(index)) {
++index;
}
return x.get(index - 1).distance;
}
/**
 * Alternative construction of the tip-by-tip tree variance using one traversal to
 * build root-to-tip path lists (avoiding repeated MRCA searches). Completely
 * missing tips are pruned; optionally adds the root-prior variance to every entry.
 *
 * @param includeRoot add 1/priorSampleSize (root integration) to all entries
 */
public double[][] computeTreeVariance2(boolean includeRoot) {
final int tipCount = treeModel.getExternalNodeCount();
double[][] variance = new double[tipCount][tipCount];
NodeToRootDistanceList[] tipToRootDistances = new NodeToRootDistanceList[tipCount];
// Recurse down tree to generate lists
addNodeToList(treeModel.getRoot(), new NodeToRootDistanceList(), tipToRootDistances);
for (int i = 0; i < tipCount; ++i) {
// Fill in diagonal
List<NodeToRootDistance> iList = tipToRootDistances[i];
double marginalTime = iList.get(iList.size() - 1).distance;
variance[i][i] = marginalTime;
for (int j = i + 1; j < tipCount; ++j) {
List<NodeToRootDistance> jList = tipToRootDistances[j];
double time = getTimeBetweenNodeToRootLists(iList, jList);
variance[j][i] = variance[i][j] = time;
}
}
variance = removeMissingTipsInTreeVariance(variance); // Automatically prune missing tips
if (DEBUG) {
System.err.println("");
System.err.println("New tree (trimmed) conditional variance:\n" + new Matrix(variance));
}
if (includeRoot) {
// Marginalizing over the root adds 1/priorSampleSize to every covariance entry.
for (int i = 0; i < variance.length; ++i) {
for (int j = 0; j < variance[i].length; ++j) {
variance[i][j] += 1.0 / getPriorSampleSize();
}
}
}
return variance;
}
/**
 * Builds the (tips + clamped-nodes) tree variance matrix: diagonal entries are
 * rescaled root-to-node lengths, off-diagonal entries the root-to-MRCA length.
 * Clamped (restricted-partials) nodes contribute extra rows/columns with their
 * own pseudo-observation variance added on the diagonal. Optionally adds the
 * root-prior variance to every entry.
 * NOTE(review): unlike computeTreeVariance2, this version does NOT prune missing tips.
 *
 * @param includeRoot add 1/priorSampleSize (root integration) to all entries
 */
public double[][] computeTreeVariance(boolean includeRoot) {
final int tipCount = treeModel.getExternalNodeCount();
int length = tipCount;
boolean DO_CLAMP = true;
if (DO_CLAMP && nodeToClampMap != null) {
length += nodeToClampMap.size();
}
double[][] variance = new double[length][length];
for (int i = 0; i < tipCount; i++) {
// Fill in diagonal
double marginalTime = getRescaledLengthToRoot(treeModel.getExternalNode(i));
variance[i][i] = marginalTime;
// Fill in upper right triangle,
for (int j = i + 1; j < tipCount; j++) {
NodeRef mrca = findMRCA(i, j);
variance[i][j] = getRescaledLengthToRoot(mrca);
}
}
if (DO_CLAMP && nodeToClampMap != null) {
List<RestrictedPartials> partialsList = new ArrayList<RestrictedPartials>();
for (Map.Entry<NodeRef, RestrictedPartials> keySet : nodeToClampMap.entrySet()) {
partialsList.add(keySet.getValue());
}
for (int i = 0; i < partialsList.size(); ++i) {
RestrictedPartials partials = partialsList.get(i);
NodeRef node = partials.getNode();
// Clamped node: its own root-to-node length plus pseudo-observation variance.
variance[tipCount + i][tipCount + i] = getRescaledLengthToRoot(node) +
1.0 / partials.getPriorSampleSize();
// Covariance with each real tip.
for (int j = 0; j < tipCount; ++j) {
NodeRef friend = treeModel.getExternalNode(j);
NodeRef mrca = TreeUtils.getCommonAncestor(treeModel, node, friend);
variance[j][tipCount + i] = getRescaledLengthToRoot(mrca);
}
// Covariance with previously-added clamped nodes.
for (int j = 0; j < i; ++j) {
NodeRef friend = partialsList.get(j).getNode();
NodeRef mrca = TreeUtils.getCommonAncestor(treeModel, node, friend);
variance[tipCount + j][tipCount + i] = getRescaledLengthToRoot(mrca);
}
}
}
// Make symmetric
for (int i = 0; i < length; i++) {
for (int j = i + 1; j < length; j++) {
variance[j][i] = variance[i][j];
}
}
if (includeRoot) {
// Marginalizing over the root adds 1/priorSampleSize to every entry.
for (int i = 0; i < variance.length; ++i) {
for (int j = 0; j < variance[i].length; ++j) {
variance[i][j] += 1.0 / getPriorSampleSize();
}
}
}
return variance;
}
/**
 * Finds the most recent common ancestor of two tips, identified by their indices,
 * via their taxon ids.
 */
private NodeRef findMRCA(int iTip, int jTip) {
    final Set<String> taxa = new HashSet<String>(Arrays.asList(
            treeModel.getTaxonId(iTip), treeModel.getTaxonId(jTip)));
    return TreeUtils.getCommonAncestorNode(treeModel, taxa);
}
/**
 * Returns the tree variance with rows and columns of completely-missing tips
 * removed; returns the input unchanged when no tips are missing.
 *
 * @param variance tipCount-by-tipCount tree variance matrix
 * @return pruned nonMissing-by-nonMissing matrix (or the original reference)
 */
private double[][] removeMissingTipsInTreeVariance(double[][] variance) {
    final int tipCount = treeModel.getExternalNodeCount();
    final int nonMissing = countNonMissingTips();
    if (nonMissing == tipCount) { // Do nothing
        return variance;
    }
    double[][] outVariance = new double[nonMissing][nonMissing];
    int iReal = 0;
    for (int i = 0; i < tipCount; i++) {
        if (!missingTraits.isCompletelyMissing(i)) {
            int jReal = 0;
            for (int j = 0; j < tipCount; j++) {
                // BUG FIX: the inner test previously checked index i (copy-paste),
                // so columns were never pruned and jReal overran the output width
                // whenever any tip was missing.
                if (!missingTraits.isCompletelyMissing(j)) {
                    outVariance[iReal][jReal] = variance[i][j];
                    jReal++;
                }
            }
            iReal++;
        }
    }
    return outVariance;
}
/**
 * Counts external nodes whose traits are not completely missing.
 */
private int countNonMissingTips() {
    final int tipCount = treeModel.getExternalNodeCount();
    int nonMissing = 0;
    // BUG FIX: the original decremented its own loop bound (tipCount--) while
    // iterating, which skipped trailing tips once any earlier tip was missing
    // (e.g. tips {0,4} missing out of 5 returned 4 instead of 3).
    for (int i = 0; i < tipCount; i++) {
        if (!missingTraits.isCompletelyMissing(i)) {
            nonMissing++;
        }
    }
    return nonMissing;
}
}