/*
 * GeneralLikelihoodCore.java
 *
 * Copyright (C) 2002-2006 Alexei Drummond and Andrew Rambaut
 *
 * This file is part of BEAST.
 * See the NOTICE file distributed with this work for additional
 * information regarding copyright ownership and licensing.
 *
 * BEAST is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * BEAST is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with BEAST; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA 02110-1301 USA
 */

package dr.evomodel.newtreelikelihood;

import dr.evomodel.substmodel.SubstitutionModel;
import dr.evomodel.sitemodel.SiteModel;

public class GeneralLikelihoodCore implements LikelihoodCore {

    public static final boolean DEBUG = false;

    protected final int stateCount;
    protected int nodeCount;
    protected int stateTipCount;
    protected int patternCount;
    protected int partialsSize;
    protected int matrixSize;
    protected int matrixCount;

    protected double[] cMatrix;
    protected double[] storedCMatrix;
    protected double[] eigenValues;
    protected double[] storedEigenValues;

    protected double[] frequencies;
    protected double[] storedFrequencies;
    protected double[] categoryProportions;
    protected double[] storedCategoryProportions;
    protected double[] categoryRates;
    protected double[] storedCategoryRates;

    protected double[][][] partials;
    protected int[][] states;
    protected double[][][] matrices;

    protected int[] currentMatricesIndices;
    protected int[] storedMatricesIndices;
    protected int[] currentPartialsIndices;
    protected int[] storedPartialsIndices;

    protected boolean useScaling = false;
    protected double[][][] scalingFactors;

    /**
     * Constructor
     *
     * @param stateCount number of states
     */
    public GeneralLikelihoodCore(int stateCount) {
        this.stateCount = stateCount;
    }

    public boolean canHandleTipPartials() {
        return true;
    }

    public boolean canHandleTipStates() {
        return true;
    }

    public boolean canHandleDynamicRescaling() {
        return true;
    }
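    /*
     * A note on the storage layout (illustrative sketch, inferred from initialize() below;
     * the index arithmetic here is explanatory and not part of the public API):
     *
     *   partials[b][node]  holds matrixCount * patternCount * stateCount entries, laid out
     *                      category-major, so the entry for category l, pattern k, state i
     *                      sits at index (l * patternCount + k) * stateCount + i.
     *   matrices[b][node]  holds matrixCount * (stateCount + 1) * stateCount entries; each
     *                      row of a transition matrix is padded with one extra column that
     *                      is fixed at 1.0 and used for missing/ambiguous tip states.
     *
     * The leading index b is a double buffer: currentPartialsIndices[node] and
     * currentMatricesIndices[node] select the active buffer (0 or 1) and are flipped
     * before each write, so storeState()/restoreState() only need to save and swap the
     * index arrays rather than copy the data itself.
     */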
    /**
     * Initializes the partial likelihood arrays.
     *
     * @param nodeCount     the number of nodes in the tree
     * @param stateTipCount the number of tips with states (rather than partials - can be zero)
     * @param patternCount  the number of patterns
     * @param matrixCount   the number of matrices (i.e., the number of categories)
     */
    public void initialize(int nodeCount, int stateTipCount, int patternCount, int matrixCount) {

        this.nodeCount = nodeCount;
        this.stateTipCount = stateTipCount;
        this.patternCount = patternCount;
        this.matrixCount = matrixCount;

        cMatrix = new double[stateCount * stateCount * stateCount];
        storedCMatrix = new double[stateCount * stateCount * stateCount];

        eigenValues = new double[stateCount];
        storedEigenValues = new double[stateCount];

        frequencies = new double[stateCount];
        storedFrequencies = new double[stateCount];

        categoryRates = new double[matrixCount];
        storedCategoryRates = new double[matrixCount];

        categoryProportions = new double[matrixCount];
        storedCategoryProportions = new double[matrixCount];

        partialsSize = patternCount * stateCount * matrixCount;

        partials = new double[2][nodeCount][partialsSize];
        scalingFactors = new double[2][nodeCount][patternCount];

        states = new int[nodeCount][patternCount * matrixCount];

        matrixSize = (stateCount + 1) * stateCount;

        matrices = new double[2][nodeCount][matrixCount * matrixSize];

        currentMatricesIndices = new int[nodeCount];
        storedMatricesIndices = new int[nodeCount];

        currentPartialsIndices = new int[nodeCount];
        storedPartialsIndices = new int[nodeCount];
    }

    /**
     * Cleans up and deallocates arrays.
     */
    public void finalize() throws Throwable {
        super.finalize();

        nodeCount = 0;
        patternCount = 0;
        matrixCount = 0;

        partials = null;
        currentPartialsIndices = null;
        storedPartialsIndices = null;
        states = null;
        matrices = null;
        currentMatricesIndices = null;
        storedMatricesIndices = null;
    }

    /**
     * Sets the partials for a tip.
     *
     * @param tipIndex the tip index
     * @param partials an array of patternCount * stateCount partial likelihoods
     */
    public void setTipPartials(int tipIndex, double[] partials) {
        int k = 0;
        for (int i = 0; i < matrixCount; i++) {
            System.arraycopy(partials, 0, this.partials[0][tipIndex], k, partials.length);
            k += partials.length;
        }
    }

    /**
     * Sets the states for a tip - these are numbered from 0 and remain
     * constant throughout the run.
     *
     * @param tipIndex the tip index
     * @param states   an array of patternCount state indices
     */
    public void setTipStates(int tipIndex, int[] states) {
        int k = 0;
        for (int i = 0; i < matrixCount; i++) {
            for (int j = 0; j < states.length; j++) {
                this.states[tipIndex][k] = (states[j] < stateCount ? states[j] : stateCount);
                k++;
            }
        }
    }
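    /*
     * A note on the tip-state encoding above (illustrative, not part of the API):
     * any state index >= stateCount (e.g. a gap or ambiguity code) is clamped to
     * stateCount, which indexes the padded extra column of each transition-matrix row.
     * That column is set to 1.0 in calculateTransitionProbabilityMatrices(), so an
     * unknown tip state contributes a factor of 1.0 for every parent state, i.e. no
     * information. For a nucleotide model (stateCount = 4) the pattern "ACG-" at a tip
     * might therefore be encoded as {0, 1, 2, 4}.
     */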
    /**
     * Called when the substitution model has been updated so precalculations
     * can be obtained.
     */
    public void updateSubstitutionModel(SubstitutionModel substitutionModel) {

        System.arraycopy(substitutionModel.getFrequencyModel().getFrequencies(), 0,
                frequencies, 0, frequencies.length);

        double[][] Evec = substitutionModel.getEigenVectors();
//        if (DEBUG) System.err.println(new dr.math.matrixAlgebra.Vector(Evec[0]));
        double[][] Ievc = substitutionModel.getInverseEigenVectors();
//        if (DEBUG) System.err.println(new dr.math.matrixAlgebra.Vector(Ievc[0]));

        int l = 0;
        for (int i = 0; i < stateCount; i++) {
            for (int j = 0; j < stateCount; j++) {
                for (int k = 0; k < stateCount; k++) {
                    cMatrix[l] = Evec[i][k] * Ievc[k][j];
                    l++;
                }
            }
        }

        System.arraycopy(substitutionModel.getEigenValues(), 0, eigenValues, 0, eigenValues.length);

//        if (DEBUG) System.err.println(new dr.math.matrixAlgebra.Vector(cMatrix));
//        if (DEBUG) System.err.println(cMatrix[stateCount * stateCount * stateCount - 1]);
//        if (DEBUG) System.exit(-1);
    }

    /**
     * Called when the site model has been updated so rates and proportions
     * can be obtained.
     */
    public void updateSiteModel(SiteModel siteModel) {
        for (int i = 0; i < categoryRates.length; i++) {
            categoryRates[i] = siteModel.getRateForCategory(i);
        }
        System.arraycopy(siteModel.getCategoryProportions(), 0,
                categoryProportions, 0, categoryProportions.length);
    }

    /**
     * Specify the branch lengths that are being updated. These will be used to construct
     * the transition probability matrices for each branch.
     *
     * @param branchUpdateIndices the node indices of the branches to be updated
     * @param branchLengths       the branch lengths of the branches to be updated
     * @param branchUpdateCount   the number of branch updates
     */
    public void updateMatrices(int[] branchUpdateIndices, double[] branchLengths, int branchUpdateCount) {
        for (int i = 0; i < branchUpdateCount; i++) {
            if (DEBUG) System.err.println("Updating matrix for node " + branchUpdateIndices[i]);
            currentMatricesIndices[branchUpdateIndices[i]] = 1 - currentMatricesIndices[branchUpdateIndices[i]];
            calculateTransitionProbabilityMatrices(branchUpdateIndices[i], branchLengths[i]);
            if (DEBUG && branchUpdateIndices[i] == 0) {
                System.err.println(matrices[currentMatricesIndices[0]][0][0]);
                System.err.println(matrices[currentMatricesIndices[0]][0][184]);
            }
        }
    }

    int debugCount = 0;

    private void calculateTransitionProbabilityMatrices(int nodeIndex, double branchLength) {

        double[] tmp = new double[stateCount];

        int n = 0;
        for (int l = 0; l < matrixCount; l++) {
//            if (DEBUG) System.err.println("1: Rate " + l + " = " + categoryRates[l]);
            for (int i = 0; i < stateCount; i++) {
                tmp[i] = Math.exp(eigenValues[i] * branchLength * categoryRates[l]);
            }
//            if (DEBUG) System.err.println(new dr.math.matrixAlgebra.Vector(tmp));
//            if (DEBUG) System.exit(-1);

            int m = 0;
            for (int i = 0; i < stateCount; i++) {
                for (int j = 0; j < stateCount; j++) {
                    double sum = 0.0;
                    for (int k = 0; k < stateCount; k++) {
                        sum += cMatrix[m] * tmp[k];
                        m++;
                    }
//                    if (DEBUG) System.err.println("1: matrices[][][" + n + "] = " + sum);
                    matrices[currentMatricesIndices[nodeIndex]][nodeIndex][n] = sum;
                    n++;
                }
                // the extra column at the end of each row is fixed at 1.0 (missing data)
                matrices[currentMatricesIndices[nodeIndex]][nodeIndex][n] = 1.0;
                n++;
            }
//            if (DEBUG) System.err.println(new dr.math.matrixAlgebra.Vector(matrices[currentMatricesIndices[nodeIndex]][nodeIndex]));
//            if (DEBUG) System.exit(0);
        }
    }
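    /*
     * A sketch of what the two methods above compute (the notation is illustrative and
     * not part of the code). Given the eigendecomposition Q = E D E^-1 of the substitution
     * model's rate matrix, the transition probabilities for a branch of length t under
     * category rate r_l are
     *
     *     P(t * r_l) = E exp(D t r_l) E^-1,  i.e.
     *     P[i][j]    = sum_k Evec[i][k] * exp(eigenValues[k] * t * r_l) * Ievc[k][j]
     *
     * updateSubstitutionModel() flattens the products Evec[i][k] * Ievc[k][j] into
     * cMatrix in (i, j, k) order, so calculateTransitionProbabilityMatrices() only has
     * to form tmp[k] = exp(eigenValues[k] * t * r_l) and take dot products.
     */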
    public void updatePartials(int[] operations, int[] dependencies, int operationCount, boolean rescale) {
        updatePartials(operations, dependencies, operationCount);
    }

    /**
     * Specify the updates to be made. This specifies which partials are to be
     * calculated and, by giving the dependencies between the operations, allows
     * them to be done in parallel if required.
     *
     * @param operations     an array of partial likelihood calculations to be done.
     *                       This is an array of triplets of node numbers specifying the two source
     *                       (descendant) nodes and the destination node for which the partials are
     *                       being calculated.
     * @param dependencies   an array of dependencies for each of the operations.
     *                       This is an array of pairs of integers, one pair per operation above.
     *                       The first of each pair specifies which future operation depends on this
     *                       one. The second is a boolean (0 or 1) indicating whether this operation
     *                       depends on another. If these dependencies are not used then the
     *                       operations can safely be done in order.
     * @param operationCount the number of operations
     */
    public void updatePartials(int[] operations, int[] dependencies, int operationCount) {

        int x = 0;
        for (int op = 0; op < operationCount; op++) {
            int nodeIndex1 = operations[x];
            x++;
            int nodeIndex2 = operations[x];
            x++;
            int nodeIndex3 = operations[x];
            x++;

            currentPartialsIndices[nodeIndex3] = 1 - currentPartialsIndices[nodeIndex3];

            if (nodeIndex1 < stateTipCount) {
                if (nodeIndex2 < stateTipCount) {
                    updateStatesStates(nodeIndex1, nodeIndex2, nodeIndex3);
                } else {
                    updateStatesPartials(nodeIndex1, nodeIndex2, nodeIndex3);
                }
            } else {
                if (nodeIndex2 < stateTipCount) {
                    updateStatesPartials(nodeIndex2, nodeIndex1, nodeIndex3);
                } else {
                    updatePartialsPartials(nodeIndex1, nodeIndex2, nodeIndex3);
                }
            }

            if (useScaling) {
                scalePartials(nodeIndex3);
            }
        }
    }
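    /*
     * An example of the operations layout (hypothetical node numbering; the caller is
     * responsible for supplying a valid post-order traversal). For a three-tip tree with
     * tips 0-2, internal node 3 and root node 4, the partials could be updated with
     *
     *     int[] operations = { 0, 1, 3,    // node 3 from tips 0 and 1
     *                          3, 2, 4 };  // root 4 from node 3 and tip 2
     *     updatePartials(operations, null, 2);
     *
     * Each destination must appear only after both of its sources have been computed.
     * The dependencies argument may be null here because this implementation ignores it.
     */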
    /**
     * Calculates partial likelihoods at a node when both children have states.
     */
    private void updateStatesStates(int nodeIndex1, int nodeIndex2, int nodeIndex3) {
        double[] matrices1 = matrices[currentMatricesIndices[nodeIndex1]][nodeIndex1];
        double[] matrices2 = matrices[currentMatricesIndices[nodeIndex2]][nodeIndex2];

        int[] states1 = states[nodeIndex1];
        int[] states2 = states[nodeIndex2];

        double[] partials3 = partials[currentPartialsIndices[nodeIndex3]][nodeIndex3];

        int v = 0;
        for (int l = 0; l < matrixCount; l++) {
            for (int k = 0; k < patternCount; k++) {
                int state1 = states1[k];
                int state2 = states2[k];

                int w = l * matrixSize;

                for (int i = 0; i < stateCount; i++) {
                    partials3[v] = matrices1[w + state1] * matrices2[w + state2];
                    v++;
                    w += (stateCount + 1);
                }
            }
        }
    }

    /**
     * Calculates partial likelihoods at a node when one child has states and one has partials.
     *
     * @param nodeIndex1 the index of the child node with states
     * @param nodeIndex2 the index of the child node with partials
     * @param nodeIndex3 the index of the destination node
     */
    private void updateStatesPartials(int nodeIndex1, int nodeIndex2, int nodeIndex3) {
        double[] matrices1 = matrices[currentMatricesIndices[nodeIndex1]][nodeIndex1];
        double[] matrices2 = matrices[currentMatricesIndices[nodeIndex2]][nodeIndex2];

        int[] states1 = states[nodeIndex1];
        double[] partials2 = partials[currentPartialsIndices[nodeIndex2]][nodeIndex2];

        double[] partials3 = partials[currentPartialsIndices[nodeIndex3]][nodeIndex3];

        double sum, tmp;

        int u = 0;
        int v = 0;

        for (int l = 0; l < matrixCount; l++) {
            for (int k = 0; k < patternCount; k++) {
                int state1 = states1[k];

                int w = l * matrixSize;

                for (int i = 0; i < stateCount; i++) {
                    tmp = matrices1[w + state1];

                    sum = 0.0;
                    for (int j = 0; j < stateCount; j++) {
                        sum += matrices2[w] * partials2[v + j];
                        w++;
                    }
                    // increment for the extra column at the end
                    w++;

                    partials3[u] = tmp * sum;
                    u++;
                }
                v += stateCount;
            }
        }
    }

    /**
     * Calculates partial likelihoods at a node when both children have partials.
     */
    private void updatePartialsPartials(int nodeIndex1, int nodeIndex2, int nodeIndex3) {
        double[] matrices1 = matrices[currentMatricesIndices[nodeIndex1]][nodeIndex1];
        double[] matrices2 = matrices[currentMatricesIndices[nodeIndex2]][nodeIndex2];

        double[] partials1 = partials[currentPartialsIndices[nodeIndex1]][nodeIndex1];
        double[] partials2 = partials[currentPartialsIndices[nodeIndex2]][nodeIndex2];

        double[] partials3 = partials[currentPartialsIndices[nodeIndex3]][nodeIndex3];

        double sum1, sum2;

        int u = 0;
        int v = 0;

        for (int l = 0; l < matrixCount; l++) {
            for (int k = 0; k < patternCount; k++) {
                int w = l * matrixSize;

                for (int i = 0; i < stateCount; i++) {
                    sum1 = sum2 = 0.0;

                    for (int j = 0; j < stateCount; j++) {
                        sum1 += matrices1[w] * partials1[v + j];
                        sum2 += matrices2[w] * partials2[v + j];
                        w++;
                    }
                    // increment for the extra column at the end
                    w++;

                    partials3[u] = sum1 * sum2;
                    u++;
                }
                v += stateCount;
            }

            if (DEBUG) {
//                System.err.println("1:PP node = " + nodeIndex3);
//                for (int p = 0; p < partials3.length; p++) {
//                    System.err.println("1:PP\t" + partials3[p]);
//                }
                System.err.println("node = " + nodeIndex3);
                System.err.println(new dr.math.matrixAlgebra.Vector(partials3));
                System.err.println(new dr.math.matrixAlgebra.Vector(scalingFactors[currentPartialsIndices[nodeIndex3]][nodeIndex3]));
//                System.exit(-1);
            }
        }
    }
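    /*
     * A sketch of the recursion implemented by the three update methods above
     * (Felsenstein's pruning algorithm; the notation is illustrative only). For
     * destination node 3 with children 1 and 2, category l, pattern k and parent state i:
     *
     *     partials3[l,k,i] = ( sum_j P1[i][j] * partials1[l,k,j] )
     *                      * ( sum_j P2[i][j] * partials2[l,k,j] )
     *
     * where P1 and P2 are the transition matrices of the child branches for category l.
     * When a child is a tip with a known state s, its inner sum collapses to P[i][s]
     * (updateStatesStates / updateStatesPartials); an ambiguous tip state picks up the
     * padded column of 1.0 instead.
     */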
    /**
     * Scale the partials at a given node. This uses a scaling suggested by Ziheng Yang in
     * Yang (2000) J. Mol. Evol. 51: 423-432
     * <p/>
     * This function looks over the partial likelihoods for each state at each pattern
     * and finds the largest. If this is less than the scaling threshold (hard-coded
     * below as 1E+40, so in practice every pattern is rescaled) then it rescales the
     * partials for that pattern by dividing by this number (i.e., normalizing to between
     * 0 and 1). It then stores the log of this scaling. This is called for every internal
     * node after the partials are calculated so it accounts for most of the performance
     * hit. Ziheng suggests only doing this on a proportion of nodes but this sounded like
     * a headache to organize (and he doesn't use the threshold idea which improves the
     * performance quite a bit).
     *
     * @param nodeIndex the index of the node whose partials are to be scaled
     */
    protected void scalePartials(int nodeIndex) {
        int u = 0;

        for (int i = 0; i < patternCount; i++) {

            double scaleFactor = 0.0;
            int v = u;
            for (int k = 0; k < matrixCount; k++) {
                for (int j = 0; j < stateCount; j++) {
                    if (partials[currentPartialsIndices[nodeIndex]][nodeIndex][v] > scaleFactor) {
                        scaleFactor = partials[currentPartialsIndices[nodeIndex]][nodeIndex][v];
                    }
                    v++;
                }
                v += (patternCount - 1) * stateCount;
            }

            if (scaleFactor < 1E+40) {
                v = u;
                for (int k = 0; k < matrixCount; k++) {
                    for (int j = 0; j < stateCount; j++) {
                        partials[currentPartialsIndices[nodeIndex]][nodeIndex][v] /= scaleFactor;
                        v++;
                    }
                    v += (patternCount - 1) * stateCount;
                }
                scalingFactors[currentPartialsIndices[nodeIndex]][nodeIndex][i] = Math.log(scaleFactor);
            } else {
                scalingFactors[currentPartialsIndices[nodeIndex]][nodeIndex][i] = 0.0;
            }
            u += stateCount;
        }
    }

    /**
     * This function returns the log scaling factor for a pattern by summing over
     * the log scalings used at each node. If scaling is off then this just returns 0.
     *
     * @param pattern the pattern index
     * @return the log scaling factor
     */
    public double getLogScalingFactor(int pattern) {
        double logScalingFactor = 0.0;
        if (useScaling) {
            for (int i = 0; i < nodeCount; i++) {
                logScalingFactor += scalingFactors[currentPartialsIndices[i]][i][pattern];
                if (DEBUG && pattern == 1) System.err.println("Adding " + scalingFactors[currentPartialsIndices[i]][i][pattern]);
            }
        }
        if (DEBUG) System.err.println("1:SF " + logScalingFactor + " for " + pattern);
        return logScalingFactor;
    }

    /**
     * Calculates pattern log likelihoods at a node.
     *
     * @param rootNodeIndex     the index of the root node
     * @param outLogLikelihoods an array into which the log likelihoods will go
     */
    public void calculateLogLikelihoods(int rootNodeIndex, double[] outLogLikelihoods) {

        // @todo I have a feeling this could be done in a single set of nested loops.

        double[] rootPartials = partials[currentPartialsIndices[rootNodeIndex]][rootNodeIndex];

        double[] tmp = new double[patternCount * stateCount];

        int u = 0;
        int v = 0;
        for (int k = 0; k < patternCount; k++) {
            for (int i = 0; i < stateCount; i++) {
                tmp[u] = rootPartials[v] * categoryProportions[0];
                u++;
                v++;
            }
        }

        for (int l = 1; l < matrixCount; l++) {
            u = 0;
            for (int k = 0; k < patternCount; k++) {
                for (int i = 0; i < stateCount; i++) {
                    tmp[u] += rootPartials[v] * categoryProportions[l];
                    u++;
                    v++;
                }
            }
        }

        u = 0;
        for (int k = 0; k < patternCount; k++) {
            double sum = 0.0;
            for (int i = 0; i < stateCount; i++) {
                sum += frequencies[i] * tmp[u];
                u++;
            }
            outLogLikelihoods[k] = Math.log(sum) + getLogScalingFactor(k);
            if (DEBUG) {
                System.err.println("log lik " + k + " = " + outLogLikelihoods[k]);
            }
        }
        if (DEBUG) System.exit(-1);
    }
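    /*
     * In the notation used above, calculateLogLikelihoods() computes, for each pattern k,
     *
     *     L_k = sum_i frequencies[i] * sum_l categoryProportions[l] * rootPartials[l,k,i]
     *     outLogLikelihoods[k] = log(L_k) + getLogScalingFactor(k)
     *
     * i.e. the root partials are averaged over the rate categories, weighted by the state
     * frequencies, and the per-pattern log scale factors accumulated during rescaling are
     * added back on the log scale.
     */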
    /**
     * Store current state
     */
    public void storeState() {

        System.arraycopy(cMatrix, 0, storedCMatrix, 0, cMatrix.length);
        System.arraycopy(eigenValues, 0, storedEigenValues, 0, eigenValues.length);

        System.arraycopy(frequencies, 0, storedFrequencies, 0, frequencies.length);
        System.arraycopy(categoryRates, 0, storedCategoryRates, 0, categoryRates.length);
        System.arraycopy(categoryProportions, 0, storedCategoryProportions, 0, categoryProportions.length);

        System.arraycopy(currentMatricesIndices, 0, storedMatricesIndices, 0, nodeCount);
        System.arraycopy(currentPartialsIndices, 0, storedPartialsIndices, 0, nodeCount);
    }

    /**
     * Restore the stored state
     */
    public void restoreState() {
        // Rather than copying the stored stuff back, just swap the pointers...
        double[] tmp = cMatrix;
        cMatrix = storedCMatrix;
        storedCMatrix = tmp;

        tmp = eigenValues;
        eigenValues = storedEigenValues;
        storedEigenValues = tmp;

        tmp = frequencies;
        frequencies = storedFrequencies;
        storedFrequencies = tmp;

        tmp = categoryRates;
        categoryRates = storedCategoryRates;
        storedCategoryRates = tmp;

        tmp = categoryProportions;
        categoryProportions = storedCategoryProportions;
        storedCategoryProportions = tmp;

        int[] tmp3 = currentMatricesIndices;
        currentMatricesIndices = storedMatricesIndices;
        storedMatricesIndices = tmp3;

        int[] tmp4 = currentPartialsIndices;
        currentPartialsIndices = storedPartialsIndices;
        storedPartialsIndices = tmp4;
    }
}
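/*
 * A rough sketch of how a driver might use this class (illustrative only; the actual
 * caller in BEAST is the tree likelihood that owns this LikelihoodCore, and the variable
 * names below are hypothetical):
 *
 *     GeneralLikelihoodCore core = new GeneralLikelihoodCore(4);            // nucleotides
 *     core.initialize(nodeCount, stateTipCount, patternCount, matrixCount);
 *     for (int tip = 0; tip < stateTipCount; tip++) {
 *         core.setTipStates(tip, tipStates[tip]);                           // or setTipPartials
 *     }
 *
 *     // per evaluation:
 *     core.updateSubstitutionModel(substitutionModel);                      // if it changed
 *     core.updateSiteModel(siteModel);                                      // if it changed
 *     core.updateMatrices(branchUpdateIndices, branchLengths, branchUpdateCount);
 *     core.updatePartials(operations, null, operationCount);                // post-order
 *     core.calculateLogLikelihoods(rootNodeIndex, patternLogLikelihoods);
 *
 *     // storeState() before a proposal and restoreState() on rejection let the
 *     // double-buffered partials and matrices be reused without recomputation.
 */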