/*
* WishartStatisticsWrapper.java
*
* Copyright (c) 2002-2017 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.evomodel.treedatalikelihood.continuous;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.TreeTrait;
import dr.evomodel.treedatalikelihood.*;
import dr.evomodel.treedatalikelihood.continuous.cdi.ContinuousDiffusionIntegrator;
import dr.evomodelxml.treelikelihood.TreeTraitParserUtilities;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.Loggable;
import dr.inference.loggers.NumberColumn;
import dr.inference.model.*;
import dr.math.distributions.WishartSufficientStatistics;
import dr.math.interfaces.ConjugateWishartStatisticsProvider;
import dr.math.matrixAlgebra.Vector;
import dr.xml.*;
import java.util.List;
import static dr.evomodel.treedatalikelihood.ProcessSimulationDelegate.AbstractRealizedContinuousTraitDelegate.getTipTraitName;
import static dr.evomodelxml.treelikelihood.TreeTraitParserUtilities.DEFAULT_TRAIT_NAME;
/**
 * Supplies conjugate Wishart sufficient statistics computed from a continuous-trait
 * {@link TreeDataLikelihood}.
 * <p>
 * The statistics come from a post-order traversal of the tree. During that traversal,
 * {@code setComputeWishartStatistics(true)} is enabled on a {@link ContinuousDataLikelihoodDelegate}.
 * <p>
 * When the likelihood delegate's integrator is a {@link ContinuousDiffusionIntegrator.Multivariate},
 * a separate observed-data-only delegate is used. Its tip data are first overwritten with a draw
 * from the tip-sample trait.
 * <p>
 * Results are cached until any registered model or variable changes. The class is also
 * {@link Loggable}: it logs one column per entry of the scale matrix ("OPij") and, when the
 * tip-sample trait exists, one column per sampled tip value ("TIPk").
 *
 * @author Marc A. Suchard
 */
public class WishartStatisticsWrapper extends AbstractModel
        implements ConjugateWishartStatisticsProvider, Loggable {

    public static final String PARSER_NAME = "wishartStatistics";
    public static final String TRAIT_NAME = TreeTraitParserUtilities.TRAIT_NAME;

    private static final boolean DEBUG = false;

    private final LikelihoodTreeTraversal treeTraversalDelegate;
    private final TreeTrait tipSampleTrait; // may be null if the likelihood exposes no such trait
    private final int dimTrait;
    private final int numTrait;
    private final int tipCount;
    private final int dimPartial;
    private final ContinuousDataLikelihoodDelegate likelihoodDelegate;
    private final ContinuousDataLikelihoodDelegate outerProductDelegate;
    private final TreeDataLikelihood dataLikelihood;

    private boolean outerProductsKnown;
    private boolean savedOuterProductsKnown;
    private WishartSufficientStatistics wishartStatistics;
    private WishartSufficientStatistics savedWishartStatistics;

    /**
     * Creates a wrapper around a continuous-trait tree data likelihood.
     *
     * @param name               model name
     * @param traitName          trait name used to look up the tip-sample trait via {@code getTipTraitName}
     * @param dataLikelihood     tree data likelihood to listen to; it is registered as a sub-model
     * @param likelihoodDelegate continuous delegate of {@code dataLikelihood}
     */
    public WishartStatisticsWrapper(final String name,
                                    final String traitName,
                                    final TreeDataLikelihood dataLikelihood,
                                    final ContinuousDataLikelihoodDelegate likelihoodDelegate) {
        super(name);
        this.dataLikelihood = dataLikelihood;
        this.likelihoodDelegate = likelihoodDelegate;
        this.dimTrait = likelihoodDelegate.getTraitDim();
        this.numTrait = likelihoodDelegate.getTraitCount();
        this.tipCount = dataLikelihood.getTree().getExternalNodeCount();
        // Each per-trait tip slot holds dimTrait values plus one extra entry
        // (set to +Infinity in simulateMissingTraits).
        this.dimPartial = dimTrait + 1;

        addModel(dataLikelihood);

        tipSampleTrait = dataLikelihood.getTreeTrait(getTipTraitName(traitName));

        treeTraversalDelegate = new LikelihoodTreeTraversal(
                dataLikelihood.getTree(),
                dataLikelihood.getBranchRateModel(),
                TreeTraversal.TraversalType.POST_ORDER);

        if (likelihoodDelegate.getIntegrator() instanceof ContinuousDiffusionIntegrator.Multivariate) {
            outerProductDelegate = likelihoodDelegate.createObservedDataOnly(likelihoodDelegate);
        } else {
            outerProductDelegate = likelihoodDelegate;
        }

        outerProductsKnown = false;
    }

    /**
     * Returns the (cached) Wishart sufficient statistics, recomputing them if any model or
     * variable has changed since the last computation.
     */
    @Override
    public WishartSufficientStatistics getWishartStatistics() {
        if (!outerProductsKnown) {
            computeOuterProducts();
            outerProductsKnown = true;
        }
        return wishartStatistics;
    }

    /**
     * Draws tip trait values from the tip-sample trait and writes them, one tip at a time,
     * directly into {@code outerProductDelegate}.
     *
     * @throws IllegalStateException if no tip-sample trait was found at construction
     */
    private void simulateMissingTraits() {
        if (tipSampleTrait == null) {
            throw new IllegalStateException(
                    "No tip-sample trait available; cannot impute missing traits for Wishart statistics");
        }

        likelihoodDelegate.fireModelChanged(); // Force new sample!

        final double[] sample = (double[]) tipSampleTrait.getTrait(dataLikelihood.getTree(), null);

        if (DEBUG) {
            System.err.println("Attempting to simulate missing traits");
            System.err.println(new Vector(sample));
        }

        final ContinuousDiffusionIntegrator cdi = outerProductDelegate.getIntegrator();
        assert (cdi instanceof ContinuousDiffusionIntegrator.Basic);

        // The trailing entry of each per-trait slot is constant across tips, so set it once;
        // the leading dimTrait entries are overwritten for every tip below.
        final double[] buffer = new double[dimPartial * numTrait];
        for (int trait = 0; trait < numTrait; ++trait) {
            buffer[trait * dimPartial + dimTrait] = Double.POSITIVE_INFINITY;
        }

        for (int tip = 0; tip < tipCount; ++tip) {
            int sampleOffset = tip * dimTrait * numTrait;
            int bufferOffset = 0;
            for (int trait = 0; trait < numTrait; ++trait) {
                System.arraycopy(sample, sampleOffset, buffer, bufferOffset, dimTrait);
                sampleOffset += dimTrait;
                bufferOffset += dimPartial;
            }
            outerProductDelegate.setTipDataDirectly(tip, buffer);
        }

        if (DEBUG) {
            System.err.println("Finished draw");
        }
    }

    /**
     * Recomputes {@link #wishartStatistics} with a full post-order traversal on
     * {@code outerProductDelegate}. When a separate observed-data-only delegate is in use,
     * missing tip values are imputed first.
     */
    private void computeOuterProducts() {
        // Make sure partials on tree are ready
        dataLikelihood.getLogLikelihood();

        if (likelihoodDelegate != outerProductDelegate) {
            simulateMissingTraits();
        }

        treeTraversalDelegate.updateAllNodes();
        treeTraversalDelegate.dispatchTreeTraversalCollectBranchAndNodeOperations();

        final List<DataLikelihoodDelegate.BranchOperation> branchOperations =
                treeTraversalDelegate.getBranchOperations();
        final List<DataLikelihoodDelegate.NodeOperation> nodeOperations =
                treeTraversalDelegate.getNodeOperations();

        final NodeRef root = dataLikelihood.getTree().getRoot();

        outerProductDelegate.setComputeWishartStatistics(true);
        try {
            outerProductDelegate.calculateLikelihood(branchOperations, nodeOperations, root.getNumber());
        } catch (DataLikelihoodDelegate.LikelihoodException e) {
            // Preserve the original failure for diagnosis
            throw new RuntimeException("Unhandled exception", e);
        } finally {
            // Always leave the delegate in normal (non-statistics) mode, even on failure
            outerProductDelegate.setComputeWishartStatistics(false);
        }

        wishartStatistics = outerProductDelegate.getWishartStatistics();

        if (DEBUG) {
            System.err.println("WS: " + wishartStatistics);
        }
    }

    /** Returns the precision parameter of the likelihood delegate's diffusion model. */
    @Override
    public MatrixParameterInterface getPrecisionParamter() {
        return likelihoodDelegate.getDiffusionModel().getPrecisionParameter();
    }

    @Override
    protected void storeState() {
        savedOuterProductsKnown = outerProductsKnown;
        if (outerProductsKnown) {
            if (savedWishartStatistics == null) {
                savedWishartStatistics = wishartStatistics.clone();
            } else {
                wishartStatistics.copyTo(savedWishartStatistics);
            }
        }
    }

    @Override
    protected void restoreState() {
        outerProductsKnown = savedOuterProductsKnown;
        if (outerProductsKnown) {
            // Swap references rather than copy; the old current object becomes the save slot
            final WishartSufficientStatistics tmp = wishartStatistics;
            wishartStatistics = savedWishartStatistics;
            savedWishartStatistics = tmp;
        }
    }

    @Override
    protected void acceptState() {
        // Do nothing
    }

    @Override
    protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {
        outerProductsKnown = false;
    }

    @Override
    protected void handleModelChangedEvent(Model model, Object object, int index) {
        outerProductsKnown = false;
        // TODO If no partially missing traits and diffusion model hit, then no update necessary
    }

    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {

        @Override
        public Object parseXMLObject(XMLObject xo) throws XMLParseException {

            final String name = xo.hasId() ? xo.getId() : PARSER_NAME;
            final String traitName = xo.getAttribute(TRAIT_NAME, DEFAULT_TRAIT_NAME);

            final TreeDataLikelihood treeDataLikelihood =
                    (TreeDataLikelihood) xo.getChild(TreeDataLikelihood.class);

            final DataLikelihoodDelegate delegate = treeDataLikelihood.getDataLikelihoodDelegate();

            if (!(delegate instanceof ContinuousDataLikelihoodDelegate)) {
                throw new XMLParseException("May not provide a sequence data likelihood in the precision Gibbs sampler");
            }

            final ContinuousDataLikelihoodDelegate continuousData = (ContinuousDataLikelihoodDelegate) delegate;

            return new WishartStatisticsWrapper(name, traitName, treeDataLikelihood, continuousData);
        }

        /**
         * @return an array of syntax rules required by this element.
         * Order is not important.
         */
        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return syntax;
        }

        @Override
        public String getParserDescription() {
            return "Provides Wishart sufficient statistics from a continuous-trait tree data likelihood";
        }

        @Override
        public Class getReturnType() {
            return ConjugateWishartStatisticsProvider.class;
        }

        /**
         * @return Parser name, which is identical to name of xml element parsed by it.
         */
        @Override
        public String getParserName() {
            return PARSER_NAME;
        }

        private final XMLSyntaxRule[] syntax = new XMLSyntaxRule[] {
                new ElementRule(TreeDataLikelihood.class),
                AttributeRule.newStringRule(TRAIT_NAME, true),
        };
    };

    /**
     * Builds the log columns: dimTrait * dimTrait outer-product columns ("OP" + row + col,
     * 1-based), followed by one "TIP" column per element of the tip-sample trait (if present).
     */
    @Override
    public LogColumn[] getColumns() {

        int sampleLength = 0;
        if (tipSampleTrait != null) {
            final double[] sample = (double[]) tipSampleTrait.getTrait(dataLikelihood.getTree(), null);
            sampleLength = sample.length;
        }

        final LogColumn[] columns = new LogColumn[dimTrait * dimTrait + sampleLength];

        int index = 0;
        for (int i = 0; i < dimTrait; ++i) {
            for (int j = 0; j < dimTrait; ++j) {
                columns[index] = new OuterProductColumn("OP" + (i + 1) + "" + (j + 1), index);
                ++index;
            }
        }

        for (int i = 0; i < sampleLength; ++i) {
            columns[index] = new TipSampleColumn("TIP" + (i + 1), i);
            ++index;
        }

        return columns;
    }

    /** Logs one entry of the Wishart statistics' scale matrix (flattened index). */
    private class OuterProductColumn extends NumberColumn {

        private final int index;

        public OuterProductColumn(String label, int index) {
            super(label);
            this.index = index;
        }

        @Override
        public double getDoubleValue() {
            final WishartSufficientStatistics ws = getWishartStatistics();
            return ws.getScaleMatrix()[index];
        }
    }

    /** Logs one element of the tip-sample trait. */
    private class TipSampleColumn extends NumberColumn {

        private final int index;

        public TipSampleColumn(String label, int index) {
            super(label);
            this.index = index;
        }

        @Override
        public double getDoubleValue() {
            final double[] sample = (double[]) tipSampleTrait.getTrait(dataLikelihood.getTree(), null);
            return sample[index];
        }
    }
}