/* * MultiPartitionTreeLikelihood.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package dr.evomodel.treelikelihood; import beagle.*; import dr.evolution.tree.TreeUtils; import dr.evomodel.branchmodel.BranchModel; import dr.evomodel.treedatalikelihood.BufferIndexHelper; import dr.evomodelxml.treelikelihood.BeagleTreeLikelihoodParser; import dr.evomodel.siteratemodel.SiteRateModel; import dr.evolution.alignment.AscertainedSitePatterns; import dr.evolution.alignment.PatternList; import dr.evolution.datatype.DataType; import dr.evolution.tree.NodeRef; import dr.evolution.tree.Tree; import dr.evolution.util.Taxon; import dr.evolution.util.TaxonList; import dr.evomodel.branchratemodel.BranchRateModel; import dr.evomodel.branchratemodel.DefaultBranchRateModel; import dr.evomodel.tree.TreeModel; import dr.evomodel.tipstatesmodel.TipStatesModel; import dr.inference.model.Model; import dr.inference.model.Parameter; import dr.inference.model.ThreadAwareLikelihood; import java.util.*; import java.util.logging.Logger; /** * BeagleTreeLikelihoodModel - implements a Likelihood Function for sequences on a tree. * * @author Andrew Rambaut * @author Alexei Drummond * @author Marc Suchard * @version $Id$ */ @Deprecated @SuppressWarnings("serial") public class MultiPartitionTreeLikelihood extends AbstractTreeLikelihood implements ThreadAwareLikelihood { // This property is a comma-delimited list of resource numbers (0 == CPU) to // allocate each BEAGLE instance to. If less than the number of instances then // will wrap around. private static final String RESOURCE_ORDER_PROPERTY = "beagle.resource.order"; private static final String PREFERRED_FLAGS_PROPERTY = "beagle.preferred.flags"; private static final String REQUIRED_FLAGS_PROPERTY = "beagle.required.flags"; private static final String SCALING_PROPERTY = "beagle.scaling"; private static final String RESCALE_FREQUENCY_PROPERTY = "beagle.rescale"; private static final String EXTRA_BUFFER_COUNT_PROPERTY = "beagle.extra.buffer.count"; private static final String FORCE_VECTORIZATION = "beagle.force.vectorization"; // Which scheme to use if choice not specified (or 'default' is selected): private static final PartialsRescalingScheme DEFAULT_RESCALING_SCHEME = PartialsRescalingScheme.DYNAMIC; private static int instanceCount = 0; private static List<Integer> resourceOrder = null; private static List<Integer> preferredOrder = null; private static List<Integer> requiredOrder = null; private static List<String> scalingOrder = null; private static List<Integer> extraBufferOrder = null; // Default frequency for complete recomputation of scaling factors under the 'dynamic' scheme private static final int RESCALE_FREQUENCY = 100; private static final int RESCALE_TIMES = 1; public MultiPartitionTreeLikelihood(List<PatternList> patternLists, List<SiteRateModel> siteRateModels, TreeModel treeModel, BranchModel branchModel, BranchRateModel branchRateModel, TipStatesModel tipStatesModel, boolean useAmbiguities, PartialsRescalingScheme rescalingScheme) { this(patternLists, siteRateModels, treeModel, branchModel, branchRateModel, tipStatesModel, useAmbiguities, rescalingScheme, null); } public MultiPartitionTreeLikelihood(List<PatternList> patternLists, List<SiteRateModel> siteRateModels, TreeModel treeModel, BranchModel branchModel, BranchRateModel branchRateModel, TipStatesModel tipStatesModel, boolean useAmbiguities, PartialsRescalingScheme rescalingScheme, Map<Set<String>, Parameter> partialsRestrictions) { super(BeagleTreeLikelihoodParser.TREE_LIKELIHOOD, treeModel); try { final Logger logger = Logger.getLogger("dr.evomodel"); logger.info("Using BEAGLE TreeLikelihood"); // should be a 1 to 1 correspondence of patternLists to siteModels. assert(patternLists.size() == siteRateModels.size()); this.dataType = patternLists.get(0).getDataType(); this.stateCount = dataType.getStateCount(); partitionCount = patternLists.size(); this.patternLists.addAll(patternLists); for (PatternList patternList : patternLists) { // check all patternLists). assert(patternList.getDataType() == dataType); } this.siteRateModels.addAll(siteRateModels); this.categoryCount = this.siteRateModels.get(0).getCategoryCount(); for (SiteRateModel siteRateModel : siteRateModels) { // check all siteRateModels use the same number of categories // (could be relaxed but this will make for easier bookkeeping). assert(siteRateModel.getCategoryCount() == categoryCount); addModel(siteRateModel); } this.branchModel = branchModel; addModel(this.branchModel); if (branchRateModel != null) { this.branchRateModel = branchRateModel; logger.info(" Branch rate model used: " + branchRateModel.getModelName()); } else { this.branchRateModel = new DefaultBranchRateModel(); } addModel(this.branchRateModel); this.tipStatesModel = tipStatesModel; this.tipCount = treeModel.getExternalNodeCount(); internalNodeCount = nodeCount - tipCount; int compactPartialsCount = tipCount; if (useAmbiguities) { // if we are using ambiguities then we don't use tip partials compactPartialsCount = 0; } // one partials buffer for each tip and two for each internal node (for store restore) partialBufferHelper = new BufferIndexHelper(nodeCount, tipCount); // one scaling buffer for each internal node plus an extra for the accumulation, then doubled for store/restore scaleBufferHelper = new BufferIndexHelper(getScaleBufferCount(), 0); // Attempt to get the resource order from the System Property if (resourceOrder == null) { resourceOrder = parseSystemPropertyIntegerArray(RESOURCE_ORDER_PROPERTY); } if (preferredOrder == null) { preferredOrder = parseSystemPropertyIntegerArray(PREFERRED_FLAGS_PROPERTY); } if (requiredOrder == null) { requiredOrder = parseSystemPropertyIntegerArray(REQUIRED_FLAGS_PROPERTY); } if (scalingOrder == null) { scalingOrder = parseSystemPropertyStringArray(SCALING_PROPERTY); } if (extraBufferOrder == null) { extraBufferOrder = parseSystemPropertyIntegerArray(EXTRA_BUFFER_COUNT_PROPERTY); } int extraBufferCount = -1; // default if (extraBufferOrder.size() > 0) { extraBufferCount = extraBufferOrder.get(instanceCount % extraBufferOrder.size()); } substitutionModelDelegates = new SubstitutionModelDelegate[partitionCount]; for (int i = 0; i < partitionCount; i++) { substitutionModelDelegates[i] = new SubstitutionModelDelegate(treeModel, branchModel, extraBufferCount); } // first set the rescaling scheme to use from the parser this.rescalingScheme = rescalingScheme; int[] resourceList = null; long preferenceFlags = 0; long requirementFlags = 0; if (scalingOrder.size() > 0) { this.rescalingScheme = PartialsRescalingScheme.parseFromString( scalingOrder.get(instanceCount % scalingOrder.size())); } if (resourceOrder.size() > 0) { // added the zero on the end so that a CPU is selected if requested resource fails resourceList = new int[]{resourceOrder.get(instanceCount % resourceOrder.size()), 0}; if (resourceList[0] > 0) { preferenceFlags |= BeagleFlag.PROCESSOR_GPU.getMask(); // Add preference weight against CPU } } if (preferredOrder.size() > 0) { preferenceFlags = preferredOrder.get(instanceCount % preferredOrder.size()); } if (requiredOrder.size() > 0) { requirementFlags = requiredOrder.get(instanceCount % requiredOrder.size()); } // Define default behaviour here if (this.rescalingScheme == PartialsRescalingScheme.DEFAULT) { //if GPU: the default is dynamic scaling in BEAST if (resourceList != null && resourceList[0] > 1) { this.rescalingScheme = DEFAULT_RESCALING_SCHEME; } else { // if CPU: just run as fast as possible // this.rescalingScheme = PartialsRescalingScheme.NONE; // Dynamic should run as fast as none until first underflow this.rescalingScheme = DEFAULT_RESCALING_SCHEME; } } if (this.rescalingScheme == PartialsRescalingScheme.AUTO) { preferenceFlags |= BeagleFlag.SCALING_AUTO.getMask(); useAutoScaling = true; } else { // preferenceFlags |= BeagleFlag.SCALING_MANUAL.getMask(); } String r = System.getProperty(RESCALE_FREQUENCY_PROPERTY); if (r != null) { rescalingFrequency = Integer.parseInt(r); if (rescalingFrequency < 1) { rescalingFrequency = RESCALE_FREQUENCY; } } patternCounts = new int[partitionCount]; int total = 0; for (int i = 0; i < patternLists.size(); i++) { patternCounts[i] = patternLists.get(i).getPatternCount(); total += patternCounts[i]; } patternCount = total; patternWeights = new double[patternCount]; int k = 0; for (PatternList patternList : patternLists) { for (int j = 0; j < patternList.getPatternCount(); j++) { patternWeights[k] = patternList.getPatternWeight(j); k++; } } if (preferenceFlags == 0 && resourceList == null) { // else determine dataset characteristics if (stateCount == 4 && patternCount < 10000) // TODO determine good cut-off preferenceFlags |= BeagleFlag.PROCESSOR_CPU.getMask(); } boolean forceVectorization = false; String vectorizationString = System.getProperty(FORCE_VECTORIZATION); if (vectorizationString != null) { forceVectorization = true; } if (BeagleFlag.VECTOR_SSE.isSet(preferenceFlags) && (stateCount != 4) && !forceVectorization ) { // @todo SSE doesn't seem to work for larger state spaces so for now we override the // SSE option. preferenceFlags &= ~BeagleFlag.VECTOR_SSE.getMask(); preferenceFlags |= BeagleFlag.VECTOR_NONE.getMask(); if (stateCount > 4 && this.rescalingScheme == PartialsRescalingScheme.DYNAMIC) { this.rescalingScheme = PartialsRescalingScheme.DELAYED; } } if (!BeagleFlag.PRECISION_SINGLE.isSet(preferenceFlags)) { // if single precision not explicitly set then prefer double preferenceFlags |= BeagleFlag.PRECISION_DOUBLE.getMask(); } if (substitutionModelDelegates[0].canReturnComplexDiagonalization()) { requirementFlags |= BeagleFlag.EIGEN_COMPLEX.getMask(); } int eigenModelCount = 0; int matrixBufferCount = 0; for (SubstitutionModelDelegate substitutionModelDelegate: substitutionModelDelegates) { eigenModelCount += substitutionModelDelegate.getEigenBufferCount(); matrixBufferCount += substitutionModelDelegate.getMatrixBufferCount(); } instanceCount++; beagle = BeagleFactory.loadBeagleInstance( tipCount, partialBufferHelper.getBufferCount(), compactPartialsCount, stateCount, patternCount, eigenModelCount, matrixBufferCount, categoryCount, scaleBufferHelper.getBufferCount(), // Always allocate; they may become necessary resourceList, preferenceFlags, requirementFlags ); InstanceDetails instanceDetails = beagle.getDetails(); ResourceDetails resourceDetails = null; if (instanceDetails != null) { resourceDetails = BeagleFactory.getResourceDetails(instanceDetails.getResourceNumber()); if (resourceDetails != null) { StringBuilder sb = new StringBuilder(" Using BEAGLE resource "); sb.append(resourceDetails.getNumber()).append(": "); sb.append(resourceDetails.getName()).append("\n"); if (resourceDetails.getDescription() != null) { String[] description = resourceDetails.getDescription().split("\\|"); for (String desc : description) { if (desc.trim().length() > 0) { sb.append(" ").append(desc.trim()).append("\n"); } } } sb.append(" with instance flags: ").append(instanceDetails.toString()); logger.info(sb.toString()); } else { logger.info(" Error retrieving BEAGLE resource for instance: " + instanceDetails.toString()); } } else { logger.info(" No external BEAGLE resources available, or resource list/requirements not met, using Java implementation"); } logger.info(" " + (useAmbiguities ? "Using" : "Ignoring") + " ambiguities in tree likelihood."); if (patternLists.size() > 1) { logger.info(" With " + patternCount + " unique site patterns in " + patternLists.size() + " partitions."); } else { logger.info(" With " + patternCount + " unique site patterns."); } if (tipStatesModel != null) { throw new UnsupportedOperationException("Tip error models not supported by MultiPartitionTreeLikelihood yet"); // tipStatesModel.setTree(treeModel); // // if (tipStatesModel.getModelType() == TipStatesModel.Type.PARTIALS) { // tipPartials = new double[patternCount * stateCount]; // } else { // tipStates = new int[patternCount]; // } // // addModel(tipStatesModel); } for (int i = 0; i < tipCount; i++) { // Find the id of tip i in the patternList String id = treeModel.getTaxonId(i); for (PatternList patternList : patternLists) { int index = patternList.getTaxonIndex(id); if (index == -1) { throw new TaxonList.MissingTaxonException("Taxon, " + id + ", in tree, " + treeModel.getId() + ", is not found in patternList, " + patternList.getId()); } else { if (tipStatesModel != null) { throw new UnsupportedOperationException("Tip error models not supported by MultiPartitionTreeLikelihood yet"); // // using a tipPartials model. // // First set the observed states: // tipStatesModel.setStates(patternList, index, i, id); // // if (tipStatesModel.getModelType() == TipStatesModel.Type.PARTIALS) { // // Then set the tip partials as determined by the model: // setPartials(beagle, tipStatesModel, i); // } else { // // or the tip states: // tipStatesModel.getTipStates(i, tipStates); // beagle.setTipStates(i, tipStates); // } } else { if (useAmbiguities) { setPartials(beagle, index, i); } else { setStates(beagle, index, i); } } } } } this.partialsRestrictions = partialsRestrictions; // hasRestrictedPartials = (partialsRestrictions != null); if (hasRestrictedPartials) { numRestrictedPartials = partialsRestrictions.size(); updateRestrictedNodePartials = true; partialsMap = new Parameter[treeModel.getNodeCount()]; partials = new double[stateCount * patternCount * categoryCount]; } else { numRestrictedPartials = 0; updateRestrictedNodePartials = false; } beagle.setPatternWeights(patternWeights); String rescaleMessage = " Using rescaling scheme : " + this.rescalingScheme.getText(); if (this.rescalingScheme == PartialsRescalingScheme.AUTO && resourceDetails != null && (resourceDetails.getFlags() & BeagleFlag.SCALING_AUTO.getMask()) == 0) { // If auto scaling in BEAGLE is not supported then do it here this.rescalingScheme = PartialsRescalingScheme.DYNAMIC; rescaleMessage = " Auto rescaling not supported in BEAGLE, using : " + this.rescalingScheme.getText(); } if (this.rescalingScheme == PartialsRescalingScheme.DYNAMIC) { rescaleMessage += " (rescaling every " + rescalingFrequency + " evaluations)"; } logger.info(rescaleMessage); if (this.rescalingScheme == PartialsRescalingScheme.DYNAMIC) { everUnderflowed = false; // If false, BEAST does not rescale until first under-/over-flow. } updateSubstitutionModel = new boolean[partitionCount]; updateSiteModel = new boolean[partitionCount]; for (int i = 0; i < partitionCount; i++) { updateSubstitutionModel[i] = true; updateSiteModel[i] = true; } patternLogLikelihoods = new double[patternCount]; } catch (TaxonList.MissingTaxonException mte) { throw new RuntimeException(mte.toString()); } this.useAmbiguities = useAmbiguities; hasInitialized = true; } private static List<Integer> parseSystemPropertyIntegerArray(String propertyName) { List<Integer> order = new ArrayList<Integer>(); String r = System.getProperty(propertyName); if (r != null) { String[] parts = r.split(","); for (String part : parts) { try { int n = Integer.parseInt(part.trim()); order.add(n); } catch (NumberFormatException nfe) { System.err.println("Invalid entry '" + part + "' in " + propertyName); } } } return order; } private static List<String> parseSystemPropertyStringArray(String propertyName) { List<String> order = new ArrayList<String>(); String r = System.getProperty(propertyName); if (r != null) { String[] parts = r.split(","); for (String part : parts) { try { String s = part.trim(); order.add(s); } catch (NumberFormatException nfe) { System.err.println("Invalid entry '" + part + "' in " + propertyName); } } } return order; } public TipStatesModel getTipStatesModel() { return tipStatesModel; } public TreeModel getTreeModel() { return treeModel; } public BranchModel getBranchModel() { return branchModel; } public BranchRateModel getBranchRateModel() { return branchRateModel; } public PartialsRescalingScheme getRescalingScheme() { return rescalingScheme; } public Map<Set<String>, Parameter> getPartialsRestrictions() { return partialsRestrictions; } public boolean useAmbiguities() { return useAmbiguities; } protected int getScaleBufferCount() { return internalNodeCount + 1; } /** * Sets the partials from a sequence in an alignment. * * @param beagle beagle * @param sequenceIndex sequenceIndex * @param nodeIndex nodeIndex */ protected final void setPartials(Beagle beagle, int sequenceIndex, int nodeIndex) { double[] partials = new double[patternCount * stateCount * categoryCount]; boolean[] stateSet; int v = 0; for (PatternList patternList : patternLists) { for (int i = 0; i < patternList.getPatternCount(); i++) { int state = patternList.getPatternState(sequenceIndex, i); stateSet = dataType.getStateSet(state); for (int j = 0; j < stateCount; j++) { if (stateSet[j]) { partials[v] = 1.0; } else { partials[v] = 0.0; } v++; } } } // if there is more than one category then replicate the partials for each int n = patternCount * stateCount; int k = n; for (int i = 1; i < categoryCount; i++) { System.arraycopy(partials, 0, partials, k, n); k += n; } beagle.setPartials(nodeIndex, partials); } /** * Sets the partials from a sequence in an alignment. */ // protected final void setPartials(Beagle beagle, // TipStatesModel tipStatesModel, // int nodeIndex) { // double[] partials = new double[patternCount * stateCount * categoryCount]; // // tipStatesModel.getTipPartials(nodeIndex, partials); // // // if there is more than one category then replicate the partials for each // int n = patternCount * stateCount; // int k = n; // for (int i = 1; i < categoryCount; i++) { // System.arraycopy(partials, 0, partials, k, n); // k += n; // } // // beagle.setPartials(nodeIndex, partials); // } public int getPatternCount() { return patternCount; } /** * Sets the partials from a sequence in an alignment. * * @param beagle beagle * @param sequenceIndex sequenceIndex * @param nodeIndex nodeIndex */ private final void setStates(Beagle beagle, int sequenceIndex, int nodeIndex) { int[] states = new int[patternCount]; int k = 0; for (PatternList patternList : patternLists) { for (int j = 0; j < patternList.getPatternCount(); j++) { states[k] = patternList.getPatternState(sequenceIndex, j); k++; } } beagle.setTipStates(nodeIndex, states); } // public void setStates(int tipIndex, int[] states) { // System.err.println("BTL:setStates"); // beagle.setTipStates(tipIndex, states); // makeDirty(); // } // // public void getStates(int tipIndex, int[] states) { // System.err.println("BTL:getStates"); // beagle.getTipStates(tipIndex, states); // } // public final void setPatternWeights1(double[] patternWeights) { // this.patternWeights = patternWeights; // beagle.setPatternWeights(patternWeights); // } // ************************************************************** // ModelListener IMPLEMENTATION // ************************************************************** /** * Handles model changed events from the submodels. */ protected void handleModelChangedEvent(Model model, Object object, int index) { fireModelChanged(); if (model == treeModel) { if (object instanceof TreeModel.TreeChangedEvent) { if (((TreeModel.TreeChangedEvent) object).isNodeChanged()) { // If a node event occurs the node and its two child nodes // are flagged for updating (this will result in everything // above being updated as well. Node events occur when a node // is added to a branch, removed from a branch or its height or // rate changes. updateNodeAndChildren(((TreeModel.TreeChangedEvent) object).getNode()); updateRestrictedNodePartials = true; } else if (((TreeModel.TreeChangedEvent) object).isTreeChanged()) { // Full tree events result in a complete updating of the tree likelihood // This event type is now used for EmpiricalTreeDistributions. // System.err.println("Full tree update event - these events currently aren't used\n" + // "so either this is in error or a new feature is using them so remove this message."); updateAllNodes(); updateRestrictedNodePartials = true; } else { // Other event types are ignored (probably trait changes). //System.err.println("Another tree event has occured (possibly a trait change)."); } } } else if (model == branchRateModel) { if (index == -1) { if (COUNT_TOTAL_OPERATIONS) totalRateUpdateAllCount++; updateAllNodes(); } else { if (COUNT_TOTAL_OPERATIONS) totalRateUpdateSingleCount++; updateNode(treeModel.getNode(index)); } } else if (model == branchModel) { // if (index == -1) { // updateSubstitutionModel = true; // updateAllNodes(); // } else { // updateNode(treeModel.getNode(index)); // } makeDirty(); } else if (siteRateModels.contains(model)) { updateSiteModel[siteRateModels.indexOf(model)] = true; updateAllNodes(); } else if (model == tipStatesModel) { if (object instanceof Taxon) { for (int i = 0; i < treeModel.getNodeCount(); i++) if (treeModel.getNodeTaxon(treeModel.getNode(i)) != null && treeModel.getNodeTaxon(treeModel.getNode(i)).getId().equalsIgnoreCase(((Taxon) object).getId())) updateNode(treeModel.getNode(i)); } else updateAllNodes(); } else { throw new RuntimeException("Unknown componentChangedEvent"); } super.handleModelChangedEvent(model, object, index); } @Override public void makeDirty() { super.makeDirty(); for (int i = 0; i < partitionCount; i++) { updateSubstitutionModel[i] = true; updateSiteModel[i] = true; } updateRestrictedNodePartials = true; } // ************************************************************** // Model IMPLEMENTATION // ************************************************************** /** * Stores the additional state other than model components */ protected void storeState() { partialBufferHelper.storeState(); for (SubstitutionModelDelegate substitutionModelDelegate : substitutionModelDelegates) { substitutionModelDelegate.storeState(); } if (useScaleFactors || useAutoScaling) { // Only store when actually used scaleBufferHelper.storeState(); System.arraycopy(scaleBufferIndices, 0, storedScaleBufferIndices, 0, scaleBufferIndices.length); // storedRescalingCount = rescalingCount; } super.storeState(); } /** * Restore the additional stored state */ protected void restoreState() { for (int i = 0; i < partitionCount; i++) { // this is required to upload the categoryRates to BEAGLE after the restore updateSiteModel[i] = true; } partialBufferHelper.restoreState(); for (SubstitutionModelDelegate substitutionModelDelegate : substitutionModelDelegates) { substitutionModelDelegate.restoreState(); } if (useScaleFactors || useAutoScaling) { scaleBufferHelper.restoreState(); int[] tmp = storedScaleBufferIndices; storedScaleBufferIndices = scaleBufferIndices; scaleBufferIndices = tmp; // rescalingCount = storedRescalingCount; } updateRestrictedNodePartials = true; super.restoreState(); } // int marcCount = 0; // ************************************************************** // Likelihood IMPLEMENTATION // ************************************************************** /** * Calculate the log likelihood of the current state. * * @return the log likelihood. */ protected double calculateLogLikelihood() { if (branchUpdateIndices == null) { branchUpdateIndices = new int[nodeCount]; branchLengths = new double[nodeCount]; scaleBufferIndices = new int[internalNodeCount]; storedScaleBufferIndices = new int[internalNodeCount]; } if (operations == null) { operations = new int[numRestrictedPartials + 1][internalNodeCount * Beagle.OPERATION_TUPLE_SIZE]; operationCount = new int[numRestrictedPartials + 1]; } recomputeScaleFactors = false; if (this.rescalingScheme == PartialsRescalingScheme.ALWAYS) { useScaleFactors = true; recomputeScaleFactors = true; } else if (this.rescalingScheme == PartialsRescalingScheme.DYNAMIC && everUnderflowed) { useScaleFactors = true; if (rescalingCountInner < RESCALE_TIMES) { recomputeScaleFactors = true; makeDirty(); // System.err.println("Recomputing scale factors"); } rescalingCountInner++; rescalingCount++; if (rescalingCount > rescalingFrequency) { rescalingCount = 0; rescalingCountInner = 0; } } else if (this.rescalingScheme == PartialsRescalingScheme.DELAYED && everUnderflowed) { useScaleFactors = true; recomputeScaleFactors = true; rescalingCount++; } if (tipStatesModel != null) { throw new UnsupportedOperationException("Tip error models not supported by MultiPartitionTreeLikelihood yet"); // int tipCount = treeModel.getExternalNodeCount(); // for (int index = 0; index < tipCount; index++) { // if (updateNode[index]) { // if (tipStatesModel.getModelType() == TipStatesModel.Type.PARTIALS) { // tipStatesModel.getTipPartials(index, tipPartials); // beagle.setTipPartials(index, tipPartials); // } else { // tipStatesModel.getTipStates(index, tipStates); // beagle.setTipStates(index, tipStates); // } // } // } } branchUpdateCount = 0; operationListCount = 0; if (hasRestrictedPartials) { for (int i = 0; i <= numRestrictedPartials; i++) { operationCount[i] = 0; } } else { operationCount[0] = 0; } final NodeRef root = treeModel.getRoot(); traverse(treeModel, root, null, true); for (int i = 0; i < partitionCount; i++) { if (updateSubstitutionModel[i]) { // TODO More efficient to update only the substitution model that changed, instead of all substitutionModelDelegates[i].updateSubstitutionModels(beagle); // we are currently assuming a no-category model... } if (updateSiteModel[i]) { double[] categoryRates = this.siteRateModels.get(i).getCategoryRates(); beagle.setCategoryRates(categoryRates); // TODO needs category rates for each partition... // beagle.setCategoryRates(i, categoryRates); } } if (branchUpdateCount > 0) { for (SubstitutionModelDelegate substitutionModelDelegate : substitutionModelDelegates) { substitutionModelDelegate.updateTransitionMatrices( beagle, branchUpdateIndices, branchLengths, branchUpdateCount); } } if (COUNT_TOTAL_OPERATIONS) { totalMatrixUpdateCount += branchUpdateCount; for (int i = 0; i <= numRestrictedPartials; i++) { totalOperationCount += operationCount[i]; } } double logL; boolean done; boolean firstRescaleAttempt = true; do { if (hasRestrictedPartials) { for (int i = 0; i <= numRestrictedPartials; i++) { beagle.updatePartials(operations[i], operationCount[i], Beagle.NONE); if (i < numRestrictedPartials) { // restrictNodePartials(restrictedIndices[i]); } } } else { beagle.updatePartials(operations[0], operationCount[0], Beagle.NONE); } int cumulateScaleBufferIndex = Beagle.NONE; if (useScaleFactors) { if (recomputeScaleFactors) { scaleBufferHelper.flipOffset(internalNodeCount); cumulateScaleBufferIndex = scaleBufferHelper.getOffsetIndex(internalNodeCount); beagle.resetScaleFactors(cumulateScaleBufferIndex); beagle.accumulateScaleFactors(scaleBufferIndices, internalNodeCount, cumulateScaleBufferIndex); } else { cumulateScaleBufferIndex = scaleBufferHelper.getOffsetIndex(internalNodeCount); } } else if (useAutoScaling) { beagle.accumulateScaleFactors(scaleBufferIndices, internalNodeCount, Beagle.NONE); } for (int i = 0; i < partitionCount; i++) { double[] categoryWeights = this.siteRateModels.get(i).getCategoryProportions(); // This should probably explicitly be the state frequencies for the root node... double[] frequencies = substitutionModelDelegates[i].getRootStateFrequencies(); // these could be set only when they change but store/restore would need to be considered beagle.setCategoryWeights(i, categoryWeights); beagle.setStateFrequencies(i, frequencies); } double[] sumLogLikelihoods = new double[1]; int rootIndex = partialBufferHelper.getOffsetIndex(root.getNumber()); beagle.calculateRootLogLikelihoods(new int[]{rootIndex}, new int[]{0}, new int[]{0}, new int[]{cumulateScaleBufferIndex}, 1, sumLogLikelihoods); logL = sumLogLikelihoods[0]; // if (ascertainedSitePatterns) { // // Need to correct for ascertainedSitePatterns // beagle.getSiteLogLikelihoods(patternLogLikelihoods); // logL = getAscertainmentCorrectedLogLikelihood((AscertainedSitePatterns) patternList, // patternLogLikelihoods, patternWeights); // } if (Double.isNaN(logL) || Double.isInfinite(logL)) { everUnderflowed = true; logL = Double.NEGATIVE_INFINITY; if (firstRescaleAttempt && (rescalingScheme == PartialsRescalingScheme.DYNAMIC || rescalingScheme == PartialsRescalingScheme.DELAYED)) { // we have had a potential under/over flow so attempt a rescaling if (rescalingScheme == PartialsRescalingScheme.DYNAMIC || (rescalingCount == 0)) { Logger.getLogger("dr.evomodel").info("Underflow calculating likelihood. Attempting a rescaling..."); } useScaleFactors = true; recomputeScaleFactors = true; branchUpdateCount = 0; if (hasRestrictedPartials) { for (int i = 0; i <= numRestrictedPartials; i++) { operationCount[i] = 0; } } else { operationCount[0] = 0; } // traverse again but without flipping partials indices as we // just want to overwrite the last attempt. We will flip the // scale buffer indices though as we are recomputing them. traverse(treeModel, root, null, false); done = false; // Run through do-while loop again firstRescaleAttempt = false; // Only try to rescale once } else { // we have already tried a rescale, not rescaling or always rescaling // so just return the likelihood... done = true; } } else { done = true; // No under-/over-flow, then done } } while (!done); // If these are needed... //beagle.getSiteLogLikelihoods(patternLogLikelihoods); //******************************************************************** // after traverse all nodes and patterns have been updated -- //so change flags to reflect this. for (int i = 0; i < nodeCount; i++) { updateNode[i] = false; } for (int i = 0; i < partitionCount; i++) { updateSubstitutionModel[i] = false; updateSiteModel[i] = false; } //******************************************************************** return logL; } public void getPartials(int number, double[] partials) { int cumulativeBufferIndex = Beagle.NONE; /* No need to rescale partials */ beagle.getPartials(partialBufferHelper.getOffsetIndex(number), cumulativeBufferIndex, partials); } public boolean arePartialsRescaled() { return useScaleFactors; } protected void setPartials(int number, double[] partials) { beagle.setPartials(partialBufferHelper.getOffsetIndex(number), partials); } private void restrictNodePartials(int nodeIndex) { Parameter restrictionParameter = partialsMap[nodeIndex]; if (restrictionParameter == null) { return; } getPartials(nodeIndex, partials); double[] restriction = restrictionParameter.getParameterValues(); final int partialsLengthPerCategory = stateCount * patternCount; if (restriction.length == partialsLengthPerCategory) { for (int i = 0; i < categoryCount; i++) { componentwiseMultiply(partials, partialsLengthPerCategory * i, restriction, 0, partialsLengthPerCategory); } } else { componentwiseMultiply(partials, 0, restriction, 0, partialsLengthPerCategory * categoryCount); } setPartials(nodeIndex, partials); } private void componentwiseMultiply(double[] a, final int offsetA, double[] b, final int offsetB, final int length) { for (int i = 0; i < length; i++) { a[offsetA + i] *= b[offsetB + i]; } } private void computeNodeToRestrictionMap() { Arrays.fill(partialsMap, null); for (Set<String> taxonNames : partialsRestrictions.keySet()) { NodeRef node = TreeUtils.getCommonAncestorNode(treeModel, taxonNames); partialsMap[node.getNumber()] = partialsRestrictions.get(taxonNames); } } private double getAscertainmentCorrectedLogLikelihood(AscertainedSitePatterns patternList, double[] patternLogLikelihoods, double[] patternWeights) { double logL = 0.0; double ascertainmentCorrection = patternList.getAscertainmentCorrection(patternLogLikelihoods); for (int i = 0; i < patternCount; i++) { logL += (patternLogLikelihoods[i] - ascertainmentCorrection) * patternWeights[i]; } return logL; } /** * Traverse the tree calculating partial likelihoods. * * @param tree tree * @param node node * @param operatorNumber operatorNumber * @param flip flip * @return boolean */ private boolean traverse(Tree tree, NodeRef node, int[] operatorNumber, boolean flip) { boolean update = false; int nodeNum = node.getNumber(); NodeRef parent = tree.getParent(node); if (operatorNumber != null) { operatorNumber[0] = -1; } // First update the transition probability matrix(ices) for this branch if (parent != null && updateNode[nodeNum]) { final double branchRate = branchRateModel.getBranchRate(tree, node); final double parentHeight = tree.getNodeHeight(parent); final double nodeHeight = tree.getNodeHeight(node); // Get the operational time of the branch final double branchLength = branchRate * (parentHeight - nodeHeight); if (branchLength < 0.0) { throw new RuntimeException("Negative branch length: " + branchLength); } if (flip) { for (SubstitutionModelDelegate substitutionModelDelegate : substitutionModelDelegates) { substitutionModelDelegate.flipMatrixBuffer(nodeNum); } } branchUpdateIndices[branchUpdateCount] = nodeNum; branchLengths[branchUpdateCount] = branchLength; branchUpdateCount++; update = true; } // If the node is internal, update the partial likelihoods. if (!tree.isExternal(node)) { // Traverse down the two child nodes NodeRef child1 = tree.getChild(node, 0); final int[] op1 = {-1}; final boolean update1 = traverse(tree, child1, op1, flip); NodeRef child2 = tree.getChild(node, 1); final int[] op2 = {-1}; final boolean update2 = traverse(tree, child2, op2, flip); // If either child node was updated then update this node too if (update1 || update2) { int x = operationCount[operationListCount] * Beagle.OPERATION_TUPLE_SIZE; if (flip) { // first flip the partialBufferHelper partialBufferHelper.flipOffset(nodeNum); } final int[] operations = this.operations[operationListCount]; operations[x] = partialBufferHelper.getOffsetIndex(nodeNum); if (useScaleFactors) { // get the index of this scaling buffer int n = nodeNum - tipCount; if (recomputeScaleFactors) { // flip the indicator: can take either n or (internalNodeCount + 1) - n scaleBufferHelper.flipOffset(n); // store the index scaleBufferIndices[n] = scaleBufferHelper.getOffsetIndex(n); operations[x + 1] = scaleBufferIndices[n]; // Write new scaleFactor operations[x + 2] = Beagle.NONE; } else { operations[x + 1] = Beagle.NONE; operations[x + 2] = scaleBufferIndices[n]; // Read existing scaleFactor } } else { if (useAutoScaling) { scaleBufferIndices[nodeNum - tipCount] = partialBufferHelper.getOffsetIndex(nodeNum); } operations[x + 1] = Beagle.NONE; // Not using scaleFactors operations[x + 2] = Beagle.NONE; } // TODO not sure how these will work. Commented out to allow build. operations[x + 3] = partialBufferHelper.getOffsetIndex(child1.getNumber()); // source node 1 // operations[x + 4] = substitutionModelDelegate.getMatrixIndex(child1.getNumber()); // source matrix 1 operations[x + 5] = partialBufferHelper.getOffsetIndex(child2.getNumber()); // source node 2 // operations[x + 6] = substitutionModelDelegate.getMatrixIndex(child2.getNumber()); // source matrix 2 operationCount[operationListCount]++; update = true; if (hasRestrictedPartials) { // Test if this set of partials should be restricted if (updateRestrictedNodePartials) { // Recompute map computeNodeToRestrictionMap(); updateRestrictedNodePartials = false; } if (partialsMap[nodeNum] != null) { } } } } return update; } // ************************************************************** // INSTANCE VARIABLES // ************************************************************** private int[] branchUpdateIndices; private double[] branchLengths; private int branchUpdateCount; private int[] scaleBufferIndices; private int[] storedScaleBufferIndices; private int[][] operations; private int operationListCount; private int[] operationCount; // private final boolean hasRestrictedPartials; private static final boolean hasRestrictedPartials = false; private final int numRestrictedPartials; private final Map<Set<String>, Parameter> partialsRestrictions; private Parameter[] partialsMap; private double[] partials; private boolean updateRestrictedNodePartials; // private int[] restrictedIndices; protected BufferIndexHelper partialBufferHelper; protected BufferIndexHelper scaleBufferHelper; protected final int tipCount; protected final int internalNodeCount; private PartialsRescalingScheme rescalingScheme; private int rescalingFrequency = RESCALE_FREQUENCY; protected boolean useScaleFactors = false; private boolean useAutoScaling = false; private boolean recomputeScaleFactors = false; private boolean everUnderflowed = false; private int rescalingCount = 0; private int rescalingCountInner = 0; // private int storedRescalingCount; /** * the branch-site model for these sites */ private final BranchModel branchModel; /** * A delegate to handle substitution models on branches */ private final SubstitutionModelDelegate[] substitutionModelDelegates; /** * the site model for these sites */ private final List<PatternList> patternLists = new ArrayList<PatternList>(); /** * the site model for these sites */ private final List<SiteRateModel> siteRateModels = new ArrayList<SiteRateModel>(); /** * the branch rate model */ private final BranchRateModel branchRateModel; /** * the tip partials model */ private final TipStatesModel tipStatesModel; /** * the total number of patterns */ private final int patternCount; /** * the total number of partitions */ private final int partitionCount; /** * the number of patterns for each partition */ private final int[] patternCounts; /** * the pattern likelihoods */ private final double[] patternLogLikelihoods; /** * the number of rate categories */ private final int categoryCount; /** * an array used to transfer tip partials */ // private final double[] tipPartials; /** * an array used to transfer tip states */ // private final int[] tipStates; /** * the BEAGLE library instance */ private final Beagle beagle; /** * Flag to specify that the substitution model has changed */ protected final boolean[] updateSubstitutionModel; /** * Flag to specify that the site model has changed */ private final boolean[] updateSiteModel; private final DataType dataType; /** * the pattern weights */ private final double[] patternWeights; /** * the number of states in the data */ private final int stateCount; /** * Flag to specify if ambiguity codes are in use */ protected final boolean useAmbiguities; }//END: class