package gdsc.smlm.function;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.random.Well19937c;
import org.apache.commons.math3.util.FastMath;
import org.junit.Assert;
import org.junit.Test;
import gdsc.core.test.BaseTimingTask;
import gdsc.core.test.TimingService;
import gdsc.core.utils.DoubleEquality;
public class ErfTest
{
//@formatter:off
private abstract class BaseErf
{
String name;
BaseErf(String name) { this.name = name; }
abstract double erf(double x);
abstract double erf(double x1, double x2);
}
private class ApacheErf extends BaseErf
{
ApacheErf() { super("apache erf"); }
double erf(double x) { return org.apache.commons.math3.special.Erf.erf(x); }
double erf(double x1, double x2) { return org.apache.commons.math3.special.Erf.erf(x1, x2); }
}
private class Erf extends BaseErf
{
Erf() { super("erf"); }
double erf(double x) { return gdsc.smlm.function.Erf.erf(x); }
double erf(double x1, double x2) { return gdsc.smlm.function.Erf.erf(x1, x2); }
}
private class Erf0 extends BaseErf
{
Erf0() { super("erf0"); }
double erf(double x) { return gdsc.smlm.function.Erf.erf0(x); }
double erf(double x1, double x2) { return gdsc.smlm.function.Erf.erf0(x1, x2); }
}
private class Erf2 extends BaseErf
{
Erf2() { super("erf2"); }
double erf(double x) { return gdsc.smlm.function.Erf.erf2(x); }
double erf(double x1, double x2) { return gdsc.smlm.function.Erf.erf2(x1, x2); }
}
//@formatter:on
@Test
public void erf0xHasLowError()
{
erfxHasLowError(new Erf0(), 5e-4);
}
@Test
public void erfxHasLowError()
{
erfxHasLowError(new Erf(), 3e-7);
}
@Test
public void erf2xHasLowError()
{
erfxHasLowError(new Erf2(), 1.3e-4);
}
private void erfxHasLowError(BaseErf erf, double expected)
{
RandomGenerator rg = new Well19937c(30051977);
int range = 8;
double max = 0;
for (int xi = -range; xi <= range; xi++)
{
for (int i = 0; i < 5; i++)
{
double x = xi + rg.nextDouble();
double o = erf.erf(x);
double e = org.apache.commons.math3.special.Erf.erf(x);
double error = Math.abs(o - e);
if (max < error)
max = error;
//System.out.printf("x=%f, e=%f, o=%f, error=%f\n", x, e, o, error);
Assert.assertTrue(error < expected);
}
}
System.out.printf("erfx %s max error = %g\n", erf.name, max);
}
@Test
public void erfApachexIndistinguishableFrom1()
{
erfxIndistinguishableFrom1(new ApacheErf());
}
@Test
public void erf0xIndistinguishableFrom1()
{
erfxIndistinguishableFrom1(new Erf0());
}
@Test
public void erfxIndistinguishableFrom1()
{
erfxIndistinguishableFrom1(new Erf());
}
@Test
public void erf2xIndistinguishableFrom1()
{
erfxIndistinguishableFrom1(new Erf2());
}
private void erfxIndistinguishableFrom1(BaseErf erf)
{
// Find switch using a binary search
double lower = 1;
double upper = 40;
while (DoubleEquality.complement(lower, upper) > 1)
{
double mid = (upper + lower) * 0.5;
double o = erf.erf(mid);
if (o == 1)
{
upper = mid;
}
else
{
lower = mid;
}
}
System.out.printf("erfx %s indistinguishable from 1: x > %s, x >= %s\n", erf.name, Double.toString(lower),
Double.toString(upper));
}
@Test
public void erf0xxHasLowError()
{
erfxxHasLowError(new Erf0(), 4e-2);
}
@Test
public void erfxxHasLowError()
{
erfxxHasLowError(new Erf(), 7e-4);
}
@Test
public void erf2xxHasLowError()
{
erfxxHasLowError(new Erf2(), 1.1e-2);
}
private void erfxxHasLowError(BaseErf erf, double expected)
{
RandomGenerator rg = new Well19937c(30051977);
int range = 3;
double max = 0;
for (int xi = -range; xi <= range; xi++)
{
for (int xi2 = -range; xi2 <= range; xi2++)
{
for (int i = 0; i < 5; i++)
{
double x = xi + rg.nextDouble();
for (int j = 0; j < 5; j++)
{
double x2 = xi2 + rg.nextDouble();
double o = erf.erf(x, x2);
double e = org.apache.commons.math3.special.Erf.erf(x, x2);
double error = Math.abs(o - e);
if (max < error)
max = error;
//System.out.printf("x=%f, x2=%f, e=%f, o=%f, error=%f\n", x, x2, e, o, error);
Assert.assertTrue(error < expected);
}
}
}
}
System.out.printf("erfxx %s max error = %g\n", erf.name, max);
}
@Test
public void erf0xxHasLowErrorForUnitBlocks()
{
erfxxHasLowErrorForUnitBlocks(new Erf0(), 5e-4);
}
@Test
public void erfxxHasLowErrorForUnitBlocks()
{
erfxxHasLowErrorForUnitBlocks(new Erf(), 5e-7);
}
@Test
public void erf2xxHasLowErrorForUnitBlocks()
{
erfxxHasLowErrorForUnitBlocks(new Erf2(), 1e-4);
}
private void erfxxHasLowErrorForUnitBlocks(BaseErf erf, double expected)
{
int range = 8;
double max = 0;
for (int xi = -range; xi <= range; xi++)
{
double x = xi;
double x2 = xi + 1;
double o = erf.erf(x, x2);
double e = org.apache.commons.math3.special.Erf.erf(x, x2);
double error = Math.abs(o - e);
if (max < error)
max = error;
//System.out.printf("x=%f, x2=%f, e=%f, o=%f, error=%f\n", x, x2, e, o, error);
Assert.assertTrue(error < expected);
}
System.out.printf("erfxx %s unit max error = %g\n", erf.name, max);
}
@Test
public void erf0xxHasLowerErrorThanGaussianApproximationForUnitBlocks()
{
erfxxHasLowerErrorThanGaussianApproximationForUnitBlocks(new Erf0());
}
@Test
public void erfxxHasLowerErrorThanGaussianApproximationForUnitBlocks()
{
erfxxHasLowerErrorThanGaussianApproximationForUnitBlocks(new Erf());
}
@Test
public void erf2xxHasLowerErrorThanGaussianApproximationForUnitBlocks()
{
erfxxHasLowerErrorThanGaussianApproximationForUnitBlocks(new Erf2());
}
private void erfxxHasLowerErrorThanGaussianApproximationForUnitBlocks(BaseErf erf)
{
int range = 5;
double max = 0, max2 = 0;
// Standard deviation
double s = 1.3;
final double twos2 = 2 * s * s;
double norm = 1 / (Math.PI * twos2);
final double denom = 1.0 / (Math.sqrt(2.0) * s);
double sum1 = 0, sum2 = 0, sum3 = 0;
for (int x = -range; x <= range; x++)
{
double o1 = 0.5 * erf.erf((x - 0.5) * denom, (x + 0.5) * denom);
double e1 = 0.5 * org.apache.commons.math3.special.Erf.erf((x - 0.5) * denom, (x + 0.5) * denom);
for (int y = -range; y <= range; y++)
{
double o2 = 0.5 * erf.erf((y - 0.5) * denom, (y + 0.5) * denom);
double e2 = 0.5 * org.apache.commons.math3.special.Erf.erf((y - 0.5) * denom, (y + 0.5) * denom);
double o = o1 * o2;
double e = e1 * e2;
double oo = norm * FastMath.exp(-(x * x + y * y) / twos2);
sum1 += e;
sum2 += o;
sum3 += oo;
double absError = Math.abs(o - e);
if (e < 1e-4 || absError < 1e-10)
continue;
double error = DoubleEquality.relativeError(o, e);
double error2 = DoubleEquality.relativeError(oo, e);
if (max < error)
max = error;
if (max2 < error2)
max2 = error2;
//System.out.printf("x=%d, y=%d, e=%g, o=%g, o2=%g, error=%f, error2=%f\n", x, y, e, o, oo, error, error2);
Assert.assertTrue(error < error2);
}
}
Assert.assertTrue(erf.name + " Gaussian 2D integral is not 1", sum1 > 0.999);
Assert.assertTrue(erf.name + " Erf approx integral is incorrect",
DoubleEquality.relativeError(sum1, sum2) < 1e-3);
Assert.assertTrue(erf.name + " Gaussian approx integral is incorrect",
DoubleEquality.relativeError(sum1, sum3) < 1e-3);
System.out.printf(erf.name + " Erf approx pixel unit max error = %f\n", max);
System.out.printf(erf.name + " Gaussian approx pixel unit max error = %f\n", max2);
}
private class MyTimingTask extends BaseTimingTask
{
BaseErf erf;
double[] x;
public MyTimingTask(BaseErf erf, double[] x)
{
super(erf.name);
this.erf = erf;
this.x = x;
}
public int getSize()
{
return 1;
}
public Object getData(int i)
{
return null;
}
public Object run(Object data)
{
for (int i = 0; i < x.length; i++)
erf.erf(x[i]);
return null;
}
}
@Test
public void erfApproxIsFaster()
{
int range = 5;
int steps = 10000;
final double[] x = new double[steps];
double total = 2 * range;
double step = total / steps;
for (int i = 0; i < steps; i++)
x[i] = -range + i * step;
TimingService ts = new TimingService(5);
ts.execute(new MyTimingTask(new ApacheErf(), x));
ts.execute(new MyTimingTask(new Erf(), x));
ts.execute(new MyTimingTask(new Erf0(), x));
ts.execute(new MyTimingTask(new Erf2(), x));
int size = ts.getSize();
ts.repeat(size);
ts.report();
}
@Test
public void gaussianIntegralApproximatesErf()
{
double x = 1.3, y = 2.2, s = 1.14;
int minx = (int) x;
int miny = (int) y;
int maxx = minx + 1;
int maxy = miny + 1;
// Full integration using the Erf
// Note: The PSF of a 2D Gaussian is described in Smith et all using a denominator
// of (2.0 * s * s) for both x and Y directions. This is wrong. We need the
// integral of the single Guassian in each dimension so the denomiator is (sqrt(2.0) * s).
// See: Smith et al, (2010). Fast, single-molecule localisation that achieves
// theoretically minimum uncertainty. Nature Methods 7, 373-375
// (supplementary note).
//final double denom = 1.0 / (2.0 * s * s); // As per Smith, etal (2010),
final double denom = 1.0 / (Math.sqrt(2.0) * s);
double e1 = 0.5 * org.apache.commons.math3.special.Erf.erf(minx * denom, maxx * denom);
double e2 = 0.5 * org.apache.commons.math3.special.Erf.erf(miny * denom, maxy * denom);
double e = e1 * e2;
double o = 0;
// Numeric integration
final double twos2 = 2 * s * s;
double norm = 1 / (Math.PI * twos2);
for (int i = 0, steps = 1; i < 4; i++, steps = (int) Math.pow(10, i))
{
// Gaussian is: FastMath.exp(-(x * x + y * y) / twos2) over all x and y
// But we can do this by separating x and y:
// FastMath.exp(-(x * x) / twos2) * FastMath.exp(-(y * y) / twos2)
// pre-compute
double[] ex = new double[steps];
double sumey = 0;
if (steps == 1)
{
// Use the actual values for x and y
ex[0] = FastMath.exp(-(x * x) / twos2);
sumey = FastMath.exp(-(y * y) / twos2);
}
else
{
for (int j = 0; j < steps; j++)
{
double xx = minx + (double) j / steps;
double yy = miny + (double) j / steps;
ex[j] = FastMath.exp(-(xx * xx) / twos2);
sumey += FastMath.exp(-(yy * yy) / twos2);
}
}
double sum = 0;
for (int j = 0; j < steps; j++)
{
sum += ex[j] * sumey;
}
//// Check
//double sum2 = 0;
//for (int j = 0; j <= steps; j++)
//{
// double xx = minx + (double) j / steps;
// for (int k = 0; k <= steps; k++)
// {
// double yy = miny + (double) k / steps;
// sum2 += FastMath.exp(-(xx * xx + yy * yy) / twos2);
// }
//}
//System.out.printf("sum=%f, sum2=%f\n", sum, sum2);
int n = steps * steps;
o = norm * sum / n;
System.out.printf("n=%d, e=%f, o=%f, error=%f\n", n, e, o, DoubleEquality.relativeError(e, o));
}
Assert.assertEquals(e, o, e * 1e-2);
}
@Test
public void analyticErfGradientCorrectForErfApproximation()
{
BaseErf erf = new Erf();
int range = 7;
int steps = 10000;
double step = (double) range / steps;
double delta = 1e-3;
DoubleEquality eq = new DoubleEquality(4, 1e-6);
for (int i = 0; i < steps; i++)
{
double x = i * step;
double x1 = x + delta;
double x2 = x - delta;
double o1 = erf.erf(x1);
double o2 = erf.erf(x2);
double delta2 = x1 - x2;
double g = (o1 - o2) / delta2;
double e = gdsc.smlm.function.Erf.dErf_dx(x);
if (!eq.almostEqualComplement(e, g))
Assert.assertTrue(x + " : " + e + " != " + g, false);
}
}
}