/* * AbstractTreeLikelihood.java * * Copyright (C) 2002-2012 Alexei Drummond, Andrew Rambaut & Marc A. Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.app.beagle.evomodel.treelikelihood; import dr.evolution.alignment.PatternList; import dr.evolution.datatype.DataType; import dr.evolution.tree.NodeRef; import dr.evomodel.tree.TreeModel; import dr.inference.model.AbstractModelLikelihood; import dr.inference.model.Model; import dr.inference.model.Parameter; import dr.inference.model.Variable; import dr.xml.Reportable; /** * AbstractTreeLikelihood - a base class for likelihood calculators of sites on a tree. * * @author Andrew Rambaut * @author Marc Suchard * @version $Id: AbstractTreeLikelihood.java,v 1.16 2005/06/07 16:27:39 alexei Exp $ */ public abstract class AbstractTreeLikelihood extends AbstractModelLikelihood implements Reportable { protected static final boolean COUNT_TOTAL_OPERATIONS = false; public AbstractTreeLikelihood(String name, PatternList patternList, TreeModel treeModel) { super(name); this.patternList = patternList; this.dataType = patternList.getDataType(); patternCount = patternList.getPatternCount(); stateCount = dataType.getStateCount(); patternWeights = patternList.getPatternWeights(); this.treeModel = treeModel; addModel(treeModel); nodeCount = treeModel.getNodeCount(); updateNode = new boolean[nodeCount]; for (int i = 0; i < nodeCount; i++) { updateNode[i] = true; } likelihoodKnown = false; } /** * Set update flag for a node and its children */ protected void updateNode(NodeRef node) { updateNode[node.getNumber()] = true; likelihoodKnown = false; } /** * Set update flag for a node and its direct children */ protected void updateNodeAndChildren(NodeRef node) { updateNode[node.getNumber()] = true; for (int i = 0; i < treeModel.getChildCount(node); i++) { NodeRef child = treeModel.getChild(node, i); updateNode[child.getNumber()] = true; } likelihoodKnown = false; } /** * Set update flag for a node and all its descendents */ protected void updateNodeAndDescendents(NodeRef node) { updateNode[node.getNumber()] = true; for (int i = 0; i < treeModel.getChildCount(node); i++) { NodeRef child = treeModel.getChild(node, i); updateNodeAndDescendents(child); } likelihoodKnown = false; } /** * Set update flag for all nodes */ protected void updateAllNodes() { for (int i = 0; i < nodeCount; i++) { updateNode[i] = true; } likelihoodKnown = false; } /** * Set update flag for a pattern */ protected void updatePattern(int i) { if (updatePattern != null) { updatePattern[i] = true; } likelihoodKnown = false; } /** * Set update flag for all patterns */ protected void updateAllPatterns() { if (updatePattern != null) { for (int i = 0; i < patternCount; i++) { updatePattern[i] = true; } } likelihoodKnown = false; } public final double[] getPatternWeights() { return patternWeights; } // ************************************************************** // VariableListener IMPLEMENTATION // ************************************************************** protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) { // do nothing } // ************************************************************** // Model IMPLEMENTATION // ************************************************************** protected void handleModelChangedEvent(Model model, Object object, int index) { if (COUNT_TOTAL_OPERATIONS) totalModelChangedCount++; likelihoodKnown = false; } /** * Stores the additional state other than model components */ protected void storeState() { storedLikelihoodKnown = likelihoodKnown; storedLogLikelihood = logLikelihood; } /** * Restore the additional stored state */ protected void restoreState() { likelihoodKnown = storedLikelihoodKnown; logLikelihood = storedLogLikelihood; } protected void acceptState() { } // nothing to do // ************************************************************** // Likelihood IMPLEMENTATION // ************************************************************** public final Model getModel() { return this; } public final double getLogLikelihood() { if (COUNT_TOTAL_OPERATIONS) totalGetLogLikelihoodCount++; if (!likelihoodKnown) { if (COUNT_TOTAL_OPERATIONS) totalcalculateLikelihoodCount++; logLikelihood = calculateLogLikelihood(); likelihoodKnown = true; } return logLikelihood; } /** * Forces a complete recalculation of the likelihood next time getLikelihood is called */ public void makeDirty() { if (COUNT_TOTAL_OPERATIONS) totalMakeDirtyCount++; likelihoodKnown = false; updateAllNodes(); updateAllPatterns(); } protected abstract double calculateLogLikelihood(); public String getReport() { if (hasInitialized) { String rtnValue = getClass().getName() + "(" + getLogLikelihood() + ")"; if (COUNT_TOTAL_OPERATIONS) rtnValue += " total operations = " + totalOperationCount + " matrix updates = " + totalMatrixUpdateCount + " model changes = " + totalModelChangedCount + " make dirties = " + totalMakeDirtyCount + " calculate likelihoods = " + totalcalculateLikelihoodCount + " get likelihoods = " + totalGetLogLikelihoodCount + " all rate updates = " + totalRateUpdateAllCount + " partial rate updates = " + totalRateUpdateSingleCount; return rtnValue; } else { return getClass().getName() + "(uninitialized)"; } } // ************************************************************** // INSTANCE VARIABLES // ************************************************************** /** * the tree */ protected TreeModel treeModel = null; /** * the patternList */ protected PatternList patternList = null; protected DataType dataType = null; /** * the pattern weights */ protected double[] patternWeights; /** * the number of patterns */ protected int patternCount; /** * the number of states in the data */ protected int stateCount; /** * the number of nodes in the tree */ protected int nodeCount; /** * Flags to specify which patterns are to be updated */ protected boolean[] updatePattern = null; /** * Flags to specify which nodes are to be updated */ protected boolean[] updateNode; private double logLikelihood; private double storedLogLikelihood; protected boolean likelihoodKnown = false; private boolean storedLikelihoodKnown = false; protected boolean hasInitialized = false; protected int totalOperationCount = 0; protected int totalMatrixUpdateCount = 0; protected int totalGetLogLikelihoodCount = 0; protected int totalModelChangedCount = 0; protected int totalMakeDirtyCount = 0; protected int totalcalculateLikelihoodCount = 0; protected int totalRateUpdateAllCount = 0; protected int totalRateUpdateSingleCount = 0; }