package dr.evomodel.continuous;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evolution.tree.TreeTraitProvider;
import dr.evolution.util.Taxon;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.tree.TreeModel;
import dr.evomodelxml.treelikelihood.TreeTraitParserUtilities;
import dr.inference.distribution.MultivariateDistributionLikelihood;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.NumberColumn;
import dr.inference.model.*;
import dr.math.distributions.MultivariateDistribution;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.util.Author;
import dr.util.Citable;
import dr.util.Citation;
import dr.util.CommonCitations;
import dr.xml.*;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import java.util.ArrayList;
import java.util.List;
import java.util.logging.Logger;
/**
* @author Marc Suchard
*/
public abstract class AbstractMultivariateTraitLikelihood extends AbstractModelLikelihood
implements TreeTraitProvider, Citable {
// ---- XML element / attribute names used by PARSER below ----
/** Parser (element) name; also passed to the superclass constructor as the model name. */
public static final String TRAIT_LIKELIHOOD = "multivariateTraitLikelihood";
/** Child element holding the conjugate root prior (mean and prior sample size). */
public static final String CONJUGATE_ROOT_PRIOR = "conjugateRootPrior";
/** Not referenced elsewhere in this file. */
public static final String MODEL = "diffusionModel";
/** Not referenced elsewhere in this file. */
public static final String TREE = "tree";
/** Boolean attribute (parser default false): keep per-node cached log likelihoods. */
public static final String CACHE_BRANCHES = "cacheBranches";
/** Boolean attribute (parser default true): report the node trait as one multivariate tree trait. */
public static final String REPORT_MULTIVARIATE = "reportAsMultivariate";
/** Optional child element wrapping a Parameter that is handed to {@link #check(Parameter)}. */
public static final String CHECK = "check";
/** Boolean attribute (parser default false): normalize by total tree length instead of root height. */
public static final String USE_TREE_LENGTH = "useTreeLength";
/** Boolean attribute (parser default false): divide branch lengths by the tree length/height. */
public static final String SCALE_BY_TIME = "scaleByTime";
/** Not referenced elsewhere in this file. */
public static final String SUBSTITUTIONS = "substitutions";
/** Optional child element wrapping a Model that is registered as a sub-model. */
public static final String SAMPLING_DENSITY = "samplingDensity";
/** Boolean attribute (parser default false): selects the integrated (conjugate) likelihood variants. */
public static final String INTEGRATE = "integrateInternalTraits";
/** Boolean attribute (parser default false): divide, rather than multiply, branch lengths by branch rates. */
public static final String RECIPROCAL_RATES = "reciprocalRates";
/** Child element of the conjugate root prior wrapping the pseudo-observation count Parameter. */
public static final String PRIOR_SAMPLE_SIZE = "priorSampleSize";
/** Integer attribute accepted by the syntax rules; not read by the parser in this file. */
public static final String RANDOM_SAMPLE = "randomSample";
/** Boolean attribute (parser default false): selects the non-phylogenetic likelihood. */
public static final String IGNORE_PHYLOGENY = "ignorePhylogeny";
/** Optional child element wrapping the Taxon used for ascertainment correction. */
public static final String ASCERTAINMENT = "ascertainedTaxon";
/** Boolean attribute (parser default true): forwarded to the non-phylogenetic likelihood. */
public static final String EXCHANGEABLE_TIPS = "exchangeableTips";
/**
 * Convenience constructor for likelihoods that have no delta parameter.
 * Delegates to the full constructor, passing {@code null} for the delta parameter.
 */
public AbstractMultivariateTraitLikelihood(String traitName,
                                           TreeModel treeModel,
                                           MultivariateDiffusionModel diffusionModel,
                                           CompoundParameter traitParameter,
                                           List<Integer> missingIndices,
                                           boolean cacheBranches,
                                           boolean scaleByTime,
                                           boolean useTreeLength,
                                           BranchRateModel rateModel,
                                           Model samplingDensity,
                                           boolean reportAsMultivariate,
                                           boolean reciprocalRates) {
    this(traitName,
            treeModel,
            diffusionModel,
            traitParameter,
            null, // no delta parameter
            missingIndices,
            cacheBranches,
            scaleByTime,
            useTreeLength,
            rateModel,
            samplingDensity,
            reportAsMultivariate,
            reciprocalRates);
}
/**
 * Full constructor: stores the configuration, registers the sub-models and variables this
 * likelihood listens to, allocates the per-node caches when requested, derives the trait
 * dimensions and finally logs a summary via {@link #printInformtion()}.
 *
 * @param deltaParameter optional (may be null); its first value is added to the rescaled
 *                       length of external branches
 * @param rateModel      optional (may be null) branch rate model
 * @param samplingDensity optional (may be null) model registered as a sub-model
 * @throws RuntimeException if the per-taxon trait dimension is not a multiple of the
 *                          diffusion dimension
 */
public AbstractMultivariateTraitLikelihood(String traitName,
                                           TreeModel treeModel,
                                           MultivariateDiffusionModel diffusionModel,
                                           CompoundParameter traitParameter,
                                           Parameter deltaParameter,
                                           List<Integer> missingIndices,
                                           boolean cacheBranches,
                                           boolean scaleByTime,
                                           boolean useTreeLength,
                                           BranchRateModel rateModel,
                                           Model samplingDensity,
                                           boolean reportAsMultivariate,
                                           boolean reciprocalRates) {
    super(TRAIT_LIKELIHOOD);

    // --- plain configuration ---
    this.traitName = traitName;
    this.treeModel = treeModel;
    this.rateModel = rateModel;
    this.diffusionModel = diffusionModel;
    this.traitParameter = traitParameter;
    this.missingIndices = missingIndices;
    this.deltaParameter = deltaParameter;
    this.reportAsMultivariate = reportAsMultivariate;
    this.cacheBranches = cacheBranches;
    this.scaleByTime = scaleByTime;
    this.useTreeLength = useTreeLength;
    this.reciprocalRates = reciprocalRates;

    // --- listener registration (same order as before: tree, diffusion, delta, rate,
    //     sampling density, traits) ---
    addModel(treeModel);
    addModel(diffusionModel);
    if (deltaParameter != null) {
        addVariable(deltaParameter);
    }
    if (rateModel != null) {
        hasRateModel = true;
        addModel(rateModel);
    }
    if (samplingDensity != null) {
        addModel(samplingDensity);
    }
    if (traitParameter != null) {
        addVariable(traitParameter);
    }

    // --- per-node caches, one slot per tree node ---
    if (cacheBranches) {
        cachedLogLikelihoods = new double[treeModel.getNodeCount()];
        storedCachedLogLikelihood = new double[treeModel.getNodeCount()];
        validLogLikelihoods = new boolean[treeModel.getNodeCount()];
        storedValidLogLikelihoods = new boolean[treeModel.getNodeCount()];
    }

    // --- dimensions ---
    dimTrait = diffusionModel.getPrecisionmatrix().length;
    if (traitParameter == null) {
        dim = 0;
    } else {
        dim = traitParameter.getParameter(0).getDimension();
    }
    numData = dim / dimTrait;
    if (dim % dimTrait != 0) {
        throw new RuntimeException("dim is not divisible by dimTrait");
    }

    recalculateTreeLength();
    printInformtion();
}
/**
 * Logs a human-readable summary of this likelihood's configuration (trait name, diffusion
 * and rate models, normalization, citation, dimensions) to the "dr.evomodel" logger.
 * Refreshes the tree length first when time scaling is enabled.
 */
protected void printInformtion() {
    final String normalization;
    if (!scaleByTime) {
        normalization = "off";
    } else if (useTreeLength) {
        normalization = "length";
    } else {
        normalization = "height";
    }
    final String heterogeneity = (rateModel == null) ? "homogeneous" : rateModel.getId();

    final StringBuilder message = new StringBuilder("Creating multivariate diffusion model:\n");
    message.append("\tTrait: ").append(traitName).append("\n");
    message.append("\tDiffusion process: ").append(diffusionModel.getId()).append("\n");
    message.append("\tHeterogenity model: ").append(heterogeneity).append("\n");
    message.append("\tTree normalization: ").append(normalization).append("\n");
    message.append("\tUsing reciprocal (precision) rates: ").append(reciprocalRates).append("\n");

    if (scaleByTime) {
        recalculateTreeLength();
        final String label = useTreeLength ? "\tInitial tree length: " : "\tInitial tree height: ";
        message.append(label).append(treeLength).append("\n");
    }

    message.append(extraInfo());
    message.append("\tPlease cite:\n");
    message.append(Citable.Utils.getCitationString(this));
    message.append("\n\tDiffusion dimension : ").append(dimTrait).append("\n");
    message.append("\tNumber of observations: ").append(numData).append("\n");

    Logger.getLogger("dr.evomodel").info(message.toString());
}
/**
 * Citation printed when ascertainment correction is switched on
 * (see {@code setAscertainedTaxon}). Declared final: it is a shared constant
 * and is never reassigned.
 */
private static final Citable TraitAscertainmentCitation = new Citable() {
    public List<Citation> getCitations() {
        List<Citation> citations = new ArrayList<Citation>();
        citations.add(
                new Citation(
                        new Author[]{
                                new Author("MA", "Suchard"),
                                new Author("J", "Novembre"),
                                new Author("B", "von Holdt"),
                                new Author("G", "Cybis"),
                        },
                        Citation.Status.IN_PREPARATION
                )
        );
        return citations;
    }
};
/**
 * @return a new mutable list containing the single citation for this likelihood
 */
public List<Citation> getCitations() {
    final List<Citation> result = new ArrayList<Citation>();
    result.add(CommonCitations.LEMEY_2010);
    return result;
}
/**
 * @return subclass-specific text that printInformtion() appends to the start-up log
 *         message, just before the citation section
 */
protected abstract String extraInfo();
/**
 * @return the compound trait parameter supplied at construction (may be null)
 */
public CompoundParameter getTraitParameter() {
    return this.traitParameter;
}
/**
 * Enables ascertainment correction with respect to the given taxon and logs the
 * corresponding citation.
 * <p>
 * The taxon is validated before any state is touched, so a failed call leaves a
 * previously configured ascertained taxon intact. Because the correction changes the
 * value returned by {@code getLogLikelihood()}, any cached log likelihood is invalidated.
 *
 * @param taxon the ascertained taxon; must be present in the tree
 * @throws RuntimeException if the taxon is not found in the tree
 */
public void setAscertainedTaxon(Taxon taxon) {
    final int taxonIndex = treeModel.getTaxonIndex(taxon);
    if (taxonIndex == -1) {
        throw new RuntimeException("Taxon " + taxon.getId() + " is not in tree " + treeModel.getId());
    }

    // Commit only after validation succeeded.
    ascertainedTaxonIndex = taxonIndex;
    doAscertainmentCorrect = true;
    // A value cached before the correction was enabled no longer applies.
    likelihoodKnown = false;

    StringBuilder sb = new StringBuilder("Enabling ascertainment correction for multivariate trait model: ");
    sb.append(getId()).append("\n");
    sb.append("\tTaxon: ").append(taxon.getId()).append("\n");
    sb.append("\tPlease cite:\n");
    sb.append(Citable.Utils.getCitationString(TraitAscertainmentCitation));
    Logger.getLogger("dr.evomodel").info(sb.toString());
}
/**
 * Returns the length of the branch above {@code node} after applying, in order:
 * the branch rate (divided when reciprocal rates are in use, multiplied otherwise),
 * the tree length/height normalization, and the delta offset for external nodes.
 *
 * @param node the node whose parent branch is measured
 * @return the rescaled branch length
 */
public double getRescaledBranchLength(NodeRef node) {
    double rescaled = treeModel.getBranchLength(node);

    if (hasRateModel) {
        final double rate = rateModel.getBranchRate(treeModel, node);
        // reciprocal: rate scales as precision (inv-time); otherwise as variance (time)
        rescaled = reciprocalRates ? rescaled / rate : rescaled * rate;
    }

    if (scaleByTime) {
        rescaled /= treeLength;
    }

    final boolean tipWithDelta = deltaParameter != null && treeModel.isExternal(node);
    if (tipWithDelta) {
        rescaled += deltaParameter.getParameterValue(0);
    }

    return rescaled;
}
// **************************************************************
// ModelListener IMPLEMENTATION
// **************************************************************
/**
 * Reacts to a change in one of the registered sub-models.
 * <p>
 * Without branch caching the likelihood is simply marked unknown (and the tree length is
 * refreshed for tree changes). With branch caching the affected per-node cache entries are
 * invalidated depending on which model fired and what kind of event it was.
 *
 * @throws RuntimeException if branch caching is on and the event comes from a model other
 *                          than the diffusion, tree or rate model, or is of an unexpected kind
 */
protected void handleModelChangedEvent(Model model, Object object, int index) {
    if (!cacheBranches) {
        likelihoodKnown = false;
        if (model == treeModel) {
            recalculateTreeLength();
        }
        return;
    }

    if (model == diffusionModel) {
        updateAllNodes();
        return;
    }
    if (model == treeModel) {
        handleTreeModelChanged(object);
        return;
    }
    if (model == rateModel) {
        handleRateModelChanged(object, index);
        return;
    }
    throw new RuntimeException("Unknown componentChangedEvent");
}

/** Invalidates cached nodes in response to an event fired by the tree model. */
private void handleTreeModelChanged(Object object) {
    if (!(object instanceof TreeModel.TreeChangedEvent)) {
        // fireTreeEvents sends two events when a node trait is changed;
        // the one carrying a Parameter is deliberately ignored.
        if (object instanceof Parameter) {
            return;
        }
        throw new RuntimeException("Unexpected object throwing events in AbstractMultivariateTraitLikelihood");
    }

    final TreeModel.TreeChangedEvent treeEvent = (TreeModel.TreeChangedEvent) object;
    if (treeEvent.isTreeChanged()) {
        recalculateTreeLength();
        updateAllNodes();
    } else if (treeEvent.isHeightChanged()) {
        invalidateAfterHeightOrNodeChange(treeEvent);
    } else if (treeEvent.isNodeParameterChanged()) {
        updateNodeAndChildren(treeEvent.getNode());
    } else if (treeEvent.isNodeChanged()) {
        invalidateAfterHeightOrNodeChange(treeEvent);
    } else {
        throw new RuntimeException("Unexpected TreeModel TreeChangedEvent occurring in AbstractMultivariateTraitLikelihood");
    }
}

/**
 * Shared handling for height and node changes: refreshes the tree length, then invalidates
 * everything when the normalization constant may have moved, or only the node and its
 * children otherwise.
 */
private void invalidateAfterHeightOrNodeChange(TreeModel.TreeChangedEvent treeEvent) {
    recalculateTreeLength();
    if (useTreeLength || (scaleByTime && treeModel.isRoot(treeEvent.getNode()))) {
        updateAllNodes();
    } else {
        updateNodeAndChildren(treeEvent.getNode());
    }
}

/** Invalidates cached nodes in response to an event fired by the branch rate model. */
private void handleRateModelChanged(Object object, int index) {
    if (index == -1) {
        updateAllNodes();
        return;
    }
    final boolean branchSpecific =
            object == null || ((Parameter) object).getDimension() == 2 * (treeModel.getNodeCount() - 1);
    if (branchSpecific) {
        updateNode(treeModel.getNode(index)); // This is a branch specific update
    } else {
        updateAllNodes(); // Probably an epoch model
    }
}
/** Marks the cached log likelihood of every tree node, and the total, as stale. */
private void updateAllNodes() {
    int nodeIndex = 0;
    while (nodeIndex < treeModel.getNodeCount()) {
        validLogLikelihoods[nodeIndex] = false;
        nodeIndex++;
    }
    likelihoodKnown = false;
}
/** Marks the cached log likelihood of a single node, and the total, as stale. */
private void updateNode(NodeRef node) {
    final int nodeNumber = node.getNumber();
    validLogLikelihoods[nodeNumber] = false;
    likelihoodKnown = false;
}
/** Marks the cached log likelihoods of a node and its direct children, and the total, as stale. */
private void updateNodeAndChildren(NodeRef node) {
    validLogLikelihoods[node.getNumber()] = false;
    for (int childIndex = 0; childIndex < treeModel.getChildCount(node); childIndex++) {
        final NodeRef child = treeModel.getChild(node, childIndex);
        validLogLikelihoods[child.getNumber()] = false;
    }
    likelihoodKnown = false;
}
/**
 * Refreshes the normalization constant used when scaling by time: the sum of all non-root
 * branch lengths when normalizing by tree length, otherwise the root height.
 * Does nothing when time scaling is disabled.
 */
public void recalculateTreeLength() {
    if (!scaleByTime) {
        return;
    }

    if (!useTreeLength) {
        // Normalizing by tree height.
        treeLength = treeModel.getNodeHeight(treeModel.getRoot());
        return;
    }

    double total = 0;
    final int nodeCount = treeModel.getNodeCount();
    for (int nodeIndex = 0; nodeIndex < nodeCount; nodeIndex++) {
        final NodeRef current = treeModel.getNode(nodeIndex);
        if (treeModel.isRoot(current)) {
            continue; // the root has no branch above it
        }
        total += treeModel.getBranchLength(current);
    }
    treeLength = total;
}
// **************************************************************
// VariableListener IMPLEMENTATION
// **************************************************************
/**
 * Reacts to a change in one of the registered variables (the delta parameter or the trait
 * parameter) by marking the likelihood as unknown.
 * <p>
 * A delta change alters the rescaled length of every external branch
 * (see {@code getRescaledBranchLength}) and is not routed through the tree model, so when
 * branch caching is enabled the per-node cache is invalidated as well; otherwise stale
 * cached branch contributions could be reused.
 */
protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {
    if (variable == deltaParameter) {
        if (cacheBranches) {
            updateAllNodes(); // drop cached branch terms computed with the old delta
        }
        likelihoodKnown = false;
    }
    if (variable == traitParameter) {
        // Per-node invalidation for trait changes is handled first by the treeModel
        likelihoodKnown = false;
    }
    if (!cacheBranches) {
        likelihoodKnown = false;
    }
}
// **************************************************************
// Model IMPLEMENTATION
// **************************************************************
/**
 * Saves the current likelihood state (known flag, value, tree length and, when branch
 * caching is on, a copy of the per-node caches) so it can be restored later.
 */
protected void storeState() {
    storedTreeLength = treeLength;
    storedLogLikelihood = logLikelihood;
    storedLikelihoodKnown = likelihoodKnown;

    if (!cacheBranches) {
        return;
    }
    System.arraycopy(cachedLogLikelihoods, 0, storedCachedLogLikelihood, 0, treeModel.getNodeCount());
    System.arraycopy(validLogLikelihoods, 0, storedValidLogLikelihoods, 0, treeModel.getNodeCount());
}
/**
 * Reinstates the state saved by the last store: known flag, value, tree length and, when
 * branch caching is on, the per-node caches (swapped with the stored arrays, not copied).
 */
protected void restoreState() {
    treeLength = storedTreeLength;
    logLikelihood = storedLogLikelihood;
    likelihoodKnown = storedLikelihoodKnown;

    if (!cacheBranches) {
        return;
    }

    final double[] previousValues = cachedLogLikelihoods;
    cachedLogLikelihoods = storedCachedLogLikelihood;
    storedCachedLogLikelihood = previousValues;

    final boolean[] previousFlags = validLogLikelihoods;
    validLogLikelihoods = storedValidLogLikelihoods;
    storedValidLogLikelihoods = previousFlags;
}
/** Accepting a proposed state requires no work here. */
protected void acceptState() {
    // nothing to do
}
/**
 * @return the tree model this likelihood is defined on
 */
public TreeModel getTreeModel() {
    return this.treeModel;
}
/**
 * @return the name of the trait this likelihood models
 */
public String getTraitName() {
    return this.traitName;
}
/**
 * @return the diffusion model supplied at construction
 */
public MultivariateDiffusionModel getDiffusionModel() {
    return this.diffusionModel;
}
// public boolean getInSubstitutionTime() {
// return inSubstitutionTime;
// }
// **************************************************************
// Likelihood IMPLEMENTATION
// **************************************************************
/**
 * @return this object, which acts as its own model
 */
public Model getModel() {
    final Model self = this;
    return self;
}
/**
 * @return the runtime class name followed by the current log likelihood in parentheses;
 *         note that this evaluates the likelihood if it is not already known
 */
public String toString() {
    final String className = getClass().getName();
    return className + "(" + getLogLikelihood() + ")";
}
/**
 * Returns the log likelihood, recomputing it only when the cached value is stale.
 * When ascertainment correction is enabled, the correction for the ascertained taxon is
 * subtracted from the freshly computed value.
 *
 * @return the (possibly cached) log likelihood
 */
public final double getLogLikelihood() {
    if (likelihoodKnown) {
        return logLikelihood;
    }

    logLikelihood = calculateLogLikelihood();
    if (doAscertainmentCorrect) {
        // Evaluate the correction fully before touching logLikelihood.
        final double ascertainmentTerm = calculateAscertainmentCorrection(ascertainedTaxonIndex);
        logLikelihood -= ascertainmentTerm;
    }
    likelihoodKnown = true;

    return logLikelihood;
}
/**
 * Computes the term that getLogLikelihood() subtracts from the log likelihood when
 * ascertainment correction is enabled.
 *
 * @param taxonIndex index of the ascertained taxon, as returned by the tree model
 * @return the correction to subtract
 */
protected abstract double calculateAscertainmentCorrection(int taxonIndex);
/**
 * @return the value reported in the ".data" log column produced by getColumns()
 */
public abstract double getLogDataLikelihood();
/**
 * Forces a full recomputation on the next likelihood request; with branch caching on,
 * every per-node cache entry is invalidated too.
 */
public void makeDirty() {
    likelihoodKnown = false;
    if (cacheBranches) {
        updateAllNodes();
    }
}
/**
 * @return two log columns: "&lt;id&gt;.joint" (the likelihood column) and "&lt;id&gt;.data"
 *         (the value of {@link #getLogDataLikelihood()})
 */
public LogColumn[] getColumns() {
    final LogColumn jointColumn = new LikelihoodColumn(getId() + ".joint");
    final LogColumn dataColumn = new NumberColumn(getId() + ".data") {
        public double getDoubleValue() {
            return getLogDataLikelihood();
        }
    };
    return new LogColumn[]{jointColumn, dataColumn};
}
/**
 * Computes the log likelihood; invoked by getLogLikelihood() whenever the cached value
 * is not known.
 *
 * @return the log likelihood, before any ascertainment correction
 */
public abstract double calculateLogLikelihood();
// public double getMaxLogLikelihood() {
// return maxLogLikelihood;
// }
// **************************************************************
// Loggable IMPLEMENTATION
// **************************************************************
/** Lazily built by {@link #getTreeTraits()}; null until first requested. */
private TreeTrait[] treeTraits = null;

/**
 * Returns the tree traits exposed by this likelihood, building them on first use.
 * A single node-level double-array trait named after the modelled trait is provided.
 *
 * @return the (cached) array of tree traits
 * @throws RuntimeException if the root trait has more than one value and multivariate
 *                          reporting is disabled
 */
public TreeTrait[] getTreeTraits() {
    if (treeTraits != null) {
        return treeTraits;
    }

    final double[] rootTrait = getRootNodeTrait();
    if (rootTrait.length != 1 && !reportAsMultivariate) {
        throw new RuntimeException("Reporting of traits is only supported as multivariate");
    }

    final TreeTrait nodeTrait = new TreeTrait.DA() {
        public String getTraitName() {
            return traitName;
        }

        public Intent getIntent() {
            return Intent.NODE;
        }

        public Class getTraitClass() {
            return Double.class;
        }

        public double[] getTrait(Tree tree, NodeRef node) {
            return getTraitForNode(tree, node, traitName);
        }
    };
    treeTraits = new TreeTrait[]{nodeTrait};

    return treeTraits;
}
/**
 * Looks up a tree trait by name.
 *
 * @param key the trait name to search for
 * @return the first trait whose name equals {@code key}, or null if there is none
 */
public TreeTrait getTreeTrait(String key) {
    final TreeTrait[] available = getTreeTraits();
    for (int i = 0; i < available.length; i++) {
        final TreeTrait candidate = available[i];
        if (candidate.getTraitName().equals(key)) {
            return candidate;
        }
    }
    return null;
}
/**
 * @return the number of observations per taxon (trait dimension divided by diffusion dimension)
 */
public final int getNumData() {
    return this.numData;
}
/**
 * @return the diffusion dimension (size of the diffusion model's precision matrix)
 */
public final int getDimTrait() {
    return this.dimTrait;
}
/**
 * @return the multivariate trait stored at the root of the tree under this trait's name
 */
protected double[] getRootNodeTrait() {
    final NodeRef root = treeModel.getRoot();
    return treeModel.getMultivariateNodeTrait(root, traitName);
}
/**
 * Supplies the trait values reported for a node; used by the tree trait built in
 * getTreeTraits().
 *
 * @param tree      the tree being reported
 * @param node      the node whose trait is requested
 * @param traitName the name of the trait
 * @return the trait values for the node
 */
public abstract double[] getTraitForNode(Tree tree, NodeRef node, String traitName);
/**
 * Delegates validation of the given parameter to the diffusion model.
 *
 * @param trait the parameter to check
 * @throws XMLParseException if the diffusion model rejects the parameter
 */
public void check(Parameter trait) throws XMLParseException {
    this.diffusionModel.check(trait);
}
// **************************************************************
// XMLElement IMPLEMENTATION
// **************************************************************
/**
 * XML serialization is not supported.
 *
 * @throws RuntimeException always
 */
public Element createElement(Document d) {
    final String reason = "Not implemented yet!";
    throw new RuntimeException(reason);
}
// **************************************************************
// XMLObjectParser
// **************************************************************
public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
/**
 * @return the XML element name handled by this parser
 */
public String getParserName() {
    return AbstractMultivariateTraitLikelihood.TRAIT_LIKELIHOOD;
}
/**
 * Builds the concrete trait likelihood described by an XML element.
 * <p>
 * Which implementation is created depends on the element's content:
 * <ul>
 * <li>not integrating internal traits: {@code SampledMultivariateTraitLikelihood};</li>
 * <li>integrating, with a {@code MultivariateDistributionLikelihood} child as root prior:
 *     {@code SemiConjugateMultivariateTraitLikelihood} (the prior must be multivariate normal);</li>
 * <li>integrating, with a conjugate root prior element:
 *     {@code NonPhylogeneticMultivariateTraitLikelihood} when the phylogeny is ignored,
 *     otherwise {@code FullyConjugateMultivariateTraitLikelihood}.</li>
 * </ul>
 * Afterwards the optional randomize, jitter, check and ascertainment children are applied.
 *
 * @param xo the XML element to parse
 * @return the constructed likelihood
 * @throws XMLParseException if the root prior is missing or not multivariate normal, if the
 *                           conjugate prior mean has the wrong dimension, or if ascertainment
 *                           correction is requested without integrating internal traits
 */
public Object parseXMLObject(XMLObject xo) throws XMLParseException {
    MultivariateDiffusionModel diffusionModel = (MultivariateDiffusionModel) xo.getChild(MultivariateDiffusionModel.class);
    TreeModel treeModel = (TreeModel) xo.getChild(TreeModel.class);

    boolean cacheBranches = xo.getAttribute(CACHE_BRANCHES, false);
    boolean integrate = xo.getAttribute(INTEGRATE, false);
    boolean useTreeLength = xo.getAttribute(USE_TREE_LENGTH, false);
    boolean scaleByTime = xo.getAttribute(SCALE_BY_TIME, false);
    boolean reciprocalRates = xo.getAttribute(RECIPROCAL_RATES, false);
    boolean reportAsMultivariate = xo.getAttribute(REPORT_MULTIVARIATE, true);

    BranchRateModel rateModel = (BranchRateModel) xo.getChild(BranchRateModel.class);

    // Trait data: the utilities may replace the default trait name with the one found in the XML.
    TreeTraitParserUtilities utilities = new TreeTraitParserUtilities();
    String traitName = TreeTraitParserUtilities.DEFAULT_TRAIT_NAME;
    TreeTraitParserUtilities.TraitsAndMissingIndices returnValue =
            utilities.parseTraitsFromTaxonAttributes(xo, traitName, treeModel, integrate);
    CompoundParameter traitParameter = returnValue.traitParameter;
    List<Integer> missingIndices = returnValue.missingIndices;
    traitName = returnValue.traitName;

    Model samplingDensity = null;
    if (xo.hasChildNamed(SAMPLING_DENSITY)) {
        XMLObject cxo = xo.getChild(SAMPLING_DENSITY);
        samplingDensity = (Model) cxo.getChild(Model.class);
    }

    // NOTE: delta is only forwarded to the conjugate-root-prior likelihoods below.
    Parameter deltaParameter = null;
    if (xo.hasChildNamed("delta")) {
        XMLObject cxo = xo.getChild("delta");
        deltaParameter = (Parameter) cxo.getChild(Parameter.class);
    }

    AbstractMultivariateTraitLikelihood like;

    if (integrate) {
        MultivariateDistributionLikelihood rootPrior =
                (MultivariateDistributionLikelihood) xo.getChild(MultivariateDistributionLikelihood.class);

        if (rootPrior != null) {
            // Test against the type we are about to cast to, so that an unsupported prior
            // is reported as a parse error rather than a ClassCastException.
            if (!(rootPrior.getDistribution() instanceof MultivariateNormalDistribution))
                throw new XMLParseException("Only multivariate normal priors allowed for Gibbs sampling the root trait");

            MultivariateNormalDistribution rootDistribution =
                    (MultivariateNormalDistribution) rootPrior.getDistribution();

            like = new SemiConjugateMultivariateTraitLikelihood(traitName, treeModel, diffusionModel,
                    traitParameter, missingIndices, cacheBranches,
                    scaleByTime, useTreeLength, rateModel, samplingDensity, reportAsMultivariate,
                    rootDistribution, reciprocalRates);
        } else {
            XMLObject cxo = xo.getChild(CONJUGATE_ROOT_PRIOR);
            if (cxo == null) {
                throw new XMLParseException("Must specify a conjugate or multivariate normal root prior");
            }

            boolean ignorePhylogeny = xo.getAttribute(IGNORE_PHYLOGENY, false);

            Parameter meanParameter = (Parameter) cxo.getChild(MultivariateDistributionLikelihood.MVN_MEAN)
                    .getChild(Parameter.class);
            if (meanParameter.getDimension() != diffusionModel.getPrecisionmatrix().length) {
                throw new XMLParseException("Root prior mean dimension does not match trait diffusion dimension");
            }

            Parameter sampleSizeParameter = (Parameter) cxo.getChild(PRIOR_SAMPLE_SIZE).getChild(Parameter.class);
            double[] mean = meanParameter.getParameterValues();
            double pseudoObservations = sampleSizeParameter.getParameterValue(0);

            if (ignorePhylogeny) {
                boolean exchangeableTips = xo.getAttribute(EXCHANGEABLE_TIPS, true);
                like = new NonPhylogeneticMultivariateTraitLikelihood(traitName, treeModel, diffusionModel,
                        traitParameter, deltaParameter, missingIndices, cacheBranches,
                        scaleByTime, useTreeLength, rateModel, samplingDensity, reportAsMultivariate,
                        mean, pseudoObservations, reciprocalRates, exchangeableTips);
            } else {
                like = new FullyConjugateMultivariateTraitLikelihood(traitName, treeModel, diffusionModel,
                        traitParameter, deltaParameter, missingIndices, cacheBranches,
                        scaleByTime, useTreeLength, rateModel, samplingDensity, reportAsMultivariate,
                        mean, pseudoObservations, reciprocalRates);
            }
        }
    } else {
        like = new SampledMultivariateTraitLikelihood(traitName, treeModel, diffusionModel,
                traitParameter, missingIndices, cacheBranches,
                scaleByTime, useTreeLength, rateModel, samplingDensity, reportAsMultivariate,
                reciprocalRates);
    }

    // Randomization is only applied when internal traits are sampled, not integrated.
    if (!integrate && xo.hasChildNamed(TreeTraitParserUtilities.RANDOMIZE)) {
        utilities.randomize(xo);
    }

    if (xo.hasChildNamed(TreeTraitParserUtilities.JITTER)) {
        utilities.jitter(xo, diffusionModel.getPrecisionmatrix().length, missingIndices);
    }

    if (xo.hasChildNamed(CHECK)) {
        XMLObject cxo = xo.getChild(CHECK);
        Parameter check = (Parameter) cxo.getChild(Parameter.class);
        like.check(check);
    }

    if (xo.hasChildNamed(ASCERTAINMENT)) {
        XMLObject cxo = xo.getChild(ASCERTAINMENT);
        Taxon taxon = (Taxon) cxo.getChild(Taxon.class);

        if (!integrate) {
            throw new XMLParseException("Ascertainment correction is currently only implemented" +
                    " for integrated multivariate trait likelihood models");
        }

        like.setAscertainedTaxon(taxon);
    }

    return like;
}
//************************************************************************
// AbstractXMLObjectParser implementation
//************************************************************************
/**
 * @return a one-sentence description of what this parser produces
 */
public String getParserDescription() {
    final String description = "Provides the likelihood of a continuous trait evolving on a tree by a " +
            "given diffusion model.";
    return description;
}
/**
 * @return the syntax rules describing the accepted attributes and child elements
 */
public XMLSyntaxRule[] getSyntaxRules() {
    return this.rules;
}
// Syntax rules for the trait likelihood element, returned by getSyntaxRules().
// NOTE(review): the trailing boolean argument is presumably the "optional" flag — confirm
// against the rule classes.
private final XMLSyntaxRule[] rules = {
// Trait name attribute and the element wrapping the trait parameter.
new StringAttributeRule(TreeTraitParserUtilities.TRAIT_NAME, "The name of the trait for which a likelihood should be calculated"),
new ElementRule(TreeTraitParserUtilities.TRAIT_PARAMETER, new XMLSyntaxRule[]{
new ElementRule(Parameter.class)
}),
// "delta" element wrapping a single Parameter.
new ElementRule("delta", new XMLSyntaxRule[]{
new ElementRule(Parameter.class)
}, true),
AttributeRule.newBooleanRule(INTEGRATE, true),
// Root prior: either a multivariate distribution likelihood or a conjugate root prior
// element (mean + prior sample size). The exclusive-or constraint is commented out.
// new XORRule(
new ElementRule(MultivariateDistributionLikelihood.class, true),
new ElementRule(CONJUGATE_ROOT_PRIOR, new XMLSyntaxRule[]{
new ElementRule(MultivariateDistributionLikelihood.MVN_MEAN,
new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
new ElementRule(PRIOR_SAMPLE_SIZE,
new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
}, true),
// true),
// Taxon used for ascertainment correction.
new ElementRule(ASCERTAINMENT, new XMLSyntaxRule[] {
new ElementRule(Taxon.class)
}, true),
// Model components.
new ElementRule(MultivariateDiffusionModel.class),
new ElementRule(TreeModel.class),
new ElementRule(BranchRateModel.class, true),
// Attributes.
AttributeRule.newDoubleArrayRule("cut", true),
AttributeRule.newBooleanRule(REPORT_MULTIVARIATE, true),
AttributeRule.newBooleanRule(USE_TREE_LENGTH, true),
AttributeRule.newBooleanRule(SCALE_BY_TIME, true),
AttributeRule.newBooleanRule(RECIPROCAL_RATES, true),
AttributeRule.newBooleanRule(CACHE_BRANCHES, true),
AttributeRule.newIntegerRule(RANDOM_SAMPLE, true),
AttributeRule.newBooleanRule(IGNORE_PHYLOGENY, true),
AttributeRule.newBooleanRule(EXCHANGEABLE_TIPS, true),
new ElementRule(Parameter.class, true),
// Rules contributed by the trait parser utilities, plus the "check" element.
TreeTraitParserUtilities.randomizeRules(true),
TreeTraitParserUtilities.jitterRules(true),
new ElementRule(CHECK, new XMLSyntaxRule[]{
new ElementRule(Parameter.class)
}, true)
};
/**
 * @return the base type of the objects produced by this parser
 */
public Class getReturnType() {
    final Class producedType = AbstractMultivariateTraitLikelihood.class;
    return producedType;
}
};
// ---- Configuration supplied at construction ----
// Tree the trait evolves on; registered as a sub-model.
TreeModel treeModel = null;
// Diffusion process; its precision matrix size defines dimTrait.
MultivariateDiffusionModel diffusionModel = null;
// Name of the modelled trait.
String traitName = null;
// Trait values; registered as a variable when non-null.
CompoundParameter traitParameter;
// Missing-data indices as passed to the constructor; not otherwise used in this class.
List<Integer> missingIndices;
// ---- Likelihood cache and its stored copy ----
protected double logLikelihood;
// Not read or written elsewhere in this class.
protected double maxLogLikelihood = Double.NEGATIVE_INFINITY;
private double storedLogLikelihood;
// True while logLikelihood holds an up-to-date value.
protected boolean likelihoodKnown = false;
private boolean storedLikelihoodKnown = false;
// ---- Branch rate handling ----
private BranchRateModel rateModel = null;
// Set to true in the constructor when a rate model was supplied.
private boolean hasRateModel = false;
// Normalization constant maintained by recalculateTreeLength(): total branch length or root height.
private double treeLength;
private double storedTreeLength;
// ---- Flags fixed at construction ----
private final boolean reportAsMultivariate;
private final boolean scaleByTime;
private final boolean useTreeLength;
private final boolean reciprocalRates;
// ---- Per-node caches, allocated only when cacheBranches is true (one slot per tree node) ----
protected boolean cacheBranches;
protected double[] cachedLogLikelihoods;
protected double[] storedCachedLogLikelihood;
// validLogLikelihoods[i] == false marks node i's cached value as stale.
protected boolean[] validLogLikelihoods;
protected boolean[] storedValidLogLikelihoods;
// Optional; its first value is added to the rescaled length of external branches.
private final Parameter deltaParameter;
// ---- Ascertainment correction, enabled by setAscertainedTaxon() ----
private boolean doAscertainmentCorrect = false;
private int ascertainedTaxonIndex;
// ---- Dimensions derived in the constructor ----
// dim / dimTrait
protected int numData;
// Length of the diffusion model's precision matrix.
protected int dimTrait;
// Dimension of the first trait parameter, or 0 when there is no trait parameter.
protected int dim;
}