package test.dr.math;
import dr.evomodel.substmodel.FrequencyModel;
import dr.evomodel.substmodel.nucleotide.HKY;
import dr.evomodel.substmodel.MarkovJumpsSubstitutionModel;
import dr.evolution.datatype.Nucleotides;
import dr.inference.markovjumps.*;
import dr.inference.model.Parameter;
import dr.math.MathUtils;
import dr.math.matrixAlgebra.Vector;
import java.util.Arrays;
/**
 * Tests for {@link SubordinatedProcess} and {@link UniformizedStateHistory}, using an
 * HKY nucleotide substitution model as the underlying continuous-time Markov chain.
 * <p>
 * Arrays whose names start with {@code r} hold hard-coded reference values; they are
 * presumably taken from the R markovjumps package (see the R snippet in the comment at
 * the end of this file) and are passed through
 * {@link MarkovJumpsCore#makeComparableToRPackage} before being compared.
 *
 * @author Marc A. Suchard
 */
public class UniformizedStateHistoryTest extends MathTestCase {

    /** Substitution model shared by all tests; rebuilt in {@link #setUp()}. */
    HKY hky;

    /** Subordinated process built from the infinitesimal matrix of {@link #hky}. */
    SubordinatedProcess process;

    /** Number of states of the data type of {@link #hky}. */
    int stateCount;

    /**
     * Builds the shared fixture: seeds the random number generator, creates an HKY model
     * with kappa = 2 and fixed base frequencies, and constructs the subordinated process
     * from the model's infinitesimal matrix.
     */
    public void setUp() {
        // Fixed seed so that the simulation-based tests are reproducible.
        MathUtils.setSeed(666);

        Parameter kappa = new Parameter.Default(1, 2.0);
        double[] pi = {0.45, 0.05, 0.30, 0.20};
        Parameter freqs = new Parameter.Default(pi);
        FrequencyModel f = new FrequencyModel(Nucleotides.INSTANCE, freqs);
        hky = new HKY(kappa, f);
        stateCount = hky.getDataType().getStateCount();

        // Flattened stateCount x stateCount rate matrix filled in by the model.
        double[] lambda = new double[stateCount * stateCount];
        hky.getInfinitesimalMatrix(lambda);
        process = new SubordinatedProcess(lambda, stateCount);
    }

    /**
     * Compares the 1-step and 3-step DTMC probabilities returned by the subordinated
     * process against the hard-coded reference matrices, to within 1E-6.
     */
    public void testSubordinatedProcessGeneration() {
        double[] oneStep = process.getDtmcProbabilities(1);
        double[] threeStep = process.getDtmcProbabilities(3);

        double[] rOneStep = {
                0.2608696, 0.5217391, 4.347826e-02, 0.1739130,
                0.7826087, 0.0000000, 4.347826e-02, 0.1739130,
                0.3913043, 0.2608696, 2.220446e-16, 0.3478261,
                0.3913043, 0.2608696, 8.695652e-02, 0.2608696
        };
        MarkovJumpsCore.makeComparableToRPackage(rOneStep);

        double[] rThreeStep = {
                0.3935235, 0.3570313, 0.04988904, 0.1995562,
                0.5355470, 0.2150078, 0.04988904, 0.1995562,
                0.4490014, 0.2993343, 0.04980685, 0.2018575,
                0.4490014, 0.2993343, 0.05046437, 0.2012000
        };
        MarkovJumpsCore.makeComparableToRPackage(rThreeStep);

        System.out.println("oneStep = " + new Vector(oneStep));
        assertEquals(rOneStep, oneStep, 1E-6);

        System.out.println("threeStep = " + new Vector(threeStep));
        assertEquals(rThreeStep, threeStep, 1E-6);
    }

    /**
     * Checks the normalized distribution computed by
     * {@code computePdfNextChainState} for start state 1, end state 0, n = 4 and i = 1
     * against the hard-coded reference vector, to within 1E-6.
     */
    public void testComputePdfForNextDraw() {
        int startState = 1;
        int endState = 0;
        // Corresponding indices noted by the original author for the R code:
        // rStartState = 3 (1 + 1 = 2 -> 3), rEndState = 1 (0 + 1 = 1 -> 1)

        double[] pdf = new double[stateCount];
        int n = 4;
        int i = 1;
        process.computePdfNextChainState(startState, endState, n, i, pdf);
        // Normalize before comparing against the reference vector.
        pdf = MathUtils.getNormalized(pdf);
        System.err.println("PDF = " + new Vector(pdf));

        double[] rPDF = new double[]{3.422934e-01, 3.105519e-01, 2.216160e-16, 3.471547e-01};
        MarkovJumpsCore.makeComparableToRPackage(rPDF);
        assertEquals(rPDF, pdf, 1E-6);
    }

    /**
     * Checks that {@code drawNumberOfChanges} agrees with the distribution returned by
     * {@code computePDFDirectly}: cutoffs placed 1E-6 below or above a partial sum of
     * the pdf must yield the expected number of changes.
     *
     * @throws RuntimeException wrapping any {@link SubordinatedProcess.Exception}
     */
    public void testTotalChangesSamplingMethods() {
        try {
            // Case 1: start state 1, end state 0, time 0.5.
            int startState = 1;
            int endState = 0;
            double time = 0.5;

            double[] ctmcProbabilities = new double[stateCount * stateCount];
            hky.getTransitionProbabilities(time, ctmcProbabilities);
            // Entry (startState, endState) of the flattened transition matrix.
            double ctmcProbability = ctmcProbabilities[startState * stateCount + endState];

            double[] pdf = process.computePDFDirectly(startState, endState, time, ctmcProbability, 10);
            System.out.println("PDF = " + new Vector(pdf));

            // Just below pdf[0] + pdf[1]: expect 1 change.
            double cutoff = pdf[0] + pdf[1] - 1E-6;
            System.out.println("Test cutoff = " + cutoff);
            int draw = process.drawNumberOfChanges(startState, endState, time, ctmcProbability, cutoff);
            assertEquals(1, draw);

            // Just above pdf[0] + pdf[1]: expect 2 changes.
            cutoff = pdf[0] + pdf[1] + 1E-6;
            System.out.println("Test cutoff = " + cutoff);
            draw = process.drawNumberOfChanges(startState, endState, time, ctmcProbability, cutoff);
            assertEquals(2, draw);

            // Just above pdf[0] + pdf[1] + pdf[2]: expect 3 changes.
            cutoff = pdf[0] + pdf[1] + pdf[2] + 1E-6;
            System.out.println("Test cutoff = " + cutoff);
            draw = process.drawNumberOfChanges(startState, endState, time, ctmcProbability, cutoff);
            assertEquals(3, draw);

            System.out.println("");

            // Case 2: start and end in state 1, time 0.75.
            startState = 1;
            endState = 1;
            time = 0.75;

            hky.getTransitionProbabilities(time, ctmcProbabilities);
            ctmcProbability = ctmcProbabilities[startState * stateCount + endState];

            pdf = process.computePDFDirectly(startState, endState, time, ctmcProbability, 10);
            System.out.println("PDF = " + new Vector(pdf));

            // Just below pdf[0] + pdf[1] + pdf[2]: expect 2 changes.
            cutoff = pdf[0] + pdf[1] + pdf[2] - 1E-6;
            System.out.println("Test cutoff = " + cutoff);
            draw = process.drawNumberOfChanges(startState, endState, time, ctmcProbability, cutoff);
            assertEquals(2, draw);

            // Just above pdf[0] + pdf[1] + pdf[2]: expect 3 changes.
            cutoff = pdf[0] + pdf[1] + pdf[2] + 1E-6;
            System.out.println("Test cutoff = " + cutoff);
            draw = process.drawNumberOfChanges(startState, endState, time, ctmcProbability, cutoff);
            assertEquals(3, draw);
        } catch (SubordinatedProcess.Exception e) {
            // Keep the original exception as the cause so the failure can be diagnosed.
            throw new RuntimeException("Subordinated process exception", e);
        }
    }

    /**
     * Simulates one million state histories from state 1 at time 1.0 to state 3 at
     * time 3.0 and compares the average registered jump counts with the values computed
     * by {@link MarkovJumpsSubstitutionModel}, to within 1E-2.
     *
     * @throws RuntimeException wrapping any {@link SubordinatedProcess.Exception}
     */
    public void testStateHistorySimulationForJumps() {
        try {
            double startingTime = 1.0;
            double endingTime = 3.0;
            int startingState = 1;
            int endingState = 3;
            int N = 1000000;

            double[] tmp = new double[stateCount * stateCount];
            hky.getTransitionProbabilities(endingTime - startingTime, tmp);
            double transitionProbability = tmp[startingState * stateCount + endingState];

            double[][] registers = new double[2][stateCount * stateCount];
            MarkovJumpsCore.fillRegistrationMatrix(registers[0], stateCount); // Count all jumps
            // Register only the entry at row 2, column 1 of the flattened matrix
            // (presumably jumps from state 2 to state 1).
            registers[1][2 * stateCount + 1] = 1.0;

            // Running sums of registered counts over all simulated histories.
            double[] expectations = new double[registers.length];
            for (int i = 0; i < N; i++) {
                StateHistory history = UniformizedStateHistory.simulateConditionalOnEndingState(
                        startingTime,
                        startingState,
                        endingTime,
                        endingState,
                        transitionProbability,
                        stateCount,
                        process);
                for (int j = 0; j < registers.length; j++) {
                    expectations[j] += history.getTotalRegisteredCounts(registers[j]);
                }
            }

            // Determine analytic solution
            MarkovJumpsSubstitutionModel markovjumps = new MarkovJumpsSubstitutionModel(hky);
            double[] mjExpectations = new double[stateCount * stateCount];
            for (int j = 0; j < registers.length; j++) {
                // Turn the running sum into a Monte Carlo average.
                expectations[j] /= (double) N;
                System.out.println("Expected number for register = " + expectations[j]);

                markovjumps.setRegistration(registers[j]);
                markovjumps.computeCondStatMarkovJumps(endingTime - startingTime, mjExpectations);
                assertEquals(mjExpectations[startingState * stateCount + endingState], expectations[j], 1E-2);
            }
        } catch (SubordinatedProcess.Exception e) {
            // Keep the original exception as the cause so the failure can be diagnosed.
            throw new RuntimeException("Subordinated process exception", e);
        }
    }

    /**
     * Simulates one million state histories from state 1 at time 1.0 to state 3 at
     * time 3.0 and compares the average total reward for three reward vectors with the
     * values computed by {@link MarkovJumpsSubstitutionModel} in
     * {@link MarkovJumpsType#REWARDS} mode, to within 1E-2.
     *
     * @throws RuntimeException wrapping any {@link SubordinatedProcess.Exception}
     */
    public void testStateHistorySimulationForRewards() {
        try {
            double startingTime = 1.0;
            double endingTime = 3.0;
            int startingState = 1;
            int endingState = 3;
            int N = 1000000;

            double[] tmp = new double[stateCount * stateCount];
            hky.getTransitionProbabilities(endingTime - startingTime, tmp);
            double transitionProbability = tmp[startingState * stateCount + endingState];

            double[][] registers = new double[3][stateCount];
            Arrays.fill(registers[0], 1.0); // Reward all states
            registers[1][0] = 1.0; // Reward only state 0
            registers[2][3] = 1.0; // Reward only state 3

            // Running sums of rewards over all simulated histories.
            double[] expectations = new double[registers.length];
            for (int i = 0; i < N; i++) {
                StateHistory history = UniformizedStateHistory.simulateConditionalOnEndingState(
                        startingTime,
                        startingState,
                        endingTime,
                        endingState,
                        transitionProbability,
                        stateCount,
                        process);
                for (int j = 0; j < registers.length; j++) {
                    expectations[j] += history.getTotalReward(registers[j]);
                }
            }

            // Determine analytic solution
            MarkovJumpsSubstitutionModel markovjumps =
                    new MarkovJumpsSubstitutionModel(hky, MarkovJumpsType.REWARDS);
            double[] mjExpectations = new double[stateCount * stateCount];
            for (int j = 0; j < registers.length; j++) {
                // Turn the running sum into a Monte Carlo average.
                expectations[j] /= (double) N;
                System.out.println("Expected reward for register[" + j + "] = " + expectations[j]);

                markovjumps.setRegistration(registers[j]);
                markovjumps.computeCondStatMarkovJumps(endingTime - startingTime, mjExpectations);
                assertEquals(mjExpectations[startingState * stateCount + endingState], expectations[j], 1E-2);
            }
        } catch (SubordinatedProcess.Exception e) {
            // Keep the original exception as the cause so the failure can be diagnosed.
            throw new RuntimeException("Subordinated process exception", e);
        }
    }
}
/*
# R markovjumps code
hky = as.eigen(hky.model(2, 1, c(0.45, 0.30, 0.05, 0.20), scale = T))
maxRate = - min(hky$rate.matrix)
R = diag(4) + hky$rate.matrix / maxRate
R3 = R %*% R %*% R
R4 = R %*% R3
(R[3,] * R3[,1]) / R4[3,1]
*/