/*
 * CodonPartitionedRobustCounting.java
 *
 * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
 *
 * This file is part of BEAST.
 * See the NOTICE file distributed with this work for additional
 * information regarding copyright ownership and licensing.
 *
 * BEAST is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * BEAST is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with BEAST; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA 02110-1301 USA
 */

package dr.evomodel.substmodel;

import dr.evomodel.siteratemodel.SiteRateModel;
import dr.evomodel.treelikelihood.AncestralStateBeagleTreeLikelihood;
import dr.evomodel.treelikelihood.utilities.TreeTraitLogger;
import dr.evolution.datatype.Codons;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evolution.tree.TreeTraitProvider;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.tree.TreeModel;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.Loggable;
import dr.inference.markovjumps.StateHistory;
import dr.inference.model.AbstractModel;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.math.MathUtils;
import dr.util.Citable;
import dr.util.Citation;
import dr.util.CommonCitations;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/**
 * A class for implementing robust counting for synonymous and nonsynonymous changes in BEAST using BEAGLE.
 * This work is supported by NSF grant 0856099.
 * <p/>
 * O'Brien JD, Minin VN and Suchard MA (2009) Learning to count: robust estimates for labeled distances between
 * molecular sequences. Molecular Biology and Evolution, 26, 801-814
 * <p/>
 * The model combines exactly three {@link AncestralStateBeagleTreeLikelihood} partitions (one per codon
 * position) into a product-chain substitution model and exposes per-branch / per-site counts as
 * {@link TreeTrait}s. Computed counts are cached and recomputed lazily after any model change,
 * variable change or state restore.
 *
 * @author Marc A. Suchard
 * @author Vladimir Minin
 */
public class CodonPartitionedRobustCounting extends AbstractModel implements TreeTraitProvider, Loggable, Citable {

    private static final boolean DEBUG = false;

    public static final String UNCONDITIONED_PREFIX = "u_";
    public static final String SITE_SPECIFIC_PREFIX = "c_";
    public static final String TOTAL_PREFIX = "total_";
    public static final String UNCONDITIONED_TOTAL_PREFIX = "utotal_";
    public static final String BASE_TRAIT_PREFIX = "base_";
    public static final String COMPLETE_HISTORY_PREFIX = "all_";
    public static final String UNCONDITIONED_PER_BRANCH_PREFIX = "b_u_";

    /**
     * Size of the product-chain state space. States are indexed as {@code i * 16 + j * 4 + k}
     * (see {@link #getCanonicalState(int, int, int)}), i.e. three base-4 digits.
     */
    private static final int STATE_COUNT = 64;

//    public CodonPartitionedRobustCounting(String name, TreeModel tree,
//                                          AncestralStateBeagleTreeLikelihood[] partition,
//                                          Codons codons,
//                                          CodonLabeling codonLabeling,
//                                          boolean useUniformization) {
//        this(name, tree, partition, codons, codonLabeling, useUniformization,
//                StratifiedTraitOutputFormat.SUM_OVER_SITES, StratifiedTraitOutputFormat.SUM_OVER_SITES);
//
//    }

    /**
     * Convenience constructor that delegates to the full constructor with
     * {@code forceUnconditionalAverageRate} set to {@code false}.
     */
    public CodonPartitionedRobustCounting(String name, TreeModel tree,
                                          AncestralStateBeagleTreeLikelihood[] partition,
                                          Codons codons,
                                          CodonLabeling codonLabeling,
                                          boolean useUniformization,
                                          boolean includeExternalBranches,
                                          boolean includeInternalBranches,
                                          boolean doUnconditionalPerBranch,
                                          boolean saveCompleteHistory,
                                          boolean tryNewNeutralModel,
                                          StratifiedTraitOutputFormat branchFormat,
                                          StratifiedTraitOutputFormat logFormat,
                                          String prefix) {
        this(name, tree, partition, codons, codonLabeling, useUniformization,
                includeExternalBranches, includeInternalBranches, doUnconditionalPerBranch,
                saveCompleteHistory, false, tryNewNeutralModel, branchFormat, logFormat, prefix);
    }

    /**
     * Builds the product-chain model over the three partitions, creates the Markov-jumps helper
     * (uniformized or not), registers the labeling matrix and sets up all tree traits.
     *
     * @param name                          model name passed to {@link AbstractModel}
     * @param tree                          tree whose branches are counted over
     * @param partition                     exactly three tree likelihoods with equal pattern counts
     * @param codons                        codon data type used to build the register matrix
     * @param codonLabeling                 which kind of change is counted
     * @param useUniformization             use a {@link UniformizedSubstitutionModel} instead of a plain
     *                                      {@link MarkovJumpsSubstitutionModel}
     * @param includeExternalBranches       include external branches in sums over the tree
     * @param includeInternalBranches       include internal (non-root) branches in sums over the tree
     * @param doUnconditionalPerBranch      additionally expose unconditioned per-branch traits
     * @param saveCompleteHistory           record and expose complete histories (only populated when
     *                                      uniformization is used)
     * @param forceUnconditionalAverageRate use a second product-chain model, built with the final
     *                                      constructor flag set to {@code true}, for unconditional quantities
     * @param tryNewNeutralModel            stored but not otherwise read in this class
     * @param branchFormat                  stored but not otherwise read in this class
     * @param logFormat                     stored but not otherwise read in this class
     * @param prefix                        optional prefix for the names of the "total" traits; may be null
     * @throws RuntimeException if {@code partition} does not hold three entries or their pattern
     *                          counts differ
     */
    public CodonPartitionedRobustCounting(String name, TreeModel tree,
                                          AncestralStateBeagleTreeLikelihood[] partition,
                                          Codons codons,
                                          CodonLabeling codonLabeling,
                                          boolean useUniformization,
                                          boolean includeExternalBranches,
                                          boolean includeInternalBranches,
                                          boolean doUnconditionalPerBranch,
                                          boolean saveCompleteHistory,
                                          boolean forceUnconditionalAverageRate,
                                          boolean tryNewNeutralModel,
                                          StratifiedTraitOutputFormat branchFormat,
                                          StratifiedTraitOutputFormat logFormat,
                                          String prefix) {
        super(name);
        this.tree = tree;
        addModel(tree);

        if (partition.length != 3) {
            throw new RuntimeException("CodonPartition models require 3 partitions");
        }

        this.partition = partition;
        this.codonLabeling = codonLabeling;
        branchRateModel = partition[0].getBranchRateModel();
        addModel(branchRateModel);

        List<SubstitutionModel> substModelsList = new ArrayList<SubstitutionModel>(3);
        List<SiteRateModel> siteRateModelsList = new ArrayList<SiteRateModel>(3);

        numCodons = partition[0].getPatternWeights().length;

        for (int i = 0; i < 3; i++) {
            substModelsList.add(partition[i].getBranchModel().getRootSubstitutionModel());
            siteRateModelsList.add(partition[i].getSiteRateModel());
            if (partition[i].getPatternWeights().length != numCodons) {
                throw new RuntimeException("All sequence lengths must be equal in CodonPartitionedRobustCounting");
            }
        }

        this.saveCompleteHistory = saveCompleteHistory;

        productChainModel =
                new ProductChainSubstitutionModel("codonLabeling", substModelsList, siteRateModelsList, false);
        addModel(productChainModel);

        this.forceUnconditionalAverageRate = forceUnconditionalAverageRate;
        if (forceUnconditionalAverageRate) {
            averagedProductChainModel =
                    new ProductChainSubstitutionModel("codonLabeling", substModelsList, siteRateModelsList, true);
            addModel(averagedProductChainModel);
        }

        this.useUniformization = useUniformization;
        if (useUniformization) {
            markovJumps = new UniformizedSubstitutionModel(productChainModel);
            ((UniformizedSubstitutionModel) markovJumps).setSaveCompleteHistory(saveCompleteHistory);
            if (forceUnconditionalAverageRate) {
                averagedMarkovJumps = new UniformizedSubstitutionModel(averagedProductChainModel);
                ((UniformizedSubstitutionModel) averagedMarkovJumps).setSaveCompleteHistory(saveCompleteHistory);
            }
        } else {
            markovJumps = new MarkovJumpsSubstitutionModel(productChainModel);
            if (forceUnconditionalAverageRate) {
                averagedMarkovJumps = new MarkovJumpsSubstitutionModel(averagedProductChainModel);
            }
        }

        double[] synRegMatrix = CodonLabeling.getRegisterMatrix(codonLabeling, codons, true);
        markovJumps.setRegistration(synRegMatrix);

        condMeanMatrix = new double[STATE_COUNT * STATE_COUNT];

        this.branchFormat = branchFormat;
        this.logFormat = logFormat;

        computedCounts = new double[tree.getNodeCount()][]; // TODO Temporary until there exists a helper class

        this.includeExternalBranches = includeExternalBranches;
        this.includeInternalBranches = includeInternalBranches;
        this.doUnconditionedPerBranch = doUnconditionalPerBranch;

        this.tryNewNeutralModel = tryNewNeutralModel;
        //this.neutralSubstitutionModel = null; // new ComplexSubstitutionModel();

        this.prefix = prefix;

        setupTraits();
    }

    /**
     * Returns the cached unconditioned per-site counts for the branch above {@code child},
     * recomputing the values for all branches first if the cache has been invalidated.
     *
     * @param child node at the lower end of the branch
     * @return one value per site, indexed like the partitions' patterns
     */
    public double[] getUnconditionalCountsForBranch(NodeRef child) {
        if (!unconditionsPerBranchKnown) {
            computeAllUnconditionalCountsPerBranch();
            unconditionsPerBranchKnown = true;
        }
        return unconditionedCountsPerBranch[child.getNumber()];
    }

    /**
     * Returns the cached per-site counts for the branch above {@code child}, recomputing the values
     * for all branches first if the cache has been invalidated.
     *
     * @param child node at the lower end of the branch
     * @return one value per site; no entry is ever computed for the root, so its slot stays null
     */
    public double[] getExpectedCountsForBranch(NodeRef child) { // TODO This function will implement TraitProvider
        if (!countsKnown) {
            computeAllExpectedCounts();
        }
        return computedCounts[child.getNumber()];
    }

    /**
     * Recomputes the per-site counts of every non-root branch and marks the cache as valid.
     */
    private void computeAllExpectedCounts() {
        for (int i = 0; i < tree.getNodeCount(); i++) {
            NodeRef child = tree.getNode(i);
            if (!tree.isRoot(child)) {
                computedCounts[child.getNumber()] = computeExpectedCountsForBranch(child);
            }
        }
        countsKnown = true;
    }

    /**
     * Computes the per-site counts along the branch above {@code child} from the states the three
     * partitions report for the child and its parent. When uniformization and complete-history
     * saving are both enabled, the filtered history of each site is recorded as a side effect.
     *
     * @param child a non-root node
     * @return a freshly allocated array with one count per site
     */
    private double[] computeExpectedCountsForBranch(NodeRef child) {

        // Get child node reconstructed sequence
        final int[] childSeq0 = partition[0].getStatesForNode(tree, child);
        final int[] childSeq1 = partition[1].getStatesForNode(tree, child);
        final int[] childSeq2 = partition[2].getStatesForNode(tree, child);

        // Get parent node reconstructed sequence
        final NodeRef parent = tree.getParent(child);
        final int[] parentSeq0 = partition[0].getStatesForNode(tree, parent);
        final int[] parentSeq1 = partition[1].getStatesForNode(tree, parent);
        final int[] parentSeq2 = partition[2].getStatesForNode(tree, parent);

        double branchRateTime = branchRateModel.getBranchRate(tree, child) * tree.getBranchLength(child);

        double[] count = new double[numCodons];

        if (!useUniformization) {
            markovJumps.computeCondStatMarkovJumps(branchRateTime, condMeanMatrix);
        } else {
            // Fill condMeanMatrix with transition probabilities
            markovJumps.getSubstitutionModel().getTransitionProbabilities(branchRateTime, condMeanMatrix);
        }

        for (int i = 0; i < numCodons; i++) {

            // Construct this child and parent codon
            final int childState = getCanonicalState(childSeq0[i], childSeq1[i], childSeq2[i]);
            final int parentState = getCanonicalState(parentSeq0[i], parentSeq1[i], parentSeq2[i]);

//            final int vChildState = getVladimirState(childSeq0[i], childSeq1[i], childSeq2[i]);
//            final int vParentState = getVladimirState(parentSeq0[i], parentSeq1[i], parentSeq2[i]);

            final int matrixIndex = parentState * STATE_COUNT + childState;

            final double codonCount;
            if (!useUniformization) {
                codonCount = condMeanMatrix[matrixIndex];
            } else {
                codonCount = ((UniformizedSubstitutionModel) markovJumps).computeCondStatMarkovJumps(
                        parentState,
                        childState,
                        branchRateTime,
                        condMeanMatrix[matrixIndex]
                );
            }

            if (useUniformization && saveCompleteHistory) {
                UniformizedSubstitutionModel usModel = (UniformizedSubstitutionModel) markovJumps;

                // Allocated on first use; stays null when this branch is never taken
                if (completeHistoryPerNode == null) {
                    completeHistoryPerNode = new String[tree.getNodeCount()][numCodons];
                }

                StateHistory history = usModel.getStateHistory();

                // Only report syn or nonsyn changes
                double[] register = usModel.getRegistration();
                history = history.filterChanges(register);

                int historyCount = history.getNumberOfJumps();
                if (historyCount > 0) {
                    double parentTime = tree.getNodeHeight(parent);
                    double childTime = tree.getNodeHeight(child);
                    history.rescaleTimesOfEvents(parentTime, childTime);

                    // MAS may have broken the next line
                    String hstring = history.toStringChanges(i + 1, usModel.dataType, false);
                    if (DEBUG) {
                        System.err.println("site " + (i + 1) + " : "
                                + history.getNumberOfJumps()
                                + " : "
                                + history.toStringChanges(i + 1, usModel.dataType)
                                + " " + codonLabeling.getText());
                    }
                    completeHistoryPerNode[child.getNumber()][i] = hstring;
                } else {
                    completeHistoryPerNode[child.getNumber()][i] = null;
                }
            }
            count[i] = codonCount;
        }
        return count;
    }

    /**
     * Creates every tree trait this provider exposes, registers them with {@link #treeTraits}
     * and builds the {@link #treeTraitLogger} used for column logging.
     */
    private void setupTraits() {

        TreeTrait baseTrait = new TreeTrait.DA() {

            public String getTraitName() {
                return BASE_TRAIT_PREFIX + codonLabeling.getText();
            }

            public Intent getIntent() {
                return Intent.BRANCH;
            }

            public double[] getTrait(Tree tree, NodeRef node) {
                return getExpectedCountsForBranch(node);
            }

            public boolean getLoggable() {
                return false;
            }
        };

        if (saveCompleteHistory) {
            TreeTrait stringTrait = new TreeTrait.SA() {

                public String getTraitName() {
                    return COMPLETE_HISTORY_PREFIX + codonLabeling.getText();
                }

                public Intent getIntent() {
                    return Intent.BRANCH;
                }

                public boolean getFormatAsArray() {
                    return true;
                }

                public String[] getTrait(Tree tree, NodeRef node) {
                    // Must run first: this lazily (re)computes the counts, which is also what
                    // fills completeHistoryPerNode
                    double[] count = getExpectedCountsForBranch(node); // Lazy simulation of complete histories

                    // completeHistoryPerNode is only allocated on the uniformization path; when it
                    // was never allocated there are no events to report
                    final String[] siteHistories =
                            (completeHistoryPerNode == null) ? null : completeHistoryPerNode[node.getNumber()];

                    List<String> events = new ArrayList<String>();
                    if (siteHistories != null) {
                        for (int i = 0; i < numCodons; i++) {
                            String eventString = siteHistories[i];
                            if (eventString != null) {
                                if (eventString.contains("},{")) { // There are multiple events
                                    Collections.addAll(events, eventString.split("(?<=\\}),(?=\\{)"));
                                } else {
                                    events.add(eventString);
                                }
                            }
                        }
                    }

                    if (DEBUG) {
                        double sum = 0.0;
                        for (double d : count) {
                            if (d > 0.0) {
                                sum += 1;
                            }
                        }
                        System.err.println(events.size() + " " + sum);
                        if (Math.abs(events.size() - sum) > 0.5) {
                            System.err.println("Error");
                            for (int i = 0; i < count.length; ++i) {
                                if (count[i] != 0.0) {
                                    System.err.println(i + ": " + count[i]
                                            + (siteHistories == null ? null : siteHistories[i]));
                                }
                            }
                            System.err.println("Error");
                            int c = 0;
                            for (String s : events) {
                                c++;
                                System.err.println(c + ":" + s);
                            }
                            System.exit(-1);
                        }
                    }

                    return events.toArray(new String[events.size()]);
                }

                public boolean getLoggable() {
                    return true;
                }
            };
            treeTraits.addTrait(stringTrait);
        }

        TreeTrait unconditionedSum;
        if (!TRIAL) {
            unconditionedSum = new TreeTrait.D() {
                public String getTraitName() {
                    return UNCONDITIONED_PREFIX + codonLabeling.getText();
                }

                public Intent getIntent() {
                    return Intent.WHOLE_TREE;
                }

                public Double getTrait(Tree tree, NodeRef node) {
                    return getUnconditionedTraitValue();
                }

                public boolean getLoggable() {
                    return false;
                }
            };
        } else {
            unconditionedSum = new TreeTrait.DA() {
                public String getTraitName() {
                    return UNCONDITIONED_PREFIX + codonLabeling.getText();
                }

                public Intent getIntent() {
                    return Intent.WHOLE_TREE;
                }

                public double[] getTrait(Tree tree, NodeRef node) {
                    return getUnconditionedTraitValues();
                }

                public boolean getLoggable() {
                    return false;
                }
            };
        }

        TreeTrait sumOverTreeTrait = new TreeTrait.SumOverTreeDA(
                SITE_SPECIFIC_PREFIX + codonLabeling.getText(),
                baseTrait,
                includeExternalBranches,
                includeInternalBranches) {
            @Override
            public boolean getLoggable() {
                return false;
            }
        };

        // This should be the default output in tree logs
        TreeTrait sumOverSitesTrait = new TreeTrait.SumAcrossArrayD(
                codonLabeling.getText(),
                baseTrait) {
            @Override
            public boolean getLoggable() {
                return true;
            }
        };

        // This should be the default output in columns logs
        String name = prefix != null
                ? prefix + TOTAL_PREFIX + codonLabeling.getText()
                : TOTAL_PREFIX + codonLabeling.getText();
        TreeTrait sumOverSitesAndTreeTrait = new TreeTrait.SumOverTreeD(
                name,
                sumOverSitesTrait,
                includeExternalBranches,
                includeInternalBranches) {
            @Override
            public boolean getLoggable() {
                return true;
            }
        };

        treeTraitLogger = new TreeTraitLogger(
                tree,
                new TreeTrait[]{sumOverSitesAndTreeTrait}
        );

        treeTraits.addTrait(baseTrait);
        treeTraits.addTrait(unconditionedSum);
        treeTraits.addTrait(sumOverSitesTrait);
        treeTraits.addTrait(sumOverTreeTrait);
        treeTraits.addTrait(sumOverSitesAndTreeTrait);

        if (doUnconditionedPerBranch) {

            TreeTrait unconditionedBase = new TreeTrait.DA() {

                public String getTraitName() {
                    return UNCONDITIONED_PER_BRANCH_PREFIX + codonLabeling.getText();
                }

                public Intent getIntent() {
                    return Intent.BRANCH;
                }

                public double[] getTrait(Tree tree, NodeRef node) {
                    return getUnconditionalCountsForBranch(node);
                }

                public boolean getLoggable() {
                    return false; // TODO Should be switched to true to log unconditioned values per branch
                }
            };

            TreeTrait sumUnconditionedOverSitesTrait = new TreeTrait.SumAcrossArrayD(
                    UNCONDITIONED_PER_BRANCH_PREFIX + codonLabeling.getText(),
                    unconditionedBase) {
                @Override
                public boolean getLoggable() {
                    return true;
                }
            };

            String nameU = prefix != null
                    ? prefix + UNCONDITIONED_TOTAL_PREFIX + codonLabeling.getText()
                    : UNCONDITIONED_TOTAL_PREFIX + codonLabeling.getText();
            TreeTrait sumUnconditionedOverSitesAndTreeTrait = new TreeTrait.SumOverTreeD(
                    nameU,
                    sumUnconditionedOverSitesTrait,
                    includeExternalBranches,
                    includeInternalBranches) {
                public boolean getLoggable() {
                    return true;
                }
            };

            // Replaces the logger built above so that both totals are reported as columns
            treeTraitLogger = new TreeTraitLogger(
                    tree,
                    new TreeTrait[]{sumOverSitesAndTreeTrait, sumUnconditionedOverSitesAndTreeTrait}
            );

            treeTraits.addTrait(unconditionedBase);
            treeTraits.addTrait(sumUnconditionedOverSitesTrait);
        }
    }

    /**
     * @return all tree traits registered by {@link #setupTraits()}
     */
    public TreeTrait[] getTreeTraits() {
        return treeTraits.getTreeTraits();
    }

    /**
     * @param key trait name
     * @return whatever the trait helper returns for {@code key}
     */
    public TreeTrait getTreeTrait(String key) {
        return treeTraits.getTreeTrait(key);
    }

    /**
     * Combines the three per-position states into one product-chain state index,
     * {@code i * 16 + j * 4 + k}.
     */
    private int getCanonicalState(int i, int j, int k) {
        return i * 16 + j * 4 + k;
    }

//    private int getVladimirState(int i, int j, int k) {
//        if (i == 1) i = 2;
//        else if (i == 2) i = 1;
//
//        if (j == 1) j = 2;
//        else if (j == 2) j = 1;
//
//        if (k == 1) k = 2;
//        else if (k == 2) k = 1;
//
//        return i * 16 + j * 4 + k + 1;
//    }

    /**
     * @return the log columns supplied by the tree-trait logger
     */
    public LogColumn[] getColumns() {
        return treeTraitLogger.getColumns();
    }

    /**
     * @return the number of sites, i.e. the pattern count of the first partition
     */
    public int getDimension() {
        return numCodons;
    }

    /**
     * Fills {@link #unconditionedCountsPerBranch} (allocating it on first use) with one simulated
     * value per site for every non-root branch. The caller is responsible for setting the
     * validity flag.
     */
    private void computeAllUnconditionalCountsPerBranch() {

        if (unconditionedCountsPerBranch == null) {
            unconditionedCountsPerBranch = new double[tree.getNodeCount()][numCodons];
        }

        double[] rootDistribution = getUnconditionalRootDistribution();

        for (int i = 0; i < tree.getNodeCount(); i++) {
            NodeRef node = tree.getNode(i);
            if (!tree.isRoot(node)) {
                final double expectedLength = getExpectedBranchLength(node);
                fillInUnconditionalTraitValues(expectedLength, rootDistribution,
                        unconditionedCountsPerBranch[node.getNumber()]);
            }
        }
    }

    /**
     * Fills {@link #unconditionedCounts} (allocating it on first use) with one simulated value per
     * site over the expected tree length. The caller is responsible for setting the validity flag.
     */
    private void computeUnconditionedTraitValues() {

        if (unconditionedCounts == null) {
            unconditionedCounts = new double[numCodons];
        }

        final double treeLength = getExpectedTreeLength();
        double[] rootDistribution = getUnconditionalRootDistribution();

//        final int stateCount = 64;
//        double[] lambda = new double[stateCount * stateCount];
//        productChainModel.getInfinitesimalMatrix(lambda);
//        for (int i = 0; i < numCodons; i++) {
//            final int startingState = MathUtils.randomChoicePDF(rootDistribution);
//            StateHistory history = StateHistory.simulateUnconditionalOnEndingState(
//                    0.0,
//                    startingState,
//                    treeLength,
//                    lambda,
//                    stateCount
//            );
//            unconditionedCounts[i] = markovJumps.getProcessForSimulant(history);
//        }

        fillInUnconditionalTraitValues(treeLength, rootDistribution, unconditionedCounts);
    }

    /**
     * @return the frequencies of the averaged product-chain model when
     * {@code forceUnconditionalAverageRate} is set, otherwise those of the regular one
     */
    private double[] getUnconditionalRootDistribution() {
        if (forceUnconditionalAverageRate) {
            return averagedProductChainModel.getFrequencyModel().getFrequencies();
        } else {
            return productChainModel.getFrequencyModel().getFrequencies();
        }
    }

    /**
     * Writes the infinitesimal matrix of the model selected by {@code forceUnconditionalAverageRate}
     * into {@code lambda}.
     */
    private void fillInUnconditionalQMatrix(double[] lambda) {
        if (forceUnconditionalAverageRate) {
            averagedProductChainModel.getInfinitesimalMatrix(lambda);
        } else {
            productChainModel.getInfinitesimalMatrix(lambda);
        }
    }

    /**
     * For every site, draws a starting state from {@code freq}, simulates a history of the given
     * length that is not conditioned on its ending state, and stores the resulting process value.
     *
     * @param expectedLength length of the simulated history
     * @param freq           distribution the starting state is drawn from
     * @param out            destination array; its first {@code numCodons} entries are overwritten
     */
    private void fillInUnconditionalTraitValues(double expectedLength, double[] freq, double[] out) {
        double[] lambda = new double[STATE_COUNT * STATE_COUNT];
        fillInUnconditionalQMatrix(lambda);
        for (int i = 0; i < numCodons; i++) {
            final int startingState = MathUtils.randomChoicePDF(freq);
            StateHistory history = StateHistory.simulateUnconditionalOnEndingState(
                    0.0,
                    startingState,
                    expectedLength,
                    lambda,
                    STATE_COUNT
            );
            out[i] = markovJumps.getProcessForSimulant(history);
        }
    }

    /**
     * @return the cached whole-tree unconditioned values, recomputed first if invalidated
     */
    private double[] getUnconditionedTraitValues() {
        if (!unconditionsKnown) {
            computeUnconditionedTraitValues();
            unconditionsKnown = true;
        }
        return unconditionedCounts;
    }

    /**
     * Simulates a single history over the expected tree length, not conditioned on its ending
     * state, and returns its process value. Nothing is cached, so each call draws a new simulation.
     *
     * @return the process value of the simulated history
     * @throws RuntimeException if the compile-time {@code TRIAL} switch is off
     */
    public Double getUnconditionedTraitValue() {
        if (!TRIAL) {
            throw new RuntimeException("Believed broken for neutral models");
            // return markovJumps.getMarginalRate() * getExpectedTreeLength();
        } else {
            final double treeLength = getExpectedTreeLength();
            double[] rootDistribution = getUnconditionalRootDistribution();
            final int startingState = MathUtils.randomChoicePDF(rootDistribution);
            double[] lambda = new double[STATE_COUNT * STATE_COUNT];
            fillInUnconditionalQMatrix(lambda);
            StateHistory history = StateHistory.simulateUnconditionalOnEndingState(
                    0.0,
                    startingState,
                    treeLength,
                    lambda,
                    STATE_COUNT
            );
            return markovJumps.getProcessForSimulant(history);
        }
    }

    /**
     * @return branch rate times branch length for the branch above {@code node}
     */
    private double getExpectedBranchLength(NodeRef node) {
        return branchRateModel.getBranchRate(tree, node) * tree.getBranchLength(node);
    }

    /**
     * Sums the expected branch lengths over the external and/or internal non-root branches,
     * according to the include flags.
     */
    private double getExpectedTreeLength() {
        double expectedTreeLength = 0;
        if (includeExternalBranches) {
            for (int i = 0; i < tree.getExternalNodeCount(); i++) {
                NodeRef node = tree.getExternalNode(i);
                expectedTreeLength += getExpectedBranchLength(node);
            }
        }
        if (includeInternalBranches) {
            for (int i = 0; i < tree.getInternalNodeCount(); i++) {
                NodeRef node = tree.getInternalNode(i);
                if (!tree.isRoot(node)) {
                    expectedTreeLength += getExpectedBranchLength(node);
                }
            }
        }
        return expectedTreeLength;
    }

    /**
     * Marks every cached count array as stale so the next read recomputes it. The three
     * validity flags are always cleared together.
     */
    private void invalidateCachedCounts() {
        countsKnown = false;
        unconditionsKnown = false;
        unconditionsPerBranchKnown = false;
    }

    protected void handleModelChangedEvent(Model model, Object object, int index) {
        invalidateCachedCounts();
    }

    protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {
        // Clears the per-branch unconditioned cache as well, like the other invalidation hooks
        invalidateCachedCounts();
    }

    protected void storeState() {
        // Do nothing
    }

    protected void restoreState() {
        // Nothing was stored, so simply force recomputation
        invalidateCachedCounts();
    }

    protected void acceptState() {
        // Do nothing
    }

    private final AncestralStateBeagleTreeLikelihood[] partition;
    private final MarkovJumpsSubstitutionModel markovJumps;
    private MarkovJumpsSubstitutionModel averagedMarkovJumps = null;
    private final boolean forceUnconditionalAverageRate;

    private final boolean useUniformization;
    private final BranchRateModel branchRateModel;
    private final ProductChainSubstitutionModel productChainModel;
    private ProductChainSubstitutionModel averagedProductChainModel = null;
    private final CodonLabeling codonLabeling;
    private final Tree tree;
    private final String prefix;

    private final StratifiedTraitOutputFormat branchFormat;
    private final StratifiedTraitOutputFormat logFormat;

    // Scratch buffer of STATE_COUNT * STATE_COUNT entries, overwritten for every branch
    private final double[] condMeanMatrix;

    private int numCodons;

    // Validity flags for the three caches below
    private boolean countsKnown = false;
    private boolean unconditionsKnown = false;
    private boolean unconditionsPerBranchKnown = false;
    private double[] unconditionedCounts;
    private double[][] unconditionedCountsPerBranch;

    private double[][] computedCounts; // TODO Temporary storage until generic TreeTraitProvider/Helpers are finished

    // Indexed [node number][site]; allocated only when uniformization and history saving are both on
    private String[][] completeHistoryPerNode;

    protected Helper treeTraits = new Helper();
    protected TreeTraitLogger treeTraitLogger;

    private final boolean includeExternalBranches;
    private final boolean includeInternalBranches;
    private final boolean doUnconditionedPerBranch;

    private static final boolean TRIAL = true;

    private boolean saveCompleteHistory = false;

    private boolean tryNewNeutralModel = false;

    @Override
    public Citation.Category getCategory() {
        return Citation.Category.COUNTING_PROCESSES;
    }

    @Override
    public String getDescription() {
        StringBuilder sb = new StringBuilder("Using robust counting (first citation) for labeled distances between sequences" +
                " to efficiently estimate site-specific dN/dS rate ratios (second citation)");
        if (saveCompleteHistory) {
            sb.append(" and inferring the complete transition history (third citation)");
        }
        return sb.toString();
    }

    /**
     * @return a list of citations associated with this object
     */
    @Override
    public List<Citation> getCitations() {
        List<Citation> list = new ArrayList<Citation>();
        list.add(CommonCitations.OBRIEN_2009_LEARNING);
        list.add(CommonCitations.LEMEY_2012_RENAISSANCE);
        if (saveCompleteHistory) {
            list.add(CommonCitations.BLOOM_2013_STABILITY);
        }
        return list;
    }
}