/*
* SubstitutionModelDelegate.java
*
* Copyright (C) 2002-2012 Alexei Drummond, Andrew Rambaut & Marc A. Suchard
*
* This file is part of BEAST.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package dr.app.beagle.evomodel.treelikelihood;
import beagle.Beagle;
import dr.app.beagle.evomodel.branchmodel.BranchModel;
import dr.app.beagle.evomodel.substmodel.EigenDecomposition;
import dr.app.beagle.evomodel.substmodel.SubstitutionModel;
import dr.evolution.tree.Tree;
import java.util.*;
/**
* @author Andrew Rambaut
* @author Filip Bielejec
* @author Marc A. Suchard
* @version $Id$
*/
public final class SubstitutionModelDelegate {
private static final boolean DEBUG = false;
private static final int BUFFER_POOL_SIZE = 100;
private final Tree tree;
private final List<SubstitutionModel> substitutionModelList;
private final BranchModel branchModel;
private final int eigenCount;
private final int nodeCount;
private final int extraBufferCount;
private final BufferIndexHelper eigenBufferHelper;
private BufferIndexHelper matrixBufferHelper;
private Deque<Integer> availableBuffers = new ArrayDeque<Integer>();
public SubstitutionModelDelegate(Tree tree, BranchModel branchModel) {
this.tree = tree;
this.substitutionModelList = branchModel.getSubstitutionModels();
this.branchModel = branchModel;
eigenCount = substitutionModelList.size();
nodeCount = tree.getNodeCount();
// two eigen buffers for each decomposition for store and restore.
eigenBufferHelper = new BufferIndexHelper(eigenCount, 0);
// two matrices for each node less the root
matrixBufferHelper = new BufferIndexHelper(nodeCount, 0);
this.extraBufferCount = branchModel.requiresMatrixConvolution() ? BUFFER_POOL_SIZE : 0;
for (int i = 0; i < extraBufferCount; i++) {
pushAvailableBuffer(i + matrixBufferHelper.getBufferCount());
}
}// END: Constructor
public boolean canReturnComplexDiagonalization() {
return substitutionModelList.get(0).getEigenDecomposition().canReturnComplexDiagonalization();
}
public int getEigenBufferCount() {
return eigenBufferHelper.getBufferCount();
}
public int getMatrixBufferCount() {
return matrixBufferHelper.getBufferCount() + extraBufferCount;
}
public int getSubstitutionModelCount() {
return substitutionModelList.size();
}
public SubstitutionModel getSubstitutionModel(int index) {
return substitutionModelList.get(index);
}
public void updateSubstitutionModels(Beagle beagle) {
for (int i = 0; i < eigenCount; i++) {
eigenBufferHelper.flipOffset(i);
EigenDecomposition ed = substitutionModelList.get(i).getEigenDecomposition();
beagle.setEigenDecomposition(
eigenBufferHelper.getOffsetIndex(i),
ed.getEigenVectors(),
ed.getInverseEigenVectors(),
ed.getEigenValues());
}
}
public void updateTransitionMatrices(Beagle beagle, int[] branchIndices, double[] edgeLength, int updateCount) {
int[][] probabilityIndices = new int[eigenCount][updateCount];
double[][] edgeLengths = new double[eigenCount][updateCount];
int[] counts = new int[eigenCount];
List<Deque<Integer>> convolutionList = new ArrayList<Deque<Integer>>();
for (int i = 0; i < updateCount; i++) {
BranchModel.Mapping mapping = branchModel.getBranchModelMapping(tree.getNode(branchIndices[i]));
int[] order = mapping.getOrder();
double[] weights = mapping.getWeights();
if (order.length == 1) {
probabilityIndices[order[0]][counts[order[0]]] = matrixBufferHelper.getOffsetIndex(branchIndices[i]);
edgeLengths[order[0]][counts[order[0]]] = edgeLength[i];
counts[order[0]] ++;
} else {
double sum = 0.0;
for (double w : weights) {
sum += w;
}
Deque<Integer> bufferIndices = new ArrayDeque<Integer>();
for (int j = 0; j < order.length; j++) {
int buffer;
boolean done;
do {
done = true;
buffer = popAvailableBuffer();
if (buffer < 0) {
// no buffers available
if (DEBUG) {
System.out.println("Ran out of buffers for transition matrices - computing current list.");
}
// we have run out of buffers, process what we have and continue...
computeTransitionMatrices(beagle, probabilityIndices, edgeLengths, counts);
convolveMatrices(beagle, convolutionList);
// reset the counts
for (int k = 0; k < eigenCount; k ++) {
counts[k] = 0;
}
done = false;
}
} while (!done);
probabilityIndices[order[j]][counts[order[j]]] = buffer;
edgeLengths[order[j]][counts[order[j]]] = weights[j] * edgeLength[i] / sum;
counts[order[j]]++;
bufferIndices.add(buffer);
}
bufferIndices.add(matrixBufferHelper.getOffsetIndex(branchIndices[i]));
convolutionList.add(bufferIndices);
}
}
computeTransitionMatrices(beagle, probabilityIndices, edgeLengths, counts);
convolveMatrices(beagle, convolutionList);
}
private void computeTransitionMatrices(Beagle beagle, int[][] probabilityIndices, double[][] edgeLengths, int[] counts) {
if (DEBUG) {
System.out.print("Computing matrices:");
}
for (int i = 0; i < eigenCount; i++) {
if (DEBUG) {
for (int j = 0; j < counts[i]; j++) {
System.out.print(" " + probabilityIndices[i][j]);
}
}
beagle.updateTransitionMatrices(eigenBufferHelper.getOffsetIndex(i),
probabilityIndices[i],
null, // firstDerivativeIndices
null, // secondDerivativeIndices
edgeLengths[i],
counts[i]);
}
if (DEBUG) {
System.out.println();
}
}
private void convolveMatrices(Beagle beagle, List<Deque<Integer>> convolutionList) {
while (convolutionList.size() > 0) {
int[] firstConvolutionBuffers = new int[nodeCount];
int[] secondConvolutionBuffers = new int[nodeCount];
int[] resultConvolutionBuffers = new int[nodeCount];
int operationsCount = 0;
List<Deque<Integer>> empty = new ArrayList<Deque<Integer>>();
for (Deque<Integer> convolve : convolutionList) {
if (convolve.size() > 3) {
firstConvolutionBuffers[operationsCount] = convolve.pop();
secondConvolutionBuffers[operationsCount] = convolve.pop();
int buffer;
boolean done;
do {
done = true;
buffer = popAvailableBuffer();
if (buffer < 0) {
// no buffers available
// throw new RuntimeException("All out of buffers");
// we have run out of buffers, process what we have and continue...
if (DEBUG) {
System.out.println("Ran out of buffers for convolving - computing current list.");
System.out.print("Convolving matrices:");
for (int i = 0; i < operationsCount; i++) {
System.out.print(" " + firstConvolutionBuffers[i] + "*" + secondConvolutionBuffers[i] + "->" + resultConvolutionBuffers[i]);
}
System.out.println();
}
beagle.convolveTransitionMatrices(firstConvolutionBuffers, // A
secondConvolutionBuffers, // B
resultConvolutionBuffers, // C
operationsCount // count
);
operationsCount = 0;
done = false;
}
} while(!done);
resultConvolutionBuffers[operationsCount] = buffer;
convolve.push(buffer);
operationsCount ++;
} else if (convolve.size() == 3) {
firstConvolutionBuffers[operationsCount] = convolve.pop();
secondConvolutionBuffers[operationsCount] = convolve.pop();
resultConvolutionBuffers[operationsCount] = convolve.pop();
operationsCount ++;
} else {
throw new RuntimeException("Unexpected convolve list size");
}
if (convolve.size() == 0) {
empty.add(convolve);
}
}
if (DEBUG) {
System.out.print("Convolving matrices:");
for (int i = 0; i < operationsCount; i++) {
System.out.print(" " + firstConvolutionBuffers[i] + "*" + secondConvolutionBuffers[i] + "->" + resultConvolutionBuffers[i]);
}
System.out.println();
}
beagle.convolveTransitionMatrices(firstConvolutionBuffers, // A
secondConvolutionBuffers, // B
resultConvolutionBuffers, // C
operationsCount // count
);
for (int i = 0; i < operationsCount; i++) {
if (firstConvolutionBuffers[i] >= matrixBufferHelper.getBufferCount()) {
pushAvailableBuffer(firstConvolutionBuffers[i]);
}
if (secondConvolutionBuffers[i] >= matrixBufferHelper.getBufferCount()) {
pushAvailableBuffer(secondConvolutionBuffers[i]);
}
}
convolutionList.removeAll(empty);
}
}
private int popAvailableBuffer() {
if (availableBuffers.isEmpty()) {
return -1;
}
return availableBuffers.pop();
}
private void pushAvailableBuffer(int index) {
availableBuffers.push(index);
}
public double[] getRootStateFrequencies() {
return substitutionModelList.get(0).getFrequencyModel().getFrequencies();
}// END: getStateFrequencies
public void flipMatrixBuffer(int branchIndex) {
matrixBufferHelper.flipOffset(branchIndex);
}
public int getMatrixIndex(int branchIndex) {
return matrixBufferHelper.getOffsetIndex(branchIndex);
}
public void storeState() {
eigenBufferHelper.storeState();
matrixBufferHelper.storeState();
}
public void restoreState() {
eigenBufferHelper.restoreState();
matrixBufferHelper.restoreState();
}
}// END: class