/* * PurifyingTreeLikelihood.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.oldevomodel.treelikelihood; import dr.evolution.alignment.PatternList; import dr.evolution.tree.NodeRef; import dr.evolution.tree.Tree; import dr.evolution.util.TaxonList; import dr.oldevomodel.sitemodel.SiteModel; import dr.oldevomodel.substmodel.FrequencyModel; import dr.evomodel.tree.TreeModel; import dr.inference.model.Likelihood; import dr.inference.model.Model; import dr.inference.model.Parameter; import dr.xml.*; /** * PurifyingTreeLikelihood - implements a Likelihood Function for sequences on a tree. * * @version $Id: PurifyingTreeLikelihood.java,v 1.6 2006/01/10 16:48:28 rambaut Exp $ * * @author Andrew Rambaut */ @Deprecated // Switching to BEAGLE public class PurifyingTreeLikelihood extends AbstractTreeLikelihood { public static final String PURIFYING_TREE_LIKELIHOOD = "purifyingTreeLikelihood"; public static final String HALF_LIFE = "halfLife"; public static final String PROPORTION = "proportion"; public static final String AVERAGE = "average"; /** * Constructor. */ public PurifyingTreeLikelihood( PatternList patternList, TreeModel treeModel, SiteModel siteModel, Parameter proportionParameter, Parameter lambdaParameter, boolean useAmbiguities, boolean useAveraging ) throws TaxonList.MissingTaxonException { super(PURIFYING_TREE_LIKELIHOOD, patternList, treeModel); this.useAveraging = useAveraging; try { this.siteModel = siteModel; addModel(siteModel); this.frequencyModel = siteModel.getFrequencyModel(); addModel(frequencyModel); this.proportionParameter = proportionParameter; addVariable(proportionParameter); this.lambdaParameter = lambdaParameter; addVariable(lambdaParameter); integrateAcrossCategories = siteModel.integrateAcrossCategories(); this.categoryCount = siteModel.getCategoryCount(); if (integrateAcrossCategories) { if (patternList.getDataType() instanceof dr.evolution.datatype.Nucleotides) { if (NativeNucleotideLikelihoodCore.isAvailable()) { System.out.println("TreeLikelihood using native nucleotide likelihood core."); likelihoodCore = new NativeNucleotideLikelihoodCore(); } else { System.out.println("TreeLikelihood using Java nucleotide likelihood core."); likelihoodCore = new NucleotideLikelihoodCore(); } } else if (patternList.getDataType() instanceof dr.evolution.datatype.AminoAcids) { System.out.println("TreeLikelihood using Java amino acid likelihood core."); likelihoodCore = new AminoAcidLikelihoodCore(); } else { System.out.println("TreeLikelihood using Java general likelihood core."); likelihoodCore = new GeneralLikelihoodCore(patternList.getStateCount()); } } else { System.out.println("TreeLikelihood using Java general likelihood core."); likelihoodCore = new GeneralLikelihoodCore(patternList.getStateCount()); } probabilities = new double[stateCount * stateCount]; likelihoodCore.initialize(nodeCount, patternCount, categoryCount, integrateAcrossCategories); int extNodeCount = treeModel.getExternalNodeCount(); int intNodeCount = treeModel.getInternalNodeCount(); for (int i = 0; i < extNodeCount; i++) { // Find the id of tip i in the patternList String id = treeModel.getTaxonId(i); int index = patternList.getTaxonIndex(id); if (index == -1) { throw new TaxonList.MissingTaxonException("Taxon, " + id + ", in tree, " + treeModel.getId() + ", is not found in patternList, " + patternList.getId()); } if (useAmbiguities) { setPartials(likelihoodCore, patternList, categoryCount, index, i); } else { setStates(likelihoodCore, patternList, index, i); } } for (int i = 0; i < intNodeCount; i++) { likelihoodCore.createNodePartials(extNodeCount + i); } } catch (TaxonList.MissingTaxonException mte) { throw new RuntimeException(mte.toString()); } } // ************************************************************** // ModelListener IMPLEMENTATION // ************************************************************** /** * Handles model changed events from the submodels. */ protected void handleModelChangedEvent(Model model, Object object, int index) { if (model == treeModel) { if (object instanceof TreeModel.TreeChangedEvent) { if (((TreeModel.TreeChangedEvent)object).isNodeChanged()) { updateNodeAndChildren(((TreeModel.TreeChangedEvent)object).getNode()); } else { updateAllNodes(); } } updateRates = true; } else if (model == frequencyModel) { updateAllNodes(); } else if (model instanceof SiteModel) { updateAllNodes(); } else { throw new RuntimeException("Unknown componentChangedEvent"); } super.handleModelChangedEvent(model, object, index); } // ************************************************************** // Model IMPLEMENTATION // ************************************************************** public void handleParameterChangedEvent(Parameter parameter, int index) { // mu has changed if (parameter == proportionParameter || parameter == lambdaParameter) { updateRates = true; updateAllNodes(); } } /** * Stores the additional state other than model components */ protected void storeState() { likelihoodCore.storeState(); super.storeState(); } /** * Restore the additional stored state */ protected void restoreState() { likelihoodCore.restoreState(); super.restoreState(); } // ************************************************************** // Likelihood IMPLEMENTATION // ************************************************************** /** * Calculate the log likelihood of the current state. * @return the log likelihood. */ protected double calculateLogLikelihood() { NodeRef root = treeModel.getRoot(); if (rootPartials == null) { rootPartials = new double[patternCount * stateCount]; } if (patternLogLikelihoods == null) { patternLogLikelihoods = new double[patternCount]; } if (!integrateAcrossCategories) { if (siteCategories == null) { siteCategories = new int[patternCount]; } for (int i = 0; i < patternCount; i++) { siteCategories[i] = siteModel.getCategoryOfSite(i); } } double p = proportionParameter.getParameterValue(0); double lambda = Math.log(2)/lambdaParameter.getParameterValue(0); updateRates = true; updateAllNodes(); if (updateRates) { if (nodeTimes == null) { nodeTimes = new double[treeModel.getNodeCount()]; } calculateNodeRates(treeModel, root, 1.0, p, lambda); } traverse(treeModel, root); updateRates = false; //******************************************************************** // after traverse all nodes and patterns have been updated -- //so change flags to reflect this. for (int i = 0; i < nodeCount; i++) { updateNode[i] = false; } //******************************************************************** double logL = 0.0; for (int i = 0; i < patternCount; i++) { logL += patternLogLikelihoods[i] * patternWeights[i]; } if (Double.isNaN(logL)) { throw new RuntimeException("Likelihood NaN"); } return logL; } /** * Traverse the tree calculating partial likelihoods. * @return whether the partials for this node were recalculated. */ private double calculateNodeRates(TreeModel tree, NodeRef node, double mu, double p, double lambda) { NodeRef parent = tree.getParent(node); double time0 = 0.0; // If the node is internal, update the partial likelihoods. if (!tree.isExternal(node)) { // Traverse down the two child nodes NodeRef child1 = tree.getChild(node, 0); double t1 = calculateNodeRates(tree, child1, mu, p, lambda); NodeRef child2 = tree.getChild(node, 1); double t2 = calculateNodeRates(tree, child2, mu, p, lambda); if (useAveraging) { time0 = (t1 + t2) / 2.0; } else { // pick larger of the two if (t1 > t2) { time0 = t1; } else { time0 = t2; } } } // don't bother if you are at the root because rate at root is ignored if (parent == null) return 0; double branchTime = tree.getNodeHeight(parent) - tree.getNodeHeight(node); double time1 = time0 + branchTime; double branchRate = rateIntegral(time1, mu, p, lambda); if (time0 > 0.0) { branchRate -= rateIntegral(time0, mu, p, lambda); } if (branchRate != tree.getNodeRate(node)) { updateNode(node); nodeTimes[node.getNumber()] = branchRate; } return time1; } private double rateIntegral(double time, double mu, double p, double lambda) { return mu * ( (p * time) - (((1.0 - p) / lambda) * (Math.exp(-lambda * time) - 1.0))); } /** * Traverse the tree calculating partial likelihoods. * @return whether the partials for this node were recalculated. */ private boolean traverse(Tree tree, NodeRef node) { boolean update = false; int nodeNum = node.getNumber(); NodeRef parent = tree.getParent(node); // First update the transition probability matrix(ices) for this branch if (parent != null && updateNode[nodeNum]) { // Get the average rate over this branch // *************************************************************** // Rate at nodes model //double branchRate = (tree.getNodeRate(node) + tree.getNodeRate(parent)) / 2; // *************************************************************** // *************************************************************** // Rate at branches model double branchTime = nodeTimes[node.getNumber()]; // *************************************************************** for (int i = 0; i < categoryCount; i++) { double branchLength = siteModel.getRateForCategory(i) * branchTime; siteModel.getSubstitutionModel().getTransitionProbabilities(branchLength, probabilities); likelihoodCore.setNodeMatrix(nodeNum, i, probabilities); } update = true; } // If the node is internal, update the partial likelihoods. if (!tree.isExternal(node)) { // Traverse down the two child nodes NodeRef child1 = tree.getChild(node, 0); boolean update1 = traverse(tree, child1); NodeRef child2 = tree.getChild(node, 1); boolean update2 = traverse(tree, child2); // If either child node was updated then update this node too if (update1 || update2) { int childNum1 = child1.getNumber(); int childNum2 = child2.getNumber(); if (integrateAcrossCategories) { likelihoodCore.calculatePartials(childNum1, childNum2, nodeNum); } else { likelihoodCore.calculatePartials(childNum1, childNum2, nodeNum, siteCategories); } if (parent == null) { // No parent this is the root of the tree - // calculate the pattern likelihoods double[] frequencies = frequencyModel.getFrequencies(); if (integrateAcrossCategories) { // moved this call to here, because non-integrating siteModels don't need to support it - AD double[] proportions = siteModel.getCategoryProportions(); likelihoodCore.integratePartials(nodeNum, proportions, rootPartials); } else { likelihoodCore.getPartials(nodeNum, rootPartials); } likelihoodCore.calculateLogLikelihoods(rootPartials, frequencies, patternLogLikelihoods); } update = true; } } return update; } // ************************************************************** // XMLElement IMPLEMENTATION // ************************************************************** public org.w3c.dom.Element createElement(org.w3c.dom.Document d) { throw new RuntimeException("createElement not implemented"); } /** * The XML parser */ public static XMLObjectParser PARSER = new AbstractXMLObjectParser() { public String getParserName() { return PURIFYING_TREE_LIKELIHOOD; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { XMLObject cxo = (XMLObject)xo.getChild(PROPORTION); Parameter proportionParam = (Parameter)cxo.getChild(Parameter.class); cxo = (XMLObject)xo.getChild(HALF_LIFE); Parameter lambdaParam = (Parameter)cxo.getChild(Parameter.class); PatternList patternList = (PatternList)xo.getChild(PatternList.class); TreeModel treeModel = (TreeModel)xo.getChild(TreeModel.class); SiteModel siteModel = (SiteModel)xo.getChild(SiteModel.class); boolean useAveraging = xo.getBooleanAttribute(AVERAGE); try { return new PurifyingTreeLikelihood(patternList, treeModel, siteModel, proportionParam, lambdaParam, false, useAveraging); } catch (TaxonList.MissingTaxonException e) { throw new XMLParseException(e.toString()); } } //************************************************************************ // AbstractXMLObjectParser implementation //************************************************************************ public String getParserDescription() { return "This element represents the likelihood of a patternlist on a tree given the site model."; } public Class getReturnType() { return Likelihood.class; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules = { AttributeRule.newBooleanRule(AVERAGE, false), new ElementRule(PROPORTION, new XMLSyntaxRule[] { new ElementRule(Parameter.class) }), new ElementRule(HALF_LIFE, new XMLSyntaxRule[] { new ElementRule(Parameter.class) }), new ElementRule(PatternList.class), new ElementRule(TreeModel.class), new ElementRule(SiteModel.class) }; }; // ************************************************************** // INSTANCE VARIABLES // ************************************************************** /** the frequency model for these sites */ protected FrequencyModel frequencyModel = null; /** the site model for these sites */ protected SiteModel siteModel = null; protected Parameter proportionParameter = null; protected Parameter lambdaParameter = null; private boolean updateRates = false; private double[] nodeTimes = null; // If true then the average of the two incoming branch lengths is used for rate function in internal branches (as opposed to longest) private boolean useAveraging = true; private boolean integrateAcrossCategories = false; /** the root partial likelihoods */ protected double[] branchRates = null; /** the categories for each site */ protected int[] siteCategories = null; /** the root partial likelihoods */ protected double[] rootPartials = null; /** the pattern likelihoods */ protected double[] patternLogLikelihoods = null; /** the number of rate categories */ protected int categoryCount; /** an array used to store transition probabilities */ protected double[] probabilities; /** the LikelihoodCore */ protected LikelihoodCore likelihoodCore; }