/* * MultivariateDistributionLikelihood.java * * Copyright (c) 2002-2012 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.inference.distribution; import dr.inference.model.*; import dr.inferencexml.distribution.DistributionLikelihoodParser; import dr.math.distributions.*; import dr.util.Attribute; import dr.xml.*; /** * @author Marc Suchard */ public class MultivariateDistributionLikelihood extends AbstractDistributionLikelihood { public static final String MVN_PRIOR = "multivariateNormalPrior"; public static final String MVN_MEAN = "meanParameter"; public static final String MVN_PRECISION = "precisionParameter"; public static final String MVN_CV = "coefficientOfVariation"; public static final String WISHART_PRIOR = "multivariateWishartPrior"; public static final String INV_WISHART_PRIOR = "multivariateInverseWishartPrior"; public static final String DIRICHLET_PRIOR = "dirichletPrior"; public static final String DF = "df"; public static final String SCALE_MATRIX = "scaleMatrix"; public static final String MVGAMMA_PRIOR = "multivariateGammaPrior"; public static final String MVGAMMA_SHAPE = "shapeParameter"; public static final String MVGAMMA_SCALE = "scaleParameter"; public static final String COUNTS = "countsParameter"; public static final String NON_INFORMATIVE = "nonInformative"; public static final String MULTIVARIATE_LIKELIHOOD = "multivariateDistributionLikelihood"; public static final String DATA = "data"; private final MultivariateDistribution distribution; public MultivariateDistributionLikelihood(String name, ParametricMultivariateDistributionModel model) { super(model); this.distribution = model; } public MultivariateDistributionLikelihood(String name, MultivariateDistribution distribution) { super(new DefaultModel(name)); this.distribution = distribution; } public MultivariateDistributionLikelihood(MultivariateDistribution distribution) { this(distribution.getType(), distribution); } public double calculateLogLikelihood() { double logL = 0.0; for (Attribute<double[]> data : dataList) { logL += distribution.logPdf(data.getAttributeValue()); } return logL; } @Override public void addData(Attribute<double[]> data) { super.addData(data); if (data instanceof Variable && getModel() instanceof DefaultModel) { ((DefaultModel) getModel()).addVariable((Variable) data); } } public MultivariateDistribution getDistribution() { return distribution; } public static XMLObjectParser DIRICHLET_PRIOR_PARSER = new AbstractXMLObjectParser() { public String getParserName() { return DIRICHLET_PRIOR; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { XMLObject cxo = xo.getChild(COUNTS); Parameter counts = (Parameter) cxo.getChild(Parameter.class); DirichletDistribution dirichlet = new DirichletDistribution(counts.getParameterValues()); MultivariateDistributionLikelihood likelihood = new MultivariateDistributionLikelihood( dirichlet); cxo = xo.getChild(DATA); for (int j = 0; j < cxo.getChildCount(); j++) { if (cxo.getChild(j) instanceof Parameter) { likelihood.addData((Parameter) cxo.getChild(j)); } else { throw new XMLParseException("illegal element in " + xo.getName() + " element " + cxo.getName()); } } return likelihood; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules = { new ElementRule(COUNTS, new XMLSyntaxRule[]{new ElementRule(Parameter.class)}), new ElementRule(DATA, new XMLSyntaxRule[]{new ElementRule(Parameter.class)}, 1, Integer.MAX_VALUE), }; public String getParserDescription() { return "Calculates the likelihood of some data under a Dirichlet distribution."; } public Class getReturnType() { return Likelihood.class; } }; public static XMLObjectParser INV_WISHART_PRIOR_PARSER = new AbstractXMLObjectParser() { public String getParserName() { return INV_WISHART_PRIOR; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { int df = xo.getIntegerAttribute(DF); XMLObject cxo = xo.getChild(SCALE_MATRIX); MatrixParameter scaleMatrix = (MatrixParameter) cxo.getChild(MatrixParameter.class); InverseWishartDistribution invWishart = new InverseWishartDistribution(df, scaleMatrix.getParameterAsMatrix()); MultivariateDistributionLikelihood likelihood = new MultivariateDistributionLikelihood( invWishart); cxo = xo.getChild(DATA); for (int j = 0; j < cxo.getChildCount(); j++) { if (cxo.getChild(j) instanceof MatrixParameter) { likelihood.addData((MatrixParameter) cxo.getChild(j)); } else { throw new XMLParseException("illegal element in " + xo.getName() + " element " + cxo.getName()); } } return likelihood; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules = { AttributeRule.newDoubleRule(DF), new ElementRule(SCALE_MATRIX, new XMLSyntaxRule[]{new ElementRule(MatrixParameter.class)}), }; public String getParserDescription() { return "Calculates the likelihood of some data under an Inverse-Wishart distribution."; } public Class getReturnType() { return Likelihood.class; } }; public static XMLObjectParser WISHART_PRIOR_PARSER = new AbstractXMLObjectParser() { public String getParserName() { return WISHART_PRIOR; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { MultivariateDistributionLikelihood likelihood; if (xo.hasAttribute(NON_INFORMATIVE) && xo.getBooleanAttribute(NON_INFORMATIVE)) { // Make non-informative settings XMLObject cxo = xo.getChild(DATA); int dim = ((MatrixParameter) cxo.getChild(0)).getColumnDimension(); likelihood = new MultivariateDistributionLikelihood(new WishartDistribution(dim)); } else { if (!xo.hasAttribute(DF) || !xo.hasChildNamed(SCALE_MATRIX)) { throw new XMLParseException("Must specify both a df and scaleMatrix"); } double df = xo.getDoubleAttribute(DF); XMLObject cxo = xo.getChild(SCALE_MATRIX); MatrixParameter scaleMatrix = (MatrixParameter) cxo.getChild(MatrixParameter.class); likelihood = new MultivariateDistributionLikelihood( new WishartDistribution(df, scaleMatrix.getParameterAsMatrix()) ); } XMLObject cxo = xo.getChild(DATA); for (int j = 0; j < cxo.getChildCount(); j++) { if (cxo.getChild(j) instanceof MatrixParameter) { likelihood.addData((MatrixParameter) cxo.getChild(j)); } else { throw new XMLParseException("illegal element in " + xo.getName() + " element " + cxo.getName()); } } return likelihood; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules;{ rules = new XMLSyntaxRule[]{ AttributeRule.newBooleanRule(NON_INFORMATIVE, true), AttributeRule.newDoubleRule(DF, true), new ElementRule(SCALE_MATRIX, new XMLSyntaxRule[]{new ElementRule(MatrixParameter.class)}, true), new ElementRule(DATA, new XMLSyntaxRule[]{new ElementRule(MatrixParameter.class, 1, Integer.MAX_VALUE)} ) }; } public String getParserDescription() { return "Calculates the likelihood of some data under a Wishart distribution."; } public Class getReturnType() { return Likelihood.class; } }; public static XMLObjectParser MULTIVARIATE_LIKELIHOOD_PARSER = new AbstractXMLObjectParser() { public String getParserName() { return MULTIVARIATE_LIKELIHOOD; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { XMLObject cxo = xo.getChild(DistributionLikelihoodParser.DISTRIBUTION); ParametricMultivariateDistributionModel distribution = (ParametricMultivariateDistributionModel) cxo.getChild(ParametricMultivariateDistributionModel.class); MultivariateDistributionLikelihood likelihood = new MultivariateDistributionLikelihood(xo.getId(), distribution); cxo = xo.getChild(DATA); if (cxo != null) { for (int j = 0; j < cxo.getChildCount(); j++) { if (cxo.getChild(j) instanceof Parameter) { Parameter data = (Parameter) cxo.getChild(j); if (data instanceof MatrixParameter) { MatrixParameter matrix = (MatrixParameter) data; if (matrix.getParameter(0).getDimension() != distribution.getMean().length) throw new XMLParseException("dim(" + data.getStatisticName() + ") = " + matrix.getParameter(0).getDimension() + " is not equal to dim(" + distribution.getType() + ") = " + distribution.getMean().length + " in " + xo.getName() + "element"); for (int i = 0; i < matrix.getParameterCount(); i++) { likelihood.addData(matrix.getParameter(i)); } } else { if (data.getDimension() != distribution.getMean().length) throw new XMLParseException("dim(" + data.getStatisticName() + ") = " + data.getDimension() + " is not equal to dim(" + distribution.getType() + ") = " + distribution.getMean().length + " in " + xo.getName() + "element"); likelihood.addData(data); } } else { throw new XMLParseException("illegal element in " + xo.getName() + " element"); } } } return likelihood; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules = { new ElementRule(DistributionLikelihoodParser.DISTRIBUTION, new XMLSyntaxRule[]{new ElementRule(ParametricMultivariateDistributionModel.class)} ), new ElementRule(DATA, new XMLSyntaxRule[]{new ElementRule(Parameter.class, 1, Integer.MAX_VALUE)}, true) }; public String getParserDescription() { return "Calculates the likelihood of some data under a given multivariate distribution."; } public Class getReturnType() { return MultivariateDistributionLikelihood.class; } }; public static XMLObjectParser MVN_PRIOR_PARSER = new AbstractXMLObjectParser() { public String getParserName() { return MVN_PRIOR; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { XMLObject cxo = xo.getChild(MVN_MEAN); Parameter mean = (Parameter) cxo.getChild(Parameter.class); cxo = xo.getChild(MVN_PRECISION); MatrixParameter precision = (MatrixParameter) cxo.getChild(MatrixParameter.class); if (mean.getDimension() != precision.getRowDimension() || mean.getDimension() != precision.getColumnDimension()) throw new XMLParseException("Mean and precision have wrong dimensions in " + xo.getName() + " element"); MultivariateDistributionLikelihood likelihood = new MultivariateDistributionLikelihood( new MultivariateNormalDistribution(mean.getParameterValues(), precision.getParameterAsMatrix()) ); cxo = xo.getChild(DATA); if (cxo != null) { for (int j = 0; j < cxo.getChildCount(); j++) { if (cxo.getChild(j) instanceof Parameter) { Parameter data = (Parameter) cxo.getChild(j); if (data instanceof MatrixParameter) { MatrixParameter matrix = (MatrixParameter) data; if (matrix.getParameter(0).getDimension() != mean.getDimension()) throw new XMLParseException("dim(" + data.getStatisticName() + ") = " + matrix.getParameter(0).getDimension() + " is not equal to dim(" + mean.getStatisticName() + ") = " + mean.getDimension() + " in " + xo.getName() + "element"); for (int i = 0; i < matrix.getParameterCount(); i++) { likelihood.addData(matrix.getParameter(i)); } } else { if (data.getDimension() != mean.getDimension()) throw new XMLParseException("dim(" + data.getStatisticName() + ") = " + data.getDimension() + " is not equal to dim(" + mean.getStatisticName() + ") = " + mean.getDimension() + " in " + xo.getName() + "element"); likelihood.addData(data); } } else { throw new XMLParseException("illegal element in " + xo.getName() + " element"); } } } return likelihood; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules = { new ElementRule(MVN_MEAN, new XMLSyntaxRule[]{new ElementRule(Parameter.class)}), new ElementRule(MVN_PRECISION, new XMLSyntaxRule[]{new ElementRule(MatrixParameter.class)}), new ElementRule(DATA, new XMLSyntaxRule[]{new ElementRule(Parameter.class, 1, Integer.MAX_VALUE)}, true) }; public String getParserDescription() { return "Calculates the likelihood of some data under a given multivariate-normal distribution."; } public Class getReturnType() { return MultivariateDistributionLikelihood.class; } }; public static XMLObjectParser MVGAMMA_PRIOR_PARSER = new AbstractXMLObjectParser() { public String getParserName() { return MVGAMMA_PRIOR; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { double[] shape; double[] scale; if (xo.hasChildNamed(MVGAMMA_SHAPE)) { XMLObject cxo = xo.getChild(MVGAMMA_SHAPE); shape = ((Parameter) cxo.getChild(Parameter.class)).getParameterValues(); cxo = xo.getChild(MVGAMMA_SCALE); scale = ((Parameter) cxo.getChild(Parameter.class)).getParameterValues(); if (shape.length != scale.length) throw new XMLParseException("Shape and scale have wrong dimensions in " + xo.getName() + " element"); } else { XMLObject cxo = xo.getChild(MVN_MEAN); double[] mean = ((Parameter) cxo.getChild(Parameter.class)).getParameterValues(); cxo = xo.getChild(MVN_CV); double[] cv = ((Parameter) cxo.getChild(Parameter.class)).getParameterValues(); if (mean.length != cv.length) throw new XMLParseException("Mean and CV have wrong dimensions in " + xo.getName() + " element"); final int dim = mean.length; shape = new double[dim]; scale = new double[dim]; for (int i = 0; i < dim; i++) { double c2 = cv[i] * cv[i]; shape[i] = 1.0 / c2; scale[i] = c2 * mean[i]; } } MultivariateDistributionLikelihood likelihood = new MultivariateDistributionLikelihood( new MultivariateGammaDistribution(shape, scale) ); XMLObject cxo = xo.getChild(DATA); for (int j = 0; j < cxo.getChildCount(); j++) { if (cxo.getChild(j) instanceof Parameter) { Parameter data = (Parameter) cxo.getChild(j); likelihood.addData(data); if (data.getDimension() != shape.length) throw new XMLParseException("dim(" + data.getStatisticName() + ") != " + shape.length + " in " + xo.getName() + "element"); } else { throw new XMLParseException("illegal element in " + xo.getName() + " element"); } } return likelihood; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules = { new XORRule( new ElementRule(MVGAMMA_SHAPE, new XMLSyntaxRule[]{new ElementRule(Parameter.class)}), new ElementRule(MVN_MEAN, new XMLSyntaxRule[]{new ElementRule(Parameter.class)})), new XORRule( new ElementRule(MVGAMMA_SCALE, new XMLSyntaxRule[]{new ElementRule(Parameter.class)}), new ElementRule(MVN_CV, new XMLSyntaxRule[]{new ElementRule(Parameter.class)})), new ElementRule(DATA, new XMLSyntaxRule[]{new ElementRule(Parameter.class, 1, Integer.MAX_VALUE)}) }; public String getParserDescription() { return "Calculates the likelihood of some data under a given multivariate-gamma distribution."; } public Class getReturnType() { return MultivariateDistributionLikelihood.class; } }; }