package statalign.base.hmm; import java.util.Arrays; import statalign.base.Utils; /** * An implementation of the abstract pair-HMM that emits characters into non-observable sequences * Used when a new alignment is proposed that aligns together two substrings via an ancestral sequence. * * @author novak * */ public class HmmNonParam extends HmmSilent { //final double P = 0.9999999999; double P = 0.99999; // original //final double P = 0.99; final double Q = 0.6; // original //final double Q = 0.2; final int SILENT = 7; /* transition matrix for 3-sequence alignment HMM, st. 7 is silent, st. 0 is start, st. 6 is end*/ private double transMatrix[][] = null; @Override public void updateParam(double[] _P) { P = _P[0]; updateTransMatrix(); } private void updateTransMatrix() { transMatrix = new double[][] {{0,P*P*P*P*P,P*P*P*P*(1-P),P*P*P*P*(1-P),1-P,P*(1-P),P-P*P*P-P*(1-P),P*P*P*(1-P)*(1-P)}, {0,1-P+P*P*P*P*P*P,P*P*P*P*P*(1-P),P*P*P*P*P*(1-P),P*(1-P),P*P*(1-P),P*P*P*(1-P),P*P*P*P*(1-P)*(1-P)}, {0,P*P*P*P*P*Q,1-Q+P*P*P*P*(1-P)*Q,P*P*P*P*(1-P)*Q,(1-P)*Q,P*(1-P)*Q,P*P*(1-P)*Q,P*P*P*(1-P)*(1-P)*Q}, {0,P*P*P*P*P*Q,P*P*P*P*(1-P)*Q,1-Q+P*P*P*P*(1-P)*Q,(1-P)*Q,P*(1-P)*Q,P*P*(1-P)*Q,P*P*P*(1-P)*(1-P)*Q}, {0,P*P*P*P*Q,P*P*P*(1-P)*Q,P*P*P*(1-P)*Q,1-Q,(1-P)*Q,P*(1-P)*Q,P*P*(1-P)*(1-P)*Q}, {0,P*P*P*Q,P*P*(1-P)*Q,P*P*(1-P)*Q,0,1-Q,(1-P)*Q,P*(1-P)*(1-P)*Q}, {0,0,0,0,0,0,0,0}, {0,P*P*P*P*P*Q,P*P*P*P*(1-P)*Q,P*P*P*P*(1-P)*Q,(1-P)*Q,P*(1-P)*Q,P*P*(1-P)*Q,1-Q+P*P*P*(1-P)*(1-P)*Q}}; double sil2sil = 1-transMatrix[SILENT][SILENT]; int i, j; for(i = 0; i <= 5; i++) for(j = 1; j < 7; j++) redTransMatrix[i][j] = Math.log(transMatrix[i][j]+transMatrix[i][SILENT]*transMatrix[SILENT][j]/sil2sil); for(i = 0; i < 7; i++) { redTransMatrix[i][0] = Utils.log0; redTransMatrix[6][i] = Utils.log0; } for(i = 0; i < 8; i++) for(j = 0; j < 8; j++) transMatrix[i][j] = Math.log(transMatrix[i][j]); } /* reduced transition matrix, silent st. eliminated, st. 0 is start, st. 6 is end */ private double redTransMatrix[][] = new double[7][7]; /* states' emission descriptor: first dim. is parent/left child/right child, 2nd dim. is state, value is 0/1 */ private int stateEmit[][] = {{0,1,1,1,0,0,0,1}, {0,1,1,0,1,0,0,0}, {0,1,0,1,0,1,0,0}}; /* converts states' emission pattern as a binary number (p=4,l=2,r=1) to state # : e.g. 7->1, 6->2 8->7 is virtual, so that end state has a pattern, too */ private int emitPatt2State[] = {0,5,4,-1,7,3,2,1,6}; /** * Constructs a HMMSilent for Tree. Sets up the transition matrices. */ public HmmNonParam() { updateTransMatrix(); } /** * Returns the transition matrix. * See the abstract class HmmSilent for more details. */ @Override public double[][] preCalcTransMatrix(double[][] transMatrix, double t1, double t2) { return this.transMatrix; } /** * Returns the transition matrix. * See the abstract class HmmSilent for more details. */ @Override public double[][] preCalcRedTransMatrix(double[][] redTransMatrix, double[][] transMatrix) { double[][] redTransMatrixCopy = new double[this.redTransMatrix.length][]; for (int i = 0; i < this.redTransMatrix.length; i++) { redTransMatrixCopy[i] = this.redTransMatrix[i].clone(); } return redTransMatrixCopy; } /** * Returns the index of the silent state. */ @Override public int getSilent() { return SILENT; } /** * Returns the index of the start state. */ @Override public int getStart() { return 0; } /** * Returns the index of the end state. */ @Override public int getEnd() { return 6; } /** * Returns an array specifying the emission pattern of each state. See class Hmm. */ @Override public int[][] getStateEmit() { return stateEmit; } /** * Returns a conversion array from emission patterns (coded as integers) into state * indices. See class Hmm for details. */ @Override public int[] getEmitPatt2State() { return emitPatt2State; } /** * For testing/debugging purposes. * * @param args No argument is used. */ public static void main(String[] args) { HmmNonParam hmm = new HmmNonParam(); double tm[][] = hmm.preCalcTransMatrix(null, 1, 1); double redtm[][] = hmm.preCalcRedTransMatrix(null, tm); double heat = 0.9; for (int i = 0; i < redtm.length; i++) { double tempSum = Utils.log0; for (int j = 0; j < redtm[i].length; j++) { redtm[i][j] = redtm[i][j] * heat; tempSum = Utils.logAdd(tempSum, redtm[i][j]); } if(tempSum != Double.NEGATIVE_INFINITY){ for (int j = 0; j < redtm[i].length; j++) { redtm[i][j] = redtm[i][j] - tempSum; } } } boolean failed = false; for(int i = 0; i < 7; i++) { double sum = Utils.log0; for(int j = 0; j < 7; j++) sum = Utils.logAdd(sum, redtm[i][j]); sum = Math.exp(sum); System.out.println(Arrays.toString(redtm[i])); System.out.println("Row "+i+" sum: "+sum); if(Math.abs(sum-(i==6?0.0:1.0)) > 1e-5) { failed = true; for(int j = 0; j < 7; j++) System.out.println(" "+i+"->"+j+" likelihood: "+Math.exp(redtm[i][j])); } } System.out.println(failed?"Test failed.":"Test passed."); } }