/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.commons.math3.distribution; import java.io.Serializable; import org.apache.commons.math3.exception.MathInternalError; import org.apache.commons.math3.exception.NotStrictlyPositiveException; import org.apache.commons.math3.exception.NumberIsTooLargeException; import org.apache.commons.math3.exception.OutOfRangeException; import org.apache.commons.math3.exception.util.LocalizedFormats; import org.apache.commons.math3.random.RandomGenerator; import org.apache.commons.math3.random.RandomDataImpl; import org.apache.commons.math3.util.FastMath; /** * Base class for integer-valued discrete distributions. Default * implementations are provided for some of the methods that do not vary * from distribution to distribution. * */ public abstract class AbstractIntegerDistribution implements IntegerDistribution, Serializable { /** Serializable version identifier */ private static final long serialVersionUID = -1146319659338487221L; /** * RandomData instance used to generate samples from the distribution. * @deprecated As of 3.1, to be removed in 4.0. Please use the * {@link #random} instance variable instead. */ @Deprecated protected final RandomDataImpl randomData = new RandomDataImpl(); /** * RNG instance used to generate samples from the distribution. * @since 3.1 */ protected final RandomGenerator random; /** * @deprecated As of 3.1, to be removed in 4.0. Please use * {@link #AbstractIntegerDistribution(RandomGenerator)} instead. */ @Deprecated protected AbstractIntegerDistribution() { // Legacy users are only allowed to access the deprecated "randomData". // New users are forbidden to use this constructor. random = null; } /** * @param rng Random number generator. * @since 3.1 */ protected AbstractIntegerDistribution(RandomGenerator rng) { random = rng; } /** * {@inheritDoc} * * The default implementation uses the identity * <p>{@code P(x0 < X <= x1) = P(X <= x1) - P(X <= x0)}</p> */ public double cumulativeProbability(int x0, int x1) throws NumberIsTooLargeException { if (x1 < x0) { throw new NumberIsTooLargeException(LocalizedFormats.LOWER_ENDPOINT_ABOVE_UPPER_ENDPOINT, x0, x1, true); } return cumulativeProbability(x1) - cumulativeProbability(x0); } /** * {@inheritDoc} * * The default implementation returns * <ul> * <li>{@link #getSupportLowerBound()} for {@code p = 0},</li> * <li>{@link #getSupportUpperBound()} for {@code p = 1}, and</li> * <li>{@link #solveInverseCumulativeProbability(double, int, int)} for * {@code 0 < p < 1}.</li> * </ul> */ public int inverseCumulativeProbability(final double p) throws OutOfRangeException { if (p < 0.0 || p > 1.0) { throw new OutOfRangeException(p, 0, 1); } int lower = getSupportLowerBound(); if (p == 0.0) { return lower; } if (lower == Integer.MIN_VALUE) { if (checkedCumulativeProbability(lower) >= p) { return lower; } } else { lower -= 1; // this ensures cumulativeProbability(lower) < p, which // is important for the solving step } int upper = getSupportUpperBound(); if (p == 1.0) { return upper; } // use the one-sided Chebyshev inequality to narrow the bracket // cf. AbstractRealDistribution.inverseCumulativeProbability(double) final double mu = getNumericalMean(); final double sigma = FastMath.sqrt(getNumericalVariance()); final boolean chebyshevApplies = !(Double.isInfinite(mu) || Double.isNaN(mu) || Double.isInfinite(sigma) || Double.isNaN(sigma) || sigma == 0.0); if (chebyshevApplies) { double k = FastMath.sqrt((1.0 - p) / p); double tmp = mu - k * sigma; if (tmp > lower) { lower = ((int) FastMath.ceil(tmp)) - 1; } k = 1.0 / k; tmp = mu + k * sigma; if (tmp < upper) { upper = ((int) FastMath.ceil(tmp)) - 1; } } return solveInverseCumulativeProbability(p, lower, upper); } /** * This is a utility function used by {@link * #inverseCumulativeProbability(double)}. It assumes {@code 0 < p < 1} and * that the inverse cumulative probability lies in the bracket {@code * (lower, upper]}. The implementation does simple bisection to find the * smallest {@code p}-quantile <code>inf{x in Z | P(X<=x) >= p}</code>. * * @param p the cumulative probability * @param lower a value satisfying {@code cumulativeProbability(lower) < p} * @param upper a value satisfying {@code p <= cumulativeProbability(upper)} * @return the smallest {@code p}-quantile of this distribution */ protected int solveInverseCumulativeProbability(final double p, int lower, int upper) { while (lower + 1 < upper) { int xm = (lower + upper) / 2; if (xm < lower || xm > upper) { /* * Overflow. * There will never be an overflow in both calculation methods * for xm at the same time */ xm = lower + (upper - lower) / 2; } double pm = checkedCumulativeProbability(xm); if (pm >= p) { upper = xm; } else { lower = xm; } } return upper; } /** {@inheritDoc} */ public void reseedRandomGenerator(long seed) { random.setSeed(seed); randomData.reSeed(seed); } /** * {@inheritDoc} * * The default implementation uses the * <a href="http://en.wikipedia.org/wiki/Inverse_transform_sampling"> * inversion method</a>. */ public int sample() { return inverseCumulativeProbability(random.nextDouble()); } /** * {@inheritDoc} * * The default implementation generates the sample by calling * {@link #sample()} in a loop. */ public int[] sample(int sampleSize) { if (sampleSize <= 0) { throw new NotStrictlyPositiveException( LocalizedFormats.NUMBER_OF_SAMPLES, sampleSize); } int[] out = new int[sampleSize]; for (int i = 0; i < sampleSize; i++) { out[i] = sample(); } return out; } /** * Computes the cumulative probability function and checks for {@code NaN} * values returned. Throws {@code MathInternalError} if the value is * {@code NaN}. Rethrows any exception encountered evaluating the cumulative * probability function. Throws {@code MathInternalError} if the cumulative * probability function returns {@code NaN}. * * @param argument input value * @return the cumulative probability * @throws MathInternalError if the cumulative probability is {@code NaN} */ private double checkedCumulativeProbability(int argument) throws MathInternalError { double result = Double.NaN; result = cumulativeProbability(argument); if (Double.isNaN(result)) { throw new MathInternalError(LocalizedFormats .DISCRETE_CUMULATIVE_PROBABILITY_RETURNED_NAN, argument); } return result; } /** * For a random variable {@code X} whose values are distributed according to * this distribution, this method returns {@code log(P(X = x))}, where * {@code log} is the natural logarithm. In other words, this method * represents the logarithm of the probability mass function (PMF) for the * distribution. Note that due to the floating point precision and * under/overflow issues, this method will for some distributions be more * precise and faster than computing the logarithm of * {@link #probability(int)}. * <p> * The default implementation simply computes the logarithm of {@code probability(x)}.</p> * * @param x the point at which the PMF is evaluated * @return the logarithm of the value of the probability mass function at {@code x} */ public double logProbability(int x) { return FastMath.log(probability(x)); } }