/*
 * PrecisionMatrixGibbsOperator.java
 *
 * Copyright (c) 2002-2012 Alexei Drummond, Andrew Rambaut and Marc Suchard
 *
 * This file is part of BEAST.
 * See the NOTICE file distributed with this work for additional
 * information regarding copyright ownership and licensing.
 *
 * BEAST is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 *  BEAST is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with BEAST; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA  02110-1301  USA
 */

package dr.evomodel.operators;

import dr.evolution.tree.NodeRef;
import dr.evomodel.continuous.AbstractMultivariateTraitLikelihood;
import dr.evomodel.continuous.SampledMultivariateTraitLikelihood;
import dr.evomodel.tree.TreeModel;
import dr.inference.distribution.MultivariateDistributionLikelihood;
import dr.inference.distribution.MultivariateNormalDistributionModel;
import dr.inference.model.MatrixParameter;
import dr.inference.model.Parameter;
import dr.inference.operators.GibbsOperator;
import dr.inference.operators.MCMCOperator;
import dr.inference.operators.OperatorFailedException;
import dr.inference.operators.SimpleMCMCOperator;
import dr.math.distributions.WishartDistribution;
import dr.math.distributions.WishartSufficientStatistics;
import dr.math.interfaces.ConjugateWishartStatisticsProvider;
import dr.math.matrixAlgebra.IllegalDimension;
import dr.math.matrixAlgebra.SymmetricMatrix;
import dr.util.Attribute;
import dr.xml.*;

import java.util.List;

/**
 * Gibbs operator on a precision matrix with a Wishart prior.
 * <p>
 * Each move replaces the precision matrix with a draw from a Wishart distribution whose
 * degrees of freedom are the prior degrees of freedom plus the number of observations, and
 * whose scale matrix is the inverse of (inverse prior scale matrix + sum of outer products of
 * the observations). The observations come from one of three sources:
 * <ul>
 * <li>a {@link SampledMultivariateTraitLikelihood}: rescaled parent-to-child trait increments
 * along every branch of the tree;</li>
 * <li>any other trait likelihood implementing {@link ConjugateWishartStatisticsProvider}: the
 * sufficient statistics it reports;</li>
 * <li>a {@link MultivariateDistributionLikelihood} over a multivariate normal model: the
 * deviations of each datum from the model mean.</li>
 * </ul>
 *
 * @author Marc Suchard
 */
public class PrecisionMatrixGibbsOperator extends SimpleMCMCOperator implements GibbsOperator {

    public static final String VARIANCE_OPERATOR = "precisionGibbsOperator";
    public static final String TREE_MODEL = "treeModel";
    public static final String DISTRIBUTION = "distribution";
    public static final String PRIOR = "prior";

    // Exactly one of traitModel / multivariateLikelihood is non-null, depending on the constructor used.
    private final AbstractMultivariateTraitLikelihood traitModel;
    private final MultivariateDistributionLikelihood multivariateLikelihood;
    private final Parameter meanParam;
    private final MatrixParameter precisionParam;
    private final double priorDf;
    // Inverse of the prior scale matrix; null when the prior reports no scale matrix.
    private final SymmetricMatrix priorInverseScaleMatrix;
    private final TreeModel treeModel;
    private final int dim;
    // Set as a side effect of getOperationScaleMatrixAndSetObservationCount().
    private double numberObservations;
    private final String traitName;
    private final boolean isSampledTraitLikelihood;

    /**
     * Creates an operator for a normal-normal-Wishart model.
     *
     * @param likelihood        likelihood whose distribution must be a
     *                          {@link MultivariateNormalDistributionModel}; its precision matrix
     *                          parameter is the one that gets sampled
     * @param priorDistribution Wishart prior on the precision matrix
     * @param weight            operator weight
     */
    public PrecisionMatrixGibbsOperator(
            MultivariateDistributionLikelihood likelihood,
            WishartDistribution priorDistribution,
            double weight) {
        super();

        // Tree-related members are not used in this mode
        this.traitModel = null;
        this.treeModel = null;
        this.traitName = null;
        this.isSampledTraitLikelihood = false;

        this.multivariateLikelihood = likelihood;
        MultivariateNormalDistributionModel density =
                (MultivariateNormalDistributionModel) likelihood.getDistribution();
        this.meanParam = density.getMeanParameter();
        this.precisionParam = density.getPrecisionMatrixParameter();
        this.dim = meanParam.getDimension();

        this.priorDf = priorDistribution.df();
        this.priorInverseScaleMatrix = invertPriorScaleMatrix(priorDistribution);

        setWeight(weight);
    }

    /**
     * Creates an operator for the diffusion precision matrix of a trait likelihood on a tree.
     *
     * @param traitModel        trait likelihood; must be a {@link SampledMultivariateTraitLikelihood}
     *                          or implement {@link ConjugateWishartStatisticsProvider}
     * @param priorDistribution Wishart prior on the precision matrix
     * @param weight            operator weight
     * @throws RuntimeException if {@code traitModel} is of neither supported kind
     */
    public PrecisionMatrixGibbsOperator(
            AbstractMultivariateTraitLikelihood traitModel,
            WishartDistribution priorDistribution,
            double weight) {
        super();

        this.traitModel = traitModel;
        this.meanParam = null;
        this.precisionParam = (MatrixParameter) traitModel.getDiffusionModel().getPrecisionParameter();

        this.priorDf = priorDistribution.df();
        this.priorInverseScaleMatrix = invertPriorScaleMatrix(priorDistribution);

        setWeight(weight);

        this.treeModel = traitModel.getTreeModel();
        this.traitName = traitModel.getTraitName();
        this.dim = precisionParam.getRowDimension(); // assumed to be square

        this.isSampledTraitLikelihood = (traitModel instanceof SampledMultivariateTraitLikelihood);

        if (!isSampledTraitLikelihood &&
                !(traitModel instanceof ConjugateWishartStatisticsProvider)) {
            throw new RuntimeException("Only implemented for a SampledMultivariateTraitLikelihood and " +
                    "ConjugateWishartStatisticsProvider");
        }

        this.multivariateLikelihood = null;
    }

    /**
     * Returns the inverse of the prior's scale matrix, or {@code null} when the prior
     * reports no scale matrix. Shared by both constructors.
     */
    private static SymmetricMatrix invertPriorScaleMatrix(WishartDistribution priorDistribution) {
        if (priorDistribution.scaleMatrix() == null) {
            return null;
        }
        return (SymmetricMatrix) (new SymmetricMatrix(priorDistribution.scaleMatrix())).inverse();
    }

    /**
     * @return 1, the number of steps this operator reports
     */
    public int getStepCount() {
        return 1;
    }

    /**
     * Adds to {@code S} the outer product of (datum - mean) for every datum in the likelihood's
     * data list, and sets {@code numberObservations} to the number of data.
     *
     * @param S          accumulator matrix of size dim x dim; both triangles are kept in sync
     * @param likelihood source of the mean and of the data list
     */
    private void incrementOuterProduct(double[][] S,
                                       MultivariateDistributionLikelihood likelihood) {

        final double[] mean = likelihood.getDistribution().getMean();

        numberObservations = 0;
        final List<Attribute<double[]>> dataList = likelihood.getDataList();

        // Scratch buffer for the centred datum. The array handed back by getAttributeValue()
        // is deliberately NOT modified: if it is the attribute's own storage, centring in place
        // would shift the observed data a little further on every Gibbs step.
        final double[] delta = new double[dim];

        for (Attribute<double[]> d : dataList) {

            final double[] data = d.getAttributeValue();
            for (int i = 0; i < dim; i++) {
                delta[i] = data[i] - mean[i];
            }

            for (int i = 0; i < dim; i++) { // symmetric matrix,
                for (int j = i; j < dim; j++) {
                    S[j][i] = S[i][j] += delta[i] * delta[j];
                }
            }
            numberObservations += 1;
        }
    }

    /**
     * Copies the Wishart sufficient statistics reported by {@code integratedLikelihood} into
     * {@code S} (overwriting its contents) and sets {@code numberObservations} to the reported
     * degrees of freedom.
     *
     * @param S                    destination matrix
     * @param integratedLikelihood provider of the sufficient statistics
     */
    private void incrementOuterProduct(double[][] S, ConjugateWishartStatisticsProvider integratedLikelihood) {

        final WishartSufficientStatistics sufficientStatistics = integratedLikelihood.getWishartStatistics();
        final double[][] outerProducts = sufficientStatistics.getScaleMatrix();
        final double df = sufficientStatistics.getDf();

        for (int i = 0; i < outerProducts.length; i++) {
            System.arraycopy(outerProducts[i], 0, S[i], 0, S[i].length);
        }
        numberObservations = df;
    }

    /**
     * Recursively accumulates, for {@code node} and all of its descendants, the outer product of
     * the trait increment (child - parent) / sqrt(rescaled branch length) into {@code S}.
     * Branches whose rescaled length is not positive are skipped. Each contributing branch
     * increases {@code numberObservations} by one.
     *
     * @param S    accumulator matrix of size dim x dim; both triangles are kept in sync
     * @param node subtree root at which to start
     */
    private void incrementOuterProduct(double[][] S, NodeRef node) {

        if (!treeModel.isRoot(node)) {

            NodeRef parent = treeModel.getParent(node);
            double[] parentTrait = treeModel.getMultivariateNodeTrait(parent, traitName);
            double[] childTrait = treeModel.getMultivariateNodeTrait(node, traitName);
            double time = traitModel.getRescaledBranchLength(node);

            if (time > 0) {

                double sqrtTime = Math.sqrt(time);

                double[] delta = new double[dim];

                for (int i = 0; i < dim; i++) {
                    delta[i] = (childTrait[i] - parentTrait[i]) / sqrtTime;
                }

                for (int i = 0; i < dim; i++) { // symmetric matrix,
                    for (int j = i; j < dim; j++) {
                        S[j][i] = S[i][j] += delta[i] * delta[j];
                    }
                }
                numberObservations += 1; // This assumes a *single* observation per tip
            }
        }
        // recurse down tree
        for (int i = 0; i < treeModel.getChildCount(node); i++) {
            incrementOuterProduct(S, treeModel.getChild(node, i));
        }
    }

    /**
     * Computes the scale matrix for the Wishart draw, i.e. the inverse of
     * (inverse prior scale matrix + sum of outer products), and updates
     * {@code numberObservations} as a side effect. When the prior has no scale matrix only the
     * sum of outer products is inverted.
     *
     * @return the scale matrix as a dim x dim array
     * @throws RuntimeException if the matrices involved have incompatible dimensions
     */
    public double[][] getOperationScaleMatrixAndSetObservationCount() {

        // calculate sum-of-the-weighted-squares matrix over tree
        double[][] S = new double[dim][dim];

        numberObservations = 0; // Need to reset, as incrementOuterProduct can be recursive

        if (isSampledTraitLikelihood) {
            incrementOuterProduct(S, treeModel.getRoot());
        } else { // IntegratedTraitLikelihood
            if (traitModel != null) { // is a tree
                incrementOuterProduct(S, (ConjugateWishartStatisticsProvider) traitModel);
            } else { // is a normal-normal-wishart model
                incrementOuterProduct(S, multivariateLikelihood);
            }
        }

        final SymmetricMatrix inverseS2;
        try {
            SymmetricMatrix S2 = new SymmetricMatrix(S);
            if (priorInverseScaleMatrix != null) {
                S2 = priorInverseScaleMatrix.add(S2);
            }
            inverseS2 = (SymmetricMatrix) S2.inverse();
        } catch (IllegalDimension illegalDimension) {
            // Fail loudly with the cause attached instead of printing the trace and
            // dereferencing null later (assertions are disabled by default at runtime).
            throw new RuntimeException(
                    "Prior scale matrix and outer-product matrix have incompatible dimensions",
                    illegalDimension);
        }

        return inverseS2.toComponents();
    }

    /**
     * Draws a new precision matrix from a Wishart distribution and writes it into the precision
     * parameter, firing a single change event once all entries have been set quietly.
     *
     * @return 0
     * @throws OperatorFailedException declared by the operator contract; not thrown directly here
     */
    public double doOperation() throws OperatorFailedException {

        final double[][] scaleMatrix = getOperationScaleMatrixAndSetObservationCount();
        final double treeDf = numberObservations;
        final double df = priorDf + treeDf;

        double[][] draw = WishartDistribution.nextWishart(df, scaleMatrix);

        for (int i = 0; i < dim; i++) {
            Parameter column = precisionParam.getParameter(i);
            for (int j = 0; j < dim; j++) {
                column.setParameterValueQuietly(j, draw[j][i]);
            }
        }
        precisionParam.fireParameterChangedEvent();

        return 0;
    }

    /**
     * @return always {@code null}; this operator offers no performance suggestion
     */
    public String getPerformanceSuggestion() {
        return null;
    }

    /**
     * @return the operator name, {@link #VARIANCE_OPERATOR}
     */
    public String getOperatorName() {
        return VARIANCE_OPERATOR;
    }

    /**
     * XML parser for this operator. The element must have exactly two children: either a trait
     * likelihood plus a Wishart prior, or a multivariate normal likelihood plus a Wishart prior.
     */
    public static dr.xml.XMLObjectParser PARSER = new dr.xml.AbstractXMLObjectParser() {

        public String getParserName() {
            return VARIANCE_OPERATOR;
        }

        public Object parseXMLObject(XMLObject xo) throws XMLParseException {

            if (xo.getChildCount() != 2) {
                throw new XMLParseException(
                        "Element with id = '" + xo.getName() + "' should contain either:\n" +
                                "\t 1 multivariateTraitLikelihood and 1 multivariateDistributionLikelihood (prior), or\n" +
                                "\t 2 multivariateDistributionLikelihoods (likelihood and prior)\n"
                );
            }

            double weight = xo.getDoubleAttribute(WEIGHT);

            AbstractMultivariateTraitLikelihood traitModel =
                    (AbstractMultivariateTraitLikelihood) xo.getChild(AbstractMultivariateTraitLikelihood.class);

            MultivariateDistributionLikelihood prior = null;
            MatrixParameter precMatrix = null;
            MultivariateDistributionLikelihood likelihood = null;

            if (traitModel != null) {
                precMatrix = (MatrixParameter) traitModel.getDiffusionModel().getPrecisionParameter();
                prior = (MultivariateDistributionLikelihood) xo.getChild(MultivariateDistributionLikelihood.class);
            } else { // generic likelihood and prior
                for (int i = 0; i < xo.getChildCount(); ++i) {
                    final Object child = xo.getChild(i);
                    if (!(child instanceof MultivariateDistributionLikelihood)) {
                        // Unsupported child: leave prior/likelihood unset so that the check
                        // below reports a parse error rather than a ClassCastException.
                        continue;
                    }
                    MultivariateDistributionLikelihood density = (MultivariateDistributionLikelihood) child;
                    if (density.getDistribution() instanceof WishartDistribution) {
                        prior = density;
                    } else if (density.getDistribution() instanceof MultivariateNormalDistributionModel) {
                        likelihood = density;
                        precMatrix = ((MultivariateNormalDistributionModel) density.getDistribution())
                                .getPrecisionMatrixParameter();
                    }
                }

                if (prior == null || likelihood == null) {
                    throw new XMLParseException(
                            "Must provide a multivariate normal likelihood and Wishart prior in element '" +
                                    xo.getName() + "'\n"
                    );
                }
            }

            // prior may still be null in the trait-model case when no prior element was supplied
            if (prior == null || !(prior.getDistribution() instanceof WishartDistribution)) {
                throw new XMLParseException("Only a Wishart distribution is conjugate for Gibbs sampling");
            }

            // Make sure precMatrix is square and dim(precMatrix) = dim(parameter)
            if (precMatrix.getColumnDimension() != precMatrix.getRowDimension()) {
                throw new XMLParseException("The variance matrix is not square or of wrong dimension");
            }

            if (traitModel != null) {
                return new PrecisionMatrixGibbsOperator(
                        traitModel, (WishartDistribution) prior.getDistribution(), weight
                );
            } else {
                return new PrecisionMatrixGibbsOperator(
                        likelihood, (WishartDistribution) prior.getDistribution(), weight
                );
            }
        }

        //************************************************************************
        // AbstractXMLObjectParser implementation
        //************************************************************************

        public String getParserDescription() {
            return "This element returns a Gibbs operator that samples a precision matrix " +
                    "from its conjugate Wishart full conditional distribution.";
        }

        public Class getReturnType() {
            return MCMCOperator.class;
        }

        public XMLSyntaxRule[] getSyntaxRules() {
            return rules;
        }

        private final XMLSyntaxRule[] rules = new XMLSyntaxRule[]{
                AttributeRule.newDoubleRule(WEIGHT),
                new ElementRule(AbstractMultivariateTraitLikelihood.class, true),
                new ElementRule(MultivariateDistributionLikelihood.class, 1, 2),
        };
    };
}