/*
* AncestralStateTreeLikelihood.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.oldevomodel.treelikelihood;
import dr.evolution.alignment.PatternList;
import dr.evolution.datatype.DataType;
import dr.evolution.datatype.GeneralDataType;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evolution.tree.TreeTraitProvider;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.oldevomodel.sitemodel.SiteModel;
import dr.evomodel.tree.TreeModel;
import dr.inference.model.Model;
import dr.math.MathUtils;
import java.util.logging.Logger;
/**
* @author Marc A. Suchard
*/
@Deprecated // Switching to BEAGLE
public class AncestralStateTreeLikelihood extends TreeLikelihood implements TreeTraitProvider {
public static final String STATES_KEY = "states";
// private boolean useExtraReconstructedStates = false;
/**
* Constructor.
* Now also takes a DataType so that ancestral states are printed using data codes
*
* @param patternList -
* @param treeModel -
* @param siteModel -
* @param branchRateModel -
* @param useAmbiguities -
* @param storePartials -
* @param dataType - need to provide the data-type, so that corrent data characters can be returned
* @param tag - string label for reconstruction characters in tree log
* @param forceRescaling -
* @param useMAP - perform maximum aposteriori reconstruction
* @param returnML - report integrate likelihood of tip data
*/
public AncestralStateTreeLikelihood(PatternList patternList, TreeModel treeModel,
SiteModel siteModel, BranchRateModel branchRateModel,
boolean useAmbiguities, boolean storePartials,
final DataType dataType,
final String tag,
boolean forceRescaling,
boolean useMAP,
boolean returnML) {
super(patternList, treeModel, siteModel, branchRateModel, null, useAmbiguities, false, storePartials,
false, forceRescaling);
this.dataType = dataType;
this.tag = tag;
reconstructedStates = new int[treeModel.getNodeCount()][patternCount];
storedReconstructedStates = new int[treeModel.getNodeCount()][patternCount];
this.useMAP = useMAP;
this.returnMarginalLogLikelihood = returnML;
treeTraits.addTrait(STATES_KEY, new TreeTrait.IA() {
public String getTraitName() {
return tag;
}
public Intent getIntent() {
return Intent.NODE;
}
public int[] getTrait(Tree tree, NodeRef node) {
return getStatesForNode(tree,node);
}
public String getTraitString(Tree tree, NodeRef node) {
return formattedState(getStatesForNode(tree,node), dataType);
}
});
if (useAmbiguities) {
Logger.getLogger("dr.evomodel.treelikelihood").info("Ancestral reconstruction using ambiguities is currently "+
"not support without BEAGLE");
System.exit(-1);
}
}
public AncestralStateTreeLikelihood(PatternList patternList, TreeModel treeModel,
SiteModel siteModel, BranchRateModel branchRateModel,
boolean useAmbiguities, boolean storePartials,
DataType dataType,
String tag,
boolean forceRescaling) {
this(patternList, treeModel, siteModel, branchRateModel, useAmbiguities,
storePartials, dataType, tag, forceRescaling, false, true);
}
public void storeState() {
super.storeState();
for (int i = 0; i < reconstructedStates.length; i++) {
System.arraycopy(reconstructedStates[i], 0, storedReconstructedStates[i], 0, reconstructedStates[i].length);
}
storedAreStatesRedrawn = areStatesRedrawn;
storedJointLogLikelihood = jointLogLikelihood;
}
public void restoreState() {
super.restoreState();
int[][] temp = reconstructedStates;
reconstructedStates = storedReconstructedStates;
storedReconstructedStates = temp;
areStatesRedrawn = storedAreStatesRedrawn;
jointLogLikelihood = storedJointLogLikelihood;
}
public DataType getDataType() {
return dataType;
}
public int[] getStatesForNode(Tree tree, NodeRef node) {
if (tree != treeModel) {
throw new RuntimeException("Can only reconstruct states on treeModel given to constructor");
}
if (!likelihoodKnown) {
calculateLogLikelihood();
likelihoodKnown = true;
}
if (!areStatesRedrawn) {
redrawAncestralStates();
}
return reconstructedStates[node.getNumber()];
}
public void redrawAncestralStates() {
jointLogLikelihood = 0;
traverseSample(treeModel, treeModel.getRoot(), null);
areStatesRedrawn = true;
}
// private boolean checkConditioning = true;
protected void handleModelChangedEvent(Model model, Object object, int index) {
super.handleModelChangedEvent(model, object, index);
fireModelChanged(model);
}
protected double calculateLogLikelihood() {
areStatesRedrawn = false;
double marginalLogLikelihood = super.calculateLogLikelihood();
if (returnMarginalLogLikelihood) {
return marginalLogLikelihood;
}
// redraw states and return joint density of drawn states
redrawAncestralStates();
return jointLogLikelihood;
}
protected TreeTraitProvider.Helper treeTraits = new Helper();
public TreeTrait[] getTreeTraits() {
return treeTraits.getTreeTraits();
}
public TreeTrait getTreeTrait(String key) {
return treeTraits.getTreeTrait(key);
}
private static String formattedState(int[] state, DataType dataType) {
StringBuffer sb = new StringBuffer();
sb.append("\"");
if (dataType instanceof GeneralDataType) {
boolean first = true;
for (int i : state) {
if (!first) {
sb.append(" ");
} else {
first = false;
}
sb.append(dataType.getCode(i));
}
} else {
for (int i : state) {
sb.append(dataType.getChar(i));
}
}
sb.append("\"");
return sb.toString();
}
private int drawChoice(double[] measure) {
if (useMAP) {
double max = measure[0];
int choice = 0;
for (int i = 1; i < measure.length; i++) {
if (measure[i] > max) {
max = measure[i];
choice = i;
}
}
return choice;
} else {
return MathUtils.randomChoicePDF(measure);
}
}
/**
* Traverse (pre-order) the tree sampling the internal node states.
*
* @param tree - TreeModel on which to perform sampling
* @param node - current node
* @param parentState - character state of the parent node to 'node'
*/
public void traverseSample(TreeModel tree, NodeRef node, int[] parentState) {
int nodeNum = node.getNumber();
NodeRef parent = tree.getParent(node);
// This function assumes that all partial likelihoods have already been calculated
// If the node is internal, then sample its state given the state of its parent (pre-order traversal).
double[] conditionalProbabilities = new double[stateCount];
int[] state = new int[patternCount];
if (!tree.isExternal(node)) {
if (parent == null) {
double[] rootPartials = getRootPartials();
// This is the root node
for (int j = 0; j < patternCount; j++) {
System.arraycopy(rootPartials, j * stateCount, conditionalProbabilities, 0, stateCount);
double[] frequencies = frequencyModel.getFrequencies();
for (int i = 0; i < stateCount; i++) {
conditionalProbabilities[i] *= frequencies[i];
}
try {
state[j] = drawChoice(conditionalProbabilities);
} catch (Error e) {
System.err.println(e.toString());
System.err.println("Please report error to Marc");
state[j] = 0;
}
reconstructedStates[nodeNum][j] = state[j];
//System.out.println("Pr(j) = " + frequencies[state[j]]);
jointLogLikelihood += Math.log(frequencies[state[j]]);
}
} else {
// This is an internal node, but not the root
double[] partialLikelihood = new double[stateCount * patternCount];
if (categoryCount > 1)
throw new RuntimeException("Reconstruction not implemented for multiple categories yet.");
likelihoodCore.getPartials(nodeNum, partialLikelihood);
// final double branchRate = branchRateModel.getBranchRate(tree, node);
//
// // Get the operational time of the branch
// final double branchTime = branchRate * ( tree.getNodeHeight(parent) - tree.getNodeHeight(node) );
//
// for (int i = 0; i < categoryCount; i++) {
//
// siteModel.getTransitionProbabilitiesForCategory(i, branchTime, probabilities);
//
// }
//
((AbstractLikelihoodCore) likelihoodCore).getNodeMatrix(nodeNum, 0, probabilities);
for (int j = 0; j < patternCount; j++) {
int parentIndex = parentState[j] * stateCount;
int childIndex = j * stateCount;
for (int i = 0; i < stateCount; i++) {
conditionalProbabilities[i] = partialLikelihood[childIndex + i] * probabilities[parentIndex + i];
}
state[j] = drawChoice(conditionalProbabilities);
reconstructedStates[nodeNum][j] = state[j];
double contrib = probabilities[parentIndex + state[j]];
//System.out.println("Pr(" + parentState[j] + ", " + state[j] + ") = " + contrib);
jointLogLikelihood += Math.log(contrib);
}
}
// Traverse down the two child nodes
NodeRef child1 = tree.getChild(node, 0);
traverseSample(tree, child1, state);
NodeRef child2 = tree.getChild(node, 1);
traverseSample(tree, child2, state);
} else {
// This is an external leaf
((AbstractLikelihoodCore) likelihoodCore).getNodeStates(nodeNum, reconstructedStates[nodeNum]);
// Check for ambiguity codes and sample them
for (int j = 0; j < patternCount; j++) {
final int thisState = reconstructedStates[nodeNum][j];
final int parentIndex = parentState[j] * stateCount;
((AbstractLikelihoodCore) likelihoodCore).getNodeMatrix(nodeNum, 0, probabilities);
if (dataType.isAmbiguousState(thisState)) {
System.arraycopy(probabilities, parentIndex, conditionalProbabilities, 0, stateCount);
reconstructedStates[nodeNum][j] = drawChoice(conditionalProbabilities);
}
double contrib = probabilities[parentIndex + reconstructedStates[nodeNum][j]];
//System.out.println("Pr(" + parentState[j] + ", " + reconstructedStates[nodeNum][j] + ") = " + contrib);
jointLogLikelihood += Math.log(contrib);
}
}
}
private DataType dataType;
private int[][] reconstructedStates;
private int[][] storedReconstructedStates;
private String tag;
private boolean areStatesRedrawn = false;
private boolean storedAreStatesRedrawn = false;
private boolean useMAP = false;
private boolean returnMarginalLogLikelihood = true;
private double jointLogLikelihood;
private double storedJointLogLikelihood;
}