/* * ODEDemographicFunction.java * * Copyright (c) 2002-2015 Alexei Drummond, Andrew Rambaut and Marc Suchard * * This file is part of BEAST. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership and licensing. * * BEAST is free software; you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2 * of the License, or (at your option) any later version. * * BEAST is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with BEAST; if not, write to the * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, * Boston, MA 02110-1301 USA */ /* * Base class for demographic models based on numerically-integrated ODEs * */ package dr.evomodel.epidemiology; import dr.evolution.coalescent.*; /** * This interface provides methods that describe a demographic function. * @author Daniel Wilson */ /*public interface ODEDemographicFunction extends DemographicFunction { public abstract class Abstract implements ODEDemographicFunction { public Abstract(Type units) { setUnits(units); } public double getLogDemographic(double t) { return Math.log(getDemographic(t)); } public double getIntegral(double start, double finish) { return getIntensity(finish) - getIntensity(start); } } }*/ public abstract class ODEDemographicFunction extends DemographicFunction.Abstract { public ODEDemographicFunction(Type units) { super(units); } // Implement abstract types from base class /** * Default implementation * @param t * @return log(demographic at t) */ /*public double getLogDemographic(double t) { return Math.log(getDemographic(t)); }*/ /** * Calculates the integral 1/N(x) dx between start and finish. */ /*public double getIntegral(double start, double finish) { return getIntensity(finish) - getIntensity(start); }*/ /** * Returns the integral of 1/N(x) between start and finish, calling either the getAnalyticalIntegral or * getNumericalIntegral function as appropriate. */ public double getNumericalIntegral(double start, double finish) { throw new RuntimeException("not implemented"); } public double getDemographic(double t) { Evaluate(t); if(RKfail) return 0.0; return getDemographicFromPrevalence(Ynow,t); } public double getIntensity(double t) { Evaluate(t); if(RKfail) return Math.log(0.0); return Ynow[0]; } // Implement the following abstract functions: /** * Calculate the derivatives and store in dydt */ abstract void derivs(double t, double[] y, double[] dydt); /** * Set initial values of y * for(i=0;i<nvar;i++) Y[i][0] = ...; */ abstract void setInit(); /** * Calculate the effective population size from the prevalence, contained in y * @return */ abstract double getDemographicFromPrevalence(double[] y, double t); // Implemented base functions /** * Evaluate the demographic functions at time Tnow. * Store results in Ynow for immediate use by getDemographic or getIntensity */ void Evaluate(double t) { if(RKfail) return; // if (t==Tnow) ??? if(t<0.0) throw new RuntimeException("t cannot be negative"); int i; if(klast==-1 || t>T[klast]) { if(klast+1>=kmax) { // 3rd December 2011: Previous behaviour was stupid, now increase storage capacity // Assume linear, but flag warning //RKwarning = true; //for(i=0;i<nvar;i++) Ynow[i] = Y[i][klast]; //Tnow = t; //return; RKresize(); } // Continue integration try { RungeKutta(t); } catch(RuntimeException e) { System.err.println(e.getMessage()); RKfail = true; return; } for(i=0;i<nvar;i++) Ynow[i] = Y[i][klast]; Tnow = t; return; } if(t==T[klast]) { for(i=0;i<nvar;i++) Ynow[i] = Y[i][klast]; Tnow = t; return; } // Linearly interpolate int k1 = (int)Math.floor(t/dtsav); if(k1>klast-1) k1 = klast-1; if(k1<0) k1 = 0; int k2 = k1+1; if(k2>klast || T[k1]>t) { while(k2>klast || T[k1]>t) { --k1; --k2; } } else if(k1<0 || T[k2]<=t) { while(k1<0 || T[k2]<=t) { ++k1; ++k2; } } if(T[k1]==t) { for(i=0;i<nvar;i++) Ynow[i] = Y[i][k1]; } else { // Linearly interpolate for(i=0;i<nvar;i++) Ynow[i] = Y[i][k1] + (t-T[k1])*(Y[i][k2]-Y[i][k1])/(T[k2]-T[k1]); } Tnow = t; } /** * Initialize RK integration */ void RKinit() { klast = -1; RKwarning = false; RKfail = false; if(Y==null || Y.length!=nvar || Y[0].length!=kmax) { // Should be no memory leak Y = new double[nvar][kmax]; T = new double[kmax]; Ynow = new double[nvar]; ak2 = new double[nvar]; ak3 = new double[nvar]; ak4 = new double[nvar]; ak5 = new double[nvar]; ak6 = new double[nvar]; Ytemp = new double[nvar]; Yerr = new double[nvar]; y = new double[nvar]; dydt = new double[nvar]; yscal = new double[nvar]; } nok = nbad = 0; } /** * Increase kmax by kinc on the hoof */ void RKresize() { if(Y==null) throw new RuntimeException("Y not yet allocated"); if(T==null) throw new RuntimeException("T not yet allocated"); if(kmax==kabsolutemax) throw new RuntimeException("kabsolutemax exceeded"); // Store old value of kmax before incrementing it int oldkmax = kmax; kmax += kinc; // Temporary pointers to old arrays, should get garbage collected double oldY[][] = Y; double oldT[] = T; // Allocate new memory for enlarged arrays Y = new double[nvar][kmax]; T = new double[kmax]; // Copy across old values int i,j; for(i=0;i<oldkmax;i++) { for(j=0;j<nvar;j++) { Y[j][i] = oldY[j][i]; } T[i] = oldT[i]; } } /** * Flag if kmax is exceeded * @return */ public boolean RKwarn() { return RKwarning; } /** * Driving routine for RungeKutta integration */ void RungeKutta(double t2) { // if(h1<0) throw new RuntimeException("h1 must be positive"); if(klast==kmax-1) throw new RuntimeException("storage space is exceeded"); if(klast==-1) { setInit(); // virtual function, over-ride in derived class T[0] = 0; ++klast; } if(t2==0.0) return; int i,nstp; double t1 = T[klast]; // beginning of time range double t = t1; // current time double tsav = t1; // time of last saved point // Copy initial y values for (i=0;i<nvar;i++) y[i]=Y[i][klast]; // Step size double h=hinit; // NB h must be positive for(nstp=0;nstp<MAXSTP;nstp++) { double tmp = nstp; if(nstp>MAXSTP/2) { tmp = tmp+3; } // Calculate the derivatives at the present time derivs(t,y,dydt); // Calculate appropriate scalings for the error tolerance for(i=0;i<nvar;i++) yscal[i]=Math.abs(y[i])+Math.abs(dydt[i]*h)+TINY; // Storage (ensure there's room for final state) if(klast < kmax-2 && Math.abs(t-tsav) > Math.abs(dtsav)) { ++klast; for (i=0;i<nvar;i++) Y[i][klast]=y[i]; T[klast]=t; tsav=t; } // Reduce step size to avoid over-stepping target time t2 if((t+h-t2)*(t+h-t1) > 0.0) h=t2-t; // Perform the adaptive integration and update time t = rkqs(y,dydt,t,h,yscal); // Was the predicted step size used? if(hdid == h) ++nok; else ++nbad; // If the target t2 has been reached if((t-t2)*(t2-t1) >= 0.0) { // Store state at t2 if (klast < kmax-1) { ++klast; for (i=0;i<nvar;i++) Y[i][klast]=y[i]; T[klast]=t; } return; } if(Math.abs(hnext) <= hmin) { throw new RuntimeException("Step size too small in odeint"); } h=hnext; } throw new RuntimeException("Too many steps in routine odeint"); } /** * Take an adaptive step * * @return New time */ // Cannot output primitive scalars by modifying arguments. So hdid and hnext become member variables and t is returned. double rkqs(double[] y, double[] dydt, double t, double htry, double[] yscal) { int i; double errmax,h,htemp,tnew; h=htry; for(;;) { rkck(y,dydt,t,h); errmax=0.0; for(i=0;i<nvar;i++) errmax=Math.max(errmax,Math.abs(Yerr[i]/yscal[i])); errmax /= eps; if(errmax <= 1.0) break; if(Double.isNaN(errmax)) { throw new RuntimeException("errmax NaN"); // Invalid value of one or more of the variables was chosen. Halve the step size //h /= 2.0; } else { htemp=SAFETY*h*Math.pow(errmax,PSHRNK); h=(h >= 0.0 ? Math.max(htemp,0.1*h) : Math.min(htemp,0.1*h)); } tnew=t+h; if(tnew == t) { throw new RuntimeException("stepsize underflow in rkqs"); } } if(errmax > ERRCON) hnext=SAFETY*h*Math.pow(errmax,PGROW); else hnext=5.0*h; t += (hdid=h); for(i=0;i<nvar;i++) y[i]=Ytemp[i]; return t; } /** * Take one RK5 step */ // NB: Arrays such as ak2 and yerr are objects. Pointers to these objects are passed by // value into rkck. The whole array is not copied. So this should work... void rkck(double[] y, double[] dydt, double t, double h) { int i; for(i=0;i<nvar;i++) Ytemp[i]=y[i]+b21*h*dydt[i]; derivs(t+a2*h,Ytemp,ak2); for(i=0;i<nvar;i++) Ytemp[i]=y[i]+h*(b31*dydt[i]+b32*ak2[i]); derivs(t+a3*h,Ytemp,ak3); for(i=0;i<nvar;i++) Ytemp[i]=y[i]+h*(b41*dydt[i]+b42*ak2[i]+b43*ak3[i]); derivs(t+a4*h,Ytemp,ak4); for(i=0;i<nvar;i++) Ytemp[i]=y[i]+h*(b51*dydt[i]+b52*ak2[i]+b53*ak3[i]+b54*ak4[i]); derivs(t+a5*h,Ytemp,ak5); for(i=0;i<nvar;i++) Ytemp[i]=y[i]+h*(b61*dydt[i]+b62*ak2[i]+b63*ak3[i]+b64*ak4[i]+b65*ak5[i]); derivs(t+a6*h,Ytemp,ak6); for(i=0;i<nvar;i++) Ytemp[i]=y[i]+h*(c1*dydt[i]+c3*ak3[i]+c4*ak4[i]+c6*ak6[i]); for(i=0;i<nvar;i++) Yerr[i]=h*(dc1*dydt[i]+dc3*ak3[i]+dc4*ak4[i]+dc5*ak5[i]+dc6*ak6[i]); } // Member variables protected int nvar = 0; // Default to zero // Runge-Kutta integration variables protected int kmax=200; // Maximum storage capacity for integration protected int kinc=200; // Increment size for kmax when it is exceeded protected int kabsolutemax=200000; // Absolute maximum storage capacity for integration allowable protected int klast=-1; // Index of the last evaluation point protected double hinit=0.1; // Initial suggested step size for RungeKutta integration protected double[][] Y; // Storage for integration results. Y[0] must always contain Lambda, the integrated intensity function protected double[] T; // Storage for evaluated time points: interpolate between protected boolean RKwarning; // Warn if kmax is exceeded // Temporary variables protected double[] Ynow; // Instead of passing and returning vectors, store immediate value of Y protected double Tnow; // Immediate value of T // Static constants and storage used by rkck static final double a2=0.2, a3=0.3, a4=0.6, a5=1.0, a6=0.875, b21=0.2, b31=3.0/40.0, b32=9.0/40.0, b41=0.3, b42 = -0.9, b43=1.2, b51 = -11.0/54.0, b52=2.5, b53 = -70.0/27.0, b54=35.0/27.0, b61=1631.0/55296.0, b62=175.0/512.0, b63=575.0/13824.0, b64=44275.0/110592.0, b65=253.0/4096.0, c1=37.0/378.0, c3=250.0/621.0, c4=125.0/594.0, c6=512.0/1771.0, dc1=c1-2825.0/27648.0, dc3=c3-18575.0/48384.0, dc4=c4-13525.0/55296.0, dc5 = -277.00/14336.0, dc6=c6-0.25; private double[] ak2, ak3, ak4, ak5, ak6; // Static constants used by rkqs static final double SAFETY=0.9, PGROW=-0.2, PSHRNK=-0.25, ERRCON=1.89e-4; // Storage used by rkqs and RungeKutta protected double hdid, hnext; // Storage used by rkqs and rkck protected double[] Ytemp, Yerr; // Storage for RungeKutta protected int MAXSTP=10000; protected double TINY=1.0e-30; protected int nok=0, nbad=0; protected double[] y, dydt, yscal; protected double dtsav = 0.1; protected double hmin = 1.0e-16; protected double eps = 1e-4; protected boolean RKfail = false; }