/*
* EpochBranchSubstitutionModel.java
*
* Copyright (C) 2002-2012 Alexei Drummond, Andrew Rambaut & Marc A. Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.app.beagle.evomodel.sitemodel;
import beagle.Beagle;
import dr.app.beagle.evomodel.substmodel.EigenDecomposition;
import dr.app.beagle.evomodel.substmodel.FrequencyModel;
import dr.app.beagle.evomodel.substmodel.SubstitutionModel;
import dr.app.beagle.evomodel.treelikelihood.BufferIndexHelper;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.tree.TreeModel;
import dr.inference.model.AbstractModel;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.util.Author;
import dr.util.Citable;
import dr.util.Citation;
import java.util.*;
/**
* @author Filip Bielejec
* @author Marc A. Suchard
* @version $Id$
*/
@SuppressWarnings("serial")
@Deprecated // Switching to BranchModel
public class EpochBranchSubstitutionModel extends AbstractModel implements
BranchSubstitutionModel, Citable {
// /////////////
// ---DEBUG---//
// /////////////
private static final boolean DEBUG_EPOCH = false;
private static Integer stateCount = null;
private static Integer categoryCount = null;
// //////////////////
// ---END: DEBUG---//
// //////////////////
public static final boolean TRY_EPOCH = true;
public static final String EPOCH_BRANCH_SUBSTITUTION_MODEL = "EpochBranchSubstitutionModel";
private final List<SubstitutionModel> substModelList;
private final List<FrequencyModel> frequencyModelList;
private final BranchRateModel branchRateModel;
private final Parameter epochTimes;
private int firstBuffer;
private Map<Integer, double[]> convolutionMatricesMap = new HashMap<Integer, double[]>();
private int requestedBuffers;
public EpochBranchSubstitutionModel(List<SubstitutionModel> substModelList,
List<FrequencyModel> frequencyModelList,
BranchRateModel branchRateModel,
Parameter epochTimes) {
super(EPOCH_BRANCH_SUBSTITUTION_MODEL);
if (frequencyModelList.size() != 1) {
throw new IllegalArgumentException(
"EpochBranchSubstitutionModel requires one FrequencyModel");
}
this.substModelList = substModelList;
this.frequencyModelList = frequencyModelList;
this.epochTimes = epochTimes;
this.requestedBuffers = 0;
this.branchRateModel = branchRateModel;
for (SubstitutionModel model : substModelList) {
addModel(model);
}
for (FrequencyModel model : frequencyModelList) {
addModel(model);
}
if(DEBUG_EPOCH) {
stateCount = frequencyModelList.get(0).getDataType().getStateCount();
categoryCount = 4;
}//END: DEBUG_EPOCH
addVariable(epochTimes);
}// END: Constructor
/**
* @return number of extra transition matrices buffers to allocate
*/
public int getExtraBufferCount(TreeModel treeModel) {
requestedBuffers = 100;
System.out.println("Allocating " + requestedBuffers + " extra buffers.");
return requestedBuffers;
}// END: getBufferCount
public void setFirstBuffer(int firstBufferCount) {
firstBuffer = firstBufferCount;
}// END: setFirstBuffer
public EigenDecomposition getEigenDecomposition(int branchIndex,
int categoryIndex) {
return substModelList.get(branchIndex).getEigenDecomposition();
}// END: getEigenDecomposition
public SubstitutionModel getSubstitutionModel(int branchIndex,
int categoryIndex) {
return substModelList.get(branchIndex);
}// END: getSubstitutionModel
public double[] getStateFrequencies(int categoryIndex) {
return frequencyModelList.get(categoryIndex).getFrequencies();
}// END: getStateFrequencies
public int getEigenCount() {
// Use an extra eigenIndex to identify branches that need convolution
return substModelList.size() + 1;
}// END: getEigenCount
public void setEigenDecomposition(Beagle beagle, int eigenIndex,
BufferIndexHelper bufferHelper, int dummy) {
if (eigenIndex < substModelList.size()) {
EigenDecomposition ed = getEigenDecomposition(eigenIndex, dummy);
beagle.setEigenDecomposition(bufferHelper.getOffsetIndex(eigenIndex),
ed.getEigenVectors(),
ed.getInverseEigenVectors(),
ed.getEigenValues()
);
}// END: nModels check
}// END: setEigenDecomposition
public boolean canReturnComplexDiagonalization() {
for (SubstitutionModel model : substModelList) {
if (model.canReturnComplexDiagonalization()) {
return true;
}
}
return false;
}// END: canReturnComplexDiagonalization
protected void handleModelChangedEvent(Model model, Object object, int index) {
fireModelChanged();
}// END: handleModelChangedEvent
@SuppressWarnings("rawtypes")
protected void handleVariableChangedEvent(Variable variable, int index,
Parameter.ChangeType type) {
}// END: handleVariableChangedEvent
protected void storeState() {
}// END: storeState
protected void restoreState() {
}// END: restoreState
protected void acceptState() {
}// END: acceptState
/**
* Calculate weights that branch spends in each substitution model
*
* @param tree
* @param node
* @return nModels if branch needs convolution, subst model index if not
*/
public int getBranchIndex(final Tree tree, final NodeRef node,
int bufferIndex) {
int nModels = substModelList.size();
int lastTransitionTime = nModels - 2;
double[] weights = new double[nModels];
double[] transitionTimes = epochTimes.getParameterValues();
double parentHeight = tree.getNodeHeight(tree.getParent(node));
double nodeHeight = tree.getNodeHeight(node);
double branchLength = tree.getBranchLength(node);
int returnValue = 0;
if (parentHeight <= transitionTimes[0]) {
weights[0] = branchLength;
returnValue = 0;
} else {
// first case: 0-th transition time
if (nodeHeight < transitionTimes[0] && transitionTimes[0] <= parentHeight) {
weights[0] = transitionTimes[0] - nodeHeight;
returnValue = nModels;
} else {
weights[0] = 0;
}// END: 0-th model check
// second case: i to i+1 transition times
for (int i = 1; i <= lastTransitionTime; i++) {
if (nodeHeight < transitionTimes[i]) {
if (parentHeight <= transitionTimes[i] && transitionTimes[i - 1] < nodeHeight) {
weights[i] = branchLength;
returnValue = i;
} else {
double startTime = Math.max(nodeHeight, transitionTimes[i - 1]);
double endTime = Math.min(parentHeight, transitionTimes[i]);
if (endTime < startTime) {
weights[i] = 0;
} else {
weights[i] = (endTime - startTime);
returnValue = nModels;
}// END: negative weights check
}// END: full branch in middle epoch check
} else {
weights[i] = 0;
}// END: i-th model check
}// END: i loop
// third case: last transition time
if (parentHeight >= transitionTimes[lastTransitionTime] && transitionTimes[lastTransitionTime] > nodeHeight) {
weights[lastTransitionTime + 1] = parentHeight - transitionTimes[lastTransitionTime];
returnValue = nModels;
} else if (nodeHeight > transitionTimes[lastTransitionTime]) {
weights[lastTransitionTime + 1] = branchLength;
returnValue = nModels - 1;
} else {
weights[lastTransitionTime + 1] = 0;
}// END: last transition time check
}// END: if branch below first transition time bail out
if (branchRateModel != null) {
weights = scaleArray(weights, branchRateModel.getBranchRate(tree, node));
}
convolutionMatricesMap.put(bufferIndex, weights);
if (DEBUG_EPOCH) {
System.out.println("bufferIndex: " + bufferIndex);
System.out.println("weights: ");
printArray(weights, weights.length);
}// END: DEBUG_EPOCH
return returnValue;
}// END: getBranchIndex
public void updateTransitionMatrices(Beagle beagle,
int eigenIndex,
BufferIndexHelper bufferHelper,
final int[] probabilityIndices,
final int[] firstDerivativeIndices,
final int[] secondDervativeIndices,
final double[] edgeLengths,
int count // number of branches to update in parallel
) {
if (eigenIndex < substModelList.size()) {
if (DEBUG_EPOCH) {
System.out.println("Branch falls in a single category");
System.out.println("eigenIndex: " + eigenIndex);
System.out.println("Populating buffers: ");
printArray(probabilityIndices, count);
System.out.println("for weights: ");
printArray(edgeLengths, count);
}//END: DEBUG_EPOCH
// Branches fall in a single category
beagle.updateTransitionMatrices(bufferHelper.getOffsetIndex(eigenIndex),
probabilityIndices,
firstDerivativeIndices,
secondDervativeIndices,
edgeLengths,
count);
if (DEBUG_EPOCH) {
System.out.println("Transition probabilities from model: ");
for (int k = 0; k < probabilityIndices.length; k++) {
double tmp[] = new double[categoryCount * stateCount * stateCount];
beagle.getTransitionMatrix(probabilityIndices[k], // matrixIndex
tmp // outMatrix
);
System.out.println(probabilityIndices[k]);
printMatrix(tmp, stateCount, stateCount);
}
}// END: DEBUG_EPOCH
} else {
// Branches require convolution of two or more matrices
int stepSize = requestedBuffers/4 ;
if (DEBUG_EPOCH) {
System.out.println("Branch requires convolution");
System.out.println("stepSize: " + stepSize);
// System.out.println("count from tree = " + count);
// System.out.println("convolutionMatricesMap.size() = " + convolutionMatricesMap.size());
System.out.println("probabilityIndices: ");
printArray(probabilityIndices, probabilityIndices.length);
}//END: DEBUG_EPOCH
int step = 0;
while(step < count) {
if (DEBUG_EPOCH) {
System.out.println("step: " + step);
}//END: DEBUG_EPOCH
int[] firstBuffers = new int[stepSize];
int[] secondBuffers = new int[stepSize];
int[] firstExtraBuffers = new int[stepSize];
int[] secondExtraBuffers = new int[stepSize];
int[] resultBranchBuffers = new int[stepSize];
int[] probabilityBuffers = new int[stepSize];
int[] firstConvolutionBuffers = new int[stepSize];
int[] secondConvolutionBuffers = new int[stepSize];
int[] resultConvolutionBuffers = new int[stepSize];
for (int i = 0; i < stepSize; i++) {
firstBuffers[i] = firstBuffer + i;
secondBuffers[i] = (firstBuffer + stepSize) + i;
firstExtraBuffers[i] = (firstBuffer + 2 * stepSize) + i;
secondExtraBuffers[i] = (firstBuffer + 3 * stepSize) + i;
if (i < count) {
resultBranchBuffers[i] = probabilityIndices[i + step];
}
}// END: stepSize loop
if (DEBUG_EPOCH) {
System.out.println("resultBranchBuffers ");
printArray(resultBranchBuffers, resultBranchBuffers.length);
}//END: DEBUG_EPOCH
for (int i = 0; i < substModelList.size(); i++) {
int eigenBuffer = bufferHelper.getOffsetIndex(i);
double[] weights = new double[stepSize];
for (int j = 0; j < stepSize; j++) {
if ((step + j) < count) {
int index = probabilityIndices[j + step];
if (DEBUG_EPOCH) {
System.out.println("step + j: " + (step + j) + " index: " + index);
}
weights[j] = convolutionMatricesMap.get(index)[i];
}// END: index padding check
}// END: stepSize loop
if ((i == 1) && (i == (substModelList.size() - 1))) {
probabilityBuffers = secondBuffers;
firstConvolutionBuffers = firstBuffers;
secondConvolutionBuffers = probabilityBuffers;
resultConvolutionBuffers = resultBranchBuffers;
} else if ((i == 1) && (i != (substModelList.size() - 1))) {
probabilityBuffers = secondBuffers;
firstConvolutionBuffers = firstBuffers;
secondConvolutionBuffers = probabilityBuffers;
resultConvolutionBuffers = firstExtraBuffers;
} else if ((i != 1) && (i == (substModelList.size() - 1))) {
// even
if (i % 2 == 0) {
probabilityBuffers = firstBuffers;
firstConvolutionBuffers = firstExtraBuffers;
secondConvolutionBuffers = probabilityBuffers;
resultConvolutionBuffers = resultBranchBuffers;
// odd
} else {
probabilityBuffers = secondBuffers;
firstConvolutionBuffers = secondExtraBuffers;
secondConvolutionBuffers = probabilityBuffers;
resultConvolutionBuffers = resultBranchBuffers;
}
} else {
// even
if (i % 2 == 0) {
probabilityBuffers = firstBuffers;
firstConvolutionBuffers = firstExtraBuffers;
secondConvolutionBuffers = probabilityBuffers;
resultConvolutionBuffers = secondExtraBuffers;
// odd
} else {
probabilityBuffers = secondBuffers;
firstConvolutionBuffers = secondExtraBuffers;
secondConvolutionBuffers = probabilityBuffers;
resultConvolutionBuffers = firstExtraBuffers;
}// END: even-odd check
}// END: first-last buffer check
checkBuffers(probabilityBuffers);
int operationsCount = Math.min(stepSize, (count - step));
if (DEBUG_EPOCH) {
System.out.println("eigenBuffer: " + eigenBuffer);
System.out.println("Populating buffers: ");
printArray(probabilityBuffers, operationsCount);
System.out.println("for weights: ");
printArray(weights, operationsCount);
}//END: DEBUG_EPOCH
beagle.updateTransitionMatrices(eigenBuffer, // eigenIndex
probabilityBuffers, // probabilityIndices
null, // firstDerivativeIndices
null, // secondDerivativeIndices
weights, // edgeLengths
operationsCount // count
);
if (i != 0) {
if (DEBUG_EPOCH) {
System.out.println("convolving buffers: ");
printArray(firstConvolutionBuffers, operationsCount);
System.out.println("with buffers: ");
printArray(secondConvolutionBuffers, operationsCount);
System.out.println("into buffers: ");
printArray(resultConvolutionBuffers, operationsCount);
}//END: DEBUG_EPOCH
beagle.convolveTransitionMatrices(firstConvolutionBuffers, // A
secondConvolutionBuffers, // B
resultConvolutionBuffers, // C
operationsCount // count
);
}// END: 0-th eigen index check
}// END: eigen indices loop
step += stepSize;
}// END: step loop
}// END: eigenIndex check
if (DEBUG_EPOCH) {
System.out.println("Transition probabilities from model:");
for (int k = 0; k < probabilityIndices.length; k++) {
double tmp[] = new double[categoryCount * stateCount * stateCount];
beagle.getTransitionMatrix(probabilityIndices[k], // matrixIndex
tmp // outMatrix
);
System.out.println(probabilityIndices[k]);
printMatrix(tmp, stateCount, stateCount);
}
}//END: DEBUG_EPOCH
}// END: updateTransitionMatrices
private void checkBuffers(int[] probabilityBuffers) {
for (int buffer : probabilityBuffers) {
if (buffer >= firstBuffer + requestedBuffers) {
System.err.println("Programming error: requesting use of BEAGLE transition matrix buffer not allocated.");
System.err.println("Allocated: 0 to " + (firstBuffer + requestedBuffers - 1));
System.err.println("Requested = " + buffer);
System.err.println("Please complain to Button-Boy");
}
}
}//END: checkBuffers
/**
* @return a list of citations associated with this object
*/
public List<Citation> getCitations() {
List<Citation> citations = new ArrayList<Citation>();
citations.add(new Citation(new Author[]{new Author("F", "Bielejec"),
new Author("P", "Lemey"), new Author("G", "Baele"),
new Author("MA", "Suchard")}, Citation.Status.IN_PREPARATION));
return citations;
}// END: getCitations
// /////////////
// ---DEBUG---//
// /////////////
public static void printArray(double[] array) {
for (int i = 0; i < array.length; i++) {
System.out.println(String.format(Locale.US, "%.10f", array[i]));
}
System.out.print("\n");
}// END: printArray
public static void printArray(int[] array) {
for (int i = 0; i < array.length; i++) {
System.out.println(array[i]);
}
}// END: printArray
public static void printArray(double[] array, int nrow) {
for (int row = 0; row < nrow; row++) {
System.out.println(String.format(Locale.US, "%.10f", array[row]));
}
System.out.print("\n");
}// END: printArray
public static void printArray(int[] array, int nrow) {
for (int row = 0; row < nrow; row++) {
System.out.println(array[row]);
}
System.out.print("\n");
}// END: printArray
public static void print2DArray(double[][] array) {
for (int row = 0; row < array.length; row++) {
System.out.print("| ");
for (int col = 0; col < array[row].length; col++) {
System.out.print(String.format(Locale.US, "%.10f", array[row][col]) + " ");
}
System.out.print("|\n");
}
System.out.print("\n");
}// END: print2DArray
public static void print2DArray(int[][] array) {
for (int row = 0; row < array.length; row++) {
for (int col = 0; col < array[row].length; col++) {
System.out.print(array[row][col] + " ");
}
System.out.print("\n");
}
}// END: print2DArray
public static void printMatrix(double[][] matrix, int nrow, int ncol) {
for (int row = 0; row < nrow; row++) {
for (int col = 0; col < nrow; col++)
System.out.print(String.format(Locale.US, "%.10f", matrix[col + row * nrow]) + " ");
System.out.print("\n");
}
System.out.print("\n");
}// END: printMatrix
public static void printMatrix(double[] matrix, int nrow, int ncol) {
for (int row = 0; row < nrow; row++) {
System.out.print("| ");
for (int col = 0; col < nrow; col++)
System.out.print(String.format(Locale.US, "%.10f", matrix[col + row * nrow]) + " ");
System.out.print("|\n");
}
System.out.print("\n");
}// END: printMatrix
public static void printMatrix(int[] matrix, int nrow, int ncol) {
for (int row = 0; row < nrow; row++) {
System.out.print("| ");
for (int col = 0; col < nrow; col++)
System.out.print(matrix[col + row * nrow] + " ");
System.out.print("|\n");
}
System.out.print("\n");
}// END: printMatrix
public double[] scaleArray(double[] array, double scalar) {
for (int i = 0; i < array.length; i++) {
array[i] = array[i] * scalar;
}
return array;
}// END: scaleArray
// //////////////////
// ---END: DEBUG---//
// //////////////////
public int getExtraBufferCount_old(TreeModel treeModel) {
// loop over the tree to determine the count
double[] transitionTimes = epochTimes.getParameterValues();
int rootId = treeModel.getRoot().getNumber();
int count = 0;
for (NodeRef node : treeModel.getNodes()) {
if (node.getNumber() != rootId) {
double nodeHeight = treeModel.getNodeHeight(node);
double parentHeight = treeModel.getNodeHeight(treeModel
.getParent(node));
for (int i = 0; i < transitionTimes.length; i++) {
if (nodeHeight <= transitionTimes[i] && transitionTimes[i] < parentHeight) {
count++;
break;
}// END: transition time check check
}// END: transition times loop
}// END: root check
}// END: nodes loop
requestedBuffers = count * 4;
System.out.println("Allocating " + requestedBuffers + " extra buffers.");
return requestedBuffers;
}// END: getBufferCount_old
}// END: class