package gdsc.smlm.function.gaussian;
import org.apache.commons.math3.util.FastMath;
import org.junit.Assert;
import org.junit.Test;
/**
* Contains tests for the Gaussian functions in single or double precision
* <p>
* The tests show that there is very little (if any) time penalty when using double precision for the calculations.
* However the precision of the single-precision functions is 1e-4 when using reasonable Gaussian parameters. This could
* affect the convergence of optimisers/fitters if using single precision math.
*/
public class PrecisionTest
{
/** Marker value for single precision (not referenced within this class). */
int Single = 1;
/** Marker value for double precision (not referenced within this class). */
int Double = 2;
/** Number of timed passes over the data in each benchmark run. */
private int MAX_ITER = 200000;
/** Factor by which the single-precision time must beat the double-precision time in the speed tests. */
double SPEED_UP_FACTOR = 1.1;
/** Default image width; the functions map a linear index x to (x % maxx, x / maxx). */
int maxx = 10;
// Realistic values for a camera with a bias of 500, ordered by the Gaussian.* index constants:
// background, amplitude, angle, x position, y position, x SD, y SD
static double[] params2 = { 500.23, 300.12, 0, 5.12, 5.23, 1.11, 1.11 };
// Single-precision copy; must be declared after params2 since static initialisers run in textual order
static float[] params1 = toFloat(params2);
// Stripped down Gaussian functions copied from the gdsc.smlm.fitting.function.gaussian package
/**
 * Shared base for the stripped-down Gaussian functions: holds the parameter-array index constants and the image
 * width used to split a linear index into (x, y).
 */
public abstract class Gaussian
{
	/** Index of the constant background in the parameter array. */
	public static final int BACKGROUND = 0;
	/** Index of the peak amplitude. */
	public static final int AMPLITUDE = 1;
	/** Index of the rotation angle (not read by the functions in this class). */
	public static final int ANGLE = 2;
	/** Index of the centre x coordinate. */
	public static final int X_POSITION = 3;
	/** Index of the centre y coordinate. */
	public static final int Y_POSITION = 4;
	/** Index of the x standard deviation. */
	public static final int X_SD = 5;
	/** Index of the y standard deviation. */
	public static final int Y_SD = 6;

	/** Image width: linear index x maps to column x % maxx and row x / maxx. */
	int maxx;

	/**
	 * @param maxx
	 *            the image width
	 */
	public Gaussian(int maxx)
	{
		this.maxx = maxx;
	}

	/**
	 * Replace the image width.
	 *
	 * @param maxx
	 *            the new image width
	 */
	public void setMaxX(int maxx)
	{
		this.maxx = maxx;
	}
}
/** A Gaussian function over a linear pixel index, computed in double precision. */
public interface DoublePrecision
{
	/** Set the image width used to map a linear index to (x, y). */
	void setMaxX(int maxx);

	/** Prepare for evaluation using parameters laid out by the Gaussian index constants. */
	void initialise(double[] a);

	/** Evaluate at linear index x, writing the parameter gradients into dyda. */
	double eval(int x, double[] dyda);

	/** Evaluate at linear index x without computing gradients. */
	double eval(int x);
}
/** A Gaussian function over a linear pixel index, computed in single precision. */
public interface SinglePrecision
{
	/** Set the image width used to map a linear index to (x, y). */
	void setMaxX(int maxx);

	/** Prepare for evaluation using parameters laid out by the Gaussian index constants. */
	void initialise(float[] a);

	/** Evaluate at linear index x, writing the parameter gradients into dyda. */
	float eval(int x, float[] dyda);

	/** Evaluate at linear index x without computing gradients. */
	float eval(int x);
}
/**
 * Circular 2D Gaussian in double precision: the X_SD parameter is used as the width for both axes.
 * <p>
 * Gradients written by {@link #eval(int, double[])}, in order: background, amplitude, x position, y position,
 * width.
 */
public class DoubleCircularGaussian extends Gaussian implements DoublePrecision
{
	double background;
	double amplitude;
	/** Centre x coordinate. */
	double x0pos;
	/** Centre y coordinate. */
	double x1pos;
	/** Exponent factor, -1 / (2 s^2). */
	double aa;
	/** Position-gradient factor, 1 / s^2. */
	double aa2;
	/** Width-gradient factor, 1 / s^3. */
	double ax;

	public DoubleCircularGaussian(int maxx)
	{
		super(maxx);
	}

	public void initialise(double[] a)
	{
		background = a[BACKGROUND];
		amplitude = a[AMPLITUDE];
		x0pos = a[X_POSITION];
		x1pos = a[Y_POSITION];
		final double s = a[X_SD];
		final double variance = s * s;
		aa = -0.5 / variance;
		aa2 = -2.0 * aa;
		ax = 1.0 / (variance * s);
	}

	public double eval(final int x, final double[] dyda)
	{
		// The background gradient is always one
		dyda[0] = 1.0;
		final double dy = x / maxx - x1pos;
		final double dx = x % maxx - x0pos;
		final double r2 = dx * dx + dy * dy;
		final double e = FastMath.exp(aa * r2);
		dyda[1] = e;
		final double y = amplitude * e;
		final double scale = y * aa2;
		dyda[2] = scale * dx;
		dyda[3] = scale * dy;
		dyda[4] = y * (ax * r2);
		return background + y;
	}

	public double eval(final int x)
	{
		final double dy = x / maxx - x1pos;
		final double dx = x % maxx - x0pos;
		return background + amplitude * FastMath.exp(aa * (dx * dx + dy * dy));
	}
}
/**
 * Circular 2D Gaussian in single precision: the X_SD parameter is used as the width for both axes. Only the
 * exponential is computed in double precision (via FastMath) and narrowed back to float.
 * <p>
 * Gradients written by {@link #eval(int, float[])}, in order: background, amplitude, x position, y position,
 * width.
 */
public class SingleCircularGaussian extends Gaussian implements SinglePrecision
{
	float background;
	float amplitude;
	/** Centre x coordinate. */
	float x0pos;
	/** Centre y coordinate. */
	float x1pos;
	/** Exponent factor, -1 / (2 s^2). */
	float aa;
	/** Position-gradient factor, 1 / s^2. */
	float aa2;
	/** Width-gradient factor, 1 / s^3. */
	float ax;

	public SingleCircularGaussian(int maxx)
	{
		super(maxx);
	}

	public void initialise(float[] a)
	{
		background = a[BACKGROUND];
		amplitude = a[AMPLITUDE];
		x0pos = a[X_POSITION];
		x1pos = a[Y_POSITION];
		final float s = a[X_SD];
		final float variance = s * s;
		aa = -0.5f / variance;
		aa2 = -2.0f * aa;
		ax = 1.0f / (variance * s);
	}

	public float eval(final int x, final float[] dyda)
	{
		// The background gradient is always one
		dyda[0] = 1.0f;
		final float dy = x / maxx - x1pos;
		final float dx = x % maxx - x0pos;
		final float r2 = dx * dx + dy * dy;
		final float e = (float) FastMath.exp(aa * r2);
		dyda[1] = e;
		final float y = amplitude * e;
		final float scale = y * aa2;
		dyda[2] = scale * dx;
		dyda[3] = scale * dy;
		dyda[4] = y * (ax * r2);
		return background + y;
	}

	public float eval(final int x)
	{
		final float dy = x / maxx - x1pos;
		final float dx = x % maxx - x0pos;
		return background + amplitude * (float) FastMath.exp(aa * (dx * dx + dy * dy));
	}
}
/**
 * Fixed-width circular 2D Gaussian in double precision: the width (X_SD) is not fitted, so no width gradient is
 * produced.
 * <p>
 * Gradients written by {@link #eval(int, double[])}, in order: background, amplitude, x position, y position.
 */
public class DoubleFixedGaussian extends Gaussian implements DoublePrecision
{
	/** The width read from X_SD (stored but not otherwise read in this class). */
	double width;
	double background;
	double amplitude;
	/** Centre x coordinate. */
	double x0pos;
	/** Centre y coordinate. */
	double x1pos;
	/** Exponent factor, -1 / (2 s^2). */
	double aa;
	/** Position-gradient factor, 1 / s^2. */
	double aa2;

	public DoubleFixedGaussian(int maxx)
	{
		super(maxx);
	}

	public void initialise(double[] a)
	{
		background = a[BACKGROUND];
		amplitude = a[AMPLITUDE];
		x0pos = a[X_POSITION];
		x1pos = a[Y_POSITION];
		width = a[X_SD];
		aa = -0.5 / (width * width);
		aa2 = -2.0 * aa;
	}

	public double eval(final int x, final double[] dyda)
	{
		// The background gradient is always one
		dyda[0] = 1.0;
		final double dy = x / maxx - x1pos;
		final double dx = x % maxx - x0pos;
		final double e = FastMath.exp(aa * (dx * dx + dy * dy));
		dyda[1] = e;
		final double y = amplitude * e;
		final double scale = y * aa2;
		dyda[2] = scale * dx;
		dyda[3] = scale * dy;
		return background + y;
	}

	public double eval(final int x)
	{
		final double dy = x / maxx - x1pos;
		final double dx = x % maxx - x0pos;
		return background + amplitude * FastMath.exp(aa * (dx * dx + dy * dy));
	}
}
/**
 * Fixed-width circular 2D Gaussian in single precision: the width (X_SD) is not fitted, so no width gradient is
 * produced. Only the exponential is computed in double precision (via FastMath) and narrowed back to float.
 * <p>
 * Gradients written by {@link #eval(int, float[])}, in order: background, amplitude, x position, y position.
 */
public class SingleFixedGaussian extends Gaussian implements SinglePrecision
{
	/** The width read from X_SD (stored but not otherwise read in this class). */
	float width;
	float background;
	float amplitude;
	/** Centre x coordinate. */
	float x0pos;
	/** Centre y coordinate. */
	float x1pos;
	/** Exponent factor, -1 / (2 s^2). */
	float aa;
	/** Position-gradient factor, 1 / s^2. */
	float aa2;

	public SingleFixedGaussian(int maxx)
	{
		super(maxx);
	}

	public void initialise(float[] a)
	{
		background = a[BACKGROUND];
		amplitude = a[AMPLITUDE];
		x0pos = a[X_POSITION];
		x1pos = a[Y_POSITION];
		width = a[X_SD];
		aa = -0.5f / (width * width);
		aa2 = -2.0f * aa;
	}

	public float eval(final int x, final float[] dyda)
	{
		// The background gradient is always one
		dyda[0] = 1.0f;
		final float dy = x / maxx - x1pos;
		final float dx = x % maxx - x0pos;
		final float e = (float) FastMath.exp(aa * (dx * dx + dy * dy));
		dyda[1] = e;
		final float y = amplitude * e;
		final float scale = y * aa2;
		dyda[2] = scale * dx;
		dyda[3] = scale * dy;
		return background + y;
	}

	public float eval(final int x)
	{
		final float dy = x / maxx - x1pos;
		final float dx = x % maxx - x0pos;
		return background + amplitude * (float) FastMath.exp(aa * (dx * dx + dy * dy));
	}
}
/** Single- and double-precision circular Gaussians agree to within an absolute tolerance of 1e-3. */
@Test
public void circularFunctionPrecisionIs3sf()
{
	final SinglePrecision single = new SingleCircularGaussian(maxx);
	final DoublePrecision reference = new DoubleCircularGaussian(maxx);
	functionsComputeSameValue(maxx, single, reference, 1e-3);
}
/** Single- and double-precision circular Gaussians agree to within an absolute tolerance of 1e-4. */
@Test
public void circularFunctionPrecisionIs4sf()
{
	final SinglePrecision single = new SingleCircularGaussian(maxx);
	final DoublePrecision reference = new DoubleCircularGaussian(maxx);
	functionsComputeSameValue(maxx, single, reference, 1e-4);
}
/** Single- and double-precision circular Gaussians are expected to differ by more than 1e-5 somewhere. */
@Test(expected = java.lang.AssertionError.class)
public void circularFunctionPrecisionIsNot5sf()
{
	final SinglePrecision single = new SingleCircularGaussian(maxx);
	final DoublePrecision reference = new DoubleCircularGaussian(maxx);
	functionsComputeSameValue(maxx, single, reference, 1e-5);
}
/**
 * Doubles the image width until single- and double-precision circular Gaussians differ by more than 1e-3.
 * <p>
 * The expected AssertionError comes from {@code functionsComputeSameValue}. If doubling would overflow before that
 * happens, an IllegalStateException is thrown instead, so the test fails rather than looping forever.
 */
@Test(expected = java.lang.AssertionError.class)
public void circularFunctionsPrecisionIsNot3sfAtLargeXY()
{
	int maxx = this.maxx;
	try
	{
		// Stop before doubling overflows: int overflow wraps (eventually to 0), a non-positive width evaluates no
		// points, nothing can fail and an unbounded loop would never terminate.
		while (maxx <= Integer.MAX_VALUE / 2)
		{
			maxx *= 2;
			System.out.printf("maxx = %d\n", maxx);
			functionsComputeSameValue(maxx, new SingleCircularGaussian(maxx), new DoubleCircularGaussian(maxx),
					1e-3);
		}
	}
	catch (AssertionError e)
	{
		System.out.println(e.getMessage());
		throw e;
	}
	// Deliberately not an AssertionError: the test expects one, so this must register as a failure.
	throw new IllegalStateException("Single precision matched double precision to 1e-3 up to maxx = " + maxx);
}
/**
 * Expects the single-precision circular Gaussian NOT to be SPEED_UP_FACTOR times faster when evaluating gradients
 * and summing the results.
 */
@Test(expected = java.lang.AssertionError.class)
public void circularSinglePrecisionIsNotMuchFasterWithGradients()
{
	final SinglePrecision single = new SingleCircularGaussian(maxx);
	final DoublePrecision reference = new DoubleCircularGaussian(maxx);
	final boolean noSum = false;
	singlePrecisionIsFasterWithGradients(maxx, single, reference, noSum);
}
/**
 * Expects the single-precision circular Gaussian NOT to be SPEED_UP_FACTOR times faster when evaluating values only
 * and summing the results.
 */
@Test(expected = java.lang.AssertionError.class)
public void circularSinglePrecisionIsNotMuchFaster()
{
	final SinglePrecision single = new SingleCircularGaussian(maxx);
	final DoublePrecision reference = new DoubleCircularGaussian(maxx);
	final boolean noSum = false;
	singlePrecisionIsFaster(maxx, single, reference, noSum);
}
/**
 * Expects the single-precision circular Gaussian NOT to be SPEED_UP_FACTOR times faster when evaluating gradients
 * without summing the results.
 */
@Test(expected = java.lang.AssertionError.class)
public void circularSinglePrecisionIsNotMuchFasterWithGradientsNoSum()
{
	final SinglePrecision single = new SingleCircularGaussian(maxx);
	final DoublePrecision reference = new DoubleCircularGaussian(maxx);
	final boolean noSum = true;
	singlePrecisionIsFasterWithGradients(maxx, single, reference, noSum);
}
/**
 * Expects the single-precision circular Gaussian NOT to be SPEED_UP_FACTOR times faster when evaluating values only
 * without summing the results.
 */
@Test(expected = java.lang.AssertionError.class)
public void circularSinglePrecisionIsNotMuchFasterNoSum()
{
	final SinglePrecision single = new SingleCircularGaussian(maxx);
	final DoublePrecision reference = new DoubleCircularGaussian(maxx);
	final boolean noSum = true;
	singlePrecisionIsFaster(maxx, single, reference, noSum);
}
/** Single- and double-precision fixed-width Gaussians agree to within an absolute tolerance of 1e-3. */
@Test
public void fixedFunctionPrecisionIs3sf()
{
	final SinglePrecision single = new SingleFixedGaussian(maxx);
	final DoublePrecision reference = new DoubleFixedGaussian(maxx);
	functionsComputeSameValue(maxx, single, reference, 1e-3);
}
/** Single- and double-precision fixed-width Gaussians agree to within an absolute tolerance of 1e-4. */
@Test
public void fixedFunctionPrecisionIs4sf()
{
	final SinglePrecision single = new SingleFixedGaussian(maxx);
	final DoublePrecision reference = new DoubleFixedGaussian(maxx);
	functionsComputeSameValue(maxx, single, reference, 1e-4);
}
/** Single- and double-precision fixed-width Gaussians are expected to differ by more than 1e-5 somewhere. */
@Test(expected = java.lang.AssertionError.class)
public void fixedFunctionPrecisionIsNot5sf()
{
	final SinglePrecision single = new SingleFixedGaussian(maxx);
	final DoublePrecision reference = new DoubleFixedGaussian(maxx);
	functionsComputeSameValue(maxx, single, reference, 1e-5);
}
/**
 * Doubles the image width until single- and double-precision fixed-width Gaussians differ by more than 1e-3.
 * <p>
 * The expected AssertionError comes from {@code functionsComputeSameValue}. If doubling would overflow before that
 * happens, an IllegalStateException is thrown instead, so the test fails rather than looping forever.
 */
@Test(expected = java.lang.AssertionError.class)
public void fixedFunctionsPrecisionIsNot3sfAtLargeXY()
{
	int maxx = this.maxx;
	try
	{
		// Stop before doubling overflows: int overflow wraps (eventually to 0), a non-positive width evaluates no
		// points, nothing can fail and an unbounded loop would never terminate.
		while (maxx <= Integer.MAX_VALUE / 2)
		{
			maxx *= 2;
			System.out.printf("maxx = %d\n", maxx);
			functionsComputeSameValue(maxx, new SingleFixedGaussian(maxx), new DoubleFixedGaussian(maxx), 1e-3);
		}
	}
	catch (AssertionError e)
	{
		System.out.println(e.getMessage());
		throw e;
	}
	// Deliberately not an AssertionError: the test expects one, so this must register as a failure.
	throw new IllegalStateException("Single precision matched double precision to 1e-3 up to maxx = " + maxx);
}
/**
 * Expects the single-precision fixed-width Gaussian NOT to be SPEED_UP_FACTOR times faster when evaluating
 * gradients and summing the results.
 */
@Test(expected = java.lang.AssertionError.class)
public void fixedSinglePrecisionIsNotMuchFasterWithGradients()
{
	final SinglePrecision single = new SingleFixedGaussian(maxx);
	final DoublePrecision reference = new DoubleFixedGaussian(maxx);
	final boolean noSum = false;
	singlePrecisionIsFasterWithGradients(maxx, single, reference, noSum);
}
/**
 * Expects the single-precision fixed-width Gaussian NOT to be SPEED_UP_FACTOR times faster when evaluating values
 * only and summing the results.
 */
@Test(expected = java.lang.AssertionError.class)
public void fixedSinglePrecisionIsNotMuchFaster()
{
	final SinglePrecision single = new SingleFixedGaussian(maxx);
	final DoublePrecision reference = new DoubleFixedGaussian(maxx);
	final boolean noSum = false;
	singlePrecisionIsFaster(maxx, single, reference, noSum);
}
/**
 * Expects the single-precision fixed-width Gaussian NOT to be SPEED_UP_FACTOR times faster when evaluating
 * gradients without summing the results.
 */
@Test(expected = java.lang.AssertionError.class)
public void fixedSinglePrecisionIsNotMuchFasterWithGradientsNoSum()
{
	final SinglePrecision single = new SingleFixedGaussian(maxx);
	final DoublePrecision reference = new DoubleFixedGaussian(maxx);
	final boolean noSum = true;
	singlePrecisionIsFasterWithGradients(maxx, single, reference, noSum);
}
/**
 * Expects the single-precision fixed-width Gaussian NOT to be SPEED_UP_FACTOR times faster when evaluating values
 * only without summing the results.
 */
@Test(expected = java.lang.AssertionError.class)
public void fixedSinglePrecisionIsNotMuchFasterNoSum()
{
	final SinglePrecision single = new SingleFixedGaussian(maxx);
	final DoublePrecision reference = new DoubleFixedGaussian(maxx);
	final boolean noSum = true;
	singlePrecisionIsFaster(maxx, single, reference, noSum);
}
/**
 * Asserts that a single-precision and a double-precision function produce the same values, gradients and totals
 * to within an absolute tolerance.
 * <p>
 * Both functions are initialised from the shared parameters with the centre moved to (0.123 + maxx / 2,
 * 0.789 + maxx / 2) using integer division. The centre is rounded to float so both functions see the identical
 * position.
 * <p>
 * NOTE(review): only linear indices 0..maxx-1 (the first image row, y = 0) are evaluated, while the Gaussian is
 * centred mid-image; confirm whether maxx * maxx points were intended.
 *
 * @param maxx
 *            the image width
 * @param f1
 *            the single-precision function
 * @param f2
 *            the double-precision function
 * @param precision
 *            the absolute tolerance for every comparison
 */
private void functionsComputeSameValue(int maxx, SinglePrecision f1, DoublePrecision f2, final double precision)
{
	f1.setMaxX(maxx);
	f2.setMaxX(maxx);

	final float[] p1 = params1.clone();
	final double[] p2 = params2.clone();
	final float centreX = (float) (0.123 + maxx / 2);
	final float centreY = (float) (0.789 + maxx / 2);
	p2[Gaussian.X_POSITION] = centreX;
	p1[Gaussian.X_POSITION] = centreX;
	p2[Gaussian.Y_POSITION] = centreY;
	p1[Gaussian.Y_POSITION] = centreY;
	f1.initialise(p1);
	f2.initialise(p2);

	final int nParams = p1.length;
	final float[] singleGrad = new float[nParams];
	final double[] doubleGrad = new double[nParams];
	final double[] singleGradTotal = new double[nParams];
	final double[] doubleGradTotal = new double[nParams];
	double singleTotal = 0;
	double doubleTotal = 0;

	int i = 0;
	while (i < maxx)
	{
		final float singleValue = f1.eval(i);
		singleTotal += singleValue;
		final double doubleValue = f2.eval(i);
		doubleTotal += doubleValue;
		Assert.assertEquals("Different values", doubleValue, singleValue, precision);

		// The gradient-computing overloads must return the same value as the plain overloads
		final float singleValueWithGrad = f1.eval(i, singleGrad);
		final double doubleValueWithGrad = f2.eval(i, doubleGrad);
		Assert.assertEquals("Different f1 values", singleValue, singleValueWithGrad, precision);
		Assert.assertEquals("Different f2 values", doubleValue, doubleValueWithGrad, precision);

		for (int k = 0; k < nParams; k++)
		{
			singleGradTotal[k] += singleGrad[k];
			doubleGradTotal[k] += doubleGrad[k];
		}
		Assert.assertArrayEquals("Different gradients", doubleGrad, toDouble(singleGrad), precision);
		i++;
	}

	Assert.assertArrayEquals("Different total gradients", doubleGradTotal, singleGradTotal, precision);
	Assert.assertEquals("Different totals", doubleTotal, singleTotal, precision);
}
/**
 * Times gradient evaluation of both functions and asserts that single precision is at least SPEED_UP_FACTOR times
 * faster.
 * <p>
 * Each function is run three times: the first run is discarded and the remaining two are summed. All single
 * precision runs happen before the double precision runs.
 *
 * @param maxx
 *            the number of linear indices evaluated per pass
 * @param f1
 *            the single-precision function
 * @param f2
 *            the double-precision function
 * @param noSum
 *            true to time evaluation without accumulating values and gradients
 */
private void singlePrecisionIsFasterWithGradients(int maxx, SinglePrecision f1, DoublePrecision f2, boolean noSum)
{
	f1.setMaxX(maxx);
	f2.setMaxX(maxx);

	final float[] p1 = params1.clone();
	final double[] p2 = params2.clone();
	final float centreX = (float) (0.123 + maxx / 2);
	final float centreY = (float) (0.789 + maxx / 2);
	p2[Gaussian.X_POSITION] = centreX;
	p1[Gaussian.X_POSITION] = centreX;
	p2[Gaussian.Y_POSITION] = centreY;
	p1[Gaussian.Y_POSITION] = centreY;

	final long time1;
	final long time2;
	if (!noSum)
	{
		runSingleWithGradients(maxx, f1, p1); // discarded warm-up run
		time1 = runSingleWithGradients(maxx, f1, p1) + runSingleWithGradients(maxx, f1, p1);
		runDoubleWithGradients(maxx, f2, p2); // discarded warm-up run
		time2 = runDoubleWithGradients(maxx, f2, p2) + runDoubleWithGradients(maxx, f2, p2);
	}
	else
	{
		runSingleWithGradientsNoSum(maxx, f1, p1); // discarded warm-up run
		time1 = runSingleWithGradientsNoSum(maxx, f1, p1) + runSingleWithGradientsNoSum(maxx, f1, p1);
		runDoubleWithGradientsNoSum(maxx, f2, p2); // discarded warm-up run
		time2 = runDoubleWithGradientsNoSum(maxx, f2, p2) + runDoubleWithGradientsNoSum(maxx, f2, p2);
	}

	final String prefix = noSum ? "No sum " : "";
	final String singleName = f1.getClass().getSimpleName();
	final String doubleName = f2.getClass().getSimpleName();
	System.out.printf("%sGradient %s = %d, %s = %d => (%f)\n", prefix, singleName, time1, doubleName, time2,
			(double) time2 / time1);
	Assert.assertTrue(time1 * SPEED_UP_FACTOR < time2);
}
/**
 * Times MAX_ITER passes of {@code f.eval(i, g)} over linear indices 0..maxx-1, accumulating the function values
 * and the gradients on each pass.
 * <p>
 * The function is initialised once, outside the timed region, and ten untimed warm-up passes run first.
 * NOTE(review): the accumulated sum and gradient totals are never read (hence the "unused" suppression); confirm
 * the JIT cannot eliminate that work, otherwise it would not be reflected in the timing.
 *
 * @param maxx
 *            the number of linear indices evaluated per pass
 * @param f
 *            the function to time
 * @param p
 *            the function parameters
 * @return the elapsed time of the timed passes, in nanoseconds
 */
@SuppressWarnings("unused")
private long runSingleWithGradients(int maxx, SinglePrecision f, float[] p)
{
f.initialise(p);
final int n = params1.length;
float[] g = new float[n];
double[] tg = new double[n];
// Untimed warm-up: 10 passes
for (int j = 0; j < 10; j++)
{
for (int i = 0; i < maxx; i++)
{
f.eval(i, g);
}
}
long time = System.nanoTime();
double sum = 0;
for (int j = 0; j < MAX_ITER; j++)
{
// Reset per pass; the final sum is never read
sum = 0;
for (int i = 0; i < maxx; i++)
{
sum += f.eval(i, g);
for (int k = 0; k < n; k++)
tg[k] += g[k];
}
}
return System.nanoTime() - time;
}
/**
 * Times MAX_ITER passes of {@code f.eval(i, g)} over linear indices 0..maxx-1, discarding the returned values and
 * not accumulating the gradients.
 * <p>
 * The function is initialised once, outside the timed region, and ten untimed warm-up passes run first.
 *
 * @param maxx
 *            the number of linear indices evaluated per pass
 * @param f
 *            the function to time
 * @param p
 *            the function parameters
 * @return the elapsed time of the timed passes, in nanoseconds
 */
private long runSingleWithGradientsNoSum(int maxx, SinglePrecision f, float[] p)
{
f.initialise(p);
float[] g = new float[params1.length];
// Untimed warm-up: 10 passes
for (int j = 0; j < 10; j++)
{
for (int i = 0; i < maxx; i++)
{
f.eval(i, g);
}
}
long time = System.nanoTime();
for (int j = 0; j < MAX_ITER; j++)
{
for (int i = 0; i < maxx; i++)
{
f.eval(i, g);
}
}
return System.nanoTime() - time;
}
/**
 * Times MAX_ITER passes of {@code f.eval(i, g)} over linear indices 0..maxx-1, accumulating the function values
 * and the gradients on each pass.
 * <p>
 * The function is initialised once, outside the timed region, and ten untimed warm-up passes run first.
 * NOTE(review): the accumulated sum and gradient totals are never read (hence the "unused" suppression); confirm
 * the JIT cannot eliminate that work, otherwise it would not be reflected in the timing.
 *
 * @param maxx
 *            the number of linear indices evaluated per pass
 * @param f
 *            the function to time
 * @param p
 *            the function parameters
 * @return the elapsed time of the timed passes, in nanoseconds
 */
@SuppressWarnings("unused")
private long runDoubleWithGradients(int maxx, DoublePrecision f, double[] p)
{
f.initialise(p);
final int n = params1.length;
double[] g = new double[n];
double[] tg = new double[n];
// Untimed warm-up: 10 passes
for (int j = 0; j < 10; j++)
{
for (int i = 0; i < maxx; i++)
{
f.eval(i, g);
}
}
long time = System.nanoTime();
double sum = 0;
for (int j = 0; j < MAX_ITER; j++)
{
// Reset per pass; the final sum is never read
sum = 0;
for (int i = 0; i < maxx; i++)
{
sum += f.eval(i, g);
for (int k = 0; k < n; k++)
tg[k] += g[k];
}
}
return System.nanoTime() - time;
}
/**
 * Times MAX_ITER passes of {@code f.eval(i, g)} over linear indices 0..maxx-1, discarding the returned values and
 * not accumulating the gradients.
 * <p>
 * The function is initialised once, outside the timed region, and ten untimed warm-up passes run first.
 *
 * @param maxx
 *            the number of linear indices evaluated per pass
 * @param f
 *            the function to time
 * @param p
 *            the function parameters
 * @return the elapsed time of the timed passes, in nanoseconds
 */
private long runDoubleWithGradientsNoSum(int maxx, DoublePrecision f, double[] p)
{
f.initialise(p);
double[] g = new double[params1.length];
// Untimed warm-up: 10 passes
for (int j = 0; j < 10; j++)
{
for (int i = 0; i < maxx; i++)
{
f.eval(i, g);
}
}
long time = System.nanoTime();
for (int j = 0; j < MAX_ITER; j++)
{
for (int i = 0; i < maxx; i++)
{
f.eval(i, g);
}
}
return System.nanoTime() - time;
}
/**
 * Times value-only evaluation of both functions and asserts that single precision is at least SPEED_UP_FACTOR times
 * faster.
 * <p>
 * Each function is run three times: the first run is discarded and the remaining two are summed. All single
 * precision runs happen before the double precision runs.
 *
 * @param maxx
 *            the number of linear indices evaluated per pass
 * @param f1
 *            the single-precision function
 * @param f2
 *            the double-precision function
 * @param noSum
 *            true to time evaluation without accumulating the values
 */
private void singlePrecisionIsFaster(int maxx, SinglePrecision f1, DoublePrecision f2, boolean noSum)
{
	f1.setMaxX(maxx);
	f2.setMaxX(maxx);

	final float[] p1 = params1.clone();
	final double[] p2 = params2.clone();
	final float centreX = (float) (0.123 + maxx / 2);
	final float centreY = (float) (0.789 + maxx / 2);
	p2[Gaussian.X_POSITION] = centreX;
	p1[Gaussian.X_POSITION] = centreX;
	p2[Gaussian.Y_POSITION] = centreY;
	p1[Gaussian.Y_POSITION] = centreY;

	final long time1;
	final long time2;
	if (!noSum)
	{
		runSingle(maxx, f1, p1); // discarded warm-up run
		time1 = runSingle(maxx, f1, p1) + runSingle(maxx, f1, p1);
		runDouble(maxx, f2, p2); // discarded warm-up run
		time2 = runDouble(maxx, f2, p2) + runDouble(maxx, f2, p2);
	}
	else
	{
		runSingleNoSum(maxx, f1, p1); // discarded warm-up run
		time1 = runSingleNoSum(maxx, f1, p1) + runSingleNoSum(maxx, f1, p1);
		runDoubleNoSum(maxx, f2, p2); // discarded warm-up run
		time2 = runDoubleNoSum(maxx, f2, p2) + runDoubleNoSum(maxx, f2, p2);
	}

	final String prefix = noSum ? "No sum " : "";
	final String singleName = f1.getClass().getSimpleName();
	final String doubleName = f2.getClass().getSimpleName();
	System.out.printf("%s%s = %d, %s = %d => (%f)\n", prefix, singleName, time1, doubleName, time2,
			(double) time2 / time1);
	Assert.assertTrue(time1 * SPEED_UP_FACTOR < time2);
}
/**
 * Times MAX_ITER passes of {@code f.eval(i)} over linear indices 0..maxx-1, accumulating the values on each pass.
 * <p>
 * Every pass, including the ten untimed warm-up passes, re-initialises the function, so initialisation cost is
 * part of the timing. NOTE(review): the accumulated sum is never read (hence the "unused" suppression); confirm the
 * JIT cannot eliminate that work.
 *
 * @param maxx
 *            the number of linear indices evaluated per pass
 * @param f
 *            the function to time
 * @param p
 *            the function parameters
 * @return the elapsed time of the timed passes, in nanoseconds
 */
@SuppressWarnings("unused")
private long runSingle(int maxx, SinglePrecision f, float[] p)
{
// Untimed warm-up: 10 passes
for (int j = 0; j < 10; j++)
{
f.initialise(p);
for (int i = 0; i < maxx; i++)
{
f.eval(i);
}
}
long time = System.nanoTime();
double sum = 0;
for (int j = 0; j < MAX_ITER; j++)
{
// Reset per pass; the final sum is never read
sum = 0;
f.initialise(p);
for (int i = 0; i < maxx; i++)
{
sum += f.eval(i);
}
}
return System.nanoTime() - time;
}
/**
 * Times MAX_ITER passes of {@code f.eval(i)} over linear indices 0..maxx-1, discarding the returned values.
 * <p>
 * Every pass, including the ten untimed warm-up passes, re-initialises the function, so initialisation cost is
 * part of the timing.
 *
 * @param maxx
 *            the number of linear indices evaluated per pass
 * @param f
 *            the function to time
 * @param p
 *            the function parameters
 * @return the elapsed time of the timed passes, in nanoseconds
 */
private long runSingleNoSum(int maxx, SinglePrecision f, float[] p)
{
// Untimed warm-up: 10 passes
for (int j = 0; j < 10; j++)
{
f.initialise(p);
for (int i = 0; i < maxx; i++)
{
f.eval(i);
}
}
long time = System.nanoTime();
for (int j = 0; j < MAX_ITER; j++)
{
f.initialise(p);
for (int i = 0; i < maxx; i++)
{
f.eval(i);
}
}
return System.nanoTime() - time;
}
/**
 * Times MAX_ITER passes of {@code f.eval(i)} over linear indices 0..maxx-1, accumulating the values on each pass.
 * <p>
 * Every pass, including the ten untimed warm-up passes, re-initialises the function, so initialisation cost is
 * part of the timing. NOTE(review): the accumulated sum is never read (hence the "unused" suppression); confirm the
 * JIT cannot eliminate that work.
 *
 * @param maxx
 *            the number of linear indices evaluated per pass
 * @param f
 *            the function to time
 * @param p
 *            the function parameters
 * @return the elapsed time of the timed passes, in nanoseconds
 */
@SuppressWarnings("unused")
private long runDouble(int maxx, DoublePrecision f, double[] p)
{
// Untimed warm-up: 10 passes
for (int j = 0; j < 10; j++)
{
f.initialise(p);
for (int i = 0; i < maxx; i++)
{
f.eval(i);
}
}
long time = System.nanoTime();
double sum = 0;
for (int j = 0; j < MAX_ITER; j++)
{
// Reset per pass; the final sum is never read
sum = 0;
f.initialise(p);
for (int i = 0; i < maxx; i++)
{
sum += f.eval(i);
}
}
return System.nanoTime() - time;
}
/**
 * Times MAX_ITER passes of {@code f.eval(i)} over linear indices 0..maxx-1, discarding the returned values.
 * <p>
 * Every pass, including the ten untimed warm-up passes, re-initialises the function, so initialisation cost is
 * part of the timing.
 *
 * @param maxx
 *            the number of linear indices evaluated per pass
 * @param f
 *            the function to time
 * @param p
 *            the function parameters
 * @return the elapsed time of the timed passes, in nanoseconds
 */
private long runDoubleNoSum(int maxx, DoublePrecision f, double[] p)
{
// Untimed warm-up: 10 passes
for (int j = 0; j < 10; j++)
{
f.initialise(p);
for (int i = 0; i < maxx; i++)
{
f.eval(i);
}
}
long time = System.nanoTime();
for (int j = 0; j < MAX_ITER; j++)
{
f.initialise(p);
for (int i = 0; i < maxx; i++)
{
f.eval(i);
}
}
return System.nanoTime() - time;
}
/**
 * Narrows each value to single precision.
 *
 * @param p
 *            the values
 * @return a new array with each element cast to float
 */
private static float[] toFloat(double[] p)
{
	final float[] out = new float[p.length];
	int index = 0;
	for (final double value : p)
	{
		out[index++] = (float) value;
	}
	return out;
}
/**
 * Widens each value to double precision.
 *
 * @param p
 *            the values
 * @return a new array with each element widened to double
 */
private static double[] toDouble(float[] p)
{
	final double[] out = new double[p.length];
	int index = 0;
	for (final float value : p)
	{
		out[index++] = value;
	}
	return out;
}
/**
 * Writes printf-style formatted output to standard out (not called within this class).
 *
 * @param format
 *            the format string
 * @param args
 *            the format arguments
 */
void log(String format, Object... args)
{
	System.out.format(format, args);
}
}