package statalign.base.hmm;
import statalign.base.Utils;
/**
 * This class implements a simplified version of the TKF92 pair-HMM.
 * <p>
 * States: 0 = start, 1 = match (parent and child both emit), 2 = parent emits only,
 * 3 = child emits only, 4 = end. Model parameters are kept in {@code params} in the
 * order {r, lambda, mu}. Transition matrices are produced in log space.
 *
 * @author novak
 *
 */
public class HmmTkf92 extends Hmm2 {

	/*
	 * Scratch values used by calcTransMatrix: _R = r, _L = lambda, _M = mu (copied from a
	 * parameter array), _T = edge length t and _B = beta(t). They are instance fields that
	 * are overwritten on every call, so one instance must not compute transition matrices
	 * from several threads at the same time.
	 */
	private double _B, _T, _R, _L, _M;

	/*
	 * Transition probabilities (not yet log-transformed). HMM2_xy is the probability of
	 * moving from state x to state y, where S = start (state 0) and E = end (state 4).
	 * Math.exp(-_M*_T) appears throughout as alpha = exp(-mu*t). Each of the rows
	 * S, 1, 2 and 3 sums to one (see main()).
	 */
	private double HMM2_11() { return _R+(1-_R)*(1-_L*_B)*(_L/_M)*Math.exp(-_M*_T); }
	private double HMM2_12() { return (1-_R)*(1-_L*_B)*(_L/_M)*(1-Math.exp(-_M*_T)); }
	private double HMM2_13() { return (1-_R)*_L*_B; }
	private double HMM2_21() { return (1-_R)*(_L*_B/(1-Math.exp(-_M*_T)))*Math.exp(-_M*_T); }
	private double HMM2_22() { return _R+(1-_R)*_L*_B; }
	private double HMM2_23() { return (1-_R)*(1-_M*_B/(1-Math.exp(-_M*_T))); }
	private double HMM2_31() { return (1-_R)*(1-_L*_B)*_L/_M*Math.exp(-_M*_T); }
	private double HMM2_32() { return (1-_R)*(1-_L*_B)*_L/_M*(1-Math.exp(-_M*_T)); }
	private double HMM2_33() { return _R+(1-_R)*_L*_B; }
	private double HMM2_S1() { return (1-_L*_B)*_L/_M*Math.exp(-_M*_T); }
	private double HMM2_S2() { return (1-_L*_B)*_L/_M*(1-Math.exp(-_M*_T)); }
	private double HMM2_S3() { return _L*_B; }
	private double HMM2_SE() { return (1-_L*_B)*(1-_L/_M); }
	private double HMM2_1E() { return (1-_R)*(1-_L*_B)*(1-_L/_M); }
	private double HMM2_2E() { return (1-_R)*(_M*_B/(1-Math.exp(-_M*_T)))*(1-_L/_M); }
	private double HMM2_3E() { return (1-_R)*(1-_L*_B)*(1-_L/_M); }

	/* Emission pattern of the states: columns are states, 1st row: parent, 2nd row: child. */
	private final int[][] stateEmit = { { 0, 1, 1, 0, 0 },
	                                    { 0, 1, 0, 1, 0 } };

	/*
	 * Converts a state's emission pattern, read as a binary number (parent = 2, child = 1),
	 * into a state index: e.g. 3 -> 1, 2 -> 2, 1 -> 3. Pattern 4 -> 4 is virtual, so that the
	 * end state has a pattern, too.
	 */
	private final int[] emitPatt2State = { 0, 3, 2, 1, 4 };

	/**
	 * Creates a TKF92 pair-HMM.
	 *
	 * @param defParams initial model parameters in the order {r, lambda, mu}; the array is
	 *                  copied, so later changes to it do not affect this HMM. If
	 *                  {@code null}, the defaults r = 0.5, lambda = 0.009, mu = 0.011 are used.
	 */
	public HmmTkf92(double defParams[]) {
		if (defParams != null) {
			params = defParams.clone();
		} else {
			params = new double[] { 0.5, 0.009, 0.011 }; // TKF92 parameters: r, lambda, mu
			// TODO(review): original note asked "param init?" - confirm whether these
			// defaults need any further initialisation.
		}
	}

	/**
	 * Returns an array specifying the emission pattern of each state. See class Hmm.
	 * The internal array is returned (not a copy); callers must not modify it.
	 */
	public int[][] getStateEmit() {
		return stateEmit;
	}

	/**
	 * Returns a conversion array from emission patterns (coded as integers) into state
	 * indices. See class Hmm for details. The internal array is returned (not a copy);
	 * callers must not modify it.
	 */
	public int[] getEmitPatt2State() {
		return emitPatt2State;
	}

	/**
	 * Returns the index of the start state.
	 */
	public int getStart() {
		return 0;
	}

	/**
	 * Returns the index of the end state.
	 */
	public int getEnd() {
		return 4;
	}

	/**
	 * Calculates a log-space transition matrix for an edge of length {@code t} using this
	 * model's current parameters. See the abstract class Hmm2 for details.
	 *
	 * @param transMatrix matrix to fill (at least 5x5), or {@code null} to allocate one
	 * @param t edge length
	 * @return the filled matrix
	 */
	public double[][] preCalcTransMatrix(double[][] transMatrix, double t) {
		loadParams(params);
		return calcTransMatrix(transMatrix, t);
	}

	/**
	 * Calculates a log-space transition matrix for an edge of length {@code t} using the
	 * given parameters instead of this model's own. The model's parameters are left
	 * untouched, even if the calculation fails.
	 *
	 * @param transMatrix matrix to fill (at least 5x5), or {@code null} to allocate one
	 * @param t edge length
	 * @param newParams parameters {r, lambda, mu} to use for this calculation only
	 * @return the filled matrix
	 */
	public double[][] preCalcTransMatrix(double[][] transMatrix, double t, double[] newParams) {
		loadParams(newParams);
		try {
			return calcTransMatrix(transMatrix, t);
		} finally {
			// Always switch the scratch fields back to this model's own parameters.
			loadParams(params);
		}
	}

	/** Copies r, lambda and mu from {@code p} into the scratch fields. */
	private void loadParams(double[] p) {
		_R = p[0];
		_L = p[1];
		_M = p[2];
	}

	/**
	 * Fills {@code transMatrix} with the natural logs of the transition probabilities for
	 * edge length {@code t}, using whatever r, lambda and mu are currently held in the
	 * scratch fields (set by preCalcTransMatrix). Transitions into the start state and out
	 * of the end state get {@code Utils.log0}.
	 *
	 * @param transMatrix matrix to fill (at least 5x5), or {@code null} to allocate one
	 * @param t edge length
	 * @return the filled matrix
	 */
	public double[][] calcTransMatrix(double[][] transMatrix, double t) {
		if (transMatrix == null) {
			transMatrix = new double[5][5]; // TKF92 has 5 states including start (st. 0) & end (st. 4)
		}
		_T = t;
		// beta(t) = (1 - e^{(lambda-mu)t}) / (mu - lambda * e^{(lambda-mu)t})
		double decay = Math.exp((_L - _M) * _T);
		_B = (1 - decay) / (_M - _L * decay);

		transMatrix[1][1] = Math.log(HMM2_11());
		transMatrix[1][2] = Math.log(HMM2_12());
		transMatrix[1][3] = Math.log(HMM2_13());
		transMatrix[2][1] = Math.log(HMM2_21());
		transMatrix[2][2] = Math.log(HMM2_22());
		transMatrix[2][3] = Math.log(HMM2_23());
		transMatrix[3][1] = Math.log(HMM2_31());
		transMatrix[3][2] = Math.log(HMM2_32());
		transMatrix[3][3] = Math.log(HMM2_33());
		transMatrix[0][1] = Math.log(HMM2_S1());
		transMatrix[0][2] = Math.log(HMM2_S2());
		transMatrix[0][3] = Math.log(HMM2_S3());
		transMatrix[0][4] = Math.log(HMM2_SE());
		transMatrix[1][4] = Math.log(HMM2_1E());
		transMatrix[2][4] = Math.log(HMM2_2E());
		transMatrix[3][4] = Math.log(HMM2_3E());
		// The start state cannot be re-entered and the end state has no successors.
		for (int i = 0; i < 5; i++) {
			transMatrix[i][0] = Utils.log0;
			transMatrix[4][i] = Utils.log0;
		}
		return transMatrix;
	}

	/**
	 * Returns the natural log of the TKF92 stationary probability of a sequence of the given
	 * length (cf. Thorne et al. (1992)). With q = lambda/mu:
	 * P(0) = 1 - q, and for n >= 1, P(n) = (1-q) q (1-r) (q(1-r) + r)^(n-1).
	 *
	 * @param length sequence length, must be non-negative
	 * @return log P(length)
	 * @throws IllegalArgumentException if {@code length} is negative
	 */
	@Override
	public double getLogStationaryProb(int length) {
		if (length < 0) {
			throw new IllegalArgumentException("Sequence length must be non-negative: " + length);
		}
		double R = params[0];
		double L = params[1];
		double M = params[2];
		double q = L / M;
		if (length == 0) {
			// The n >= 1 terms sum to q, so the empty sequence gets the remaining 1 - q;
			// the general formula below would give a wrong value at n = 0.
			return Math.log(1 - q);
		}
		return Math.log((1 - q) * q * (1 - R))
				+ (length - 1) * Math.log(q * (1 - R) + R);
	}

	/**
	 * For testing purposes: checks that every row of the transition matrix for t = 1 sums
	 * to one (zero for the end state's row) and prints the result.
	 * @param args No argument is used.
	 */
	public static void main(String args[]) {
		HmmTkf92 hmm = new HmmTkf92(null);
		double tm[][] = hmm.preCalcTransMatrix(null, 1);
		boolean failed = false;
		for (int i = 0; i < 5; i++) {
			double sum = Utils.log0;
			for (int j = 0; j < 5; j++) {
				sum = Utils.logAdd(sum, tm[i][j]);
			}
			sum = Math.exp(sum);
			System.out.println("Row "+i+" sum: "+sum);
			if (Math.abs(sum - (i == 4 ? 0.0 : 1.0)) > 1e-5) {
				failed = true;
				for (int j = 0; j < 5; j++) {
					System.out.println(" "+i+"->"+j+" likelihood: "+Math.exp(tm[i][j]));
				}
			}
		}
		System.out.println(failed?"Test failed.":"Test passed.");
	}
}