/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.mahout.math.jet.random;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.MahoutTestCase;
import org.junit.Test;
import java.util.Arrays;
import java.util.Locale;
import java.util.Random;
/**
 * Tests for {@link Gamma}. The sampler is compared against the distribution's own cdf. The cdf is
 * compared against the exponential closed form, against rate invariance, and against reference
 * values computed with R. The pdf is compared against the closed-form density.
 */
public final class GammaTest extends MahoutTestCase {

  /**
   * Draws 100000 samples for several shapes (rate 1) and sorts them. For each q in
   * 0.01, 0.02, ..., 0.99, the empirical q-quantile must map back to roughly q under {@code cdf}.
   */
  @Test
  public void testNextDouble() {
    double[] z = new double[100000];
    Random gen = RandomUtils.getRandom();
    for (double alpha : new double[]{1, 2, 10, 0.1, 0.01, 100}) {
      Gamma g = new Gamma(alpha, 1, gen);
      for (int i = 0; i < z.length; i++) {
        z[i] = g.nextDouble();
      }
      Arrays.sort(z);
      // verify that empirical CDF matches theoretical one pretty closely
      for (double q : seq(0.01, 1, 0.01)) {
        double p = z[(int) (q * z.length)];
        assertEquals(q, g.cdf(p), 0.01);
      }
    }
  }

  /**
   * Checks {@code cdf} in three ways: against the closed form for alpha = 1, against rate
   * invariance for several shapes, and against reference values produced by R's {@code pgamma}.
   */
  @Test
  public void testCdf() {
    Random gen = RandomUtils.getRandom();

    // alpha = 1: the cdf should equal the closed form 1 - exp(-beta * x), and a rate-beta
    // distribution evaluated at x should match the rate-1 distribution evaluated at beta * x.
    Gamma unitRate = new Gamma(1, 1, gen);
    for (double beta : new double[]{1, 0.1, 2, 100}) {
      Gamma g1 = new Gamma(1, beta, gen);
      for (double x : seq(0, 0.99, 0.1)) {
        assertEquals(String.format(Locale.ENGLISH, "Closed form: x = %.4f, alpha = 1, beta = %.1f", x, beta),
            1 - Math.exp(-x * beta), g1.cdf(x), 1.0e-9);
        assertEquals(String.format(Locale.ENGLISH, "Rate invariance: x = %.4f, alpha = 1, beta = %.1f", x, beta),
            unitRate.cdf(beta * x), g1.cdf(x), 1.0e-9);
      }
    }

    // now test scaling for a selection of values of alpha.
    // The delta of 0 demands bit-for-bit equality. NOTE(review): this presumably relies on cdf
    // depending on x only through beta * x; confirm against Gamma.cdf if this ever flakes.
    for (double alpha : new double[]{0.01, 0.1, 1, 2, 10, 100, 1000}) {
      Gamma g = new Gamma(alpha, 1, gen);
      for (double beta : new double[]{0.1, 1, 2, 100}) {
        Gamma g1 = new Gamma(alpha, beta, gen);
        for (double x : seq(0, 0.9999, 0.001)) {
          assertEquals(
              String.format(Locale.ENGLISH, "Rate invariance: x = %.4f, alpha = %.2f, beta = %.1f", x, alpha, beta),
              g.cdf(x * beta), g1.cdf(x), 0);
        }
      }
    }

    // now check against known values computed using R for various values of alpha

    // > pgamma(seq(0,0.02,by=0.002),0.01,1)
    // [1] 0.0000000 0.9450896 0.9516444 0.9554919 0.9582258 0.9603474 0.9620810 0.9635462 0.9648148 0.9659329 0.9669321
    checkGammaCdf(0.01, 1, 0.0000000, 0.9450896, 0.9516444, 0.9554919, 0.9582258, 0.9603474, 0.9620810, 0.9635462, 0.9648148, 0.9659329, 0.9669321);

    // > pgamma(seq(0,0.2,by=0.02),0.1,1)
    // [1] 0.0000000 0.7095387 0.7591012 0.7891072 0.8107067 0.8275518 0.8413180 0.8529198 0.8629131 0.8716623 0.8794196
    checkGammaCdf(0.1, 1, 0.0000000, 0.7095387, 0.7591012, 0.7891072, 0.8107067, 0.8275518, 0.8413180, 0.8529198, 0.8629131, 0.8716623, 0.8794196);

    // > pgamma(seq(0,2,by=0.2),1,1)
    // [1] 0.0000000 0.1812692 0.3296800 0.4511884 0.5506710 0.6321206 0.6988058 0.7534030 0.7981035 0.8347011 0.8646647
    checkGammaCdf(1, 1, 0.0000000, 0.1812692, 0.3296800, 0.4511884, 0.5506710, 0.6321206, 0.6988058, 0.7534030, 0.7981035, 0.8347011, 0.8646647);

    // > pgamma(seq(0,20,by=2),10,1)
    // [1] 0.000000e+00 4.649808e-05 8.132243e-03 8.392402e-02 2.833757e-01 5.420703e-01 7.576078e-01 8.906006e-01 9.567017e-01 9.846189e-01 9.950046e-01
    checkGammaCdf(10, 1, 0.000000e+00, 4.649808e-05, 8.132243e-03, 8.392402e-02, 2.833757e-01, 5.420703e-01, 7.576078e-01, 8.906006e-01, 9.567017e-01, 9.846189e-01, 9.950046e-01);

    // > pgamma(seq(0,200,by=20),100,1)
    // [1] 0.000000e+00 3.488879e-37 1.206254e-15 1.481528e-06 1.710831e-02 5.132988e-01 9.721363e-01 9.998389e-01 9.999999e-01 1.000000e+00 1.000000e+00
    checkGammaCdf(100, 1, 0.000000e+00, 3.488879e-37, 1.206254e-15, 1.481528e-06, 1.710831e-02, 5.132988e-01, 9.721363e-01, 9.998389e-01, 9.999999e-01, 1.000000e+00, 1.000000e+00);
  }

  /**
   * Asserts that {@code Gamma(alpha, beta).cdf(x)} matches {@code values[i]} to within 1e-7,
   * where x = i * (2 * alpha / 10).
   *
   * <p>Every supplied value is checked, including the last one. For the 11-element R vectors
   * above, that last value is the endpoint x = 2 * alpha. The x grid depends only on alpha, not
   * on beta.
   *
   * @param alpha  shape parameter
   * @param beta   second constructor argument of {@link Gamma} (all callers here pass 1)
   * @param values reference cdf values, one per grid point starting at x = 0
   */
  private static void checkGammaCdf(double alpha, double beta, double... values) {
    Gamma g = new Gamma(alpha, beta, RandomUtils.getRandom());
    double step = 2 * alpha / 10;
    for (int i = 0; i < values.length; i++) {
      // Multiply instead of accumulating so that x matches R's grid points as closely as possible.
      double x = i * step;
      assertEquals(String.format(Locale.ENGLISH, "alpha=%.2f, i=%d, x=%.2f", alpha, i, x),
          values[i], g.cdf(x), 1.0e-7);
    }
  }

  /**
   * Returns the arithmetic sequence {@code from, from + by, from + 2*by, ...}, stopping strictly
   * before {@code to}. The endpoint is excluded even when rounding would put it within a relative
   * 1e-6 of the last step.
   *
   * <p>Each element is computed as {@code from + i * by} rather than by repeated addition, so
   * rounding error does not accumulate and the array length always matches the number of elements
   * written. An empty range yields an empty array.
   */
  private static double[] seq(double from, double to, double by) {
    int n = Math.max(0, (int) Math.ceil(0.999999 * (to - from) / by));
    double[] r = new double[n];
    for (int i = 0; i < n; i++) {
      r[i] = from + i * by;
    }
    return r;
  }

  /**
   * Compares {@code pdf} with the closed-form density
   * beta^alpha * x^(alpha-1) * exp(-beta * x) / Gamma(alpha). The normalizer is computed in log
   * space via {@code logGamma}.
   */
  @Test
  public void testPdf() {
    Random gen = RandomUtils.getRandom();
    for (double alpha : new double[]{0.01, 0.1, 1, 2, 10, 100}) {
      for (double beta : new double[]{0.1, 1, 2, 100}) {
        Gamma g1 = new Gamma(alpha, beta, gen);
        for (double x : seq(0, 0.99, 0.1)) {
          double p = Math.pow(beta, alpha) * Math.pow(x, alpha - 1) *
              Math.exp(-beta * x - org.apache.mahout.math.jet.stat.Gamma.logGamma(alpha));
          assertEquals(String.format(Locale.ENGLISH, "alpha=%.2f, beta=%.2f, x=%.2f\n", alpha, beta, x),
              p, g1.pdf(x), 1.0e-9);
        }
      }
    }
  }
}