/* * StateHistoryTest.java * * Copyright (c) 2002-2014 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ package test.dr.evomodel.substmodel; import dr.evomodel.substmodel.nucleotide.HKY; import dr.evolution.datatype.Nucleotides; import dr.evomodel.substmodel.FrequencyModel; import dr.evomodel.substmodel.MarkovJumpsSubstitutionModel; import dr.evomodel.substmodel.SubstitutionModel; import dr.evomodel.substmodel.UniformizedSubstitutionModel; import dr.inference.markovjumps.MarkovJumpsCore; import dr.inference.markovjumps.StateHistory; import dr.math.LogTricks; import dr.math.MathUtils; import dr.math.matrixAlgebra.Vector; import test.dr.math.MathTestCase; /** * @author Marc A. Suchard */ public class StateHistoryTest extends MathTestCase { public static final int N = 1000000; public void setUp() { MathUtils.setSeed(666); freqModel = new FrequencyModel(Nucleotides.INSTANCE, new double[]{0.45, 0.25, 0.05, 0.25}); baseModel = new HKY(2.0, freqModel); stateCount = baseModel.getDataType().getStateCount(); lambda = new double[stateCount * stateCount]; baseModel.getInfinitesimalMatrix(lambda); System.out.println("lambda = " + new Vector(lambda)); markovjumps = new MarkovJumpsSubstitutionModel(baseModel); } public void testGetLogLikelihood() { System.out.println("Start of getLogLikelihood test"); int startingState = 1; int endingState = 1; double duration = 0.5; int iterations = 5000000; // int iterations = 1; double[] probs = new double[16]; baseModel.getTransitionProbabilities(duration, probs); double trueProb = probs[startingState * 4 + endingState]; System.out.println("Tru prob = " + trueProb); UniformizedSubstitutionModel uSM = new UniformizedSubstitutionModel(baseModel); uSM.setSaveCompleteHistory(true); double logProb = Double.NEGATIVE_INFINITY; double prob = 0.0; double condProb = 0.0; for (int i = 0; i < iterations; ++i) { uSM.computeCondStatMarkovJumps(startingState, endingState, duration, trueProb); StateHistory history = uSM.getStateHistory(); // System.out.println(history.getEndingTime() - history.getStartingTime()); assertEquals(history.getEndingTime() - history.getStartingTime(), duration, 10E-3); assertEquals(history.getStartingState(), startingState); assertEquals(history.getEndingState(), endingState); double logLikelihood = history.getLogLikelihood(lambda, 4); prob += Math.exp(logLikelihood); logProb = LogTricks.logSum(logProb, logLikelihood); condProb += Math.exp(-logLikelihood); // System.err.println(logLikelihood); } logProb = Math.exp(logProb - Math.log(iterations)); prob /= iterations; condProb /= iterations; System.out.println("Sim prob = " + prob); System.out.println("Inv prob = " + (1.0 / condProb)); // System.out.println("log prob = " + logProb); // System.exit(-1); System.out.println(); System.out.println(); // Try using unconditioned simulation double marginalProb = 0.0; double mcProb = 0.0; double invMcProb = 0.0; int totalTries = 0; int i = 0; while (i < iterations) { startingState = MathUtils.randomChoicePDF(freqModel.getFrequencies()); StateHistory history = StateHistory.simulateUnconditionalOnEndingState(0, startingState, duration, lambda, 4); if (//history.getEndingState() == endingState && // history.getNumberOfJumps() == randomChoice1 true ) { marginalProb += 1.0; i++; double logLike = history.getLogLikelihood(lambda, 4); mcProb += Math.exp(logLike); invMcProb += Math.exp(-logLike); // if (i % 100000 == 0) System.out.println(i); } totalTries++; } marginalProb /= totalTries; mcProb /= iterations; invMcProb /= iterations; System.out.println("Sim uncd = " + marginalProb); System.out.println("mc prob = " + mcProb); System.out.println("m2 prob = " + (1.0 / invMcProb)); assertEquals(prob, trueProb); } public void testFreqDistribution() { System.out.println("Start of FreqDistribution test"); int startingState = 0; double duration = 10; // 10 expected substitutions is close to \infty double[] freq = new double[stateCount]; for (int i = 0; i < N; i++) { StateHistory simultant = StateHistory.simulateUnconditionalOnEndingState(0.0, startingState, duration, lambda, stateCount); freq[simultant.getEndingState()]++; } for (int i = 0; i < stateCount; i++) { freq[i] /= N; } System.out.println("freq = " + new Vector(freq)); assertEquals(freq, freqModel.getFrequencies(), 1E-3); System.out.println("End of FreqDistribution test\n"); } public void testCounts() { System.out.println("State of Counts test"); int startingState = 2; double duration = 0.5; int[] counts = new int[stateCount * stateCount]; double[] expectedCounts = new double[stateCount * stateCount]; for (int i = 0; i < N; i++) { StateHistory simultant = StateHistory.simulateUnconditionalOnEndingState(0.0, startingState, duration, lambda, stateCount); simultant.accumulateSufficientStatistics(counts, null); } for (int i = 0; i < stateCount * stateCount; i++) { expectedCounts[i] = (double) counts[i] / (double) N; } double[] r = new double[stateCount * stateCount]; double[] joint = new double[stateCount * stateCount]; double[] analytic = new double[stateCount * stateCount]; for (int from = 0; from < stateCount; from++) { for (int to = 0; to < stateCount; to++) { double marginal = 0; if (from != to) { MarkovJumpsCore.fillRegistrationMatrix(r, from, to, stateCount); markovjumps.setRegistration(r); markovjumps.computeJointStatMarkovJumps(duration, joint); for (int j = 0; j < stateCount; j++) { marginal += joint[startingState * stateCount + j]; // Marginalize out ending state } } analytic[from * stateCount + to] = marginal; } } System.out.println("unconditional expected counts = " + new Vector(expectedCounts)); System.out.println("analytic counts = " + new Vector(analytic)); assertEquals(expectedCounts, analytic, 1E-3); System.out.println("End of Counts test\n"); } double[] lambda; FrequencyModel freqModel; SubstitutionModel baseModel; MarkovJumpsSubstitutionModel markovjumps; int stateCount; }