/* * ARGLikelihood.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ /* * TreeLikelihood.java * * (c) 2002-2005 BEAST Development Core Team * * This package may be distributed under the * Lesser Gnu Public Licence (LGPL) */ package dr.evomodel.arg.likelihood; import dr.evolution.alignment.PatternList; import dr.evolution.tree.NodeRef; import dr.evolution.tree.Tree; import dr.evolution.util.TaxonList; import dr.evomodel.arg.ARGModel; import dr.evomodel.arg.ARGTree; import dr.evomodel.arg.operators.ARGPartitioningOperator; import dr.evomodel.branchratemodel.BranchRateModel; import dr.evomodel.branchratemodel.DefaultBranchRateModel; import dr.oldevomodel.sitemodel.SiteModel; import dr.oldevomodel.substmodel.FrequencyModel; import dr.evomodel.treelikelihood.*; import dr.inference.model.Likelihood; import dr.inference.model.Model; import dr.inference.model.Parameter; import dr.oldevomodel.treelikelihood.*; import dr.xml.*; import java.util.logging.Logger; import java.util.*; /** * ARGLikelihood - implements a Likelihood Function for sequences on an ancestral recombination graph. * * @author Marc Suchard * @version $Id: ARGLikelihood.java,v 1.3 2006/10/23 04:13:41 msuchard Exp $ */ public class ARGLikelihood extends AbstractARGLikelihood { public static final String ARG_LIKELIHOOD = "argTreeLikelihood"; public static final String USE_AMBIGUITIES = "useAmbiguities"; public static final String STORE_PARTIALS = "storePartials"; public static final String USE_SCALING = "useScaling"; /** * Constructor. */ public ARGLikelihood(PatternList patternList, ARGModel treeModel, SiteModel siteModel, BranchRateModel branchRateModel, boolean useAmbiguities, boolean storePartials, boolean useScaling) { super(ARG_LIKELIHOOD, patternList, treeModel); partition = treeModel.addLikelihoodCalculator(this); this.storePartials = storePartials; this.useAmbiguities = useAmbiguities; try { this.siteModel = siteModel; addModel(siteModel); this.frequencyModel = siteModel.getFrequencyModel(); addModel(frequencyModel); integrateAcrossCategories = siteModel.integrateAcrossCategories(); this.categoryCount = siteModel.getCategoryCount(); if (integrateAcrossCategories) { if (patternList.getDataType() instanceof dr.evolution.datatype.Nucleotides) { if (NativeNucleotideLikelihoodCore.isAvailable()) { Logger.getLogger("dr.evomodel").info("TreeLikelihood using native nucleotide likelihood core"); likelihoodCore = new NativeNucleotideLikelihoodCore(); } else { Logger.getLogger("dr.evomodel").info("TreeLikelihood using Java nucleotide likelihood core"); likelihoodCore = new NucleotideLikelihoodCore(); } } else if (patternList.getDataType() instanceof dr.evolution.datatype.AminoAcids) { if (NativeAminoAcidLikelihoodCore.isAvailable()) { Logger.getLogger("dr.evomodel").info("TreeLikelihood using native amino acid likelihood core"); likelihoodCore = new NativeAminoAcidLikelihoodCore(); } else { Logger.getLogger("dr.evomodel").info("TreeLikelihood using java likelihood core"); likelihoodCore = new AminoAcidLikelihoodCore(); } } else if (patternList.getDataType() instanceof dr.evolution.datatype.Codons) { Logger.getLogger("dr.evomodel").info("TreeLikelihood using Java codon likelihood core"); // likelihoodCore = new CodonLikelihoodCore(patternList.getStateCount()); this.useAmbiguities = true; throw new RuntimeException("Still need to merge codon likelihood core"); } else { if (patternList.getDataType() instanceof dr.evolution.datatype.OldHiddenNucleotides && NativeCovarionLikelihoodCore.isAvailable()) { Logger.getLogger("dr.evomodel").info("TreeLikelihood using native covarion likelihood core"); likelihoodCore = new NativeCovarionLikelihoodCore(); } else { Logger.getLogger("dr.evomodel").info("TreeLikelihood using Java general likelihood core"); likelihoodCore = new GeneralLikelihoodCore(patternList.getStateCount()); } } } else { Logger.getLogger("dr.evomodel").info("TreeLikelihood using Java general likelihood core"); likelihoodCore = new GeneralLikelihoodCore(patternList.getStateCount()); } Logger.getLogger("dr.evomodel").info(" " + (useAmbiguities ? "Using" : "Ignoring") + " ambiguities in tree likelihood."); Logger.getLogger("dr.evomodel").info(" Partial likelihood scaling " + (useScaling ? "on." : "off.")); if (branchRateModel != null) { this.branchRateModel = branchRateModel; Logger.getLogger("dr.evomodel").info("Branch rate model used: " + branchRateModel.getModelName()); } else { this.branchRateModel = new DefaultBranchRateModel(); } addModel(this.branchRateModel); probabilities = new double[stateCount * stateCount]; // likelihoodCore.initialize(nodeCount, patternCount, categoryCount, integrateAcrossCategories, useScaling); likelihoodCore.initialize(nodeCount, patternCount, categoryCount, integrateAcrossCategories); int extNodeCount = treeModel.getExternalNodeCount(); int intNodeCount = treeModel.getInternalNodeCount(); for (int i = 0; i < extNodeCount; i++) { // Find the id of tip i in the patternList String id = treeModel.getTaxonId(i); int index = patternList.getTaxonIndex(id); // System.err.println("id = "+id+" index = "+index); if (index == -1) { throw new TaxonList.MissingTaxonException("Taxon, " + id + ", in tree, " + treeModel.getId() + ", is not found in patternList, " + patternList.getId()); } if (useAmbiguities) { setPartials(likelihoodCore, patternList, categoryCount, index, i); } else { setStates(likelihoodCore, patternList, index, i); } } // System.exit(-1); for (int i = 0; i < intNodeCount; i++) { likelihoodCore.createNodePartials(extNodeCount + i); } } catch (TaxonList.MissingTaxonException mte) { throw new RuntimeException(mte.toString()); } } // ************************************************************** // ModelListener IMPLEMENTATION // ************************************************************** private static final boolean NO_CACHING = false; /** * Handles model changed events from the submodels. */ protected void handleModelChangedEvent(Model model, Object object, int index) { if (NO_CACHING) { reconstructTree = true; updateAllNodes(); } if (model == treeModel) { if (object instanceof ARGModel.TreeChangedEvent) { ARGModel.TreeChangedEvent event = (ARGModel.TreeChangedEvent) object; if (event.isSizeChanged() ) { updateAllNodes(); // TODO Update only affected portion of tree reconstructTree = true; } else if (event.isNodeChanged()) { // If a node event occurs the node and its two child nodes // are flagged for updating (this will result in everything // above being updated as well. Node events occur when a node // is added to a branch, removed from a branch or its height or // rate changes. NodeRef treeNode = mapARGNodesToTreeNodes.get(event.getNode()); if ( treeNode != null ) { if (event.isHeightChanged() || event.isRateChanged()) { updateNodeAndChildren(treeNode); } else { reconstructTree = true; // updateNodeAndChildren(treeNode); // TODO This doesn't work with sizeChange; why??? updateAllNodes(); } } } else if (event.isTreeChanged()) { // Full tree events result in a complete updating of the tree likelihood // These include adding and removing nodes // TODO ARG rearrangements still call this; they should not reconstructTree = true; updateAllNodes(); } else { // Other event types are ignored (probably trait changes). throw new RuntimeException("Another tree event has occured (possibly a trait change)."); } } else if (object instanceof ARGPartitioningOperator.PartitionChangedEvent) { final boolean[] updatePartition = ((ARGPartitioningOperator.PartitionChangedEvent) object).getUpdatedPartitions(); if (updatePartition[partition]) { reconstructTree = true; updateAllNodes(); // TODO Probably does not affect entire tree; fix } } else if (object instanceof Parameter) { // ignore, most of these are handled in isNodeChanged() } else throw new RuntimeException("Unexpected ARGModel update "+object.getClass()); } else if (model == branchRateModel) { // TODO Only update affected branches updateAllNodes(); } else if (model == frequencyModel) { updateAllNodes(); } else if (model instanceof SiteModel) { updateAllNodes(); } else { throw new RuntimeException("Unknown componentChangedEvent"); } super.handleModelChangedEvent(model, object, index); } // ************************************************************** // Model IMPLEMENTATION // ************************************************************** /** * Stores the additional state other than model components */ protected void storeState() { if (storePartials) { likelihoodCore.storeState(); } super.storeState(); } /** * Restore the additional stored state */ protected void restoreState() { if (storePartials) { likelihoodCore.restoreState(); } else { updateAllNodes(); } reconstructTree = true; // currently the tree is not cached, because the ARG that generates it is cached super.restoreState(); } private int getUnusedInt(Map<NodeRef,Integer> inMap) { Collection<Integer> intSet = inMap.values(); int i = tree.getExternalNodeCount(); while( intSet.contains(i) ) i++; return i; } private Set<NodeRef> unsetNodes = null; private void reconstructTree() { oldTree = tree; oldMapARGNodesToInts = mapARGNodesToInts; tree = new ARGTree(treeModel, partition); reconstructTree = false; mapARGNodesToInts = new HashMap<NodeRef,Integer>(tree.getInternalNodeCount()); mapARGNodesToTreeNodes = tree.getMapping(); if (oldTree == null) { // First initialization for(int i=0; i<tree.getInternalNodeCount(); i++) { NodeRef node = tree.getInternalNode(i); mapARGNodesToInts.put(treeModel.getMirrorNode(node),node.getNumber()); } } else { // Need to renumber if (unsetNodes == null) unsetNodes = new HashSet<NodeRef>(); else unsetNodes.clear(); // Copy over numbers for nodes that still exist in tree for (int i = 0; i < tree.getInternalNodeCount(); i++) { NodeRef newNode = tree.getInternalNode(i); NodeRef argNode = treeModel.getMirrorNode(newNode); if (oldMapARGNodesToInts.containsKey(argNode)) { // was in old tree int oldNumber = oldMapARGNodesToInts.get(argNode); treeModel.setNodeNumber(newNode,oldNumber); mapARGNodesToInts.put(argNode,oldNumber); } else // was not in old tree unsetNodes.add(newNode); } // Set unused numbers for nodes that are new and mark for update for (NodeRef node : unsetNodes) { int newNumber = getUnusedInt(mapARGNodesToInts); treeModel.setNodeNumber(node,newNumber); mapARGNodesToInts.put(node,newNumber); updateNode[newNumber] = true; } } } // ************************************************************** // Likelihood IMPLEMENTATION // ************************************************************** /** * Calculate the log likelihood of the current state. * * @return the log likelihood. */ protected double calculateLogLikelihood() { if (reconstructTree) { reconstructTree(); } NodeRef root = tree.getRoot(); if (rootPartials == null) { rootPartials = new double[patternCount * stateCount]; } if (patternLogLikelihoods == null) { patternLogLikelihoods = new double[patternCount]; } if (!integrateAcrossCategories) { if (siteCategories == null) { siteCategories = new int[patternCount]; } for (int i = 0; i < patternCount; i++) { siteCategories[i] = siteModel.getCategoryOfSite(i); } } try { traverse(tree, root); } catch (NegativeBranchLengthException e) { System.err.println("Negative branch length found, trying to return 0 likelihood"); return Double.NEGATIVE_INFINITY; } //******************************************************************** // after traverse all nodes and patterns have been updated -- //so change flags to reflect this. for (int i = 0; i < nodeCount; i++) { updateNode[i] = false; } //******************************************************************** double logL = 0.0; for (int i = 0; i < patternCount; i++) { // System.err.printf("Pattern %2d: %5.4f %5.4f\n",i,patternLogLikelihoods[i],patternWeights[i]); logL += patternLogLikelihoods[i] * patternWeights[i]; } return logL; } class NegativeBranchLengthException extends Exception { } /** * Traverse the tree calculating partial likelihoods. * * @return whether the partials for this node were recalculated. */ private boolean traverse(Tree tree, NodeRef node) throws NegativeBranchLengthException { boolean update = false; int nodeNum = node.getNumber(); // System.err.println(nodeNum); NodeRef parent = tree.getParent(node); // First update the transition probability matrix(ices) for this branch if (parent != null && updateNode[nodeNum]) { double branchRate = branchRateModel.getBranchRate(tree, node); // Get the operational time of the branch double branchTime = branchRate * (tree.getNodeHeight(parent) - tree.getNodeHeight(node)); if (branchTime < 0.0) { if (!DEBUG) { throw new RuntimeException("Negative branch length: " + branchTime); } else{ throw new NegativeBranchLengthException(); } } for (int i = 0; i < categoryCount; i++) { double branchLength = siteModel.getRateForCategory(i) * branchTime; siteModel.getSubstitutionModel().getTransitionProbabilities(branchLength, probabilities); likelihoodCore.setNodeMatrix(nodeNum, i, probabilities); } update = true; } // If the node is internal, update the partial likelihoods. if (!tree.isExternal(node)) { int nodeCount = tree.getChildCount(node); if (nodeCount != 2) throw new RuntimeException("binary trees only!"); // Traverse down the two child nodes NodeRef child1 = tree.getChild(node, 0); boolean update1 = traverse(tree, child1); NodeRef child2 = tree.getChild(node, 1); boolean update2 = traverse(tree, child2); // If either child node was updated then update this node too if (update1 || update2) { int childNum1 = child1.getNumber(); int childNum2 = child2.getNumber(); if (integrateAcrossCategories) { likelihoodCore.calculatePartials(childNum1, childNum2, nodeNum); } else { likelihoodCore.calculatePartials(childNum1, childNum2, nodeNum, siteCategories); } if (parent == null) { // No parent this is the root of the tree - // calculate the pattern likelihoods double[] frequencies = frequencyModel.getFrequencies(); if (integrateAcrossCategories) { // moved this call to here, because non-integrating siteModels don't need to support it - AD double[] proportions = siteModel.getCategoryProportions(); likelihoodCore.integratePartials(nodeNum, proportions, rootPartials); } else { likelihoodCore.getPartials(nodeNum, rootPartials); } likelihoodCore.calculateLogLikelihoods(rootPartials, frequencies, patternLogLikelihoods); } update = true; } } return update; } /** * The XML parser */ public static XMLObjectParser PARSER = new AbstractXMLObjectParser() { public String getParserName() { return ARG_LIKELIHOOD; } public Object parseXMLObject(XMLObject xo) throws XMLParseException { boolean useAmbiguities = false; boolean storePartials = true; boolean useScaling = false; if (xo.hasAttribute(USE_AMBIGUITIES)) { useAmbiguities = xo.getBooleanAttribute(USE_AMBIGUITIES); } if (xo.hasAttribute(STORE_PARTIALS)) { storePartials = xo.getBooleanAttribute(STORE_PARTIALS); } if (xo.hasAttribute(USE_SCALING)) { useScaling = xo.getBooleanAttribute(USE_SCALING); } PatternList patternList = (PatternList) xo.getChild(PatternList.class); ARGModel treeModel = (ARGModel) xo.getChild(ARGModel.class); SiteModel siteModel = (SiteModel) xo.getChild(SiteModel.class); BranchRateModel branchRateModel = (BranchRateModel) xo.getChild(BranchRateModel.class); return new ARGLikelihood(patternList, treeModel, siteModel, branchRateModel, useAmbiguities, storePartials, useScaling); } //************************************************************************ // AbstractXMLObjectParser implementation //************************************************************************ public String getParserDescription() { return "This element represents the likelihood of a patternlist on a tree given the site model."; } public Class getReturnType() { return Likelihood.class; } public XMLSyntaxRule[] getSyntaxRules() { return rules; } private final XMLSyntaxRule[] rules = { AttributeRule.newBooleanRule(USE_AMBIGUITIES, true), AttributeRule.newBooleanRule(STORE_PARTIALS, true), AttributeRule.newBooleanRule(USE_SCALING, true), new ElementRule(PatternList.class), new ElementRule(ARGModel.class), new ElementRule(SiteModel.class), new ElementRule(BranchRateModel.class, true) }; }; /** * XML Serializer for parallelization * */ // public Element toXML() { // Element likelihoodElement // } // ************************************************************** // INSTANCE VARIABLES // ************************************************************** /** * the frequency model for these sites */ protected FrequencyModel frequencyModel = null; /** * the site model for these sites */ protected SiteModel siteModel = null; /** * the branch rate model */ protected BranchRateModel branchRateModel = null; private boolean storePartials = false; private boolean integrateAcrossCategories = false; /** * the categories for each site */ protected int[] siteCategories = null; /** * the root partial likelihoods */ protected double[] rootPartials = null; /** * the pattern likelihoods */ protected double[] patternLogLikelihoods = null; /** * the number of rate categories */ protected int categoryCount; /** * an array used to store transition probabilities */ protected double[] probabilities; /** * the LikelihoodCore */ protected LikelihoodCore likelihoodCore; private boolean useAmbiguities; private boolean reconstructTree = true; private ARGTree tree = null; private ARGTree oldTree; private Map<NodeRef,Integer> mapARGNodesToInts = null; private Map<NodeRef,Integer> oldMapARGNodesToInts; private Map<NodeRef,NodeRef> mapARGNodesToTreeNodes = null; private static final boolean DEBUG = true; }