/* * File: ProbabilityMassFunctionUtil.java * Authors: Kevin R. Dixon * Company: Sandia National Laboratories * Project: Cognitive Foundry * * Copyright Feb 3, 2009, Sandia Corporation. * Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive * license for use of this work by or on behalf of the U.S. Government. * Export of this program may require a license from the United States * Government. See CopyrightHistory.txt for complete details. * */ package gov.sandia.cognition.statistics; import gov.sandia.cognition.annotation.PublicationReference; import gov.sandia.cognition.annotation.PublicationType; import gov.sandia.cognition.collection.CollectionUtil; import gov.sandia.cognition.learning.data.DefaultInputOutputPair; import gov.sandia.cognition.learning.data.InputOutputPair; import gov.sandia.cognition.math.ProbabilityUtil; import gov.sandia.cognition.math.UnivariateStatisticsUtil; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.List; import java.util.Random; /** * Utility methods for helping computations in PMFs. * @author Kevin R. Dixon * @since 3.0 */ public class ProbabilityMassFunctionUtil { /** * Computes the information-theoretic entropy of the PMF in bits. * @param <DataType> * Type of data on the domain of the PMF. * @param pmf * PMF to compute the entropy. * @return * Entropy in bits of the given PMF. */ @PublicationReference( author="Wikipedia", title="Entropy (information theory)", type=PublicationType.WebPage, year=2009, url="http://en.wikipedia.org/wiki/Entropy_(Information_theory)" ) public static <DataType> double getEntropy( final ProbabilityMassFunction<DataType> pmf ) { // Compute the entropy by looping over the values in the maps Collection<? extends DataType> domain = pmf.getDomain(); ArrayList<Double> data = new ArrayList<Double>( domain.size() ); for( DataType input : domain ) { data.add( pmf.evaluate( input ) ); } return UnivariateStatisticsUtil.computeEntropy(data); } /** * Samples from the ProbabilityMassFunction. The return value will be * sampled according to the given PMF. * @param <DataType> Type of data to use * @param pmf PMF from which to sample * @param random Random to sample from * @param numSamples Number of samples to draw from the given PMF. * @return * Samples drawn according to the given PMF. */ public static <DataType> ArrayList<DataType> sample( final ProbabilityMassFunction<DataType> pmf, final Random random, final int numSamples ) { ArrayList<DataType> samples = new ArrayList<DataType>(numSamples); sampleInto(pmf, random, numSamples, samples); return samples; } /** * Samples from the ProbabilityMassFunction. The return value will be * sampled according to the given PMF. * @param <DataType> Type of data to use * @param pmf PMF from which to sample * @param random Random to sample from * @param numSamples Number of samples to draw from the given PMF. * @param output Collection to add samples drawn according to the given PMF. */ public static <DataType> void sampleInto( final ProbabilityMassFunction<DataType> pmf, final Random random, final int numSamples, final Collection<? super DataType> output) { if (numSamples == 1) { output.add(sampleSingle(pmf, random)); } else if (numSamples > 1) { sampleMultipleInto(pmf, random, numSamples, output); } } /** * Draws a single sample from the given PMF * @param <DataType> * Type of observations generated by the PMF * @param pmf * PMF from which to draw. * @param random * Random number generator * @return * Single sample from the PMF */ public static <DataType> DataType sampleSingle( final ProbabilityMassFunction<DataType> pmf, final Random random ) { double p = random.nextDouble(); for( DataType x : pmf.getDomain() ) { p -= pmf.evaluate(x); if( p <= 0.0 ) { return x; } } return null; } /** * Samples from the ProbabilityMassFunction. The return value will be * sampled according to the given PMF. * @param <DataType> Type of data to use * @param pmf PMF from which to sample * @param random Random to sample from * @param numSamples Number of samples to draw from the given PMF. * @return * Samples drawn according to the given PMF. */ @SuppressWarnings("unchecked") public static <DataType> ArrayList<DataType> sampleMultiple( final ProbabilityMassFunction<DataType> pmf, final Random random, final int numSamples ) { final ArrayList<DataType> result = new ArrayList<DataType>(numSamples); sampleMultipleInto(pmf, random, numSamples, result); return result; } /** * Samples from the ProbabilityMassFunction. The return value will be * sampled according to the given PMF. * @param <DataType> Type of data to use * @param pmf PMF from which to sample * @param random Random to sample from * @param numSamples Number of samples to draw from the given PMF. * @param output Collection to add samples drawn according to the given PMF. */ @SuppressWarnings("unchecked") public static <DataType> void sampleMultipleInto( final ProbabilityMassFunction<DataType> pmf, final Random random, final int numSamples, final Collection<? super DataType> output) { // Compute the cumulative probability counts. // We can then use binary search to make the lookup process VERY zoomy. int N = pmf.getDomain().size(); double[] cumulativeProbabilities = new double[ N ]; ArrayList<? extends DataType> domain = CollectionUtil.asArrayList( pmf.getDomain() ); double psum = 0.0; final int domainSize = domain.size(); for( int index = 0; index < domainSize; index++ ) { final DataType x = domain.get(index); psum += pmf.evaluate( x ); cumulativeProbabilities[index] = psum; } sampleMultipleInto(cumulativeProbabilities, domain, random, numSamples, output); } /** * Samples multiple elements from the domain proportionately to the * cumulative weights in the given weight array using a fast * binary search algorithm * @param <DataType> * Type of data to be sampled * @param cumulativeWeights * Cumulative weights to sample from * @param domain * Domain from which to sample * @param random * Random number generator * @param numSamples * Number of samples to draw from the distribution * @return * Samples draw proportionately from the cumulative weights */ public static <DataType> ArrayList<DataType> sampleMultiple( final double[] cumulativeWeights, final List<? extends DataType> domain, final Random random, final int numSamples ) { final ArrayList<DataType> result = new ArrayList<DataType>(numSamples); sampleMultipleInto(cumulativeWeights, domain, random, numSamples, result); return result; } /** * Samples multiple elements from the domain proportionately to the * cumulative weights in the given weight array using a fast * binary search algorithm * @param <DataType> * Type of data to be sampled * @param cumulativeWeights * Cumulative weights to sample from * @param domain * Domain from which to sample * @param random * Random number generator * @param numSamples * Number of samples to draw from the distribution * @param output * The collection to put the samples in. */ public static <DataType> void sampleMultipleInto( final double[] cumulativeWeights, final List<? extends DataType> domain, final Random random, final int numSamples, final Collection<? super DataType> output) { for( int n = 0; n < numSamples; n++ ) { output.add(sample(cumulativeWeights, domain, random)); } } /** * Samples an element from the domain proportionately to the * cumulative weights in the given weight array using a fast * binary search algorithm. * @param <DataType> * Type of data to be sampled * @param cumulativeWeights * Cumulative weights to sample from * @param domain * Domain from which to sample * @param random * Random number generator * @return * A sample from the domain according to the weights. */ public static <DataType> DataType sample( final double[] cumulativeWeights, final List<? extends DataType> domain, final Random random) { final int index = DiscreteSamplingUtil.sampleIndexFromCumulativeProportions( random, cumulativeWeights); return domain.get(index); } /** * Samples a single element from the domain proportionately to the given * weights * @param <DataType> * Type of data to be sampled * @param weights * Weights from which we will sample proportionately * @param domain * Domain from which to return the result * @param random * Random number generator * @return * A single sample from the domain proportionately from the weights */ public static <DataType> DataType sampleSingle( final double[] weights, final Collection<? extends DataType> domain, final Random random ) { double sum = 0.0; final int N = weights.length; for( int n = 0; n < N; n++ ) { sum += weights[n]; } double s = sum * random.nextDouble(); int n = 0; for( DataType value : domain ) { s -= weights[n]; if( s <= 0.0 ) { return value; } } return null; } /** * Inverts the discrete CDF, that is p=Pr{x<=X}. * @param <DataType> Type of number from the distribution * @param cdf * CDF of a discrete distribution. * @param p * Probability to invert, must be [0,1]. * @return * Value of x such that p >= CDF(x) and p <= CDF(x_next). */ public static <DataType extends Number> InputOutputPair<DataType,Double> inverse( final CumulativeDistributionFunction<DataType> cdf, final double p ) { ProbabilityUtil.assertIsProbability(p); for( DataType x : ((DiscreteDistribution<DataType>) cdf).getDomain() ) { final double px = cdf.evaluate(x); if( p <= px ) { return new DefaultInputOutputPair<DataType, Double>( x, px ); } } throw new IllegalArgumentException( "Could not invert CDF for p=" + p ); } /** * Computes the CDF value for the given PMF for the input. That is, * the value of P=CDF(input)=sum(PMF(x<=input)). * @param input * Input to compute the CDF of. * @param distribution * Distribution to consider. * @return * CDF value of the distirbution for the given input */ public static double computeCumulativeValue( final int input, final ClosedFormDiscreteUnivariateDistribution<? super Integer> distribution ) { int minx = distribution.getMinSupport().intValue(); ProbabilityMassFunction<? super Integer> pmf = distribution.getProbabilityFunction(); double sum = 0.0; for( int x = minx; x <= input; x++ ) { sum += pmf.evaluate(x); } return sum; } }