package gdsc.smlm.function; import gdsc.core.utils.DoubleEquality; import gdsc.core.utils.StoredDataStatistics; import org.apache.commons.math3.random.RandomGenerator; import org.apache.commons.math3.random.Well19937c; import org.junit.Assert; import org.junit.Test; public class PoissonGammaGaussianFunctionTest { double[] photons = { 0, 0.25, 0.5, 1, 2, 4, 10, 100, 1000 }; double[] noise = { 30, 45, 76 }; // in electrons double[] cameraGain = { 6.5, 45 }; // ADU/e double emGain = 250; // Realistic parameters for speed test double s = 7.16; double g = 39.1; @Test public void cumulativeProbabilityIsOneWithApproximation() { for (double p : photons) for (double rn : noise) for (double cg : cameraGain) cumulativeProbabilityIsOne(p, rn, cg, true, false); } @Test public void cumulativeProbabilityIsOneWithFullIntegration() { for (double p : photons) for (double s : noise) for (double g : cameraGain) cumulativeProbabilityIsOne(p, s, g, false, false); } @Test public void cumulativeProbabilityIsOneWithSimpleIntegration() { for (double p : photons) for (double s : noise) for (double g : cameraGain) cumulativeProbabilityIsOne(p, s, g, false, true); } @Test public void approximationCloselyMatchesFullIntegration() { double e = closelyMatchesFullIntegration(5e-2, true, false); System.out.println("Approximation max error = " + e); } @Test public void simpleIntegrationCloselyMatchesFullIntegration() { double e = closelyMatchesFullIntegration(1e-3, false, true); System.out.println("Simple integration max error = " + e); } @Test public void approximationFasterThanFullIntegration() { PoissonGammaGaussianFunction f1 = new PoissonGammaGaussianFunction(1 / g, s); f1.setUseApproximation(false); f1.setUseSimpleIntegration(false); PoissonGammaGaussianFunction f2 = new PoissonGammaGaussianFunction(1 / g, s); f2.setUseSimpleIntegration(false); f2.setUseApproximation(true); fasterThan(f1, f2); } @Test public void approximationFasterThanSimpleIntegration() { PoissonGammaGaussianFunction f1 = new PoissonGammaGaussianFunction(1 / g, s); f1.setUseApproximation(false); f1.setUseSimpleIntegration(true); PoissonGammaGaussianFunction f2 = new PoissonGammaGaussianFunction(1 / g, s); f2.setUseSimpleIntegration(false); f2.setUseApproximation(true); fasterThan(f1, f2); } @Test public void simpleIntegrationFasterThanFullIntegration() { PoissonGammaGaussianFunction f1 = new PoissonGammaGaussianFunction(1 / g, s); f1.setUseApproximation(false); f1.setUseSimpleIntegration(false); PoissonGammaGaussianFunction f2 = new PoissonGammaGaussianFunction(1 / g, s); f2.setUseSimpleIntegration(true); f2.setUseApproximation(false); fasterThan(f1, f2); } private void cumulativeProbabilityIsOne(final double mu, final double rn, final double cg, boolean useApproximation, boolean useSimpleIntegration) { // Read noise should be in proportion to the camera gain double s = rn / cg; double g = emGain / cg; PoissonGammaGaussianFunction f = new PoissonGammaGaussianFunction(1 / g, s); f.setUseApproximation(useApproximation); f.setUseSimpleIntegration(useSimpleIntegration); double p = 0; int min = 1; int max = 0; // Evaluate an initial range. // Gaussian should have >99% within +/- 3s // Poisson will have mean mu with a variance mu. // At large mu it is approximately normal so use 3 sqrt(mu) for the range added to the mean if (mu > 0) { min = (int) -Math.ceil(3 * s); max = (int) Math.ceil(mu + 3 * (Math.max(s, Math.sqrt(mu)))); for (int x = min; x <= max; x++) { final double pp = f.likelihood(x, mu); //System.out.printf("x=%d, p=%f\n", x, pp); p += pp; } if (p > 1.01) Assert.fail("P > 1: " + p); } // We have most of the probability density. // Now keep evaluating up and down until no difference final double changeTolerance = 1e-6; for (int x = min - 1;; x--) { final double pp = f.likelihood(x, mu); //System.out.printf("x=%d, p=%f\n", x, pp); p += pp; if (pp / p < changeTolerance) break; } for (int x = max + 1;; x++) { final double pp = f.likelihood(x, mu); //System.out.printf("x=%d, p=%f\n", x, pp); p += pp; if (pp / p < changeTolerance) break; } System.out.printf("%s : mu=%f, rn=%f, cg=%f, s=%f, g=%f, p=%f\n", getName(f), mu, rn, cg, s, g, p); Assert.assertEquals(String.format("mu=%f, rn=%f, cg=%f, s=%f, g=%f", mu, rn, cg, s, g), 1, p, 0.02); } private double closelyMatchesFullIntegration(double error, boolean useApproximation, boolean useSimpleIntegration) { //DoubleEquality eq = new DoubleEquality(error, 1e-7); double maxError = 0; for (double rn : noise) for (double cg : cameraGain) { // Read noise should be in proportion to the camera gain double s = rn / cg; double g = emGain / cg; PoissonGammaGaussianFunction f1 = new PoissonGammaGaussianFunction(1 / g, s); f1.setUseApproximation(false); f1.setUseSimpleIntegration(false); PoissonGammaGaussianFunction f2 = new PoissonGammaGaussianFunction(1 / g, s); f2.setUseApproximation(useApproximation); f2.setUseSimpleIntegration(useSimpleIntegration); for (double p : photons) { for (double x = p * 0.5 - 5 * s; x < 2 * p; x += 1) { double p1 = f1.likelihood(x, p); double p2 = f2.likelihood(x, p); double relativeError = DoubleEquality.relativeError(p1, p2); boolean equal = relativeError <= error; //eq.almostEqualRelativeOrAbsolute(p1, p2); if (!equal) { Assert.assertTrue(String.format("rn=%f, cg=%f, s=%f, g=%f, p=%f, x=%f: %f != %f (%f)", rn, cg, s, g, p, x, p1, p2, relativeError), equal); } if (maxError < relativeError) maxError = relativeError; } } } return maxError; } private void fasterThan(PoissonGammaGaussianFunction f1, PoissonGammaGaussianFunction f2) { // Generate realistic data from the probability mass function double[][] samples = new double[photons.length][]; for (int j = 0; j < photons.length; j++) { int start = (int) (4 * -s); int u = start; StoredDataStatistics stats = new StoredDataStatistics(); while (stats.getSum() < 0.995) { stats.add(f1.likelihood(u, photons[j])); u++; } // Generate cumulative probability double[] data = stats.getValues(); for (int i = 1; i < data.length; i++) data[i] += data[i - 1]; // Sample RandomGenerator rand = new Well19937c(); double[] sample = new double[1000]; for (int i = 0; i < sample.length; i++) { final double p = rand.nextDouble(); int x = 0; while (x < data.length && data[x] < p) x++; sample[i] = x; } samples[j] = sample; } // Warm-up run(f1, samples, photons); run(f2, samples, photons); long t1 = 0; for (int i = 0; i < 5; i++) t1 += run(f1, samples, photons); long t2 = 0; for (int i = 0; i < 5; i++) t2 += run(f2, samples, photons); System.out.printf("%s %d -> %s %d = %f x\n", getName(f1), t1, getName(f2), t2, (double) t1 / t2); } private long run(PoissonGammaGaussianFunction f, double[][] samples, double[] photons) { long start = System.nanoTime(); for (int j = 0; j < photons.length; j++) { final double p = photons[j]; for (double x : samples[j]) f.likelihood(x, p); } return System.nanoTime() - start; } private String getName(PoissonGammaGaussianFunction f) { if (f.isUseApproximation()) return "Approximation"; if (f.isUseSimpleIntegration()) return "Simple integration"; return "Full integration"; } }