/*
* EpochTreeBranchSubstitutionModel.java
*
* Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.oldevomodel.treelikelihood;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.oldevomodel.sitemodel.SiteModel;
import dr.oldevomodel.substmodel.SubstitutionModel;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.math.matrixAlgebra.Vector;
import java.util.List;
/**
 * Branch substitution model in which a different {@link SubstitutionModel} applies
 * in each of a sequence of time "epochs" separated by transition times.
 * <p>
 * With K models, model {@code i} governs the interval between transition time
 * {@code i - 1} and transition time {@code i}; model 0 covers everything up to the
 * first transition time and model K-1 everything beyond the last one. For a branch
 * that spans several epochs, the transition probability matrix is the product of the
 * per-epoch matrices, each evaluated on the fraction of the branch that falls into
 * that epoch.
 *
 * @author Marc A. Suchard
 */
@Deprecated // Switching to BEAGLE
public class EpochTreeBranchSubstitutionModel extends TreeBranchSubstitutionModel {

    /** When {@code true}, epoch weights are printed to stderr and sanity-checked. */
    public static final boolean DEBUG = false;

    /**
     * @param name            model name, passed to the superclass
     * @param siteModel       site model, passed to the superclass
     * @param substModelList  one substitution model per epoch, in epoch order; must not be empty
     * @param branchModel     branch rate model, passed to the superclass
     * @param transitionTimes epoch boundaries; needs at least {@code substModelList.size() - 1} values
     * @throws IllegalArgumentException if the model list is empty or too few transition
     *                                  times are supplied
     */
    public EpochTreeBranchSubstitutionModel(String name,
                                            SiteModel siteModel, List<SubstitutionModel> substModelList, BranchRateModel branchModel,
                                            Parameter transitionTimes) {
        super(name, siteModel, null, branchModel);

        if (substModelList.isEmpty()) {
            throw new IllegalArgumentException("At least one substitution model is required");
        }

        this.modelList = substModelList;
        this.transitionTimesParameter = transitionTimes;
        this.transitionTimes = transitionTimesParameter.getParameterValues();
        this.numberModels = modelList.size();

        if (this.transitionTimes.length < numberModels - 1) {
            throw new IllegalArgumentException("Expected at least " + (numberModels - 1)
                    + " transition times for " + numberModels + " models but got "
                    + this.transitionTimes.length);
        }

        addVariable(transitionTimes);
        for (SubstitutionModel model : modelList) {
            addModel(model);
        }

        weight = new double[numberModels];
        stateCount = modelList.get(0).getDataType().getStateCount();
        stepMatrix = new double[stateCount * stateCount];
        productMatrix = new double[stateCount * stateCount];
        resultMatrix = new double[stateCount * stateCount];
    }

    /**
     * Fills {@code matrix} with the transition probabilities along the branch above
     * {@code node} for the given rate category.
     * <p>
     * The branch is split at the transition times; each epoch's model is evaluated on
     * its share of the total distance and the resulting matrices are multiplied
     * together. When the branch lies inside a single epoch, that model writes directly
     * into {@code matrix}.
     *
     * @param tree         tree providing the parent and the node heights
     * @param node         node at the lower end of the branch
     * @param rateCategory site rate category index
     * @param matrix       output array, row-major, {@code stateCount * stateCount} entries
     * @throws RuntimeException if the rate-scaled branch length is negative or no epoch
     *                          could be matched to the branch
     */
    public void getTransitionProbabilities(Tree tree, NodeRef node, int rateCategory, double[] matrix) {

        NodeRef parent = tree.getParent(node);

        final double branchRate = branchModel.getBranchRate(tree, node);

        // Get the operational time of the branch
        final double startTime = tree.getNodeHeight(parent);
        final double endTime = tree.getNodeHeight(node);
        final double branchTime = branchRate * (startTime - endTime);

        if (branchTime < 0.0) {
            throw new RuntimeException("Negative branch length: " + branchTime);
        }

        final double distance = siteModel.getRateForCategory(rateCategory) * branchTime;

        // getEpochWeights orders the two heights itself, so it does not matter that
        // startTime (parent) is normally the larger value.
        final int epochCount = getEpochWeights(startTime, endTime, weight);

        if (epochCount == 0) {
            // Only reachable with non-finite heights or transition times; copying the
            // scratch matrix here would silently return stale values.
            throw new RuntimeException("No epoch overlaps branch from " + startTime + " to " + endTime);
        }

        if (epochCount == 1) {
            // Whole branch inside one epoch: no product needed.
            for (int m = 0; m < numberModels; m++) {
                if (weight[m] > 0) {
                    modelList.get(m).getTransitionProbabilities(distance, matrix);
                    return;
                }
            }
        }

        // NOTE(review): matrices are multiplied in ascending epoch index, exactly as the
        // previous implementation did. Confirm this matches the intended direction
        // (parent -> child) for non-commuting models.
        boolean firstMatrix = true;
        for (int m = 0; m < numberModels; m++) {
            if (weight[m] > 0) {
                SubstitutionModel model = modelList.get(m);
                if (firstMatrix) {
                    model.getTransitionProbabilities(distance * weight[m], resultMatrix);
                    firstMatrix = false;
                } else {
                    model.getTransitionProbabilities(distance * weight[m], stepMatrix);
                    // Sum over unobserved state: productMatrix = resultMatrix * stepMatrix
                    int index = 0;
                    for (int i = 0; i < stateCount; i++) {
                        for (int j = 0; j < stateCount; j++) {
                            double sum = 0.0;
                            for (int k = 0; k < stateCount; k++) {
                                sum += resultMatrix[i * stateCount + k] * stepMatrix[k * stateCount + j];
                            }
                            productMatrix[index] = sum;
                            index++;
                        }
                    }
                    // Swap buffers so resultMatrix always holds the running product
                    double[] tmpMatrix = resultMatrix;
                    resultMatrix = productMatrix;
                    productMatrix = tmpMatrix;
                }
            }
        }

        System.arraycopy(resultMatrix, 0, matrix, 0, stateCount * stateCount);
    }

    /**
     * Computes, for every epoch, the fraction of the interval between
     * {@code startTime} and {@code endTime} that falls inside it.
     * <p>
     * The two times may be given in either order. Epoch {@code i} covers
     * {@code (b[i-1], b[i]]} where {@code b[-1] = -Infinity}, {@code b[K-1] = +Infinity}
     * and {@code b[i]} is the largest of {@code transitionTimes[0..i]}; for ascending
     * transition times this is simply {@code transitionTimes[i]}. Using the running
     * maximum keeps the epochs a partition of the time axis even if the transition
     * times are momentarily out of order, so the weights still sum to one.
     *
     * @param startTime one end of the interval
     * @param endTime   the other end of the interval
     * @param weights   output array of length {@code numberModels}; every entry is overwritten
     * @return the number of epochs with a strictly positive weight
     */
    private int getEpochWeights(double startTime, double endTime, double[] weights) {

        final double lower = Math.min(startTime, endTime);
        final double upper = Math.max(startTime, endTime);
        final double lengthTime = upper - lower;
        final int lastTime = numberModels - 2;

        // model 0, 1, 2, ..., K-2, K-1
        // times  0, 1, ..., K-2,
        //  where K = numberModels

        int matrixCount = 0;

        if (!(lengthTime > 0.0)) {
            // Zero-length interval: avoid 0/0 and give the full weight to the epoch that
            // contains the point (a point on a boundary belongs to the earlier epoch).
            int epoch = 0;
            while (epoch <= lastTime && lower > transitionTimes[epoch]) {
                epoch++;
            }
            for (int i = 0; i < numberModels; i++) {
                weights[i] = 0.0;
            }
            weights[epoch] = 1.0;
            matrixCount = 1;
        } else {
            double epochStart = Double.NEGATIVE_INFINITY;
            for (int i = 0; i < numberModels; i++) {
                // Last epoch is open-ended: transitionTimes[K-2] -> Infinity
                final double boundary = (i <= lastTime) ? transitionTimes[i] : Double.POSITIVE_INFINITY;
                final double epochEnd = Math.max(epochStart, boundary);
                final double overlap = Math.min(upper, epochEnd) - Math.max(lower, epochStart);
                if (overlap > 0.0) {
                    weights[i] = overlap / lengthTime;
                    matrixCount++;
                } else {
                    weights[i] = 0.0;
                }
                epochStart = epochEnd;
            }
        }

        if (DEBUG) {
            double totalWeight = 0;
            for (int i = 0; i < numberModels; i++) {
                totalWeight += weights[i];
            }
            System.err.println("Start: " + startTime + " End: " + endTime + " Count: " + matrixCount + " Weight: " + totalWeight + " - " + new Vector(weights));
            if (totalWeight > 1.001 || totalWeight < 0.999) {
                // Fail with an exception rather than terminating the whole JVM.
                throw new RuntimeException("Epoch weights do not sum to one: " + totalWeight);
            }
        }

        return matrixCount;
    }

    /**
     * Delegates to the superclass, then refreshes the cached transition times and
     * fires a model-changed event when the changed variable is the transition-times
     * parameter.
     */
    protected final void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {
        super.handleVariableChangedEvent(variable, index, type);
        if (variable == transitionTimesParameter) {
            transitionTimes = transitionTimesParameter.getParameterValues();
            fireModelChanged(variable, index);
        }
    }

    // One substitution model per epoch, in epoch order.
    private final List<SubstitutionModel> modelList;
    private final Parameter transitionTimesParameter;
    // Cached copy of the parameter values; refreshed in handleVariableChangedEvent.
    private double[] transitionTimes;
    // Scratch array filled by getEpochWeights on every call.
    private final double[] weight;
    // Scratch matrices (row-major, stateCount x stateCount) for the epoch product;
    // resultMatrix and productMatrix are swapped after each multiplication.
    private final double[] stepMatrix;
    private double[] productMatrix;
    private double[] resultMatrix;
    private final int numberModels;
    private final int stateCount;
}