/* * TreeLikelihood.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.oldevomodel.treelikelihood; import dr.evolution.alignment.AscertainedSitePatterns; import dr.evolution.alignment.PatternList; import dr.evolution.alignment.SitePatterns; import dr.evolution.datatype.DataType; import dr.evolution.tree.NodeRef; import dr.evolution.tree.Tree; import dr.evolution.util.Taxon; import dr.evolution.util.TaxonList; import dr.evomodel.branchratemodel.BranchRateModel; import dr.evomodel.branchratemodel.DefaultBranchRateModel; import dr.oldevomodel.sitemodel.SiteModel; import dr.oldevomodel.substmodel.FrequencyModel; import dr.evomodel.tipstatesmodel.TipStatesModel; import dr.evomodel.tree.TreeModel; import dr.oldevomodelxml.treelikelihood.TreeLikelihoodParser; import dr.inference.model.Model; import dr.inference.model.Statistic; import java.util.logging.Logger; /** * TreeLikelihoodModel - implements a Likelihood Function for sequences on a tree. * * @author Andrew Rambaut * @author Alexei Drummond * @version $Id: TreeLikelihood.java,v 1.31 2006/08/30 16:02:42 rambaut Exp $ */ @Deprecated // Switching to BEAGLE public class TreeLikelihood extends AbstractTreeLikelihood { private static final boolean DEBUG = false; /** * Constructor. */ public TreeLikelihood(PatternList patternList, TreeModel treeModel, SiteModel siteModel, BranchRateModel branchRateModel, TipStatesModel tipStatesModel, boolean useAmbiguities, boolean allowMissingTaxa, boolean storePartials, boolean forceJavaCore, boolean forceRescaling) { super(TreeLikelihoodParser.TREE_LIKELIHOOD, patternList, treeModel); this.storePartials = storePartials; try { this.siteModel = siteModel; addModel(siteModel); this.frequencyModel = siteModel.getFrequencyModel(); addModel(frequencyModel); this.tipStatesModel = tipStatesModel; integrateAcrossCategories = siteModel.integrateAcrossCategories(); this.categoryCount = siteModel.getCategoryCount(); final Logger logger = Logger.getLogger("dr.evomodel"); String coreName = "Java general"; if (integrateAcrossCategories) { final DataType dataType = patternList.getDataType(); if (dataType instanceof dr.evolution.datatype.Nucleotides) { if (!forceJavaCore && NativeNucleotideLikelihoodCore.isAvailable()) { coreName = "native nucleotide"; likelihoodCore = new NativeNucleotideLikelihoodCore(); } else { coreName = "Java nucleotide"; likelihoodCore = new NucleotideLikelihoodCore(); } } else if (dataType instanceof dr.evolution.datatype.AminoAcids) { if (!forceJavaCore && NativeAminoAcidLikelihoodCore.isAvailable()) { coreName = "native amino acid"; likelihoodCore = new NativeAminoAcidLikelihoodCore(); } else { coreName = "Java amino acid"; likelihoodCore = new AminoAcidLikelihoodCore(); } // The codon core was out of date and did nothing more than the general core... } else if (dataType instanceof dr.evolution.datatype.Codons) { if (!forceJavaCore && NativeGeneralLikelihoodCore.isAvailable()) { coreName = "native general"; likelihoodCore = new NativeGeneralLikelihoodCore(patternList.getStateCount()); } else { coreName = "Java general"; likelihoodCore = new GeneralLikelihoodCore(patternList.getStateCount()); } useAmbiguities = true; } else { if (!forceJavaCore && NativeGeneralLikelihoodCore.isAvailable()) { coreName = "native general"; likelihoodCore = new NativeGeneralLikelihoodCore(patternList.getStateCount()); } else { coreName = "Java general"; likelihoodCore = new GeneralLikelihoodCore(patternList.getStateCount()); } } } else { likelihoodCore = new GeneralLikelihoodCore(patternList.getStateCount()); } { final String id = getId(); logger.info("TreeLikelihood(" + ((id != null) ? id : treeModel.getId()) + ") using " + coreName + " likelihood core"); logger.info(" " + (useAmbiguities ? "Using" : "Ignoring") + " ambiguities in tree likelihood."); logger.info(" With " + patternList.getPatternCount() + " unique site patterns."); } if (branchRateModel != null) { this.branchRateModel = branchRateModel; logger.info("Branch rate model used: " + branchRateModel.getModelName()); } else { this.branchRateModel = new DefaultBranchRateModel(); } addModel(this.branchRateModel); probabilities = new double[stateCount * stateCount]; likelihoodCore.initialize(nodeCount, patternCount, categoryCount, integrateAcrossCategories); int extNodeCount = treeModel.getExternalNodeCount(); int intNodeCount = treeModel.getInternalNodeCount(); if (tipStatesModel != null) { tipStatesModel.setTree(treeModel); tipPartials = new double[patternCount * stateCount]; for (int i = 0; i < extNodeCount; i++) { // Find the id of tip i in the patternList String id = treeModel.getTaxonId(i); int index = patternList.getTaxonIndex(id); if (index == -1) { throw new TaxonList.MissingTaxonException("Taxon, " + id + ", in tree, " + treeModel.getId() + ", is not found in patternList, " + patternList.getId()); } tipStatesModel.setStates(patternList, index, i, id); likelihoodCore.createNodePartials(i); } addModel(tipStatesModel); } else { for (int i = 0; i < extNodeCount; i++) { // Find the id of tip i in the patternList String id = treeModel.getTaxonId(i); int index = patternList.getTaxonIndex(id); if (index == -1) { if (!allowMissingTaxa) { throw new TaxonList.MissingTaxonException("Taxon, " + id + ", in tree, " + treeModel.getId() + ", is not found in patternList, " + patternList.getId()); } if (useAmbiguities) { setMissingPartials(likelihoodCore, i); } else { setMissingStates(likelihoodCore, i); } } else { if (useAmbiguities) { setPartials(likelihoodCore, patternList, categoryCount, index, i); } else { setStates(likelihoodCore, patternList, index, i); } } } } for (int i = 0; i < intNodeCount; i++) { likelihoodCore.createNodePartials(extNodeCount + i); } if (forceRescaling) { likelihoodCore.setUseScaling(true); logger.info(" Forcing use of partials rescaling."); } } catch (TaxonList.MissingTaxonException mte) { throw new RuntimeException(mte.toString()); } addStatistic(new SiteLikelihoodsStatistic()); } public final LikelihoodCore getLikelihoodCore() { return likelihoodCore; } // ************************************************************** // ModelListener IMPLEMENTATION // ************************************************************** /** * Handles model changed events from the submodels. */ protected void handleModelChangedEvent(Model model, Object object, int index) { if (model == treeModel) { if (object instanceof TreeModel.TreeChangedEvent) { if (((TreeModel.TreeChangedEvent) object).isNodeChanged()) { // If a node event occurs the node and its two child nodes // are flagged for updating (this will result in everything // above being updated as well. Node events occur when a node // is added to a branch, removed from a branch or its height or // rate changes. updateNodeAndChildren(((TreeModel.TreeChangedEvent) object).getNode()); } else if (((TreeModel.TreeChangedEvent) object).isTreeChanged()) { // Full tree events result in a complete updating of the tree likelihood updateAllNodes(); } else { // Other event types are ignored (probably trait changes). //System.err.println("Another tree event has occured (possibly a trait change)."); } } } else if (model == branchRateModel) { if (index == -1) { updateAllNodes(); } else { if (DEBUG) { if (index >= treeModel.getNodeCount()) { throw new IllegalArgumentException("Node index out of bounds"); } } updateNode(treeModel.getNode(index)); } } else if (model == frequencyModel) { updateAllNodes(); } else if (model == tipStatesModel) { if(object instanceof Taxon) { for(int i=0; i<treeModel.getNodeCount(); i++) if(treeModel.getNodeTaxon(treeModel.getNode(i))!=null && treeModel.getNodeTaxon(treeModel.getNode(i)).getId().equalsIgnoreCase(((Taxon)object).getId())) updateNode(treeModel.getNode(i)); }else updateAllNodes(); } else if (model instanceof SiteModel) { updateAllNodes(); } else { throw new RuntimeException("Unknown componentChangedEvent"); } super.handleModelChangedEvent(model, object, index); } // ************************************************************** // Model IMPLEMENTATION // ************************************************************** /** * Stores the additional state other than model components */ protected void storeState() { if (storePartials) { likelihoodCore.storeState(); } super.storeState(); } /** * Restore the additional stored state */ protected void restoreState() { if (storePartials) { likelihoodCore.restoreState(); } else { updateAllNodes(); } super.restoreState(); } // ************************************************************** // Likelihood IMPLEMENTATION // ************************************************************** /** * Calculate the log likelihood of the current state. * * @return the log likelihood. */ protected double calculateLogLikelihood() { if (patternLogLikelihoods == null) { patternLogLikelihoods = new double[patternCount]; } if (!integrateAcrossCategories) { if (siteCategories == null) { siteCategories = new int[patternCount]; } for (int i = 0; i < patternCount; i++) { siteCategories[i] = siteModel.getCategoryOfSite(i); } } if (tipStatesModel != null) { int extNodeCount = treeModel.getExternalNodeCount(); for (int index = 0; index < extNodeCount; index++) { if (updateNode[index]) { likelihoodCore.setNodePartialsForUpdate(index); tipStatesModel.getTipPartials(index, tipPartials); likelihoodCore.setCurrentNodePartials(index, tipPartials); } } } final NodeRef root = treeModel.getRoot(); traverse(treeModel, root); double logL = 0.0; double ascertainmentCorrection = getAscertainmentCorrection(patternLogLikelihoods); for (int i = 0; i < patternCount; i++) { logL += (patternLogLikelihoods[i] - ascertainmentCorrection) * patternWeights[i]; } if (logL == Double.NEGATIVE_INFINITY) { Logger.getLogger("dr.evomodel").info("TreeLikelihood, " + this.getId() + ", turning on partial likelihood scaling to avoid precision loss"); // We probably had an underflow... turn on scaling likelihoodCore.setUseScaling(true); // and try again... updateAllNodes(); updateAllPatterns(); traverse(treeModel, root); logL = 0.0; ascertainmentCorrection = getAscertainmentCorrection(patternLogLikelihoods); for (int i = 0; i < patternCount; i++) { logL += (patternLogLikelihoods[i] - ascertainmentCorrection) * patternWeights[i]; } } //******************************************************************** // after traverse all nodes and patterns have been updated -- //so change flags to reflect this. for (int i = 0; i < nodeCount; i++) { updateNode[i] = false; } //******************************************************************** return logL; } public double[] getPatternLogLikelihoods() { getLogLikelihood(); // Ensure likelihood is up-to-date double ascertainmentCorrection = getAscertainmentCorrection(patternLogLikelihoods); double[] out = new double[patternCount]; for (int i = 0; i < patternCount; i++) { if (patternWeights[i] > 0) { out[i] = (patternLogLikelihoods[i] - ascertainmentCorrection) * patternWeights[i]; } else { out[i] = Double.NEGATIVE_INFINITY; } } return out; } /* Calculate ascertainment correction if working off of AscertainedSitePatterns @param patternLogProbs log pattern probabilities @return the log total probability for a pattern. */ protected double getAscertainmentCorrection(double[] patternLogProbs) { if (patternList instanceof AscertainedSitePatterns) { return ((AscertainedSitePatterns) patternList).getAscertainmentCorrection(patternLogProbs); } else { return 0.0; } } /** * Check whether the scaling is still required. If the sum of all the logScalingFactors * is zero then we simply turn off the useScaling flag. This will speed up the likelihood * calculations when scaling is not required. */ public void checkScaling() { // if (useScaling) { // if (scalingCheckCount % 1000 == 0) { // double totalScalingFactor = 0.0; // for (int i = 0; i < nodeCount; i++) { // for (int j = 0; j < patternCount; j++) { // totalScalingFactor += scalingFactors[currentPartialsIndices[i]][i][j]; // } // } // useScaling = totalScalingFactor < 0.0; // Logger.getLogger("dr.evomodel").info("LikelihoodCore total log scaling factor: " + totalScalingFactor); // if (!useScaling) { // Logger.getLogger("dr.evomodel").info("LikelihoodCore scaling turned off."); // } // } // scalingCheckCount++; // } } /** * Traverse the tree calculating partial likelihoods. * * @return whether the partials for this node were recalculated. */ protected boolean traverse(Tree tree, NodeRef node) { boolean update = false; int nodeNum = node.getNumber(); NodeRef parent = tree.getParent(node); // First update the transition probability matrix(ices) for this branch if (parent != null && updateNode[nodeNum]) { final double branchRate = branchRateModel.getBranchRate(tree, node); // Get the operational time of the branch final double branchTime = branchRate * (tree.getNodeHeight(parent) - tree.getNodeHeight(node)); if (branchTime < 0.0) { throw new RuntimeException("Negative branch length: " + branchTime); } likelihoodCore.setNodeMatrixForUpdate(nodeNum); for (int i = 0; i < categoryCount; i++) { double branchLength = siteModel.getRateForCategory(i) * branchTime; siteModel.getSubstitutionModel().getTransitionProbabilities(branchLength, probabilities); likelihoodCore.setNodeMatrix(nodeNum, i, probabilities); } update = true; } // If the node is internal, update the partial likelihoods. if (!tree.isExternal(node)) { // Traverse down the two child nodes NodeRef child1 = tree.getChild(node, 0); final boolean update1 = traverse(tree, child1); NodeRef child2 = tree.getChild(node, 1); final boolean update2 = traverse(tree, child2); // If either child node was updated then update this node too if (update1 || update2) { final int childNum1 = child1.getNumber(); final int childNum2 = child2.getNumber(); likelihoodCore.setNodePartialsForUpdate(nodeNum); if (integrateAcrossCategories) { likelihoodCore.calculatePartials(childNum1, childNum2, nodeNum); } else { likelihoodCore.calculatePartials(childNum1, childNum2, nodeNum, siteCategories); } if (COUNT_TOTAL_OPERATIONS) { totalOperationCount ++; } if (parent == null) { // No parent this is the root of the tree - // calculate the pattern likelihoods double[] frequencies = frequencyModel.getFrequencies(); double[] partials = getRootPartials(); likelihoodCore.calculateLogLikelihoods(partials, frequencies, patternLogLikelihoods); } update = true; } } return update; } public final double[] getRootPartials() { if (rootPartials == null) { rootPartials = new double[patternCount * stateCount]; } int nodeNum = treeModel.getRoot().getNumber(); if (integrateAcrossCategories) { // moved this call to here, because non-integrating siteModels don't need to support it - AD double[] proportions = siteModel.getCategoryProportions(); likelihoodCore.integratePartials(nodeNum, proportions, rootPartials); } else { likelihoodCore.getPartials(nodeNum, rootPartials); } return rootPartials; } /** * the root partial likelihoods (a temporary array that is used * to fetch the partials - it should not be examined directly - * use getRootPartials() instead). */ private double[] rootPartials = null; public class SiteLikelihoodsStatistic extends Statistic.Abstract { public SiteLikelihoodsStatistic() { super("siteLikelihoods"); } public int getDimension() { if (patternList instanceof SitePatterns) { return ((SitePatterns)patternList).getSiteCount(); } else { return patternList.getPatternCount(); } } public String getDimensionName(int dim) { return getTreeModel().getId() + "site-" + dim; } public double getStatisticValue(int i) { if (patternList instanceof SitePatterns) { int index = ((SitePatterns)patternList).getPatternIndex(i); if( index >= 0 ) { return patternLogLikelihoods[index] / patternWeights[index]; } else { return 0.0; } } else { return patternList.getPatternCount(); } } } // ************************************************************** // INSTANCE VARIABLES // ************************************************************** /** * the frequency model for these sites */ protected final FrequencyModel frequencyModel; /** * the site model for these sites */ protected final SiteModel siteModel; /** * the branch rate model */ protected final BranchRateModel branchRateModel; /** * the tip partials model */ private final TipStatesModel tipStatesModel; private final boolean storePartials; protected final boolean integrateAcrossCategories; /** * the categories for each site */ protected int[] siteCategories = null; /** * the pattern likelihoods */ protected double[] patternLogLikelihoods = null; /** * the number of rate categories */ protected int categoryCount; /** * an array used to transfer transition probabilities */ protected double[] probabilities; /** * an array used to transfer tip partials */ protected double[] tipPartials; /** * the LikelihoodCore */ protected LikelihoodCore likelihoodCore; }