package gdsc.smlm.fitting.nonlinear.gradient; import java.util.ArrayList; import org.apache.commons.math3.random.RandomDataGenerator; import org.apache.commons.math3.random.RandomGenerator; import org.apache.commons.math3.random.Well19937c; import org.ejml.data.DenseMatrix64F; import org.junit.Assert; import org.junit.Test; import gdsc.core.ij.Utils; import gdsc.core.utils.DoubleEquality; import gdsc.smlm.TestSettings; import gdsc.smlm.fitting.FisherInformationMatrix; import gdsc.smlm.function.Gradient1Function; import gdsc.smlm.function.gaussian.Gaussian2DFunction; import gdsc.smlm.function.gaussian.GaussianFunctionFactory; import gdsc.smlm.function.gaussian.erf.ErfGaussian2DFunction; public class PoissonGradientProcedureTest { boolean speedTests = true; DoubleEquality eq = new DoubleEquality(6, 1e-16); int MAX_ITER = 20000; int blockWidth = 10; double Background = 0.5; double Signal = 100; double Angle = Math.PI; double Xpos = 5; double Ypos = 5; double Xwidth = 1.2; double Ywidth = 1.2; RandomDataGenerator rdg; @Test public void gradientProcedureFactoryCreatesOptimisedProcedures() { Assert.assertEquals(PoissonGradientProcedureFactory.create(new DummyGradientFunction(6)).getClass(), PoissonGradientProcedure6.class); Assert.assertEquals(PoissonGradientProcedureFactory.create(new DummyGradientFunction(5)).getClass(), PoissonGradientProcedure5.class); Assert.assertEquals(PoissonGradientProcedureFactory.create(new DummyGradientFunction(4)).getClass(), PoissonGradientProcedure4.class); double[] b = null; Assert.assertEquals(PoissonGradientProcedureFactory.create(b, new DummyGradientFunction(6)).getClass(), PoissonGradientProcedure6.class); Assert.assertEquals(PoissonGradientProcedureFactory.create(b, new DummyGradientFunction(5)).getClass(), PoissonGradientProcedure5.class); Assert.assertEquals(PoissonGradientProcedureFactory.create(b, new DummyGradientFunction(4)).getClass(), PoissonGradientProcedure4.class); b = new double[new DummyGradientFunction(4).size()]; Assert.assertEquals(PoissonGradientProcedureFactory.create(b, new DummyGradientFunction(6)).getClass(), PoissonGradientProcedureB6.class); Assert.assertEquals(PoissonGradientProcedureFactory.create(b, new DummyGradientFunction(5)).getClass(), PoissonGradientProcedureB5.class); Assert.assertEquals(PoissonGradientProcedureFactory.create(b, new DummyGradientFunction(4)).getClass(), PoissonGradientProcedureB4.class); Assert.assertEquals(PoissonGradientProcedureFactory.create(b, new DummyGradientFunction(1)).getClass(), PoissonGradientProcedureB.class); } @Test public void gradientProcedureComputesSameAsGradientCalculator() { gradientProcedureComputesSameAsGradientCalculator(4); gradientProcedureComputesSameAsGradientCalculator(5); gradientProcedureComputesSameAsGradientCalculator(6); gradientProcedureComputesSameAsGradientCalculator(11); gradientProcedureComputesSameAsGradientCalculator(21); } @Test public void gradientProcedureIsNotSlowerThanGradientCalculator() { gradientProcedureIsNotSlowerThanGradientCalculator(4); gradientProcedureIsNotSlowerThanGradientCalculator(5); gradientProcedureIsNotSlowerThanGradientCalculator(6); // 2 peaks gradientProcedureIsNotSlowerThanGradientCalculator(11); // 4 peaks gradientProcedureIsNotSlowerThanGradientCalculator(21); } private void gradientProcedureComputesSameAsGradientCalculator(int nparams) { int iter = 10; rdg = new RandomDataGenerator(new Well19937c(30051977)); ArrayList<double[]> paramsList = new ArrayList<double[]>(iter); createFakeParams(nparams, iter, paramsList); int n = blockWidth * blockWidth; FakeGradientFunction func = new FakeGradientFunction(blockWidth, nparams); GradientCalculator calc = GradientCalculatorFactory.newCalculator(nparams, false); String name = String.format("[%d]", nparams); for (int i = 0; i < paramsList.size(); i++) { PoissonGradientProcedure p = PoissonGradientProcedureFactory.create(func); p.computeFisherInformation(paramsList.get(i)); double[][] m = calc.fisherInformationMatrix(n, paramsList.get(i), func); // Exactly the same ... double[] al = p.getLinear(); Assert.assertArrayEquals(name + " Observations: Not same alpha @ " + i, al, new DenseMatrix64F(m).data, 0); double[][] am = p.getMatrix(); for (int j = 0; j < nparams; j++) Assert.assertArrayEquals(name + " Observations: Not same alpha @ " + i, am[j], m[j], 0); } } private abstract class Timer { private int loops; int min; Timer() { } Timer(int min) { this.min = min; } long getTime() { // Run till stable timing long t1 = time(); for (int i = 0; i < 10; i++) { long t2 = t1; t1 = time(); if (loops >= min && DoubleEquality.relativeError(t1, t2) < 0.02) // 2% difference break; } return t1; } long time() { loops++; long t = System.nanoTime(); run(); t = System.nanoTime() - t; //System.out.printf("[%d] Time = %d\n", loops, t); return t; } abstract void run(); } private void gradientProcedureIsNotSlowerThanGradientCalculator(final int nparams) { org.junit.Assume.assumeTrue(speedTests || TestSettings.RUN_SPEED_TESTS); final int iter = 1000; rdg = new RandomDataGenerator(new Well19937c(30051977)); final ArrayList<double[]> paramsList = new ArrayList<double[]>(iter); createFakeParams(nparams, iter, paramsList); final int n = blockWidth * blockWidth; final FakeGradientFunction func = new FakeGradientFunction(blockWidth, nparams); GradientCalculator calc = GradientCalculatorFactory.newCalculator(nparams, false); for (int i = 0; i < paramsList.size(); i++) calc.fisherInformationMatrix(n, paramsList.get(i), func); for (int i = 0; i < paramsList.size(); i++) { PoissonGradientProcedure p = PoissonGradientProcedureFactory.create(func); p.computeFisherInformation(paramsList.get(i)); } // Realistic loops for an optimisation final int loops = 15; // Run till stable timing Timer t1 = new Timer() { @Override void run() { for (int i = 0, k = 0; i < iter; i++) { GradientCalculator calc = GradientCalculatorFactory.newCalculator(nparams, false); for (int j = loops; j-- > 0;) calc.fisherInformationMatrix(n, paramsList.get(k++ % iter), func); } } }; long time1 = t1.getTime(); Timer t2 = new Timer(t1.loops) { @Override void run() { for (int i = 0, k = 0; i < iter; i++) { PoissonGradientProcedure p = PoissonGradientProcedureFactory.create(func); for (int j = loops; j-- > 0;) p.computeFisherInformation(paramsList.get(k++ % iter)); } } }; long time2 = t2.getTime(); log("GradientCalculator = %d : PoissonGradientProcedure %d = %d : %fx\n", time1, nparams, time2, (1.0 * time1) / time2); if (TestSettings.ASSERT_SPEED_TESTS) { // Add contingency Assert.assertTrue(time2 < time1 * 1.5); } } @Test public void gradientProcedureUnrolledComputesSameAsGradientProcedure() { gradientProcedureUnrolledComputesSameAsGradientProcedure(4, false); gradientProcedureUnrolledComputesSameAsGradientProcedure(5, false); gradientProcedureUnrolledComputesSameAsGradientProcedure(6, false); } @Test public void gradientProcedureUnrolledComputesSameAsGradientProcedureWithPrecomputed() { gradientProcedureUnrolledComputesSameAsGradientProcedure(4, true); gradientProcedureUnrolledComputesSameAsGradientProcedure(5, true); gradientProcedureUnrolledComputesSameAsGradientProcedure(6, true); } private void gradientProcedureUnrolledComputesSameAsGradientProcedure(int nparams, boolean precomputed) { int iter = 10; rdg = new RandomDataGenerator(new Well19937c(30051977)); ArrayList<double[]> paramsList = new ArrayList<double[]>(iter); createFakeParams(nparams, iter, paramsList); FakeGradientFunction func = new FakeGradientFunction(blockWidth, nparams); double[] b = (precomputed) ? Utils.newArray(func.size(), 0.1, 1.3) : null; String name = String.format("[%d]", nparams); for (int i = 0; i < paramsList.size(); i++) { PoissonGradientProcedure p1 = (precomputed) ? new PoissonGradientProcedureB(b, func) : new PoissonGradientProcedure(func); p1.computeFisherInformation(paramsList.get(i)); PoissonGradientProcedure p2 = PoissonGradientProcedureFactory.create(b, func); p2.computeFisherInformation(paramsList.get(i)); // Exactly the same ... Assert.assertArrayEquals(name + " Observations: Not same alpha @ " + i, p1.getLinear(), p2.getLinear(), 0); double[][] am1 = p1.getMatrix(); double[][] am2 = p2.getMatrix(); for (int j = 0; j < nparams; j++) Assert.assertArrayEquals(name + " Observations: Not same alpha @ " + i, am1[j], am2[j], 0); } } @Test public void gradientProcedureIsFasterUnrolledThanGradientProcedure() { gradientProcedureIsFasterUnrolledThanGradientProcedure(4, false); gradientProcedureIsFasterUnrolledThanGradientProcedure(5, false); gradientProcedureIsFasterUnrolledThanGradientProcedure(6, false); } @Test public void gradientProcedureIsFasterUnrolledThanGradientProcedureWithPrecomputed() { gradientProcedureIsFasterUnrolledThanGradientProcedure(4, true); gradientProcedureIsFasterUnrolledThanGradientProcedure(5, true); gradientProcedureIsFasterUnrolledThanGradientProcedure(6, true); } private void gradientProcedureIsFasterUnrolledThanGradientProcedure(final int nparams, final boolean precomputed) { org.junit.Assume.assumeTrue(speedTests || TestSettings.RUN_SPEED_TESTS); final int iter = 100; rdg = new RandomDataGenerator(new Well19937c(30051977)); final ArrayList<double[]> paramsList = new ArrayList<double[]>(iter); createFakeParams(nparams, iter, paramsList); // Remove the timing of the function call by creating a dummy function final Gradient1Function func = new FakeGradientFunction(blockWidth, nparams); final double[] b = (precomputed) ? Utils.newArray(func.size(), 0.1, 1.3) : null; for (int i = 0; i < paramsList.size(); i++) { PoissonGradientProcedure p1 = (precomputed) ? new PoissonGradientProcedureB(b, func) : new PoissonGradientProcedure(func); p1.computeFisherInformation(paramsList.get(i)); p1.computeFisherInformation(paramsList.get(i)); PoissonGradientProcedure p2 = PoissonGradientProcedureFactory.create(b, func); p2.computeFisherInformation(paramsList.get(i)); p2.computeFisherInformation(paramsList.get(i)); // Check they are the same Assert.assertArrayEquals("M " + i, p1.getLinear(), p2.getLinear(), 0); } // Realistic loops for an optimisation final int loops = 15; // Run till stable timing Timer t1 = new Timer() { @Override void run() { for (int i = 0, k = 0; i < paramsList.size(); i++) { PoissonGradientProcedure p1 = (precomputed) ? new PoissonGradientProcedureB(b, func) : new PoissonGradientProcedure(func); for (int j = loops; j-- > 0;) p1.computeFisherInformation(paramsList.get(k++ % iter)); } } }; long time1 = t1.getTime(); Timer t2 = new Timer(t1.loops) { @Override void run() { for (int i = 0, k = 0; i < paramsList.size(); i++) { PoissonGradientProcedure p2 = PoissonGradientProcedureFactory.create(b, func); for (int j = loops; j-- > 0;) p2.computeFisherInformation(paramsList.get(k++ % iter)); } } }; long time2 = t2.getTime(); log("Precomputed=%b : Standard %d : Unrolled %d = %d : %fx\n", precomputed, time1, nparams, time2, (1.0 * time1) / time2); Assert.assertTrue(time2 < time1); } @Test public void crlbIsHigherWithPrecomputed() { int iter = 10; rdg = new RandomDataGenerator(new Well19937c(30051977)); ErfGaussian2DFunction func = (ErfGaussian2DFunction) GaussianFunctionFactory.create2D(1, 10, 10, GaussianFunctionFactory.FIT_ERF_FREE_CIRCLE, null); double[] a = new double[7]; int n = func.getNumberOfGradients(); // Get a background double[] b = new double[func.size()]; for (int i = 0; i < b.length; i++) b[i] = rdg.nextUniform(1, 2); for (int i = 0; i < iter; i++) { a[Gaussian2DFunction.BACKGROUND] = rdg.nextUniform(0.1, 0.3); a[Gaussian2DFunction.SIGNAL] = rdg.nextUniform(100, 300); a[Gaussian2DFunction.X_POSITION] = rdg.nextUniform(4, 6); a[Gaussian2DFunction.Y_POSITION] = rdg.nextUniform(4, 6); a[Gaussian2DFunction.X_SD] = rdg.nextUniform(1, 1.3); a[Gaussian2DFunction.Y_SD] = rdg.nextUniform(1, 1.3); PoissonGradientProcedure p1 = PoissonGradientProcedureFactory.create(func); p1.computeFisherInformation(a); PoissonGradientProcedure p2 = PoissonGradientProcedureFactory.create(b, func); p2.computeFisherInformation(a); FisherInformationMatrix m1 = new FisherInformationMatrix(p1.getLinear(), n); FisherInformationMatrix m2 = new FisherInformationMatrix(p2.getLinear(), n); double[] crlb1 = m1.crlb(); double[] crlb2 = m2.crlb(); Assert.assertNotNull(crlb1); Assert.assertNotNull(crlb2); //System.out.printf("%s : %s\n", Arrays.toString(crlb1), Arrays.toString(crlb2)); for (int j = 0; j < n; j++) Assert.assertTrue(crlb1[j] < crlb2[j]); } } protected int[] createFakeData(int nparams, int iter, ArrayList<double[]> paramsList, ArrayList<double[]> yList) { int[] x = new int[blockWidth * blockWidth]; for (int i = 0; i < x.length; i++) x[i] = i; for (int i = 0; i < iter; i++) { double[] params = new double[nparams]; double[] y = createFakeData(params); paramsList.add(params); yList.add(y); } return x; } private double[] createFakeData(double[] params) { int n = blockWidth * blockWidth; RandomGenerator r = rdg.getRandomGenerator(); for (int i = 0; i < params.length; i++) { params[i] = r.nextDouble(); } double[] y = new double[n]; for (int i = 0; i < y.length; i++) { y[i] = r.nextDouble(); } return y; } protected void createFakeParams(int nparams, int iter, ArrayList<double[]> paramsList) { for (int i = 0; i < iter; i++) { double[] params = new double[nparams]; createFakeParams(params); paramsList.add(params); } } private void createFakeParams(double[] params) { RandomGenerator r = rdg.getRandomGenerator(); for (int i = 0; i < params.length; i++) { params[i] = r.nextDouble(); } } protected ArrayList<double[]> copyList(ArrayList<double[]> paramsList) { ArrayList<double[]> params2List = new ArrayList<double[]>(paramsList.size()); for (int i = 0; i < paramsList.size(); i++) { params2List.add(copydouble(paramsList.get(i))); } return params2List; } private double[] copydouble(double[] d) { double[] d2 = new double[d.length]; for (int i = 0; i < d.length; i++) d2[i] = d[i]; return d2; } void log(String format, Object... args) { System.out.printf(format, args); } }