/* * AbstractObservationProcess.java * * Copyright (c) 2002-2016 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.evomodel.MSSD; import dr.evomodel.siteratemodel.SiteRateModel; import dr.evolution.alignment.AscertainedSitePatterns; import dr.evolution.alignment.PatternList; import dr.evolution.datatype.MutationDeathType; import dr.evolution.tree.NodeRef; import dr.evomodel.branchratemodel.BranchRateModel; import dr.evomodel.branchratemodel.DefaultBranchRateModel; import dr.evomodel.tree.TreeModel; import dr.inference.model.AbstractModel; import dr.inference.model.Model; import dr.inference.model.Parameter; import dr.inference.model.Variable; import dr.math.GammaFunction; /** * Package: AbstractObservationProcess * Description: * <p/> * <p/> * Created by * Alexander V. Alekseyenko (alexander.alekseyenko@gmail.com) * Date: Feb 19, 2008 * Time: 12:41:01 PM */ abstract public class AbstractObservationProcess extends AbstractModel { protected boolean[] nodePatternInclusion; protected boolean[] storedNodePatternInclusion; protected double[] cumLike; protected double[] nodePartials; protected double[] nodeLikelihoods; protected int nodeCount; protected int patternCount; protected int stateCount; protected TreeModel treeModel; protected PatternList patterns; protected double[] patternWeights; protected Parameter mu; protected Parameter lam; // update control variables protected boolean weightKnown; protected double logTreeWeight; protected double storedLogTreeWeight; private double gammaNorm; private double totalPatterns; protected MutationDeathType dataType; protected int deathState; protected SiteRateModel siteModel; private double logN; protected boolean nodePatternInclusionKnown = false; BranchRateModel branchRateModel; public AbstractObservationProcess(String Name, TreeModel treeModel, PatternList patterns, SiteRateModel siteModel, BranchRateModel branchRateModel, Parameter mu, Parameter lam) { super(Name); this.treeModel = treeModel; this.patterns = patterns; this.mu = mu; this.lam = lam; this.siteModel = siteModel; if (branchRateModel != null) { this.branchRateModel = branchRateModel; } else { this.branchRateModel = new DefaultBranchRateModel(); } addModel(treeModel); addModel(siteModel); addModel(this.branchRateModel); addVariable(mu); addVariable(lam); nodeCount = treeModel.getNodeCount(); stateCount = patterns.getDataType().getStateCount(); this.patterns = patterns; patternCount = patterns.getPatternCount(); patternWeights = patterns.getPatternWeights(); totalPatterns = 0; for (int i = 0; i < patternCount; ++i) { totalPatterns += patternWeights[i]; } logN = Math.log(totalPatterns); gammaNorm = -GammaFunction.lnGamma(totalPatterns + 1); dataType = (MutationDeathType) patterns.getDataType(); this.deathState = dataType.DEATHSTATE; setNodePatternInclusion(); cumLike = new double[patternCount]; nodeLikelihoods = new double[patternCount]; weightKnown = false; } // public Parameter getMuParameter() { // return mu; // } // // public Parameter getLamParameter() { // return lam; // } private double calculateSiteLogLikelihood(int site, double[] partials, double[] frequencies) { int v = site * stateCount; double sum = 0.0; for (int i = 0; i < stateCount; i++) { sum += frequencies[i] * partials[v + i]; } return Math.log(sum); } // @todo needs updating to use BEAGLE // private void calculateNodePatternLikelihood(int nodeIndex, // double[] freqs, // LikelihoodCore likelihoodCore, // double averageRate, // double[] cumLike) { // // get partials for node nodeIndex // likelihoodCore.getPartials(nodeIndex, nodePartials); // MAS // /* // multiply the partials by equilibrium probs // this part could be optimized by first summing // and then multiplying by equilibrium probs // */ // double prob = Math.log(getNodeSurvivalProbability(nodeIndex, averageRate)); // // for (int j = 0; j < patternCount; ++j) { // if (nodePatternInclusion[nodeIndex * patternCount + j]) { // cumLike[j] += Math.exp(calculateSiteLogLikelihood(j, nodePartials, freqs) + prob); // } // } // } // // private double accumulateCorrectedLikelihoods(double[] cumLike, double ascertainmentCorrection, // double[] patterWeights) { // double logL = 0; // for (int j = 0; j < patternCount; ++j) { // logL += Math.log(cumLike[j] / ascertainmentCorrection) * patternWeights[j]; // } // return logL; // } // // public final double nodePatternLikelihood(double[] freqs, LikelihoodPartialsProvider likelihoodCore, // ScaleFactorsHelper scaleFactorsHelper) { // int i, j; // double logL = gammaNorm; // // double birthRate = lam.getParameterValue(0); // double logProb; // if (!nodePatternInclusionKnown) // setNodePatternInclusion(); // if (nodePartials == null) { // nodePartials = new double[patternCount * stateCount]; // } // // double averageRate = getAverageRate(); // // for (j = 0; j < patternCount; ++j) cumLike[j] = 0; // // for (i = 0; i < nodeCount; ++i) { // // get partials for node i // likelihoodCore.getPartials(i, nodePartials); // scaleFactorsHelper.rescalePartials(i, nodePartials); // /* // multiply the partials by equilibrium probs // this part could be optimized by first summing // and then multiplying by equilibrium probs // */ //// likelihoodCore.calculateLogLikelihoods(nodePartials, freqs, nodeLikelihoods); // MAS Removed // logProb = Math.log(getNodeSurvivalProbability(i, averageRate)); // // for (j = 0; j < patternCount; ++j) { // if (nodePatternInclusion[i * patternCount + j]) { //// cumLike[j] += Math.exp(nodeLikelihoods[j] + logProb); // MAS Replaced with line below // cumLike[j] += Math.exp(calculateSiteLogLikelihood(j, nodePartials, freqs) // + logProb); // } // } // } // // double ascertainmentCorrection = getAscertainmentCorrection(cumLike); //// System.err.println("AscertainmentCorrection: "+ascertainmentCorrection); // // for (j = 0; j < patternCount; ++j) { // logL += Math.log(cumLike[j] / ascertainmentCorrection) * patternWeights[j]; // } // // double deathRate = mu.getParameterValue(0); // // double logTreeWeight = getLogTreeWeight(); // // if (integrateGainRate) { // logL -= gammaNorm + logN + Math.log(-logTreeWeight * deathRate / birthRate) * totalPatterns; // } else { // logL += logTreeWeight + Math.log(birthRate / deathRate) * totalPatterns; // } // return logL; // } protected double getAscertainmentCorrection(double[] patternProbs) { // This function probably belongs better to the AscertainedSitePatterns double excludeProb = 0, includeProb = 0, returnProb = 1.0; if (this.patterns instanceof AscertainedSitePatterns) { int[] includeIndices = ((AscertainedSitePatterns) patterns).getIncludePatternIndices(); int[] excludeIndices = ((AscertainedSitePatterns) patterns).getExcludePatternIndices(); for (int i = 0; i < ((AscertainedSitePatterns) patterns).getIncludePatternCount(); i++) { int index = includeIndices[i]; includeProb += patternProbs[index]; } for (int j = 0; j < ((AscertainedSitePatterns) patterns).getExcludePatternCount(); j++) { int index = excludeIndices[j]; excludeProb += patternProbs[index]; } if (includeProb == 0.0) { returnProb -= excludeProb; } else if (excludeProb == 0.0) { returnProb = includeProb; } else { returnProb = includeProb - excludeProb; } } return returnProb; } final public double getLogTreeWeight() { if (!weightKnown) { logTreeWeight = calculateLogTreeWeight(); weightKnown = true; } return logTreeWeight; } abstract public double calculateLogTreeWeight(); abstract void setNodePatternInclusion(); final public double getAverageRate() { if (!averageRateKnown) { double avgRate = 0.0; double proportions[] = siteModel.getCategoryProportions(); for (int i = 0; i < siteModel.getCategoryCount(); ++i) { avgRate += proportions[i] * siteModel.getRateForCategory(i); } averageRate = avgRate; averageRateKnown = true; } return averageRate; } public double getNodeSurvivalProbability(int index, double averageRate) { NodeRef node = treeModel.getNode(index); NodeRef parent = treeModel.getParent(node); if (parent == null) return 1.0; final double deathRate = mu.getParameterValue(0) * averageRate; //getAverageRate(); final double branchRate = branchRateModel.getBranchRate(treeModel, node); // Get the operational time of the branch final double branchTime = branchRate * treeModel.getBranchLength(node); return 1.0 - Math.exp(-deathRate * branchTime); } protected void handleModelChangedEvent(Model model, Object object, int index) { if (model == siteModel) { averageRateKnown = false; } if (model == treeModel || model == siteModel || model == branchRateModel) { weightKnown = false; } if (model == treeModel) { if (object instanceof TreeModel.TreeChangedEvent) { if (((TreeModel.TreeChangedEvent) object).isTreeChanged()) { nodePatternInclusionKnown = false; } } } } protected final void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) { if (variable == mu || variable == lam) { weightKnown = false; } else { System.err.println("AbstractObservationProcess: Got unexpected parameter changed event. (Parameter = " + variable + ")"); } } protected void storeState() { // storedAverageRate = averageRate; storedLogTreeWeight = logTreeWeight; System.arraycopy(nodePatternInclusion, 0, storedNodePatternInclusion, 0, storedNodePatternInclusion.length); } protected void restoreState() { // averageRate = storedAverageRate; averageRateKnown = false; logTreeWeight = storedLogTreeWeight; boolean[] tmp = storedNodePatternInclusion; storedNodePatternInclusion = nodePatternInclusion; nodePatternInclusion = tmp; } protected void acceptState() { } public void setIntegrateGainRate(boolean integrateGainRate) { this.integrateGainRate = integrateGainRate; } private boolean integrateGainRate = false; private double storedAverageRate; private double averageRate; private boolean averageRateKnown = false; }