/*
 * TreeDataLikelihood.java
 *
 * Copyright (c) 2002-2016 Alexei Drummond, Andrew Rambaut and Marc Suchard
 *
 * This file is part of BEAST.
 * See the NOTICE file distributed with this work for additional
 * information regarding copyright ownership and licensing.
 *
 * BEAST is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 *  BEAST is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with BEAST; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA  02110-1301  USA
 */

package dr.evomodel.treedatalikelihood;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evolution.tree.TreeTraitProvider;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.branchratemodel.DefaultBranchRateModel;
import dr.evomodel.tree.TreeModel;
import dr.inference.model.AbstractModelLikelihood;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.xml.Reportable;

import java.util.List;
import java.util.Objects;
import java.util.logging.Logger;

/**
 * TreeDataLikelihood - uses plugin delegates to compute the likelihood of some data given a tree.
 * <p>
 * The computed log likelihood is cached; it is recomputed lazily by {@link #getLogLikelihood()}
 * whenever one of the registered models (the delegate, the branch rate model or, for a random
 * tree, the tree model) reports a change.
 *
 * @author Andrew Rambaut
 * @author Marc Suchard
 * @version $Id$
 */
public final class TreeDataLikelihood extends AbstractModelLikelihood implements TreeTraitProvider, Reportable {

    /** When true, the diagnostic counters reported by {@link #getReport()} are maintained. */
    protected static final boolean COUNT_TOTAL_OPERATIONS = true;

    /** Upper bound on the number of likelihood attempts made by a single calculation. */
    private static final long MAX_UNDERFLOWS_BEFORE_ERROR = 100;

    /**
     * Creates a likelihood that delegates the numerical work to {@code likelihoodDelegate}.
     *
     * @param likelihoodDelegate the delegate that computes the likelihood; must not be null
     * @param treeModel          the tree the data is evaluated on; must not be null
     * @param branchRateModel    the model supplying branch rates; must not be null
     * @throws NullPointerException if any argument is null
     */
    public TreeDataLikelihood(DataLikelihoodDelegate likelihoodDelegate,
                              TreeModel treeModel,
                              BranchRateModel branchRateModel) {

        super("TreeDataLikelihood");  // change this to use a const once the parser exists

        // Explicit checks rather than 'assert': assertions are disabled by default, and we want
        // to fail before any listener registration has taken place.
        Objects.requireNonNull(likelihoodDelegate, "likelihoodDelegate");
        Objects.requireNonNull(treeModel, "treeModel");
        Objects.requireNonNull(branchRateModel, "branchRateModel");

        final Logger logger = Logger.getLogger("dr.evomodel");

        logger.info("\nUsing TreeDataLikelihood");

        this.likelihoodDelegate = likelihoodDelegate;
        addModel(likelihoodDelegate);
        likelihoodDelegate.setCallback(this);

        this.treeModel = treeModel;
        isTreeRandom = treeModel.isTreeRandom();
        if (isTreeRandom) {
            // Only a random tree is registered as a sub-model; see handleModelChangedEvent,
            // which rejects change events coming from a fixed tree.
            addModel(treeModel);
        }

        likelihoodKnown = false;

        this.branchRateModel = branchRateModel;
        if (!(branchRateModel instanceof DefaultBranchRateModel)) {
            logger.info(" Branch rate model used: " + branchRateModel.getModelName());
        }
        addModel(this.branchRateModel);

        treeTraversalDelegate = new LikelihoodTreeTraversal(
                treeModel,
                branchRateModel,
                likelihoodDelegate.getOptimalTraversalType());

        hasInitialized = true;
    }

    /**
     * @return the tree model this likelihood is computed on
     */
    public final Tree getTree() {
        return treeModel;
    }

    /**
     * @return the branch rate model passed to the constructor
     */
    public final BranchRateModel getBranchRateModel() {
        return branchRateModel;
    }

    /**
     * @return the delegate passed to the constructor
     */
    public final DataLikelihoodDelegate getDataLikelihoodDelegate() {
        return likelihoodDelegate;
    }

    // **************************************************************
    // Likelihood IMPLEMENTATION
    // **************************************************************

    @Override
    public final Model getModel() {
        return this;
    }

    /**
     * Returns the log likelihood, recomputing it only if it has been flagged as unknown.
     *
     * @return the (possibly cached) log likelihood
     */
    @Override
    public final double getLogLikelihood() {
        if (COUNT_TOTAL_OPERATIONS) {
            totalGetLogLikelihoodCount++;
        }

        if (!likelihoodKnown) {
            if (COUNT_TOTAL_OPERATIONS) {
                totalCalculateLikelihoodCount++;
            }

            logLikelihood = calculateLogLikelihood();
            likelihoodKnown = true;
        }

        return logLikelihood;
    }

    /**
     * Invalidates the cached likelihood, marks the delegate dirty and flags every node for update.
     */
    @Override
    public final void makeDirty() {
        if (COUNT_TOTAL_OPERATIONS) {
            totalMakeDirtyCount++;
        }

        likelihoodKnown = false;
        likelihoodDelegate.makeDirty();
        updateAllNodes();
    }

    /**
     * @return true if the cached log likelihood is currently valid
     */
    public final boolean isLikelihoodKnown() {
        return likelihoodKnown;
    }

    // **************************************************************
    // VariableListener IMPLEMENTATION
    // **************************************************************

    protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {
        // do nothing
    }

    // **************************************************************
    // ModelListener IMPLEMENTATION
    // **************************************************************

    /**
     * Translates a change in one of the registered models into node update flags, invalidates
     * the cached likelihood and forwards the change to this model's own listeners.
     *
     * @param model  the model that changed
     * @param object the event payload; only {@link TreeModel.TreeChangedEvent} is interpreted
     * @param index  for delegate / branch rate changes, the number of the affected node,
     *               or -1 to flag all nodes
     * @throws IllegalStateException if a tree change event arrives for a tree that is not random
     */
    @Override
    protected final void handleModelChangedEvent(Model model, Object object, int index) {

        if (model == treeModel) {
            if (object instanceof TreeModel.TreeChangedEvent) {

                if (!isTreeRandom) {
                    throw new IllegalStateException("Attempting to change a fixed tree");
                }

                final TreeModel.TreeChangedEvent event = (TreeModel.TreeChangedEvent) object;

                if (event.isNodeChanged()) {
                    // If a node event occurs the node and its two child nodes
                    // are flagged for updating (this will result in everything
                    // above being updated as well). Node events occur when a node
                    // is added to a branch, removed from a branch or its height or
                    // rate changes.
                    updateNodeAndChildren(event.getNode());
                } else if (event.isTreeChanged()) {
                    // Full tree events result in a complete updating of the tree likelihood
                    // This event type is now used for EmpiricalTreeDistributions.
                    updateAllNodes();
                }
                // Other event types are ignored (probably trait changes).
            }
        } else if (model == likelihoodDelegate || model == branchRateModel) {
            // Both sources use the same convention: -1 means "everything", otherwise a node number.
            if (index == -1) {
                updateAllNodes();
            } else {
                updateNode(treeModel.getNode(index));
            }
        } else {
            assert false : "Unknown componentChangedEvent";
        }

        if (COUNT_TOTAL_OPERATIONS) {
            totalModelChangedCount++;
        }

        likelihoodKnown = false;

        fireModelChanged();
    }

    // **************************************************************
    // Model IMPLEMENTATION
    // **************************************************************

    /**
     * Saves the current log likelihood so that {@link #restoreState()} can reinstate it.
     */
    @Override
    protected final void storeState() {
        assert (likelihoodKnown) : "the likelihood should always be known at this point in the cycle";
        storedLogLikelihood = logLikelihood;
    }

    /**
     * Reinstates the log likelihood saved by {@link #storeState()} and flags it as known.
     */
    @Override
    protected final void restoreState() {
        // restore the likelihood and flag it as known
        logLikelihood = storedLogLikelihood;
        likelihoodKnown = true;
    }

    @Override
    protected void acceptState() {
    } // nothing to do

    /**
     * Calculate the log likelihood of the data for the current tree.
     * <p>
     * Each attempt collects the pending branch and node operations from the traversal delegate
     * and hands them to the likelihood delegate. If the delegate throws a
     * {@link DataLikelihoodDelegate.LikelihoodException}, all nodes are flagged and the
     * calculation is retried, for at most {@link #MAX_UNDERFLOWS_BEFORE_ERROR} attempts in total.
     *
     * @return the log likelihood, or {@link Double#NEGATIVE_INFINITY} if every attempt failed.
     */
    private double calculateLogLikelihood() {

        double logL = Double.NEGATIVE_INFINITY;
        boolean done = false;
        long underflowCount = 0;

        do {
            treeTraversalDelegate.dispatchTreeTraversalCollectBranchAndNodeOperations();

            final List<DataLikelihoodDelegate.BranchOperation> branchOperations =
                    treeTraversalDelegate.getBranchOperations();
            final List<DataLikelihoodDelegate.NodeOperation> nodeOperations =
                    treeTraversalDelegate.getNodeOperations();

            if (COUNT_TOTAL_OPERATIONS) {
                totalMatrixUpdateCount += branchOperations.size();
                totalOperationCount += nodeOperations.size();
            }

            final NodeRef root = treeModel.getRoot();

            try {
                logL = likelihoodDelegate.calculateLikelihood(branchOperations, nodeOperations, root.getNumber());

                done = true;
            } catch (DataLikelihoodDelegate.LikelihoodException e) {

                // if there is an underflow, assume delegate will attempt to rescale
                // so flag all nodes to update and return to try again.
                updateAllNodes();
                underflowCount++;
            }

        } while (!done && underflowCount < MAX_UNDERFLOWS_BEFORE_ERROR);

        // after traverse all nodes and patterns have been updated --
        // so change flags to reflect this.
        setAllNodesUpdated();

        return logL;
    }

    /**
     * Tells the traversal delegate that all nodes have been updated.
     */
    private void setAllNodesUpdated() {
        treeTraversalDelegate.setAllNodesUpdated();
    }

    /**
     * Set update flag for a node only
     */
    protected void updateNode(NodeRef node) {
        if (COUNT_TOTAL_OPERATIONS) {
            totalRateUpdateSingleCount++;
        }

        treeTraversalDelegate.updateNode(node);
        likelihoodKnown = false;
    }

    /**
     * Set update flag for a node and its direct children
     */
    protected void updateNodeAndChildren(NodeRef node) {
        if (COUNT_TOTAL_OPERATIONS) {
            totalRateUpdateSingleCount += 1 + treeModel.getChildCount(node);
        }

        treeTraversalDelegate.updateNodeAndChildren(node);
        likelihoodKnown = false;
    }

    /**
     * Set update flag for a node and all its descendants
     */
    protected void updateNodeAndDescendents(NodeRef node) {
        if (COUNT_TOTAL_OPERATIONS) {
            totalRateUpdateSingleCount++;
        }

        treeTraversalDelegate.updateNodeAndDescendents(node);
        likelihoodKnown = false;
    }

    /**
     * Set update flag for all nodes
     */
    protected void updateAllNodes() {
        if (COUNT_TOTAL_OPERATIONS) {
            totalRateUpdateAllCount++;
        }

        treeTraversalDelegate.updateAllNodes();
        likelihoodKnown = false;
    }

    // **************************************************************
    // Reportable IMPLEMENTATION
    // **************************************************************

    /**
     * Builds a textual report: the delegate's own report (if it returns one), this class name with
     * the current log likelihood and, when {@link #COUNT_TOTAL_OPERATIONS} is set, the counters.
     * Note that this calls {@link #getLogLikelihood()} and may therefore trigger a calculation.
     *
     * @return the report, or a placeholder if construction has not completed
     */
    @Override
    public String getReport() {
        if (!hasInitialized) {
            return getClass().getName() + "(uninitialized)";
        }

        final StringBuilder sb = new StringBuilder();

        final String delegateString = likelihoodDelegate.getReport();
        if (delegateString != null) {
            sb.append(delegateString);
        }

        sb.append(getClass().getName() + "(" + getLogLikelihood() + ")");

        if (COUNT_TOTAL_OPERATIONS) {
            sb.append("\n total operations = " + totalOperationCount +
                    "\n matrix updates = " + totalMatrixUpdateCount +
                    "\n model changes = " + totalModelChangedCount +
                    "\n make dirties = " + totalMakeDirtyCount +
                    "\n calculate likelihoods = " + totalCalculateLikelihoodCount +
                    "\n get likelihoods = " + totalGetLogLikelihoodCount +
                    "\n all rate updates = " + totalRateUpdateAllCount +
                    "\n partial rate updates = " + totalRateUpdateSingleCount);
        }

        return sb.toString();
    }

    // **************************************************************
    // TreeTrait IMPLEMENTATION
    // **************************************************************

    /**
     * Returns an array of all the available traits
     *
     * @return the array
     */
    @Override
    public TreeTrait[] getTreeTraits() {
        return treeTraits.getTreeTraits();
    }

    /**
     * Returns a trait that is stored using a specific key. This will often be the same
     * as the 'name' of the trait but may not be depending on the application.
     *
     * @param key a unique key
     * @return the trait
     */
    @Override
    public TreeTrait getTreeTrait(String key) {
        return treeTraits.getTreeTrait(key);
    }

    // **************************************************************
    // Decorate with TreeTraitProviders
    // **************************************************************

    /**
     * Registers a single trait with this provider.
     *
     * @param trait the trait to add
     */
    public void addTrait(TreeTrait trait) {
        treeTraits.addTrait(trait);
    }

    /**
     * Registers several traits with this provider.
     *
     * @param traits the traits to add
     */
    public void addTraits(TreeTrait[] traits) {
        treeTraits.addTraits(traits);
    }

    // **************************************************************
    // INSTANCE VARIABLES
    // **************************************************************

    /**
     * The data likelihood delegate
     */
    private final DataLikelihoodDelegate likelihoodDelegate;

    /**
     * the tree model
     */
    private final TreeModel treeModel;

    /**
     * the branch rate model
     */
    private final BranchRateModel branchRateModel;

    /**
     * TreeTrait helper
     */
    private final Helper treeTraits = new Helper();

    /**
     * Tracks which nodes need updating and produces the operation lists for the delegate
     */
    private final LikelihoodTreeTraversal treeTraversalDelegate;

    /** The cached log likelihood; only meaningful while {@link #likelihoodKnown} is true. */
    private double logLikelihood;

    /** The log likelihood saved by {@link #storeState()}. */
    private double storedLogLikelihood;

    protected boolean likelihoodKnown = false;

    /** Set at the very end of the constructor; consulted by {@link #getReport()}. */
    private boolean hasInitialized = false;

    private final boolean isTreeRandom;

    // Diagnostic counters. These are incremented on every likelihood request / model change,
    // so they are 'long' to avoid wrapping around in long-running analyses.
    private long totalOperationCount = 0;
    private long totalMatrixUpdateCount = 0;
    private long totalGetLogLikelihoodCount = 0;
    private long totalModelChangedCount = 0;
    private long totalMakeDirtyCount = 0;
    private long totalCalculateLikelihoodCount = 0;
    private long totalRateUpdateAllCount = 0;
    private long totalRateUpdateSingleCount = 0;
}