/* * TreeTraitNormalDistributionModel.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.evomodel.continuous; import dr.evolution.tree.Tree; import dr.inference.distribution.ParametricMultivariateDistributionModel; import dr.inference.model.AbstractModel; import dr.inference.model.Model; import dr.inference.model.Parameter; import dr.inference.model.Variable; import dr.inferencexml.distribution.MultivariateNormalDistributionModelParser; import dr.math.distributions.MultivariateNormalDistribution; import dr.math.distributions.RandomGenerator; /** * A class that acts as a model for multivariate normally distributed data. * * @author Marc Suchard * @author Mandev Gill */ public class TreeTraitNormalDistributionModel extends AbstractModel implements ParametricMultivariateDistributionModel, RandomGenerator { public TreeTraitNormalDistributionModel(FullyConjugateMultivariateTraitLikelihood traitModel, Parameter rootValue, boolean conditionOnRoot) { super(MultivariateNormalDistributionModelParser.NORMAL_DISTRIBUTION_MODEL); this.traitModel = traitModel; if (rootValue != null) { this.rootValue = rootValue.getParameterValues(); } this.conditionOnRoot = conditionOnRoot; dim = traitModel.getTreeModel().getExternalNodeCount() * traitModel.getDimTrait(); addModel(traitModel); distributionKnown = false; //System.err.println("trait vector: " + traitModel.getTreeTraits()[0].getTraitString(traitModel.treeModel, traitModel.treeModel.getExternalNode(2))); //System.exit(0); } public TreeTraitNormalDistributionModel(FullyConjugateMultivariateTraitLikelihood traitModel, boolean conditionOnRoot) { this(traitModel, null, conditionOnRoot); } public Tree getTree() { return traitModel.getTreeModel(); } // ***************************************************************** // Interface MultivariateDistribution // ***************************************************************** public double logPdf(double[] x) { checkDistribution(); return distribution.logPdf(x); } public double[][] getScaleMatrix() { checkDistribution(); return distribution.getScaleMatrix(); } public double[] getMean() { checkDistribution(); return distribution.getMean(); } public String getType() { return "TreeTraitMVN"; } public int getDimTrait() { return traitModel.dimTrait; } // ***************************************************************** // Interface Model // ***************************************************************** public void handleModelChangedEvent(Model model, Object object, int index) { distributionKnown = false; } protected final void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) { distributionKnown = false; } protected void storeState() { storedDistribution = distribution; storedDistributionKnown = distributionKnown; } protected void restoreState() { distributionKnown = storedDistributionKnown; distribution = storedDistribution; } protected void acceptState() { } // no additional state needs accepting // ************************************************************** // Private instance variables and functions // ************************************************************** private void checkDistribution() { if (!distributionKnown) { mean = null; precision = null; distribution = createNewDistribution(); distributionKnown = true; } } private MultivariateNormalDistribution createNewDistribution() { return new MultivariateNormalDistribution(computeMean(), computePrecision()); } private double[] computeMean() { if (traitModel.strengthOfSelection != null) { return MultivariateTraitUtils.computeTreeTraitMeanOU(traitModel, rootValue, conditionOnRoot); } else { return MultivariateTraitUtils.computeTreeTraitMean(traitModel, rootValue, conditionOnRoot); } } private double[][] computePrecision() { return MultivariateTraitUtils.computeTreeTraitPrecision(traitModel, conditionOnRoot); } private final FullyConjugateMultivariateTraitLikelihood traitModel; private double[] mean; private double[][] precision; private MultivariateNormalDistribution distribution; private MultivariateNormalDistribution storedDistribution; private boolean distributionKnown; private boolean storedDistributionKnown; // RandomGenerator interface public double[] nextRandom() { checkDistribution(); return distribution.nextMultivariateNormal(); } public double logPdf(Object x) { checkDistribution(); return distribution.logPdf(x); } // ***************************************************************** // Interface DensityModel // ***************************************************************** @Override public Variable<Double> getLocationVariable() { throw new UnsupportedOperationException("Not implemented"); } private final boolean conditionOnRoot; private double[][] precisionMatrix = null; private double[] rootValue; private final int dim; /* public static final String TREE_TRAIT_NORMAL = "treeTraitNormalDistribution"; public static final String CONDITION = "conditionOnRoot"; public static XMLObjectParser TREE_TRAIT_MODEL = new AbstractXMLObjectParser() { public String getParserName() { return TREE_TRAIT_NORMAL; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { System.err.println("I AM IN THE RIGHT PARSER"); boolean conditionOnRoot = xo.getAttribute(CONDITION, false); FullyConjugateMultivariateTraitLikelihood traitModel = (FullyConjugateMultivariateTraitLikelihood) xo.getChild(FullyConjugateMultivariateTraitLikelihood.class); TreeTraitNormalDistributionModel treeTraitModel = new TreeTraitNormalDistributionModel(traitModel, conditionOnRoot); return treeTraitModel; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules = { AttributeRule.newBooleanRule(CONDITION, true), new ElementRule(FullyConjugateMultivariateTraitLikelihood.class) }; public String getParserDescription() { return "Parses TreeTraitNormalDistributionModel"; } public Class getReturnType() { return TreeTraitNormalDistributionModel.class; } }; */ }