/* * File TreeLikelihood.java * * Copyright (C) 2010 Remco Bouckaert remco@cs.auckland.ac.nz * * This file is part of BEAST2. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package beast.evolution.likelihood; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Random; import java.util.concurrent.Callable; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.RejectedExecutionException; import beast.app.BeastMCMC; import beast.app.beauti.Beauti; import beast.core.BEASTInterface; import beast.core.Description; import beast.core.Input; import beast.core.State; import beast.core.util.Log; import beast.evolution.alignment.Alignment; import beast.evolution.alignment.FilteredAlignment; import beast.evolution.sitemodel.SiteModel; import beast.evolution.substitutionmodel.SubstitutionModel; @Description("Calculates the likelihood of sequence data on a beast.tree given a site and substitution model using " + "a variant of the 'peeling algorithm'. For details, see" + "Felsenstein, Joseph (1981). Evolutionary trees from DNA sequences: a maximum likelihood approach. J Mol Evol 17 (6): 368-376.") public class ThreadedTreeLikelihood extends GenericTreeLikelihood { final public Input<Boolean> useAmbiguitiesInput = new Input<>("useAmbiguities", "flag to indicate leafs that sites containing ambiguous states should be handled instead of ignored (the default)", false); final public Input<Integer> maxNrOfThreadsInput = new Input<>("threads","maximum number of threads to use, if less than 1 the number of threads in BeastMCMC is used (default -1)", -1); final public Input<String> proportionsInput = new Input<>("proportions", "specifies proportions of patterns used per thread as space " + "delimited string. This is useful when using a mixture of BEAGLE devices that run at different speeds, e.g GPU and CPU. " + "The string is duplicated if there are more threads than proportions specified. For example, " + "'1 2' as well as '33 66' with 2 threads specifies that the first thread gets a third of the patterns and the second " + "two thirds. With 3 threads, it is interpreted as '1 2 1' = 25%, 50%, 25% and with 7 threads it is " + "'1 2 1 2 1 2 1' = 10% 20% 10% 20% 10% 20% 10%. If not specified, all threads get the same proportion of patterns."); enum Scaling {none, always, _default}; final public Input<Scaling> scalingInput = new Input<>("scaling", "type of scaling to use, one of " + Arrays.toString(Scaling.values()) + ". If not specified, the -beagle_scaling flag is used.", Scaling._default, Scaling.values()); /** private list of likelihoods, to notify framework of TreeLikelihoods being created in initAndValidate() **/ final private Input<List<TreeLikelihood>> likelihoodsInput = new Input<>("*","",new ArrayList<>()); @Override public List<Input<?>> listInputs() { List<Input<?>> list = super.listInputs(); if (!Beauti.isInBeauti() && System.getProperty("beast.is.junit.testing") == null) { // do not expose internal likelihoods to BEAUti or junit tests list.add(likelihoodsInput); } return list; } /** calculation engine **/ private TreeLikelihood [] treelikelihood; private ExecutorService pool = null; private final List<Callable<Double>> likelihoodCallers = new ArrayList<Callable<Double>>(); /** number of threads to use, changes when threading causes problems **/ private int threadCount; private double [] logPByThread; // specified a set ranges of patterns assigned to each thread // first patternPoints contains 0, then one point for each thread private int [] patternPoints; @Override public void initAndValidate() { threadCount = BeastMCMC.m_nThreads; if (maxNrOfThreadsInput.get() > 0) { threadCount = Math.min(maxNrOfThreadsInput.get(), BeastMCMC.m_nThreads); } String instanceCount = System.getProperty("beast.instance.count"); if (instanceCount != null && instanceCount.length() > 0) { threadCount = Integer.parseInt(instanceCount); } logPByThread = new double[threadCount]; // sanity check: alignment should have same #taxa as tree if (dataInput.get().getTaxonCount() != treeInput.get().getLeafNodeCount()) { throw new IllegalArgumentException("The number of nodes in the tree does not match the number of sequences"); } treelikelihood = new TreeLikelihood[threadCount]; if (dataInput.get().isAscertained) { Log.warning.println("Note, can only use single thread per alignment because the alignment is ascertained"); threadCount = 1; } if (threadCount <= 1) { treelikelihood[0] = new TreeLikelihood(); treelikelihood[0].setID(getID() + "0"); treelikelihood[0].initByName("data", dataInput.get(), "tree", treeInput.get(), "siteModel", siteModelInput.get(), "branchRateModel", branchRateModelInput.get(), "useAmbiguities", useAmbiguitiesInput.get(), "scaling" , scalingInput.get() + "" ); treelikelihood[0].getOutputs().add(this); likelihoodsInput.get().add(treelikelihood[0]); } else { pool = Executors.newFixedThreadPool(threadCount); calcPatternPoints(dataInput.get().getSiteCount()); for (int i = 0; i < threadCount; i++) { Alignment data = dataInput.get(); String filterSpec = (patternPoints[i] +1) + "-" + (patternPoints[i + 1]); if (data.isAscertained) { filterSpec += data.excludefromInput.get() + "-" + data.excludetoInput.get() + "," + filterSpec; } treelikelihood[i] = new TreeLikelihood(); treelikelihood[i].setID(getID() + i); treelikelihood[i].getOutputs().add(this); likelihoodsInput.get().add(treelikelihood[i]); FilteredAlignment filter = new FilteredAlignment(); if (i == 0 && dataInput.get() instanceof FilteredAlignment && ((FilteredAlignment)dataInput.get()).constantSiteWeightsInput.get() != null) { filter.initByName("data", dataInput.get()/*, "userDataType", m_data.get().getDataType()*/, "filter", filterSpec, "constantSiteWeights", ((FilteredAlignment)dataInput.get()).constantSiteWeightsInput.get() ); } else { filter.initByName("data", dataInput.get()/*, "userDataType", m_data.get().getDataType()*/, "filter", filterSpec ); } treelikelihood[i].initByName("data", filter, "tree", treeInput.get(), "siteModel", duplicate((BEASTInterface) siteModelInput.get(), i), "branchRateModel", duplicate(branchRateModelInput.get(), i), "useAmbiguities", useAmbiguitiesInput.get(), "scaling" , scalingInput.get() + "" ); likelihoodCallers.add(new TreeLikelihoodCaller(treelikelihood[i], i)); } } } /** create new instance of src object, connecting all inputs from src object * Note if input is a SubstModel, it is duplicated as well. * @param src object to be copied * @param i index used to extend ID with. * @return copy of src object */ private Object duplicate(BEASTInterface src, int i) { if (src == null) { return null; } BEASTInterface copy; try { copy = src.getClass().newInstance(); copy.setID(src.getID() + "_" + i); } catch (InstantiationException | IllegalAccessException e) { e.printStackTrace(); throw new RuntimeException("Programmer error: every object in the model should have a default constructor that is publicly accessible: " + src.getClass().getName()); } for (Input<?> input : src.listInputs()) { if (input.get() != null) { if (input.get() instanceof List) { // handle lists //((List)copy.getInput(input.getName())).clear(); for (Object o : (List<?>) input.get()) { if (o instanceof BEASTInterface) { // make sure it is not already in the list copy.setInputValue(input.getName(), o); } } } else if (input.get() instanceof SubstitutionModel) { // duplicate subst models BEASTInterface substModel = (BEASTInterface) duplicate((BEASTInterface) input.get(), i); copy.setInputValue(input.getName(), substModel); } else { // it is some other value copy.setInputValue(input.getName(), input.get()); } } } copy.initAndValidate(); return copy; } private void calcPatternPoints(int nPatterns) { patternPoints = new int[threadCount + 1]; if (proportionsInput.get() == null) { int range = nPatterns / threadCount; for (int i = 0; i < threadCount - 1; i++) { patternPoints[i+1] = range * (i+1); } patternPoints[threadCount] = nPatterns; } else { String [] strs = proportionsInput.get().split("\\s+"); double [] proportions = new double[threadCount]; for (int i = 0; i < threadCount; i++) { proportions[i] = Double.parseDouble(strs[i % strs.length]); } // normalise double sum = 0; for (double d : proportions) { sum += d; } for (int i = 0; i < threadCount; i++) { proportions[i] /= sum; } // cummulative for (int i = 1; i < threadCount; i++) { proportions[i] += proportions[i- 1]; } // calc ranges for (int i = 0; i < threadCount; i++) { patternPoints[i+1] = (int) (proportions[i] * nPatterns + 0.5); } } } /** * This method samples the sequences based on the tree and site model. */ @Override public void sample(State state, Random random) { throw new UnsupportedOperationException("Can't sample a fixed alignment!"); } @Override public double calculateLogP() { logP = calculateLogPByBeagle(); return logP; } class TreeLikelihoodCaller implements Callable<Double> { private final TreeLikelihood likelihood; private final int threadNr; public TreeLikelihoodCaller(TreeLikelihood likelihood, int threadNr) { this.likelihood = likelihood; this.threadNr = threadNr; } public Double call() throws Exception { try { logPByThread[threadNr] = likelihood.calculateLogP(); } catch (Exception e) { System.err.println("Something went wrong in thread " + threadNr); e.printStackTrace(); System.exit(0); } return logPByThread[threadNr]; } } private double calculateLogPByBeagle() { try { if (threadCount > 1) { pool.invokeAll(likelihoodCallers); logP = 0; for (double f : logPByThread) { logP += f; } } else { logP = treelikelihood[0].calculateLogP(); } } catch (RejectedExecutionException | InterruptedException e) { e.printStackTrace(); System.exit(0); } return logP; } /* return copy of pattern log likelihoods for each of the patterns in the alignment */ public double [] getPatternLogLikelihoods() { double [] patternLogLikelihoods = new double[dataInput.get().getPatternCount()]; int i = 0; for (TreeLikelihood b : treelikelihood) { double [] d = b.getPatternLogLikelihoods(); System.arraycopy(d, 0, patternLogLikelihoods, i, d.length); i += d.length; } return patternLogLikelihoods; } // getPatternLogLikelihoods /** CalculationNode methods **/ /** * check state for changed variables and update temp results if necessary * */ @Override protected boolean requiresRecalculation() { boolean requiresRecalculation = false; for (TreeLikelihood b : treelikelihood) { requiresRecalculation |= b.requiresRecalculation(); } return requiresRecalculation; } @Override public void store() { // for (TreeLikelihood b : treelikelihood) { // b.store(); // } super.store(); } @Override public void restore() { // for (TreeLikelihood b : treelikelihood) { // b.restore(); // } super.restore(); } /** * @return a list of unique ids for the state nodes that form the argument */ @Override public List<String> getArguments() { return Collections.singletonList(dataInput.get().getID()); } /** * @return a list of unique ids for the state nodes that make up the conditions */ @Override public List<String> getConditions() { return ((SiteModel.Base)siteModelInput.get()).getConditions(); } } // class ThreadedTreeLikelihood