package test.beast.math.distributions;
import org.apache.commons.math.ConvergenceException;
import org.apache.commons.math.FunctionEvaluationException;
import org.apache.commons.math.MathException;
import org.apache.commons.math.analysis.UnivariateRealFunction;
import org.apache.commons.math.analysis.integration.RombergIntegrator;
import org.apache.commons.math.analysis.integration.UnivariateRealIntegrator;
import org.junit.Test;
import beast.math.GammaFunction;
import beast.math.distributions.Gamma;
import beast.util.Randomizer;
import junit.framework.TestCase;
public class GammaTest extends TestCase {
@Test
public void testGammaCummulative() throws Exception {
Gamma dist = new Gamma();
dist.initByName("alpha", "0.001", "beta", "1000.0");
double v = dist.inverseCumulativeProbability(0.5);
assertEquals(v, 5.244206e-299, 1e-304);
v = dist.inverseCumulativeProbability(0.05);
assertEquals(v, 0.0);
v = dist.inverseCumulativeProbability(0.025);
assertEquals(v, 0.0);
v = dist.inverseCumulativeProbability(0.95);
assertEquals(v, 2.973588e-20, 1e-24);
v = dist.inverseCumulativeProbability(0.975);
assertEquals(v, 5.679252e-09, 1e-13);
}
/** The code below is adapted from GammaDistributionTest from BEAST 1
* This test stochastically draws gamma
* variates and compares the coded pdf
* with the actual pdf.
* The tolerance is required to be at most 1e-10.
*/
static double mypdf(double value, double shape, double scale) {
return Math.exp((shape-1) * Math.log(value) - value/scale - GammaFunction.lnGamma(shape) - shape * Math.log(scale) );
}
public void testPdf() throws MathException {
final int numberOfTests = 300;
double totErr = 0;
double ptotErr = 0; int np = 0;
double qtotErr = 0;
Randomizer.setSeed(551);
for(int i = 0; i < numberOfTests; i++){
final double mean = .01 + (3-0.01) * Randomizer.nextDouble();
final double var = .01 + (3-0.01) * Randomizer.nextDouble();
double scale0 = var / mean;
double shape = mean / scale0;
final Gamma gamma = new Gamma();
Gamma.mode mode = Gamma.mode.values()[Randomizer.nextInt(4)];
double other = 0;
switch (mode) {
case ShapeScale: other = scale0; break;
case ShapeRate: other = 1/scale0; break;
case ShapeMean: other = scale0 * shape; break;
case OneParameter: other = 1/shape; scale0 = 1/shape; break;
}
final double scale = scale0;
gamma.initByName("alpha", shape +"", "beta", other +"", "mode", mode);
final double value = Randomizer.nextGamma(shape, 1/scale);
final double mypdf = mypdf(value, shape, scale);
final double pdf = gamma.density(value);
if ( Double.isInfinite(mypdf) && Double.isInfinite(pdf)) {
continue;
}
assertFalse(Double.isNaN(mypdf));
assertFalse(Double.isNaN(pdf));
totErr += mypdf != 0 ? Math.abs((pdf - mypdf)/mypdf) : pdf;
assertFalse("nan", Double.isNaN(totErr));
//assertEquals("" + shape + "," + scale + "," + value, mypdf,gamma.pdf(value),1e-10);
final double cdf = gamma.cumulativeProbability(value);
UnivariateRealFunction f = new UnivariateRealFunction() {
public double value(double v) throws FunctionEvaluationException {
return mypdf(v, shape, scale);
}
};
final UnivariateRealIntegrator integrator = new RombergIntegrator();
integrator.setAbsoluteAccuracy(1e-14);
integrator.setMaximalIterationCount(16); // fail if it takes too much time
double x;
try {
x = integrator.integrate(f, 0, value);
ptotErr += cdf != 0.0 ? Math.abs(x-cdf)/cdf : x;
np += 1;
//assertTrue("" + shape + "," + scale + "," + value + " " + Math.abs(x-cdf)/x + "> 1e-6", Math.abs(1-cdf/x) < 1e-6);
final double q = gamma.inverseCumulativeProbability(cdf);
qtotErr += q != 0 ? Math.abs(q-value)/q : value;
//System.out.println(shape + "," + scale + " " + value);
} catch( ConvergenceException e ) {
// can't integrate , skip test
//System.out.print(" theta(" + shape + "," + scale + ") skipped");
}
// assertEquals("" + shape + "," + scale + "," + value + " " + Math.abs(q-value)/value, q, value, 1e-6);
// System.out.print("\n" + np + ": " + mode + " " + totErr/np + " " + qtotErr/np + " " + ptotErr/np);
}
//System.out.println( !Double.isNaN(totErr) );
// System.out.println(totErr);
// bad test, but I can't find a good threshold that works for all individual cases
assertTrue("failed " + totErr/numberOfTests, totErr/numberOfTests < 1e-7);
assertTrue("failed " + qtotErr/numberOfTests , qtotErr/numberOfTests < 1e-10);
assertTrue("failed " + ptotErr/np, np > 0 ? (ptotErr/np < 2e-7) : true);
}
}