/*
* File GeneralSubstitutionModel.java
*
* Copyright (C) 2010 Remco Bouckaert remco@cs.auckland.ac.nz
*
* This file is not copyright Remco! It is copied from BEAST 1.
*
* This file is part of BEAST2.
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership and licensing.
*
* BEAST is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
*
* BEAST is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with BEAST; if not, write to the
* Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
* Boston, MA 02110-1301 USA
*/
package beast.evolution.substitutionmodel;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import beast.core.Description;
import beast.core.Function;
import beast.core.Input;
import beast.core.Input.Validate;
import beast.evolution.datatype.DataType;
import beast.evolution.tree.Node;
@Description("Specifies transition probability matrix with no restrictions on the rates other " +
        "than that one of the rates is equal to one and the others are specified relative to " +
        "this unit rate. Works for any number of states.")
public class GeneralSubstitutionModel extends SubstitutionModel.Base {
    final public Input<Function> ratesInput =
            new Input<>("rates", "Rate parameter which defines the transition rate matrix. " +
                    "Only the off-diagonal entries need to be specified (diagonal makes row sum to zero in a " +
                    "rate matrix). Entry i specifies the rate from floor(i/(n-1)) to i%(n-1)+delta where " +
                    "n is the number of states and delta=1 if floor(i/(n-1)) <= i%(n-1) and 0 otherwise.", Validate.REQUIRED);
    final public Input<String> eigenSystemClass = new Input<>("eigenSystem", "Name of the class used for creating an EigenSystem", DefaultEigenSystem.class.getName());

    /**
     * a square nrOfStates x nrOfStates matrix containing current rates *
     */
    protected double[][] rateMatrix;

    /** off-diagonal rates copied from ratesInput by setupRelativeRates() */
    protected double[] relativeRates;
    /** allocated alongside relativeRates; not read or written by this class itself */
    protected double[] storedRelativeRates;
    /** created in initAndValidate() from the eigenSystem input */
    protected EigenSystem eigenSystem;
    /** decomposition of rateMatrix; (re)computed lazily when updateMatrix is set */
    protected EigenDecomposition eigenDecomposition;
    private EigenDecomposition storedEigenDecomposition;
    /** true when rates may have changed and the decomposition must be recomputed */
    protected boolean updateMatrix = true;
    private boolean storedUpdateMatrix = true;

    @Override
    public void initAndValidate() {
        super.initAndValidate();
        updateMatrix = true;
        nrOfStates = frequencies.getFreqs().length;

        final Function rates = ratesInput.get();
        final int expectedDimension = nrOfStates * (nrOfStates - 1);
        if (rates.getDimension() != expectedDimension) {
            throw new IllegalArgumentException("Dimension of input 'rates' is " + rates.getDimension() + " but a " +
                    "rate matrix of dimension " + nrOfStates + "x" + (nrOfStates - 1) + "=" + expectedDimension + " was " +
                    "expected");
        }

        try {
            eigenSystem = createEigenSystem();
        } catch (SecurityException | ClassNotFoundException | InstantiationException | IllegalAccessException | IllegalArgumentException
                | InvocationTargetException e) {
            // keep the original exception as cause: e.g. InvocationTargetException has no useful message of its own
            throw new IllegalArgumentException("Could not create EigenSystem '" + eigenSystemClass.get() + "': "
                    + e.getMessage(), e);
        }

        rateMatrix = new double[nrOfStates][nrOfStates];
        relativeRates = new double[rates.getDimension()];
        storedRelativeRates = new double[rates.getDimension()];
    } // initAndValidate

    /**
     * Creates an EigenSystem of the class named by the eigenSystemClass input, using its
     * (first declared) single-argument constructor and passing the number of states.
     *
     * @return the newly created EigenSystem
     * @throws ClassNotFoundException    if the named class cannot be found
     * @throws SecurityException         if reflective access is denied
     * @throws InvocationTargetException if the constructor itself throws
     * @throws IllegalArgumentException  if the class has no single-argument constructor, or it rejects the argument
     * @throws IllegalAccessException    if the constructor cannot be accessed
     * @throws InstantiationException    if the class is abstract
     */
    protected EigenSystem createEigenSystem() throws SecurityException, ClassNotFoundException, InstantiationException, IllegalAccessException, IllegalArgumentException, InvocationTargetException {
        final String className = eigenSystemClass.get();
        Constructor<?> ctor = null;
        for (Constructor<?> candidate : Class.forName(className).getDeclaredConstructors()) {
            if (candidate.getGenericParameterTypes().length == 1) {
                ctor = candidate;
                break;
            }
        }
        if (ctor == null) {
            // previously fell through to an arbitrary (or null) constructor and failed obscurely
            throw new IllegalArgumentException("EigenSystem class " + className
                    + " has no single-argument constructor taking the number of states");
        }
        ctor.setAccessible(true);
        return (EigenSystem) ctor.newInstance(nrOfStates);
    }

    /**
     * Recomputes relative rates, rate matrix and eigen decomposition if they are marked
     * stale, and returns the current decomposition.
     * This must be synchronized to avoid being called simultaneously by
     * two different likelihood threads - AJD
     */
    private EigenDecomposition ensureEigenDecomposition() {
        synchronized (this) {
            if (updateMatrix) {
                setupRelativeRates();
                setupRateMatrix();
                eigenDecomposition = eigenSystem.decomposeMatrix(rateMatrix);
                updateMatrix = false;
            }
            return eigenDecomposition;
        }
    }

    /**
     * Fills {@code matrix} (row-major, nrOfStates x nrOfStates) with the transition
     * probabilities exp(Q * distance), where distance = (startTime - endTime) * rate,
     * computed from the eigen decomposition of the rate matrix Q.
     */
    @Override
    public void getTransitionProbabilities(Node node, double startTime, double endTime, double rate, double[] matrix) {
        final double distance = (startTime - endTime) * rate;
        final EigenDecomposition decomposition = ensureEigenDecomposition();
        final int n = nrOfStates;

        final double[] eigenVectors = decomposition.getEigenVectors();
        final double[] inverseEigenVectors = decomposition.getInverseEigenVectors();
        final double[] eigenValues = decomposition.getEigenValues();

        // A pool of iexp matrices was once used to avoid allocation per call;
        // a quick timing experiment showed no difference - RRB
        // iexp = diag(exp(distance * eigenValues)) * inverseEigenVectors
        final double[] iexp = new double[n * n];
        for (int i = 0; i < n; i++) {
            final double scale = Math.exp(distance * eigenValues[i]);
            for (int j = 0; j < n; j++) {
                iexp[i * n + j] = inverseEigenVectors[i * n + j] * scale;
            }
        }

        // matrix = eigenVectors * iexp
        int u = 0;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                double sum = 0.0;
                for (int k = 0; k < n; k++) {
                    sum += eigenVectors[i * n + k] * iexp[k * n + j];
                }
                // abs() maps tiny negative values from round-off error to non-negative probabilities
                matrix[u] = Math.abs(sum);
                u++;
            }
        }
    } // getTransitionProbabilities

    /**
     * access to (deep copy of) rate matrix; modifying the result does not affect this model *
     */
    protected double[][] getRateMatrix() {
        final double[][] copy = new double[rateMatrix.length][];
        for (int i = 0; i < rateMatrix.length; i++) {
            copy[i] = rateMatrix[i].clone();
        }
        return copy;
    }

    /**
     * copies the current values of the rates input into relativeRates *
     */
    protected void setupRelativeRates() {
        final Function rates = this.ratesInput.get();
        for (int i = 0; i < rates.getDimension(); i++) {
            relativeRates[i] = rates.getArrayValue(i);
        }
    }

    /**
     * sets up rate matrix: fills off-diagonal entries from relativeRates, scales each
     * column by the target state's frequency, sets the diagonal so rows sum to zero,
     * and normalises to one expected substitution per unit time *
     */
    protected void setupRateMatrix() {
        final double[] freqs = frequencies.getFreqs();
        for (int i = 0; i < nrOfStates; i++) {
            rateMatrix[i][i] = 0;
            // relativeRates stores row i's off-diagonal entries contiguously, skipping the diagonal
            for (int j = 0; j < i; j++) {
                rateMatrix[i][j] = relativeRates[i * (nrOfStates - 1) + j];
            }
            for (int j = i + 1; j < nrOfStates; j++) {
                rateMatrix[i][j] = relativeRates[i * (nrOfStates - 1) + j - 1];
            }
        }
        // bring in frequencies
        for (int i = 0; i < nrOfStates; i++) {
            for (int j = i + 1; j < nrOfStates; j++) {
                rateMatrix[i][j] *= freqs[j];
                rateMatrix[j][i] *= freqs[i];
            }
        }
        // set up diagonal
        for (int i = 0; i < nrOfStates; i++) {
            double sum = 0.0;
            for (int j = 0; j < nrOfStates; j++) {
                if (i != j) {
                    sum += rateMatrix[i][j];
                }
            }
            rateMatrix[i][i] = -sum;
        }
        // normalise rate matrix to one expected substitution per unit time
        double subst = 0.0;
        for (int i = 0; i < nrOfStates; i++) {
            subst += -rateMatrix[i][i] * freqs[i];
        }
        for (int i = 0; i < nrOfStates; i++) {
            for (int j = 0; j < nrOfStates; j++) {
                rateMatrix[i][j] = rateMatrix[i][j] / subst;
            }
        }
    } // setupRateMatrix

    /*
     * CalculationNode implementation follows
     */

    /**
     * Stores the update flag and a copy of the current eigen decomposition (if any).
     */
    @Override
    public void store() {
        storedUpdateMatrix = updateMatrix;
        if (eigenDecomposition != null) {
            storedEigenDecomposition = eigenDecomposition.copy();
        }
        super.store();
    }

    /**
     * Restores the additional stored state by swapping the current and stored decompositions.
     */
    @Override
    public void restore() {
        updateMatrix = storedUpdateMatrix;
        if (storedEigenDecomposition != null) {
            final EigenDecomposition tmp = storedEigenDecomposition;
            storedEigenDecomposition = eigenDecomposition;
            eigenDecomposition = tmp;
        }
        super.restore();
    }

    @Override
    protected boolean requiresRecalculation() {
        // we only get here if something is dirty
        updateMatrix = true;
        return true;
    }

    /**
     * Returns the eigen decomposition of the current rate matrix, recomputing it first
     * if the rates have changed since it was last computed.
     *
     * @return the current eigen decomposition
     */
    @Override
    public EigenDecomposition getEigenDecomposition(Node node) {
        return ensureEigenDecomposition();
    }

    /**
     * Accepts any data type whose state count is not Integer.MAX_VALUE.
     */
    @Override
    public boolean canHandleDataType(DataType dataType) {
        return dataType.getStateCount() != Integer.MAX_VALUE;
    }
} // class GeneralSubstitutionModel