/* * TreeLikelihood.java * * Copyright (C) 2002-2009 Alexei Drummond and Andrew Rambaut * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.evomodel.newtreelikelihood; import dr.evolution.alignment.PatternList; import dr.evolution.tree.NodeRef; import dr.evolution.tree.Tree; import dr.evolution.util.TaxonList; import dr.app.beagle.evomodel.treelikelihood.AbstractTreeLikelihood; import dr.evomodel.branchratemodel.BranchRateModel; import dr.evomodel.branchratemodel.DefaultBranchRateModel; import dr.evomodel.sitemodel.SiteModel; import dr.evomodel.substmodel.FrequencyModel; import dr.evomodel.tree.TreeModel; import dr.inference.model.Likelihood; import dr.inference.model.Model; import dr.xml.*; import java.util.logging.Logger; /** * TreeLikelihoodModel - implements a Likelihood Function for sequences on a tree. * * @author Andrew Rambaut * @author Alexei Drummond * @version $Id: TreeLikelihood.java,v 1.31 2006/08/30 16:02:42 rambaut Exp $ */ public class TreeLikelihood extends AbstractTreeLikelihood { public static final String TREE_LIKELIHOOD = "treeLikelihood"; public static final String USE_AMBIGUITIES = "useAmbiguities"; public static final String DEVICE_NUMBER = "deviceNumber"; /** * Constructor. */ public TreeLikelihood(PatternList patternList, TreeModel treeModel, SiteModel siteModel, BranchRateModel branchRateModel, boolean useAmbiguities, int deviceNumber ) { super(TREE_LIKELIHOOD, patternList, treeModel); try { final Logger logger = Logger.getLogger("dr.evomodel"); logger.info("Using Vector (GPU) TreeLikelihood"); this.siteModel = siteModel; addModel(siteModel); this.frequencyModel = siteModel.getFrequencyModel(); addModel(frequencyModel); if (branchRateModel != null) { this.branchRateModel = branchRateModel; logger.info("Branch rate model used: " + branchRateModel.getModelName()); } else { this.branchRateModel = new DefaultBranchRateModel(); } addModel(this.branchRateModel); this.categoryCount = siteModel.getCategoryCount(); int extNodeCount = treeModel.getExternalNodeCount(); int[] configuration = new int[4]; configuration[0] = stateCount; configuration[1] = patternCount; configuration[2] = siteModel.getCategoryCount(); // matrixCount configuration[3] = deviceNumber; likelihoodCore = LikelihoodCoreFactory.loadLikelihoodCore(configuration, this); // override use preference on useAmbiguities based on actual ability of the likelihood core if (!likelihoodCore.canHandleTipPartials()) { useAmbiguities = false; } if (!likelihoodCore.canHandleTipStates()) { useAmbiguities = true; } likelihoodCore.initialize(nodeCount, (useAmbiguities ? 0 : extNodeCount), patternCount, categoryCount); for (int i = 0; i < extNodeCount; i++) { // Find the id of tip i in the patternList String id = treeModel.getTaxonId(i); int index = patternList.getTaxonIndex(id); if (index == -1) { throw new TaxonList.MissingTaxonException("Taxon, " + id + ", in tree, " + treeModel.getId() + ", is not found in patternList, " + patternList.getId()); } else { if (useAmbiguities) { setPartials(likelihoodCore, patternList, index, i); } else { setStates(likelihoodCore, patternList, index, i); } } } updateSubstitutionModel = true; updateSiteModel = true; } catch (TaxonList.MissingTaxonException mte) { throw new RuntimeException(mte.toString()); } hasInitialized = true; } /** * Sets the partials from a sequence in an alignment. */ protected final void setPartials(LikelihoodCore likelihoodCore, PatternList patternList, int sequenceIndex, int nodeIndex) { double[] partials = new double[patternCount * stateCount]; boolean[] stateSet; int v = 0; for (int i = 0; i < patternCount; i++) { int state = patternList.getPatternState(sequenceIndex, i); stateSet = dataType.getStateSet(state); for (int j = 0; j < stateCount; j++) { if (stateSet[j]) { partials[v] = 1.0; } else { partials[v] = 0.0; } v++; } } likelihoodCore.setTipPartials(nodeIndex, partials); } /** * Sets the partials from a sequence in an alignment. */ protected final void setStates(LikelihoodCore likelihoodCore, PatternList patternList, int sequenceIndex, int nodeIndex) { int i; int[] states = new int[patternCount]; for (i = 0; i < patternCount; i++) { states[i] = patternList.getPatternState(sequenceIndex, i); } likelihoodCore.setTipStates(nodeIndex, states); } // ************************************************************** // ModelListener IMPLEMENTATION // ************************************************************** /** * Handles model changed events from the submodels. */ protected void handleModelChangedEvent(Model model, Object object, int index) { if (model == treeModel) { if (object instanceof TreeModel.TreeChangedEvent) { if (((TreeModel.TreeChangedEvent) object).isNodeChanged()) { // If a node event occurs the node and its two child nodes // are flagged for updating (this will result in everything // above being updated as well. Node events occur when a node // is added to a branch, removed from a branch or its height or // rate changes. updateNodeAndChildren(((TreeModel.TreeChangedEvent) object).getNode()); } else if (((TreeModel.TreeChangedEvent) object).isTreeChanged()) { // Full tree events result in a complete updating of the tree likelihood // Currently this event type is not used. System.err.println("Full tree update event - these events currently aren't used\n" + "so either this is in error or a new feature is using them so remove this message."); updateAllNodes(); } else { // Other event types are ignored (probably trait changes). //System.err.println("Another tree event has occured (possibly a trait change)."); } } else throw new RuntimeException("Assertion failed: Tree model changed event fired without TreeChangedEvent object"); } else if (model == branchRateModel) { if (object instanceof TreeModel.Node) { updateNode((TreeModel.Node) object); } else if (index == -1) { updateAllNodes(); } else { updateNode(treeModel.getNode(index)); } } else if (model == frequencyModel) { updateSubstitutionModel = true; updateAllNodes(); } else if (model instanceof SiteModel) { updateSubstitutionModel = true; updateSiteModel = true; updateAllNodes(); } else { throw new RuntimeException("Unknown componentChangedEvent"); } super.handleModelChangedEvent(model, object, index); } // ************************************************************** // Model IMPLEMENTATION // ************************************************************** /** * Stores the additional state other than model components */ protected void storeState() { likelihoodCore.storeState(); super.storeState(); } /** * Restore the additional stored state */ protected void restoreState() { likelihoodCore.restoreState(); super.restoreState(); } // ************************************************************** // Likelihood IMPLEMENTATION // ************************************************************** /** * Calculate the log likelihood of the current state. * * @return the log likelihood. */ protected double calculateLogLikelihood() { if (patternLogLikelihoods == null) { patternLogLikelihoods = new double[patternCount]; } if (branchUpdateIndices == null) { branchUpdateIndices = new int[nodeCount]; branchLengths = new double[nodeCount]; } if (operations == null) { operations = new int[nodeCount * 3]; dependencies = new int[nodeCount * 2]; } branchUpdateCount = 0; operationCount = 0; final NodeRef root = treeModel.getRoot(); traverse(treeModel, root, null); if (updateSubstitutionModel) { likelihoodCore.updateSubstitutionModel(siteModel.getSubstitutionModel()); } if (updateSiteModel) { likelihoodCore.updateSiteModel(siteModel); } likelihoodCore.updateMatrices(branchUpdateIndices, branchLengths, branchUpdateCount); likelihoodCore.updatePartials(operations, dependencies, operationCount, false); nodeEvaluationCount += operationCount; likelihoodCore.calculateLogLikelihoods(root.getNumber(), patternLogLikelihoods); double logL = 0.0; for (int i = 0; i < patternCount; i++) { logL += patternLogLikelihoods[i] * patternWeights[i]; } // Attempt dynamic rescaling if over/under-flow if (logL == Double.NaN || logL == Double.POSITIVE_INFINITY) { System.err.println("Potential under/over-flow; going to attempt a partials rescaling."); updateAllNodes(); branchUpdateCount = 0; operationCount = 0; traverse(treeModel, root, null); likelihoodCore.updateMatrices(branchUpdateIndices, branchLengths, branchUpdateCount); likelihoodCore.updatePartials(operations, dependencies, operationCount, true); likelihoodCore.calculateLogLikelihoods(root.getNumber(), patternLogLikelihoods); logL = 0.0; for (int i = 0; i < patternCount; i++) { logL += patternLogLikelihoods[i] * patternWeights[i]; } } //******************************************************************** // after traverse all nodes and patterns have been updated -- //so change flags to reflect this. for (int i = 0; i < nodeCount; i++) { updateNode[i] = false; } updateSubstitutionModel = false; updateSiteModel = false; //******************************************************************** return logL; } private double[] rates; private int[] branchUpdateIndices; private double[] branchLengths; private int branchUpdateCount; private int[] operations; private int[] dependencies; private int operationCount; /** * Traverse the tree calculating partial likelihoods. */ private boolean traverse(Tree tree, NodeRef node, int[] operatorNumber) { boolean update = false; int nodeNum = node.getNumber(); NodeRef parent = tree.getParent(node); // First update the transition probability matrix(ices) for this branch if (parent != null && updateNode[nodeNum]) { final double branchRate = branchRateModel.getBranchRate(tree, node); // Get the operational time of the branch final double branchTime = branchRate * (tree.getNodeHeight(parent) - tree.getNodeHeight(node)); if (branchTime < 0.0) { throw new RuntimeException("Negative branch length: " + branchTime); } branchUpdateIndices[branchUpdateCount] = nodeNum; branchLengths[branchUpdateCount] = branchTime; branchUpdateCount++; update = true; } // If the node is internal, update the partial likelihoods. if (!tree.isExternal(node)) { // Traverse down the two child nodes NodeRef child1 = tree.getChild(node, 0); final int[] op1 = {-1}; final boolean update1 = traverse(tree, child1, op1); NodeRef child2 = tree.getChild(node, 1); final int[] op2 = {-1}; final boolean update2 = traverse(tree, child2, op2); // If either child node was updated then update this node too if (update1 || update2) { int x = operationCount * 3; operations[x] = child1.getNumber(); // source node 1 operations[x + 1] = child2.getNumber(); // source node 2 operations[x + 2] = nodeNum; // destination node int y = operationCount * 3; dependencies[y] = -1; // dependent ancestor dependencies[y + 1] = 0; // isDependent? // if one of the child nodes have an update then set the dependency // element to this operation. if (op1[0] != -1) { dependencies[op1[0] * 3] = operationCount; dependencies[y + 1] = 1; // isDependent? } if (op2[0] != -1) { dependencies[op2[0] * 3] = operationCount; dependencies[y + 1] = 1; // isDependent? } if (operatorNumber != null) { dependencies[y] = operationCount; } operationCount++; update = true; } } return update; } /** * The default XML parser - this one has the same name as dr.evomodel.treelikelihod/TreeLikelihood * so will override that if loaded. */ public static TreeLikelihoodParser PARSER = new TreeLikelihoodParser(TREE_LIKELIHOOD); static class TreeLikelihoodParser extends AbstractXMLObjectParser { private final String parserName; TreeLikelihoodParser(final String parserName) { this.parserName = parserName; } public String getParserName() { return parserName; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { boolean useAmbiguities = xo.getAttribute(USE_AMBIGUITIES, false); int deviceNumber = xo.getAttribute(DEVICE_NUMBER, 1) - 1; PatternList patternList = (PatternList) xo.getChild(PatternList.class); TreeModel treeModel = (TreeModel) xo.getChild(TreeModel.class); SiteModel siteModel = (SiteModel) xo.getChild(SiteModel.class); BranchRateModel branchRateModel = (BranchRateModel) xo.getChild(BranchRateModel.class); return new TreeLikelihood( patternList, treeModel, siteModel, branchRateModel, useAmbiguities, deviceNumber ); } //************************************************************************ // AbstractXMLObjectParser implementation //************************************************************************ public String getParserDescription() { return "This element represents the likelihood of a patternlist on a tree given the site model."; } public Class getReturnType() { return Likelihood.class; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules = { AttributeRule.newBooleanRule(USE_AMBIGUITIES, true), AttributeRule.newIntegerRule(DEVICE_NUMBER, true), new ElementRule(PatternList.class), new ElementRule(TreeModel.class), new ElementRule(SiteModel.class), new ElementRule(BranchRateModel.class, true) }; } // ************************************************************** // INSTANCE VARIABLES // ************************************************************** /** * the frequency model for these sites */ protected final FrequencyModel frequencyModel; /** * the site model for these sites */ protected final SiteModel siteModel; /** * the branch rate model */ protected final BranchRateModel branchRateModel; /** * the pattern likelihoods */ protected double[] patternLogLikelihoods = null; /** * the number of rate categories */ protected int categoryCount; /** * an array used to transfer tip partials */ protected double[] tipPartials; /** * the LikelihoodCore */ protected LikelihoodCore likelihoodCore; /** * Flag to specify that the substitution model has changed */ protected boolean updateSubstitutionModel; /** * Flag to specify that the site model has changed */ protected boolean updateSiteModel; private int nodeEvaluationCount = 0; public int getNodeEvaluationCount() { return nodeEvaluationCount; } // /*** // * Flag to specify if LikelihoodCore supports dynamic rescaling // */ // private boolean dynamicRescaling = false; }