/* * IntervalLatentLiabilityLikelihood.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.evomodel.continuous; import dr.evolution.alignment.PatternList; import dr.evomodel.tree.TreeModel; import dr.inference.model.*; import dr.math.distributions.Distribution; import dr.util.Citable; import dr.util.Citation; import dr.util.CommonCitations; import dr.xml.*; import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.logging.Logger; /** * A class to model multivariate unit-interval data as realizations from a latent (liability) multivariate Brownian diffusion * * @author Marc A. Suchard * @author Philippe Lemey * @version $Id$ */ public class IntervalLatentLiabilityLikelihood extends AbstractModelLikelihood implements LatentTruncation, Citable, SoftThresholdLikelihood { public final static String LATENT_LIABILITY_LIKELIHOOD = "intervalLatentLiabilityLikelihood"; public IntervalLatentLiabilityLikelihood(TreeModel treeModel, CompoundParameter tipTraitParameter) { super(LATENT_LIABILITY_LIKELIHOOD); this.treeModel = treeModel; this.patternList = null; this.tipTraitParameter = tipTraitParameter; addVariable(tipTraitParameter); setTipDataValuesForAllNodes(); // System.err.println("Name: " + tipTraitParameter.getId()); // System.exit(-1); for (int i = 0; i < tipTraitParameter.getParameterCount(); ++i) { Parameter p = tipTraitParameter.getParameter(i); p.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY, p.getDimension())); } // tipTraitParameter.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY, tipTraitParameter.getDimension())); StringBuilder sb = new StringBuilder(); sb.append("Constructing a unit interval latent liability likelihood model:\n"); // sb.append("\tBinary patterns: ").append(patternList.getId()).append("\n"); sb.append("\tPlease cite:\n").append(Utils.getCitationString(this)); Logger.getLogger("dr.evomodel.continuous").info(sb.toString()); } private void setTipDataValuesForAllNodes() { System.err.println(tipTraitParameter.getParameterCount()); System.err.println(tipTraitParameter.getDimension()); // System.exit(-1); if (tipData == null) { // tipData = new int[treeModel.getExternalNodeCount()][patternList.getPatternCount()]; tipData = new long[tipTraitParameter.getDimension()]; } // for (int i = 0; i < treeModel.getExternalNodeCount(); i++) { // NodeRef node = treeModel.getExternalNode(i); // String id = treeModel.getTaxonId(i); // int index = patternList.getTaxonIndex(id); // setTipDataValuesForNode(node, index); // } // for (int tip = 0; tip < treeModel.getExternalNodeCount(); ++tip) { // System.err.println("Tip #" + tip); // Parameter oneTipTraitParameter = tipTraitParameter.getParameter(tip); // int[] data = tipData[tip]; for (int index = 0; index < tipData.length; ++index) { // int datum = data[index]; // double trait = oneTipTraitParameter.getParameterValue(index); // valid = Math.round(trait) == datum; tipData[index] = Math.round(tipTraitParameter.getParameterValue(index)); // System.err.print(" " + tipData[index]); // } } // System.exit(-1); } // private void setTipDataValuesForNode(NodeRef node, int indexFromPatternList) { // // Set tip data values // int index = node.getNumber(); // if (index != indexFromPatternList) { // throw new RuntimeException("Need to figure out the indexing"); // } // // for (int datum = 0; datum < patternList.getPatternCount(); ++datum) { // tipData[index][datum] = patternList.getPattern(datum)[index] == 1; // if (DEBUG) { // Parameter oneTipTraitParameter = tipTraitParameter.getParameter(index); // System.err.println("Data = " + tipData[index][datum] + " : " + oneTipTraitParameter.getParameterValue(datum)); // } // } // } @Override protected void handleModelChangedEvent(Model model, Object object, int index) { } @Override protected void handleVariableChangedEvent(Variable variable, int index, Variable.ChangeType type) { likelihoodKnown = false; } @Override protected void storeState() { storedLogLikelihood = logLikelihood; } @Override protected void restoreState() { logLikelihood = storedLogLikelihood; likelihoodKnown = true; } @Override protected void acceptState() { // do nothing } public void makeDirty() { likelihoodKnown = false; } public Model getModel() { return this; } public double getLogLikelihood() { if (!likelihoodKnown) { logLikelihood = computeLogLikelihood(); likelihoodKnown = true; } return logLikelihood; } public String toString() { return getClass().getName() + "(" + getLogLikelihood() + ")"; } protected double computeLogLikelihood() { boolean valid = true; // for (int tip = 0; tip < tipData.length && valid; ++tip) { // valid = validTraitForTip(tip); // } for (int index = 0; index < tipData.length && valid; ++index) { double raw = tipTraitParameter.getParameterValue(index); long round = Math.round(raw); valid = round == tipData[index]; // System.err.println(tipData[index] + " " + round + " " + raw); // TODO Handle missing values } // System.err.println("valid = " + valid); // check boolean valid2 = true; for (int tip = 0; tip < treeModel.getExternalNodeCount() && valid2; ++tip) { if (!validTraitForTip(tip)) { valid2 = false; } } // System.err.println(valid + " " + valid2); if (valid != valid2) { throw new RuntimeException("Error in computing validity of tips values"); } if (valid) { return 0.0; } else { // System.exit(-1); return Double.NEGATIVE_INFINITY; } } public boolean validTraitForTip(int tip) { boolean valid = true; Parameter oneTipTraitParameter = tipTraitParameter.getParameter(tip); final int offset = oneTipTraitParameter.getDimension() * tip; for (int index = 0; index < oneTipTraitParameter.getDimension() && valid; ++index) { double raw = oneTipTraitParameter.getParameterValue(index); long round = Math.round(raw); valid = round == tipData[index + offset]; } return valid; } // ************************************************************** // XMLObjectParser // ************************************************************** public static XMLObjectParser PARSER = new AbstractXMLObjectParser() { public final static String TIP_TRAIT = "tipTrait"; public String getParserName() { return LATENT_LIABILITY_LIKELIHOOD; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { AbstractMultivariateTraitLikelihood traitLikelihood = (AbstractMultivariateTraitLikelihood) xo.getChild(AbstractMultivariateTraitLikelihood.class); // PatternList patternList = (PatternList) xo.getChild(PatternList.class); TreeModel treeModel = (TreeModel) xo.getChild(TreeModel.class); CompoundParameter tipTraitParameter = (CompoundParameter) xo.getElementFirstChild(TIP_TRAIT); int numTaxa = treeModel.getTaxonCount(); int numData = traitLikelihood.getNumData(); int dimTrait = traitLikelihood.getDimTrait(); if (tipTraitParameter.getDimension() != numTaxa * numData * dimTrait) { throw new XMLParseException("Tip trait parameter is wrong dimension in latent liability model"); } // if (!(patternList.getDataType() instanceof TwoStates)) { // throw new XMLParseException("Latent liability model currently only works for binary data"); // } // if (patternList.getPatternCount() != numData * dimTrait) { // throw new XMLParseException("Binary data is wrong dimension in latent liability model"); // } return new IntervalLatentLiabilityLikelihood(treeModel, tipTraitParameter); } //************************************************************************ // AbstractXMLObjectParser implementation //************************************************************************ public String getParserDescription() { return "Provides the likelihood of a latent liability model on multivariate-binary trait data"; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules = { new ElementRule(AbstractMultivariateTraitLikelihood.class, "The model for the latent random variables"), new ElementRule(TIP_TRAIT, CompoundParameter.class, "The parameter of tip locations from the tree"), // new ElementRule(PatternList.class, "The binary tip data"), new ElementRule(TreeModel.class, "The tree model"), }; public Class getReturnType() { return IntervalLatentLiabilityLikelihood.class; } }; @Override public Citation.Category getCategory() { return Citation.Category.TRAIT_MODELS; } @Override public String getDescription() { return "Intervaled latent liability model"; } @Override public List<Citation> getCitations() { List<Citation> citations = new ArrayList<Citation>(); citations.add(CommonCitations.CYBIS_2015_ASSESSING); return citations; } public double getNormalizationConstant(Distribution working) { return normalizationDelegate.getNormalizationConstant(working); // delegate to abstract Delegate } private final LatentTruncation.Delegate normalizationDelegate = new Delegate() { protected double computeNormalizationConstant(Distribution working) { double constant = 0.0; for (long datum : tipData) { constant += Math.log(working.cdf(datum + 0.5) - working.cdf(datum - 0.5)); } return -constant; // Note minus sign // return 16.30411; } }; public void setPathParameter(double beta){ pathParameter=beta; } @Override public double getLikelihoodCorrection() { return 0; } private TreeModel treeModel; private PatternList patternList; private CompoundParameter tipTraitParameter; private long[] tipData; private boolean likelihoodKnown = false; private double logLikelihood; private double storedLogLikelihood; private static final boolean DEBUG = false; private double pathParameter=1; }