/*
* File TreeLikelihood.java
*
* Copyright (C) 2010 Remco Bouckaert remco@cs.auckland.ac.nz
*
* This file is part of BEAST2.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package beast.evolution.likelihood;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import beast.core.Description;
import beast.core.Input;
import beast.core.State;
import beast.core.util.Log;
import beast.evolution.alignment.Alignment;
import beast.evolution.branchratemodel.BranchRateModel;
import beast.evolution.branchratemodel.StrictClockModel;
import beast.evolution.sitemodel.SiteModel;
import beast.evolution.substitutionmodel.SubstitutionModel;
import beast.evolution.tree.Node;
import beast.evolution.tree.Tree;
import beast.evolution.tree.TreeInterface;
@Description("Calculates the probability of sequence data on a beast.tree given a site and substitution model using " +
        "a variant of the 'peeling algorithm'. For details, see" +
        "Felsenstein, Joseph (1981). Evolutionary trees from DNA sequences: a maximum likelihood approach. J Mol Evol 17 (6): 368-376.")
public class TreeLikelihood extends GenericTreeLikelihood {

    final public Input<Boolean> m_useAmbiguities = new Input<>("useAmbiguities", "flag to indicate that sites containing ambiguous states should be handled instead of ignored (the default)", false);
    final public Input<Boolean> m_useTipLikelihoods = new Input<>("useTipLikelihoods", "flag to indicate that partial likelihoods are provided at the tips", false);
    final public Input<String> implementationInput = new Input<>("implementation", "name of class that implements this treelikelihood potentially more efficiently. "
            + "This class will be tried first, with the TreeLikelihood as fallback implementation. "
            + "When multi-threading, multiple objects can be created.", "beast.evolution.likelihood.BeagleTreeLikelihood");

    /** Likelihood-scaling policy used to guard against numeric underflow of the partials. */
    public static enum Scaling {none, always, _default};

    final public Input<Scaling> scaling = new Input<>("scaling", "type of scaling to use, one of " + Arrays.toString(Scaling.values()) + ". If not specified, the -beagle_scaling flag is used.", Scaling._default, Scaling.values());

    /**
     * calculation engine *
     */
    protected LikelihoodCore likelihoodCore;

    /**
     * Non-null only when a working BEAGLE instance was found during initAndValidate;
     * in that case every public operation is delegated to it.
     */
    BeagleTreeLikelihood beagle;

    /**
     * BEASTObject associated with inputs. Since none of the inputs are StateNodes, it
     * is safe to link to them only once, during initAndValidate.
     */
    SubstitutionModel substitutionModel;
    protected SiteModel.Base m_siteModel;
    protected BranchRateModel.Base branchRateModel;

    /**
     * Flag propagating dirtiness from requiresRecalculation() into traverse():
     * // when CLEAN=0, nothing needs to be recalculated for the node
     * // when DIRTY=1 indicates a node partial needs to be recalculated
     * // when FILTHY=2 indicates the indices for the node need to be recalculated
     * // (often not necessary while node partial recalculation is required)
     */
    protected int hasDirt;

    /**
     * Lengths of the branches in the tree associated with each of the nodes
     * in the tree through their node numbers. By comparing whether the
     * current branch length differs from stored branch lengths, it is tested
     * whether a node is dirty and needs to be recomputed (there may be other
     * reasons as well...).
     * These lengths take branch rate models in account.
     */
    protected double[] m_branchLengths;
    protected double[] storedBranchLengths;

    /**
     * memory allocation for likelihoods for each of the patterns *
     */
    protected double[] patternLogLikelihoods;

    /**
     * memory allocation for the root partials *
     */
    protected double[] m_fRootPartials;

    /**
     * memory allocation for probability tables obtained from the SiteModel *
     */
    double[] probabilities;

    // size of one transition probability matrix, (stateCount + 1)^2
    int matrixSize;

    /**
     * flag to indicate ascertainment correction should be applied *
     */
    boolean useAscertainedSitePatterns = false;

    /**
     * dealing with proportion of site being invariant *
     */
    double proportionInvariant = 0;
    List<Integer> constantPattern = null;

    @Override
    public void initAndValidate() {
        // sanity check: alignment should have same #taxa as tree
        if (dataInput.get().getTaxonCount() != treeInput.get().getLeafNodeCount()) {
            throw new IllegalArgumentException("The number of nodes in the tree does not match the number of sequences");
        }
        // Probe for a BEAGLE implementation first; fall back to the Java core on any failure.
        beagle = new BeagleTreeLikelihood();
        try {
            beagle.initByName(
                    "data", dataInput.get(), "tree", treeInput.get(), "siteModel", siteModelInput.get(),
                    "branchRateModel", branchRateModelInput.get(), "useAmbiguities", m_useAmbiguities.get(),
                    "useTipLikelihoods", m_useTipLikelihoods.get(), "scaling", scaling.get().toString());
            if (beagle.beagle != null) {
                //a Beagle instance was found, so we use it
                return;
            }
        } catch (Exception e) {
            // Deliberately ignored: failure to set up BEAGLE simply means we fall
            // through to the pure-Java likelihood core below.
        }
        // No Beagle instance was found, so we use the good old java likelihood core
        beagle = null;

        int nodeCount = treeInput.get().getNodeCount();
        if (!(siteModelInput.get() instanceof SiteModel.Base)) {
            throw new IllegalArgumentException("siteModel input should be of type SiteModel.Base");
        }
        m_siteModel = (SiteModel.Base) siteModelInput.get();
        m_siteModel.setDataType(dataInput.get().getDataType());
        substitutionModel = m_siteModel.substModelInput.get();

        // Default to a strict clock when no branch rate model was supplied,
        // so traverse() never needs a null check.
        if (branchRateModelInput.get() != null) {
            branchRateModel = branchRateModelInput.get();
        } else {
            branchRateModel = new StrictClockModel();
        }
        m_branchLengths = new double[nodeCount];
        storedBranchLengths = new double[nodeCount];

        int stateCount = dataInput.get().getMaxStateCount();
        int patterns = dataInput.get().getPatternCount();
        // Specialised 4-state (nucleotide) core is faster than the generic one.
        if (stateCount == 4) {
            likelihoodCore = new BeerLikelihoodCore4();
        } else {
            likelihoodCore = new BeerLikelihoodCore(stateCount);
        }

        String className = getClass().getSimpleName();
        Alignment alignment = dataInput.get();
        Log.info.println(className + "(" + getID() + ") uses " + likelihoodCore.getClass().getSimpleName());
        Log.info.println("  " + alignment.toString(true));
        // print startup messages via Log.print*

        proportionInvariant = m_siteModel.getProportionInvariant();
        m_siteModel.setPropInvariantIsCategory(false);
        if (proportionInvariant > 0) {
            calcConstantPatternIndices(patterns, stateCount);
        }

        initCore();

        patternLogLikelihoods = new double[patterns];
        m_fRootPartials = new double[patterns * stateCount];
        // one extra row/column to accommodate substitution models with an
        // additional (e.g. unknown/gap) state
        matrixSize = (stateCount + 1) * (stateCount + 1);
        probabilities = new double[matrixSize];
        Arrays.fill(probabilities, 1.0);

        if (dataInput.get().isAscertained) {
            useAscertainedSitePatterns = true;
        }
    }

    /**
     * Determine indices of m_fRootProbabilities that need to be updated
     * // due to sites being invariant. If none of the sites are invariant,
     * // the 'site invariant' category does not contribute anything to the
     * // root probability. If the site IS invariant for a certain character,
     * // taking ambiguities in account, there is a contribution of 1 from
     * // the 'site invariant' category.
     */
    void calcConstantPatternIndices(final int patterns, final int stateCount) {
        constantPattern = new ArrayList<>();
        for (int i = 0; i < patterns; i++) {
            final int[] pattern = dataInput.get().getPattern(i);
            final boolean[] isInvariant = new boolean[stateCount];
            Arrays.fill(isInvariant, true);
            for (final int state : pattern) {
                final boolean[] isStateSet = dataInput.get().getStateSet(state);
                // Ambiguous states only constrain invariance when ambiguities are handled.
                if (m_useAmbiguities.get() || !dataInput.get().getDataType().isAmbiguousState(state)) {
                    for (int k = 0; k < stateCount; k++) {
                        isInvariant[k] &= isStateSet[k];
                    }
                }
            }
            for (int k = 0; k < stateCount; k++) {
                if (isInvariant[k]) {
                    // flat index into the root partials array
                    constantPattern.add(i * stateCount + k);
                }
            }
        }
    }

    /**
     * Initialise the likelihood core: allocate per-node structures and push the
     * tip data (states or partials) into it.
     */
    protected void initCore() {
        final int nodeCount = treeInput.get().getNodeCount();
        likelihoodCore.initialize(
                nodeCount,
                dataInput.get().getPatternCount(),
                m_siteModel.getCategoryCount(),
                true, m_useAmbiguities.get()
        );
        // assumes a binary tree: nodeCount = 2 * #tips - 1
        final int extNodeCount = nodeCount / 2 + 1;
        final int intNodeCount = nodeCount / 2;

        if (m_useAmbiguities.get() || m_useTipLikelihoods.get()) {
            setPartials(treeInput.get().getRoot(), dataInput.get().getPatternCount());
        } else {
            setStates(treeInput.get().getRoot(), dataInput.get().getPatternCount());
        }
        hasDirt = Tree.IS_FILTHY;
        for (int i = 0; i < intNodeCount; i++) {
            likelihoodCore.createNodePartials(extNodeCount + i);
        }
    }

    /**
     * This method samples the sequences based on the tree and site model.
     */
    @Override
    public void sample(State state, Random random) {
        throw new UnsupportedOperationException("Can't sample a fixed alignment!");
    }

    /**
     * set leaf states in likelihood core *
     */
    protected void setStates(Node node, int patternCount) {
        if (node.isLeaf()) {
            Alignment data = dataInput.get();
            int i;
            int[] states = new int[patternCount];
            int taxonIndex = getTaxonIndex(node.getID(), data);
            for (i = 0; i < patternCount; i++) {
                int code = data.getPattern(taxonIndex, i);
                int[] statesForCode = data.getDataType().getStatesForCode(code);
                if (statesForCode.length == 1)
                    states[i] = statesForCode[0];
                else
                    states[i] = code; // Causes ambiguous states to be ignored.
            }
            likelihoodCore.setNodeStates(node.getNr(), states);
        } else {
            setStates(node.getLeft(), patternCount);
            setStates(node.getRight(), patternCount);
        }
    }

    /**
     * Resolve the row index of a taxon in the alignment, retrying with
     * surrounding quotes stripped when the plain name is not found.
     *
     * @param taxon the taxon name as a string
     * @param data the alignment
     * @return the taxon index of the given taxon name for accessing its sequence data in the given alignment
     * @throws RuntimeException if the taxon cannot be found in the alignment
     */
    private int getTaxonIndex(String taxon, Alignment data) {
        int taxonIndex = data.getTaxonIndex(taxon);
        if (taxonIndex == -1) {
            if (taxon.startsWith("'") || taxon.startsWith("\"")) {
                taxonIndex = data.getTaxonIndex(taxon.substring(1, taxon.length() - 1));
            }
            if (taxonIndex == -1) {
                throw new RuntimeException("Could not find sequence " + taxon + " in the alignment");
            }
        }
        return taxonIndex;
    }

    /**
     * set leaf partials in likelihood core *
     */
    protected void setPartials(Node node, int patternCount) {
        if (node.isLeaf()) {
            Alignment data = dataInput.get();
            int states = data.getDataType().getStateCount();
            double[] partials = new double[patternCount * states];
            int k = 0;
            int taxonIndex = getTaxonIndex(node.getID(), data);
            for (int patternIndex_ = 0; patternIndex_ < patternCount; patternIndex_++) {
                double[] tipLikelihoods = data.getTipLikelihoods(taxonIndex, patternIndex_);
                if (tipLikelihoods != null) {
                    // user-supplied tip likelihoods take precedence
                    for (int state = 0; state < states; state++) {
                        partials[k++] = tipLikelihoods[state];
                    }
                } else {
                    // otherwise derive 0/1 partials from the (possibly ambiguous) state code
                    int stateCount = data.getPattern(taxonIndex, patternIndex_);
                    boolean[] stateSet = data.getStateSet(stateCount);
                    for (int state = 0; state < states; state++) {
                        partials[k++] = (stateSet[state] ? 1.0 : 0.0);
                    }
                }
            }
            likelihoodCore.setNodePartials(node.getNr(), partials);
        } else {
            setPartials(node.getLeft(), patternCount);
            setPartials(node.getRight(), patternCount);
        }
    }

    /**
     * Calculate the log likelihood of the current state.
     *
     * @return the log likelihood.
     */
    double m_fScale = 1.01;     // scaling factor handed to the likelihood core when scaling kicks in
    int m_nScale = 0;           // number of evaluations since scaling was last (re)started
    int X = 100;                // threshold of evaluations after which scaling could be reconsidered

    @Override
    public double calculateLogP() {
        if (beagle != null) {
            logP = beagle.calculateLogP();
            return logP;
        }
        final TreeInterface tree = treeInput.get();
        try {
            if (traverse(tree.getRoot()) != Tree.IS_CLEAN)
                calcLogP();
        } catch (ArithmeticException e) {
            return Double.NEGATIVE_INFINITY;
        }
        m_nScale++;
        if (logP > 0 || (likelihoodCore.getUseScaling() && m_nScale > X)) {
            // logP > 0 can only arise through numerical error, and after X evaluations
            // with scaling on, scaling could in principle be switched off again.
            // Both cases are intentionally left unhandled: once scaling is on it stays on.
        } else if (logP == Double.NEGATIVE_INFINITY && m_fScale < 10 && !scaling.get().equals(Scaling.none)) {
            // Likelihood underflowed: turn on (or crank up) scaling, invalidate all
            // cached partials and recompute from scratch.
            m_nScale = 0;
            m_fScale *= 1.01;
            Log.warning.println("Turning on scaling to prevent numeric instability " + m_fScale);
            likelihoodCore.setUseScaling(m_fScale);
            likelihoodCore.unstore();
            hasDirt = Tree.IS_FILTHY;
            traverse(tree.getRoot());
            calcLogP();
            return logP;
        }
        return logP;
    }

    /**
     * Sum the per-pattern log likelihoods, weighted by pattern counts, applying the
     * ascertainment correction when the alignment demands it.
     */
    void calcLogP() {
        logP = 0.0;
        if (useAscertainedSitePatterns) {
            final double ascertainmentCorrection = dataInput.get().getAscertainmentCorrection(patternLogLikelihoods);
            for (int i = 0; i < dataInput.get().getPatternCount(); i++) {
                logP += (patternLogLikelihoods[i] - ascertainmentCorrection) * dataInput.get().getPatternWeight(i);
            }
        } else {
            for (int i = 0; i < dataInput.get().getPatternCount(); i++) {
                logP += patternLogLikelihoods[i] * dataInput.get().getPatternWeight(i);
            }
        }
    }

    /* Assumes there IS a branch rate model as opposed to traverse() */
    int traverse(final Node node) {
        int update = (node.isDirty() | hasDirt);
        final int nodeIndex = node.getNr();
        final double branchRate = branchRateModel.getRateForBranch(node);
        final double branchTime = node.getLength() * branchRate;

        // First update the transition probability matrix(ices) for this branch,
        // but only when the node is dirty or its rate-scaled length changed.
        if (!node.isRoot() && (update != Tree.IS_CLEAN || branchTime != m_branchLengths[nodeIndex])) {
            m_branchLengths[nodeIndex] = branchTime;
            final Node parent = node.getParent();
            likelihoodCore.setNodeMatrixForUpdate(nodeIndex);
            for (int i = 0; i < m_siteModel.getCategoryCount(); i++) {
                final double jointBranchRate = m_siteModel.getRateForCategory(i, node) * branchRate;
                substitutionModel.getTransitionProbabilities(node, parent.getHeight(), node.getHeight(), jointBranchRate, probabilities);
                likelihoodCore.setNodeMatrix(nodeIndex, i, probabilities);
            }
            update |= Tree.IS_DIRTY;
        }

        // If the node is internal, update the partial likelihoods.
        if (!node.isLeaf()) {
            // Traverse down the two child nodes
            final Node child1 = node.getLeft(); //Two children
            final int update1 = traverse(child1);
            final Node child2 = node.getRight();
            final int update2 = traverse(child2);

            // If either child node was updated then update this node too
            if (update1 != Tree.IS_CLEAN || update2 != Tree.IS_CLEAN) {
                final int childNum1 = child1.getNr();
                final int childNum2 = child2.getNr();
                likelihoodCore.setNodePartialsForUpdate(nodeIndex);
                update |= (update1 | update2);
                if (update >= Tree.IS_FILTHY) {
                    likelihoodCore.setNodeStatesForUpdate(nodeIndex);
                }

                if (m_siteModel.integrateAcrossCategories()) {
                    likelihoodCore.calculatePartials(childNum1, childNum2, nodeIndex);
                } else {
                    throw new RuntimeException("Error TreeLikelihood 201: Site categories not supported");
                }

                if (node.isRoot()) {
                    // No parent this is the root of the beast.tree -
                    // calculate the pattern likelihoods
                    final double[] frequencies = substitutionModel.getFrequencies();
                    final double[] proportions = m_siteModel.getCategoryProportions(node);
                    likelihoodCore.integratePartials(node.getNr(), proportions, m_fRootPartials);

                    if (constantPattern != null) {
                        proportionInvariant = m_siteModel.getProportionInvariant();
                        // some portion of sites is invariant, so adjust root partials for this
                        for (final int i : constantPattern) {
                            m_fRootPartials[i] += proportionInvariant;
                        }
                    }

                    likelihoodCore.calculateLogLikelihoods(m_fRootPartials, frequencies, patternLogLikelihoods);
                }
            }
        }
        return update;
    } // traverseWithBRM

    /* return copy of pattern log likelihoods for each of the patterns in the alignment */
    public double[] getPatternLogLikelihoods() {
        if (beagle != null) {
            return beagle.getPatternLogLikelihoods();
        }
        return patternLogLikelihoods.clone();
    } // getPatternLogLikelihoods

    /** CalculationNode methods **/

    /**
     * check state for changed variables and update temp results if necessary *
     */
    @Override
    protected boolean requiresRecalculation() {
        if (beagle != null) {
            return beagle.requiresRecalculation();
        }
        hasDirt = Tree.IS_CLEAN;

        if (dataInput.get().isDirtyCalculation()) {
            hasDirt = Tree.IS_FILTHY;
            return true;
        }
        if (m_siteModel.isDirtyCalculation()) {
            hasDirt = Tree.IS_DIRTY;
            return true;
        }
        // NOTE: branchRateModel is never null after initAndValidate (a StrictClockModel
        // is substituted); the null check is kept as a defensive measure only.
        if (branchRateModel != null && branchRateModel.isDirtyCalculation()) {
            // branch-length comparison in traverse() picks up the per-branch changes,
            // so hasDirt deliberately stays CLEAN here
            return true;
        }
        return treeInput.get().somethingIsDirty();
    }

    @Override
    public void store() {
        if (beagle != null) {
            beagle.store();
            super.store();
            return;
        }
        if (likelihoodCore != null) {
            likelihoodCore.store();
        }
        super.store();
        System.arraycopy(m_branchLengths, 0, storedBranchLengths, 0, m_branchLengths.length);
    }

    @Override
    public void restore() {
        if (beagle != null) {
            beagle.restore();
            super.restore();
            return;
        }
        if (likelihoodCore != null) {
            likelihoodCore.restore();
        }
        super.restore();
        // swap current and stored branch lengths rather than copying
        double[] tmp = m_branchLengths;
        m_branchLengths = storedBranchLengths;
        storedBranchLengths = tmp;
    }

    /**
     * @return a list of unique ids for the state nodes that form the argument
     */
    @Override
    public List<String> getArguments() {
        return Collections.singletonList(dataInput.get().getID());
    }

    /**
     * @return a list of unique ids for the state nodes that make up the conditions
     */
    @Override
    public List<String> getConditions() {
        return m_siteModel.getConditions();
    }

} // class TreeLikelihood