/*
 * CompoundLikelihood.java
 *
 * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
 *
 * This file is part of BEAST.
 * See the NOTICE file distributed with this work for additional
 * information regarding copyright ownership and licensing.
 *
 * BEAST is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * BEAST is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with BEAST; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA 02110-1301 USA
 */

package dr.inference.model;

import dr.util.NumberFormatter;
import dr.xml.Reportable;

import java.util.*;
import java.util.concurrent.*;

/**
 * A likelihood function which is simply the product of a set of likelihood functions
 * (i.e., the sum of their log likelihoods).
 * <p>
 * Component likelihoods that report {@code evaluateEarly()} are always evaluated first,
 * serially, on the calling thread; if their sum is {@code Double.NEGATIVE_INFINITY} the
 * remaining ("late") likelihoods are not evaluated at all. Late likelihoods are evaluated
 * either serially or, when a thread pool was requested at construction, through the pool.
 * <p>
 * When {@link #EVALUATION_TIMERS} is set, the wall-clock time and number of evaluations of
 * each component are accumulated in {@link #evaluationTimes} and {@link #evaluationCounts}.
 * Both arrays are indexed by the component's position in {@link #getLikelihoods()}.
 *
 * @author Alexei Drummond
 * @author Andrew Rambaut
 * @version $Id: CompoundLikelihood.java,v 1.19 2005/05/25 09:14:36 rambaut Exp $
 */
public class CompoundLikelihood implements Likelihood, Reportable {

    /** If true, nested compound likelihoods handed to the threaded constructor are flattened. */
    public final static boolean UNROLL_COMPOUND = true;

    /** If true, per-likelihood evaluation times and counts are recorded. */
    public final static boolean EVALUATION_TIMERS = true;

    /** Accumulated evaluation time in nanoseconds, indexed like {@link #getLikelihoods()}; null if timers are off. */
    public final long[] evaluationTimes;

    /** Number of timed evaluations, indexed like {@link #getLikelihoods()}; null if timers are off. */
    public final int[] evaluationCounts;

    /**
     * Creates a compound likelihood that may evaluate its late likelihoods in a thread pool.
     *
     * @param threads     a positive value requests a fixed pool of that many threads; a negative
     *                    value requests one thread per likelihood (only if there is more than one
     *                    likelihood); zero requests serial evaluation
     * @param likelihoods the component likelihoods; nested compound likelihoods are unrolled when
     *                    {@link #UNROLL_COMPOUND} is set
     * @throws IllegalArgumentException if the same likelihood is added more than once
     */
    public CompoundLikelihood(int threads, Collection<Likelihood> likelihoods) {

        int i = 0;
        for (Likelihood l : likelihoods) {
            addLikelihood(l, i, true);
            i++;
        }

        if (threads < 0 && this.likelihoods.size() > 1) {
            // asking for an automatic threadpool size and there is more than one likelihood to compute:
            // create a threadpool the size of the number of likelihoods
            threadCount = this.likelihoods.size();
            // threadCount = -1; // use cached thread pool
        } else if (threads > 0) {
            // use a thread pool of a specified size
            threadCount = threads;
        } else {
            // no thread pool requested or only one likelihood
            threadCount = 0;
        }

        if (threadCount > 0) {
            pool = Executors.newFixedThreadPool(threadCount);
        } else if (threadCount < 0) {
            // Currently unreachable (threadCount is never negative above); kept so that the
            // cached-pool option noted in the comment above can be re-enabled.
            pool = Executors.newCachedThreadPool();
        } else {
            // don't use a threadpool (i.e., compute serially)
            pool = null;
        }

        final int count = this.likelihoods.size();
        if (EVALUATION_TIMERS) {
            evaluationTimes = new long[count];
            evaluationCounts = new int[count];
        } else {
            evaluationTimes = null;
            evaluationCounts = null;
        }
    }

    /**
     * Creates a compound likelihood that evaluates all of its components serially.
     * Nested compound likelihoods are kept as single components (they are not unrolled).
     *
     * @param likelihoods the component likelihoods
     * @throws IllegalArgumentException if the same likelihood is added more than once
     */
    public CompoundLikelihood(Collection<Likelihood> likelihoods) {

        pool = null;
        threadCount = 0;

        int i = 0;
        for (Likelihood l : likelihoods) {
            addLikelihood(l, i, false);
            i++;
        }

        final int count = this.likelihoods.size();
        if (EVALUATION_TIMERS) {
            evaluationTimes = new long[count];
            evaluationCounts = new int[count];
        } else {
            evaluationTimes = null;
            evaluationCounts = null;
        }
    }

//    public CompoundLikelihood(BeagleBranchLikelihoods bbl) {
//
//        pool = null;
//        threadCount = 0;
//        evaluationTimes = null;
//        evaluationCounts = null;
//
//    }

    /**
     * Registers a component likelihood, sorting it into the early or late list and, if
     * requested, wrapping late likelihoods in a pool task.
     * <p>
     * The timer slot of every registered likelihood is its position in the flat
     * {@link #getLikelihoods()} list, which is the index {@link #getReport(int)} reads. This
     * keeps the slots of early, late and unrolled likelihoods distinct.
     *
     * @param likelihood the likelihood to add
     * @param index      position of the likelihood in the collection handed to the constructor;
     *                   retained for compatibility and passed on when unrolling, but not used
     *                   to select the timer slot
     * @param addToPool  whether late likelihoods should get a pool task (this also enables
     *                   unrolling of nested compound likelihoods)
     * @throws IllegalArgumentException if the likelihood has already been added
     */
    protected void addLikelihood(Likelihood likelihood, int index, boolean addToPool) {

        // unroll any compound likelihoods
        if (UNROLL_COMPOUND && addToPool && likelihood instanceof CompoundLikelihood) {
            for (Likelihood l : ((CompoundLikelihood) likelihood).getLikelihoods()) {
                addLikelihood(l, index, addToPool);
            }
            return;
        }

        if (likelihoods.contains(likelihood)) {
            throw new IllegalArgumentException("Attempted to add the same likelihood multiple times to CompoundLikelihood.");
        }

        likelihoods.add(likelihood);
        final int timerIndex = likelihoods.size() - 1;

        if (likelihood.getModel() != null) {
            compoundModel.addModel(likelihood.getModel());
        }

        if (likelihood.evaluateEarly()) {
            earlyLikelihoods.add(likelihood);
            earlyTimerIndices.add(timerIndex);
        } else {
            // late likelihood list is used to evaluate them if the thread pool is not being used...
            lateLikelihoods.add(likelihood);
            lateTimerIndices.add(timerIndex);

            if (addToPool) {
                likelihoodCallers.add(new LikelihoodCaller(likelihood, timerIndex));
            }
        }
    }

    /**
     * @return a new set holding every component likelihood together with the contents of each
     *         component's own likelihood set
     */
    public Set<Likelihood> getLikelihoodSet() {
        Set<Likelihood> set = new HashSet<Likelihood>();
        for (Likelihood l : likelihoods) {
            set.add(l);
            set.addAll(l.getLikelihoodSet());
        }
        return set;
    }

    /**
     * @return the number of component likelihoods
     */
    public int getLikelihoodCount() {
        return likelihoods.size();
    }

    /**
     * @param i position in the component list
     * @return the i'th component likelihood
     */
    public final Likelihood getLikelihood(int i) {
        return likelihoods.get(i);
    }

    /**
     * @return the internal list of component likelihoods (not a copy)
     */
    public List<Likelihood> getLikelihoods() {
        return likelihoods;
    }

    /**
     * @return the internal list of pool tasks created for the late likelihoods (not a copy)
     */
    public List<Callable<Double>> getLikelihoodCallers() {
        return likelihoodCallers;
    }

    // **************************************************************
    // Likelihood IMPLEMENTATION
    // **************************************************************

    /**
     * @return the compound model collecting the non-null models of the component likelihoods
     */
    public Model getModel() {
        return compoundModel;
    }

//    // todo: remove in release
//    static int DEBUG = 0;

    /**
     * Sums the log likelihoods of all components. Early likelihoods are evaluated first on the
     * calling thread; if they sum to negative infinity the late likelihoods are skipped.
     *
     * @return the total log likelihood
     * @throws IllegalStateException if the calling thread is interrupted while waiting for the
     *                               thread pool, or if a pooled likelihood evaluation throws;
     *                               the original exception is attached as the cause
     */
    public double getLogLikelihood() {

        double logLikelihood = evaluateLikelihoods(earlyLikelihoods, earlyTimerIndices);

        if (logLikelihood == Double.NEGATIVE_INFINITY) {
            return Double.NEGATIVE_INFINITY;
        }

        if (pool == null) {
            // Single threaded
            logLikelihood += evaluateLikelihoods(lateLikelihoods, lateTimerIndices);
        } else {
            try {
                for (Future<Double> result : pool.invokeAll(likelihoodCallers)) {
                    logLikelihood += result.get();
                }
            } catch (InterruptedException e) {
                // Returning a partial sum would hand the caller a wrong likelihood, so restore
                // the interrupt status for code further up the stack and fail loudly instead.
                Thread.currentThread().interrupt();
                throw new IllegalStateException(
                        "Interrupted while evaluating likelihoods in the thread pool", e);
            } catch (ExecutionException e) {
                // Propagate the failure, as the serial path would, rather than printing it
                // and returning a partial sum.
                throw new IllegalStateException(
                        "A likelihood evaluation failed in the thread pool", e);
            }
        }

//        if( DEBUG > 0 ) {
//            int t = DEBUG; DEBUG = 0;
//            System.err.println(getId() + ": " + getDiagnosis(0) + " = " + logLikelihood);
//            DEBUG = t;
//        }

        if (DEBUG_PARALLEL_EVALUATION) {
            System.err.println("");
        }

        return logLikelihood;
    }

    /**
     * Serially evaluates the given likelihoods and sums their log likelihoods, recording
     * timings when {@link #EVALUATION_TIMERS} is set.
     * <p>
     * If a likelihood is zero (log likelihood of negative infinity) the rest are skipped. This
     * means that expensive likelihoods such as TreeLikelihoods should be put after cheap ones
     * such as BooleanLikelihoods.
     *
     * @param subset       the likelihoods to evaluate, in order
     * @param timerIndices for each entry of {@code subset}, its slot in the timer arrays
     * @return the summed log likelihood, or negative infinity as soon as one component returns it
     */
    private double evaluateLikelihoods(List<Likelihood> subset, List<Integer> timerIndices) {
        double logLikelihood = 0.0;

        for (int k = 0; k < subset.size(); k++) {
            final Likelihood likelihood = subset.get(k);
            final double l;

            if (EVALUATION_TIMERS) {
                final int slot = timerIndices.get(k);
                final long start = System.nanoTime();
                l = likelihood.getLogLikelihood();
                evaluationTimes[slot] += System.nanoTime() - start;
                evaluationCounts[slot]++;
            } else {
                l = likelihood.getLogLikelihood();
            }

            if (l == Double.NEGATIVE_INFINITY) {
                return Double.NEGATIVE_INFINITY;
            }
            logLikelihood += l;
        }

        return logLikelihood;
    }

    /**
     * Marks every component likelihood as dirty.
     */
    public void makeDirty() {
        for (Likelihood likelihood : likelihoods) {
            likelihood.makeDirty();
        }
    }

    /**
     * @return always false
     */
    public boolean evaluateEarly() {
        return false;
    }

    /**
     * @return the diagnosis string built with an indent of zero
     */
    public String getDiagnosis() {
        return getDiagnosis(0);
    }

    /**
     * Builds a listing of each component's name and current log likelihood, followed by the
     * total. Nested compound likelihoods are listed recursively inside parentheses.
     * Note that this evaluates the likelihoods.
     *
     * @param indent number of spaces to indent each line with; a negative value puts the
     *               components on a single line
     * @return the diagnosis text
     */
    public String getDiagnosis(int indent) {
        final StringBuilder message = new StringBuilder();
        final NumberFormatter nf = new NumberFormatter(6);

        boolean first = true;
        for (Likelihood lik : likelihoods) {

            if (first) {
                first = false;
            } else {
                message.append(", ");
            }

            if (indent >= 0) {
                appendNewLine(message, indent);
            }
            message.append(lik.prettyName()).append("=");

            if (lik instanceof CompoundLikelihood) {
                final String d = ((CompoundLikelihood) lik).getDiagnosis(indent < 0 ? -1 : indent + 2);
                if (d != null && d.length() > 0) {
                    message.append("(").append(d);
                    if (indent >= 0) {
                        appendNewLine(message, indent);
                    }
                    message.append(")");
                }
            } else {
                final double logLikelihood = lik.getLogLikelihood();
                if (logLikelihood == Double.NEGATIVE_INFINITY) {
                    message.append("-Inf");
                } else if (Double.isNaN(logLikelihood)) {
                    message.append("NaN");
                } else if (logLikelihood == Double.POSITIVE_INFINITY) {
                    message.append("+Inf");
                } else {
                    message.append(nf.formatDecimal(logLikelihood, 4));
                }
            }
        }

        // The total always goes on its own line, even when indent is negative.
        appendNewLine(message, indent);
        message.append("Total = ").append(this.getLogLikelihood());

        return message.toString();
    }

    /**
     * Appends a line break followed by {@code indent} spaces (none if indent is not positive).
     */
    private static void appendNewLine(StringBuilder target, int indent) {
        target.append("\n");
        for (int i = 0; i < indent; i++) {
            target.append(" ");
        }
    }

    /**
     * @return the id of this likelihood (may be null)
     */
    @Override
    public String toString() {
        return getId();
        // really bad for debugging
        //return Double.toString(getLogLikelihood());
    }

    /**
     * @return the pretty name produced by {@code Abstract.getPrettyName}
     */
    public String prettyName() {
        return Abstract.getPrettyName(this);
    }

    /**
     * @return whether {@link #setUsed()} has been called
     */
    public boolean isUsed() {
        return used;
    }

    /**
     * Flags this likelihood and all of its components as used.
     */
    public void setUsed() {
        used = true;
        for (Likelihood l : likelihoods) {
            l.setUsed();
        }
    }

    /**
     * @return the size of the thread pool, or zero if likelihoods are evaluated serially
     */
    public int getThreadCount() {
        return threadCount;
    }

    /**
     * @return the accumulated evaluation times in nanoseconds (the internal array, not a copy)
     */
    public long[] getEvaluationTimes() {
        return evaluationTimes;
    }

    /**
     * @return the evaluation counts (the internal array, not a copy)
     */
    public int[] getEvaluationCounts() {
        return evaluationCounts;
    }

    /**
     * Zeroes all accumulated evaluation times and counts.
     */
    public void resetEvaluationTimes() {
        for (int i = 0; i < evaluationTimes.length; i++) {
            evaluationTimes[i] = 0;
            evaluationCounts[i] = 0;
        }
    }

    // **************************************************************
    // Loggable IMPLEMENTATION
    // **************************************************************

    /**
     * @return the log columns.
     */
    public dr.inference.loggers.LogColumn[] getColumns() {
        return new dr.inference.loggers.LogColumn[]{
                new LikelihoodColumn(getId() == null ? "likelihood" : getId())
        };
    }

    /** Log column reporting the current total log likelihood. */
    private class LikelihoodColumn extends dr.inference.loggers.NumberColumn {
        public LikelihoodColumn(String label) {
            super(label);
        }

        public double getDoubleValue() {
            return getLogLikelihood();
        }
    }

    // **************************************************************
    // Reportable IMPLEMENTATION
    // **************************************************************

    /**
     * @return the timing report built with an indent of zero
     */
    public String getReport() {
        return getReport(0);
    }

    /**
     * Builds a report of how many times each component was evaluated and how long that took.
     * Nested compound likelihoods are reported recursively inside parentheses.
     *
     * @param indent number of spaces to indent each line with; a negative value puts the
     *               components on a single line
     * @return the timing report, or a fixed notice if {@link #EVALUATION_TIMERS} is off
     */
    public String getReport(int indent) {
        if (!EVALUATION_TIMERS) {
            return "No evaluation timer report available";
        }

        final StringBuilder message = new StringBuilder("\n");
        final NumberFormatter nf = new NumberFormatter(6);

        boolean first = true;
        int index = 0;
        for (Likelihood lik : likelihoods) {

            if (first) {
                first = false;
            } else {
                message.append(", ");
            }

            if (indent >= 0) {
                appendNewLine(message, indent);
            }
            message.append(lik.prettyName()).append("=");

            if (lik instanceof CompoundLikelihood) {
                final String d = ((CompoundLikelihood) lik).getReport(indent < 0 ? -1 : indent + 2);
                if (d != null && d.length() > 0) {
                    message.append("(").append(d);
                    if (indent >= 0) {
                        appendNewLine(message, indent);
                    }
                    message.append(")");
                }
            } else {
                final double secs = (double) evaluationTimes[index] / 1.0E9;
                message.append(evaluationCounts[index]).append(" evaluations in ")
                        .append(nf.format(secs)).append(" secs (")
                        .append(nf.format(secs / evaluationCounts[index])).append(" secs/eval)");
            }
            index++;
        }

        return message.toString();
    }

    // **************************************************************
    // Identifiable IMPLEMENTATION
    // **************************************************************

    private String id = null;

    /**
     * @param id the new id of this likelihood
     */
    public void setId(String id) {
        this.id = id;
    }

    /**
     * @return the id of this likelihood, or null if none has been set
     */
    public String getId() {
        return id;
    }

    private boolean used = false;

    private final int threadCount;

    private final ExecutorService pool;

    private final ArrayList<Likelihood> likelihoods = new ArrayList<Likelihood>();
    private final CompoundModel compoundModel = new CompoundModel("compoundModel");

    private final ArrayList<Likelihood> earlyLikelihoods = new ArrayList<Likelihood>();
    private final ArrayList<Likelihood> lateLikelihoods = new ArrayList<Likelihood>();

    // For each entry of earlyLikelihoods / lateLikelihoods, its position in likelihoods,
    // which is the slot used in evaluationTimes / evaluationCounts.
    private final List<Integer> earlyTimerIndices = new ArrayList<Integer>();
    private final List<Integer> lateTimerIndices = new ArrayList<Integer>();

    private final List<Callable<Double>> likelihoodCallers = new ArrayList<Callable<Double>>();

    /**
     * Pool task that evaluates a single likelihood and, when {@link #EVALUATION_TIMERS} is
     * set, records the timing in the slot given by {@code index}.
     */
    class LikelihoodCaller implements Callable<Double> {

        /**
         * @param likelihood the likelihood to evaluate
         * @param index      slot in the timer arrays; also printed in the debug output
         */
        public LikelihoodCaller(Likelihood likelihood, int index) {
            this.likelihood = likelihood;
            this.index = index;
        }

        @Override
        public Double call() throws Exception {
            if (DEBUG_PARALLEL_EVALUATION) {
                System.err.print("Invoking thread #" + index + " for " + likelihood.getId() + ": ");
            }
            if (EVALUATION_TIMERS) {
                final long start = System.nanoTime();
                final double logL = likelihood.getLogLikelihood();
                evaluationTimes[index] += System.nanoTime() - start;
                evaluationCounts[index]++;
                return logL;
            }
            return likelihood.getLogLikelihood();
        }

        private final Likelihood likelihood;
        private final int index;
    }

    /** If true, pool tasks print a trace line to standard error when invoked. */
    public static final boolean DEBUG_PARALLEL_EVALUATION = false;
}