/* * MultivariateDiffusionModel.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.evomodel.continuous; import dr.evolution.tree.Tree; import dr.evolution.tree.TreeAttributeProvider; import dr.inference.model.*; import dr.math.distributions.MultivariateNormalDistribution; import dr.xml.*; import org.w3c.dom.Document; import org.w3c.dom.Element; /** * @author Marc Suchard */ public class MultivariateDiffusionModel extends AbstractModel implements TreeAttributeProvider { public static final String DIFFUSION_PROCESS = "multivariateDiffusionModel"; public static final String DIFFUSION_CONSTANT = "precisionMatrix"; public static final String PRECISION_TREE_ATTRIBUTE = "precision"; public static final double LOG2PI = Math.log(2*Math.PI); /** * Construct a diffusion model. */ public MultivariateDiffusionModel(MatrixParameterInterface diffusionPrecisionMatrixParameter) { super(DIFFUSION_PROCESS); this.diffusionPrecisionMatrixParameter = diffusionPrecisionMatrixParameter; calculatePrecisionInfo(); addVariable(diffusionPrecisionMatrixParameter); } public MultivariateDiffusionModel() { super(DIFFUSION_PROCESS); } // public void randomize(Parameter trait) { // } public void check(Parameter trait) throws XMLParseException { assert trait != null; } public MatrixParameterInterface getPrecisionParameter() { checkVariableChanged(); return diffusionPrecisionMatrixParameter; } public double[][] getPrecisionmatrix() { if (diffusionPrecisionMatrixParameter != null) { checkVariableChanged(); return diffusionPrecisionMatrixParameter.getParameterAsMatrix(); } return null; } public double getDeterminantPrecisionMatrix() { checkVariableChanged(); return determinatePrecisionMatrix; } /** * @return the log likelihood of going from start to stop in the given time */ public double getLogLikelihood(double[] start, double[] stop, double time) { if (time == 0) { boolean equal = true; for(int i=0; i<start.length; i++) { if( start[i] != stop[i] ) { equal = false; break; } } if (equal) return 0.0; return Double.NEGATIVE_INFINITY; } return calculateLogDensity(start, stop, time); } protected void checkVariableChanged(){ if(variableChanged){ calculatePrecisionInfo(); variableChanged=false; } } protected double calculateLogDensity(double[] start, double[] stop, double time) { checkVariableChanged(); final double logDet = Math.log(determinatePrecisionMatrix); return MultivariateNormalDistribution.logPdf(stop, start, diffusionPrecisionMatrix, logDet, time); } // todo should be a test, no? public static void main(String[] args) { double[] start = {1, 2}; double[] stop = {0, 0}; double[][] precision = {{2, 0.5}, {0.5, 1}}; double scale = 0.2; MatrixParameter precMatrix = new MatrixParameter("Hello"); precMatrix.addParameter(new Parameter.Default(precision[0])); precMatrix.addParameter(new Parameter.Default(precision[1])); MultivariateDiffusionModel model = new MultivariateDiffusionModel(precMatrix); System.err.println("logPDF = " + model.calculateLogDensity(start, stop, scale)); System.err.println("Should be -19.948"); } protected void calculatePrecisionInfo() { diffusionPrecisionMatrix = diffusionPrecisionMatrixParameter.getParameterAsMatrix(); determinatePrecisionMatrix = MultivariateNormalDistribution.calculatePrecisionMatrixDeterminate( diffusionPrecisionMatrix); } // ***************************************************************** // Interface Model // ***************************************************************** public void handleModelChangedEvent(Model model, Object object, int index) { // no intermediates need to be recalculated... } protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) { variableChanged=true; // calculatePrecisionInfo(); } protected void storeState() { savedDeterminatePrecisionMatrix = determinatePrecisionMatrix; savedDiffusionPrecisionMatrix = diffusionPrecisionMatrix; storedVariableChanged=variableChanged; } protected void restoreState() { determinatePrecisionMatrix = savedDeterminatePrecisionMatrix; diffusionPrecisionMatrix = savedDiffusionPrecisionMatrix; variableChanged=storedVariableChanged; } protected void acceptState() { } // no additional state needs accepting public String[] getTreeAttributeLabel() { return new String[] {PRECISION_TREE_ATTRIBUTE}; } public String[] getAttributeForTree(Tree tree) { if (diffusionPrecisionMatrixParameter != null) { return new String[] {diffusionPrecisionMatrixParameter.toSymmetricString()}; } diffusionPrecisionMatrixParameter.toString(); return new String[] { "null" }; } // ************************************************************** // XMLElement IMPLEMENTATION // ************************************************************** public Element createElement(Document document) { throw new RuntimeException("Not implemented!"); } // ************************************************************** // XMLObjectParser // ************************************************************** public static XMLObjectParser PARSER = new AbstractXMLObjectParser() { public String getParserName() { return DIFFUSION_PROCESS; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { XMLObject cxo = xo.getChild(DIFFUSION_CONSTANT); MatrixParameterInterface diffusionParam = (MatrixParameterInterface) cxo.getChild(MatrixParameterInterface.class); return new MultivariateDiffusionModel(diffusionParam); } //************************************************************************ // AbstractXMLObjectParser implementation //************************************************************************ public String getParserDescription() { return "Describes a multivariate normal diffusion process."; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules = { new ElementRule(DIFFUSION_CONSTANT, new XMLSyntaxRule[]{new ElementRule(MatrixParameterInterface.class)}), }; public Class getReturnType() { return MultivariateDiffusionModel.class; } }; // ************************************************************** // Private instance variables // ************************************************************** protected MatrixParameterInterface diffusionPrecisionMatrixParameter; private double determinatePrecisionMatrix; private double savedDeterminatePrecisionMatrix; private double[][] diffusionPrecisionMatrix; private double[][] savedDiffusionPrecisionMatrix; private boolean variableChanged=true; private boolean storedVariableChanged; }