package gdsc.smlm.fitting.nonlinear.gradient;
import java.util.ArrayList;
import org.apache.commons.math3.random.RandomDataGenerator;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.random.Well19937c;
import org.ejml.data.DenseMatrix64F;
import org.junit.Assert;
import org.junit.Test;
import gdsc.core.utils.DoubleEquality;
import gdsc.core.utils.Statistics;
import gdsc.smlm.TestSettings;
import gdsc.smlm.fitting.linear.EJMLLinearSolver;
import gdsc.smlm.function.Gradient1Function;
import gdsc.smlm.function.gaussian.GaussianFunctionFactory;
import gdsc.smlm.function.gaussian.erf.ErfGaussian2DFunction;
import gdsc.smlm.function.gaussian.erf.SingleFreeCircularErfGaussian2DFunction;
/**
* Contains speed tests for the fastest method for calculating the Hessian and gradient vector
* for use in the LVM algorithm.
* <p>
* Note: This class is a test-bed for implementation strategies. The fastest strategy can then be used for other
* gradient procedures.
*/
public class LSQLVMGradientProcedureTest
{
// Local switch for the timing tests. It is OR'd with TestSettings.RUN_SPEED_TESTS in the Assume checks of the
// timing methods. NOTE(review): it is hard-coded to true, so those Assume checks can never skip a test.
boolean speedTests = true;
// NOTE(review): not read by any method visible in this file (gradientProcedureComputesGradient declares its own
// local 'eq' that shadows it) - confirm whether it is still needed.
DoubleEquality eq = new DoubleEquality(6, 1e-16);
// NOTE(review): not referenced in the visible part of this class.
int MAX_ITER = 20000;
// Width of the square data block; every generated data set has blockWidth * blockWidth values.
int blockWidth = 10;
// Centre values used by doubleCreateGaussianData when drawing the parameters of the simulated Gaussian.
// Each is perturbed by random(double) before use. Background may be changed temporarily by a test.
double Background = 0.5;
double Signal = 100;
// NOTE(review): not referenced in the visible part of this class.
double Angle = Math.PI;
double Xpos = 5;
double Ypos = 5;
double Xwidth = 1.2;
double Ywidth = 1.2;
// Random source for all simulated data. Each test re-creates it with a fixed seed before generating data.
RandomDataGenerator rdg;
/**
 * Checks that each of the three procedure factories returns its dedicated implementation class
 * when the function has 4, 5 or 6 gradient parameters.
 * <p>
 * The expected class is passed as the first argument of {@code assertEquals} so that a failure is
 * reported as "expected: the optimised class, but was: the class actually created".
 */
@Test
public void gradientProcedureFactoryCreatesOptimisedProcedures()
{
    // The data content is irrelevant; only the type of the created procedure is inspected
    final double[] y = new double[0];

    // Matrix factory
    Assert.assertEquals(LSQLVMGradientProcedureMatrix6.class,
            LSQLVMGradientProcedureMatrixFactory.create(y, null, new DummyGradientFunction(6)).getClass());
    Assert.assertEquals(LSQLVMGradientProcedureMatrix5.class,
            LSQLVMGradientProcedureMatrixFactory.create(y, null, new DummyGradientFunction(5)).getClass());
    Assert.assertEquals(LSQLVMGradientProcedureMatrix4.class,
            LSQLVMGradientProcedureMatrixFactory.create(y, null, new DummyGradientFunction(4)).getClass());

    // Linear factory
    Assert.assertEquals(LSQLVMGradientProcedureLinear6.class,
            LSQLVMGradientProcedureLinearFactory.create(y, null, new DummyGradientFunction(6)).getClass());
    Assert.assertEquals(LSQLVMGradientProcedureLinear5.class,
            LSQLVMGradientProcedureLinearFactory.create(y, null, new DummyGradientFunction(5)).getClass());
    Assert.assertEquals(LSQLVMGradientProcedureLinear4.class,
            LSQLVMGradientProcedureLinearFactory.create(y, null, new DummyGradientFunction(4)).getClass());

    // Default factory
    Assert.assertEquals(LSQLVMGradientProcedure6.class,
            LSQLVMGradientProcedureFactory.create(y, null, new DummyGradientFunction(6)).getClass());
    Assert.assertEquals(LSQLVMGradientProcedure5.class,
            LSQLVMGradientProcedureFactory.create(y, null, new DummyGradientFunction(5)).getClass());
    Assert.assertEquals(LSQLVMGradientProcedure4.class,
            LSQLVMGradientProcedureFactory.create(y, null, new DummyGradientFunction(4)).getClass());
}
/**
 * Runs the calculator-equivalence check using procedures created by the linear factory.
 */
@Test
public void gradientProcedureLinearComputesSameAsGradientCalculator()
{
    final BaseLSQLVMGradientProcedureFactory linearFactory = new LSQLVMGradientProcedureLinearFactory();
    gradientProcedureComputesSameAsGradientCalculator(linearFactory);
}
/**
 * Runs the calculator-equivalence check using procedures created by the matrix factory.
 */
@Test
public void gradientProcedureMatrixComputesSameAsGradientCalculator()
{
    final BaseLSQLVMGradientProcedureFactory matrixFactory = new LSQLVMGradientProcedureMatrixFactory();
    gradientProcedureComputesSameAsGradientCalculator(matrixFactory);
}
/**
 * Runs the calculator-equivalence check using procedures created by the default factory.
 */
@Test
public void gradientProcedureComputesSameAsGradientCalculator()
{
    final BaseLSQLVMGradientProcedureFactory defaultFactory = new LSQLVMGradientProcedureFactory();
    gradientProcedureComputesSameAsGradientCalculator(defaultFactory);
}
/**
 * Runs the calculator-equivalence check for each tested parameter count (4, 5, 6, 11 and 21), in that order.
 *
 * @param factory
 *            the factory used to create the procedures under test
 */
private void gradientProcedureComputesSameAsGradientCalculator(BaseLSQLVMGradientProcedureFactory factory)
{
    final int[] parameterCounts = { 4, 5, 6, 11, 21 };
    for (final int nparams : parameterCounts)
    {
        gradientProcedureComputesSameAsGradientCalculator(nparams, factory);
    }
}
/**
 * Runs the speed comparison against the gradient calculator using procedures created by the linear factory.
 */
@Test
public void gradientProcedureLinearIsNotSlowerThanGradientCalculator()
{
    final BaseLSQLVMGradientProcedureFactory linearFactory = new LSQLVMGradientProcedureLinearFactory();
    gradientProcedureIsNotSlowerThanGradientCalculator(linearFactory);
}
/**
 * Runs the speed comparison against the gradient calculator using procedures created by the matrix factory.
 */
@Test
public void gradientProcedureMatrixIsNotSlowerThanGradientCalculator()
{
    final BaseLSQLVMGradientProcedureFactory matrixFactory = new LSQLVMGradientProcedureMatrixFactory();
    gradientProcedureIsNotSlowerThanGradientCalculator(matrixFactory);
}
/**
 * Runs the speed comparison against the gradient calculator using procedures created by the default factory.
 */
@Test
public void gradientProcedureIsNotSlowerThanGradientCalculator()
{
    final BaseLSQLVMGradientProcedureFactory defaultFactory = new LSQLVMGradientProcedureFactory();
    gradientProcedureIsNotSlowerThanGradientCalculator(defaultFactory);
}
/**
 * Runs the speed comparison for each tested parameter count, in ascending order.
 *
 * @param factory
 *            the factory used to create the procedures under test
 */
private void gradientProcedureIsNotSlowerThanGradientCalculator(BaseLSQLVMGradientProcedureFactory factory)
{
    // 4, 5 and 6 parameters; then 11 (2 peaks) and 21 (4 peaks)
    final int[] parameterCounts = { 4, 5, 6, 11, 21 };
    for (final int nparams : parameterCounts)
    {
        gradientProcedureIsNotSlowerThanGradientCalculator(nparams, factory);
    }
}
/**
 * Checks that a procedure created by the factory produces exactly the same value, beta vector and alpha matrix
 * as the {@link GradientCalculator} for 10 sets of fake data.
 * <p>
 * The calculator is the reference, so its results are passed as the expected values of each assertion.
 *
 * @param nparams
 *            the number of parameters of the fake gradient function
 * @param factory
 *            the factory used to create the procedures under test
 */
private void gradientProcedureComputesSameAsGradientCalculator(int nparams, BaseLSQLVMGradientProcedureFactory factory)
{
    final int iter = 10;
    // Fixed seed so that every run uses the same data
    rdg = new RandomDataGenerator(new Well19937c(30051977));

    // Output buffers for the calculator; they are reused for every data set
    final double[][] alpha = new double[nparams][nparams];
    final double[] beta = new double[nparams];

    final ArrayList<double[]> paramsList = new ArrayList<double[]>(iter);
    final ArrayList<double[]> yList = new ArrayList<double[]>(iter);

    final int[] x = createFakeData(nparams, iter, paramsList, yList);
    final int n = x.length;
    final FakeGradientFunction func = new FakeGradientFunction(blockWidth, nparams);

    final GradientCalculator calc = GradientCalculatorFactory.newCalculator(nparams, false);

    final String name = factory.getClass().getSimpleName();

    for (int i = 0; i < paramsList.size(); i++)
    {
        final BaseLSQLVMGradientProcedure p = factory.createProcedure(yList.get(i), func);
        p.gradient(paramsList.get(i));
        final double s = p.value;
        final double s2 = calc.findLinearised(n, yList.get(i), paramsList.get(i), alpha, beta, func);

        // Exactly the same (tolerance of zero); calculator = expected, procedure = actual
        Assert.assertEquals(name + " Result: Not same @ " + i, s2, s, 0);
        Assert.assertArrayEquals(name + " Observations: Not same beta @ " + i, beta, p.beta, 0);

        // Compare alpha in both of the layouts exposed by the procedure
        final double[] al = p.getAlphaLinear();
        Assert.assertArrayEquals(name + " Observations: Not same alpha @ " + i, new DenseMatrix64F(alpha).data, al,
                0);

        final double[][] am = p.getAlphaMatrix();
        for (int j = 0; j < nparams; j++)
            Assert.assertArrayEquals(name + " Observations: Not same alpha @ " + i, alpha[j], am[j], 0);
    }
}
/**
 * Times repeated executions of {@link #run()} until two consecutive timings agree.
 */
private abstract class Timer
{
    /** The number of times {@link #run()} has been timed. */
    private int loops;
    /** The minimum number of timed runs required before {@link #getTime()} may stop early. */
    int min;

    /** Creates a timer with no minimum number of runs. */
    Timer()
    {
        this(0);
    }

    /**
     * @param min
     *            the minimum number of timed runs required before the timing is accepted as stable
     */
    Timer(int min)
    {
        this.min = min;
    }

    /**
     * Times one run, then up to 10 more. Stops early once at least {@code min} runs have been timed and the
     * latest timing is within 2% (relative error) of the previous one.
     *
     * @return the last measured time in nanoseconds
     */
    long getTime()
    {
        long latest = time();
        int repeats = 0;
        while (repeats++ < 10)
        {
            final long previous = latest;
            latest = time();
            final boolean enoughRuns = loops >= min;
            if (enoughRuns && DoubleEquality.relativeError(latest, previous) < 0.02)
                return latest;
        }
        return latest;
    }

    /**
     * Executes {@link #run()} once and counts the execution.
     *
     * @return the elapsed time in nanoseconds
     */
    long time()
    {
        loops++;
        final long start = System.nanoTime();
        run();
        return System.nanoTime() - start;
    }

    /** The work to be timed. */
    abstract void run();
}
/**
 * Times the {@link GradientCalculator} against procedures created by the factory on 1000 sets of fake data
 * and logs the relative speed.
 * <p>
 * The comparison is only asserted when {@code TestSettings.ASSERT_SPEED_TESTS} is set; the procedure is then
 * allowed to take up to 1.5x the calculator time. The assertion message carries both timings so that a
 * failure can be diagnosed from the report alone.
 *
 * @param nparams
 *            the number of parameters of the fake gradient function
 * @param factory
 *            the factory used to create the procedures under test
 */
private void gradientProcedureIsNotSlowerThanGradientCalculator(final int nparams,
        final BaseLSQLVMGradientProcedureFactory factory)
{
    org.junit.Assume.assumeTrue(speedTests || TestSettings.RUN_SPEED_TESTS);

    final int iter = 1000;
    // Fixed seed so that every run uses the same data
    rdg = new RandomDataGenerator(new Well19937c(30051977));
    final double[][] alpha = new double[nparams][nparams];
    final double[] beta = new double[nparams];

    final ArrayList<double[]> paramsList = new ArrayList<double[]>(iter);
    final ArrayList<double[]> yList = new ArrayList<double[]>(iter);

    int[] x = createFakeData(nparams, iter, paramsList, yList);
    final int n = x.length;
    final FakeGradientFunction func = new FakeGradientFunction(blockWidth, nparams);

    // Execute both implementations once over all the data before any timing is taken
    GradientCalculator calc = GradientCalculatorFactory.newCalculator(nparams, false);
    for (int i = 0; i < paramsList.size(); i++)
        calc.findLinearised(n, yList.get(i), paramsList.get(i), alpha, beta, func);

    for (int i = 0; i < paramsList.size(); i++)
    {
        BaseLSQLVMGradientProcedure p = factory.createProcedure(yList.get(i), func);
        p.gradient(paramsList.get(i));
    }

    // Realistic loops for an optimisation
    final int loops = 15;

    // Run till stable timing
    Timer t1 = new Timer()
    {
        @Override
        void run()
        {
            // k cycles through the parameter sets so each inner call uses a different set
            for (int i = 0, k = 0; i < iter; i++)
            {
                GradientCalculator calc = GradientCalculatorFactory.newCalculator(nparams, false);
                for (int j = loops; j-- > 0;)
                    calc.findLinearised(n, yList.get(i), paramsList.get(k++ % iter), alpha, beta, func);
            }
        }
    };
    long time1 = t1.getTime();

    // Require at least as many timed runs as the first timer used
    Timer t2 = new Timer(t1.loops)
    {
        @Override
        void run()
        {
            for (int i = 0, k = 0; i < iter; i++)
            {
                BaseLSQLVMGradientProcedure p = factory.createProcedure(yList.get(i), func);
                for (int j = loops; j-- > 0;)
                    p.gradient(paramsList.get(k++ % iter));
            }
        }
    };
    long time2 = t2.getTime();

    log("GradientCalculator = %d : %s %d = %d : %fx\n", time1, factory.getClass().getSimpleName(), nparams, time2,
            (1.0 * time1) / time2);
    if (TestSettings.ASSERT_SPEED_TESTS)
    {
        // Add contingency
        Assert.assertTrue(factory.getClass().getSimpleName() + " " + nparams + " took " + time2 +
                "ns which is not within 1.5x of GradientCalculator " + time1 + "ns", time2 < time1 * 1.5);
    }
}
/**
 * Tests the method that will be used for the standard and unrolled versions for all other
 * 'gradient procedures', for 4, 5 and 6 parameters.
 */
@Test
public void gradientProcedureUnrolledComputesSameAsGradientProcedure()
{
    for (int nparams = 4; nparams <= 6; nparams++)
    {
        gradientProcedureUnrolledComputesSameAsGradientProcedure(nparams);
    }
}
/**
 * Checks that the procedure created by {@code LSQLVMGradientProcedureFactory} for the given number of
 * parameters produces exactly the same value, beta and alpha as the standard {@code LSQLVMGradientProcedure}
 * for 10 sets of fake data.
 * <p>
 * The standard procedure is the reference, so its results are passed as the expected values. Failure messages
 * are labelled with the class of the factory-created procedure, i.e. the implementation under test.
 *
 * @param nparams
 *            the number of parameters of the fake gradient function
 */
private void gradientProcedureUnrolledComputesSameAsGradientProcedure(int nparams)
{
    final int iter = 10;
    // Fixed seed so that every run uses the same data
    rdg = new RandomDataGenerator(new Well19937c(30051977));

    final ArrayList<double[]> paramsList = new ArrayList<double[]>(iter);
    final ArrayList<double[]> yList = new ArrayList<double[]>(iter);

    createFakeData(nparams, iter, paramsList, yList);
    final FakeGradientFunction func = new FakeGradientFunction(blockWidth, nparams);

    for (int i = 0; i < paramsList.size(); i++)
    {
        // p1 = implementation chosen by the factory; p2 = standard implementation (reference)
        final BaseLSQLVMGradientProcedure p1 = LSQLVMGradientProcedureFactory.create(yList.get(i), null, func);
        p1.gradient(paramsList.get(i));

        final BaseLSQLVMGradientProcedure p2 = new LSQLVMGradientProcedure(yList.get(i), func);
        p2.gradient(paramsList.get(i));

        final String name = p1.getClass().getSimpleName();

        // Exactly the same (tolerance of zero)
        Assert.assertEquals(name + " Result: Not same @ " + i, p2.value, p1.value, 0);
        Assert.assertArrayEquals(name + " Observations: Not same beta @ " + i, p2.beta, p1.beta, 0);

        Assert.assertArrayEquals(name + " Observations: Not same alpha @ " + i, p2.getAlphaLinear(),
                p1.getAlphaLinear(), 0);

        final double[][] am1 = p1.getAlphaMatrix();
        final double[][] am2 = p2.getAlphaMatrix();
        for (int j = 0; j < nparams; j++)
            Assert.assertArrayEquals(name + " Observations: Not same alpha @ " + i, am2[j], am1[j], 0);
    }
}
/**
 * Compares the speed of procedures from the default factory against those from the matrix factory.
 */
@Test
public void gradientProcedureIsFasterUnrolledThanGradientProcedureMatrix()
{
    final BaseLSQLVMGradientProcedureFactory matrixFactory = new LSQLVMGradientProcedureMatrixFactory();
    final BaseLSQLVMGradientProcedureFactory defaultFactory = new LSQLVMGradientProcedureFactory();
    gradientProcedure2IsFasterUnrolledThanGradientProcedure1(matrixFactory, defaultFactory);
}
/**
 * Compares the speed of procedures from the linear factory against those from the matrix factory.
 */
@Test
public void gradientProcedureLinearIsFasterUnrolledThanGradientProcedureMatrix()
{
    final BaseLSQLVMGradientProcedureFactory matrixFactory = new LSQLVMGradientProcedureMatrixFactory();
    final BaseLSQLVMGradientProcedureFactory linearFactory = new LSQLVMGradientProcedureLinearFactory();
    gradientProcedure2IsFasterUnrolledThanGradientProcedure1(matrixFactory, linearFactory);
}
/**
 * Runs the speed comparison of two factories for 4, 5, 6, 11 and 21 parameters. The timing is only asserted
 * for 4, 5 and 6 parameters (the unrolled versions); the larger sizes are timed and logged only.
 *
 * @param factory1
 *            the factory for the baseline procedures
 * @param factory2
 *            the factory for the procedures expected to be faster
 */
private void gradientProcedure2IsFasterUnrolledThanGradientProcedure1(BaseLSQLVMGradientProcedureFactory factory1,
        BaseLSQLVMGradientProcedureFactory factory2)
{
    final int[] parameterCounts = { 4, 5, 6, 11, 21 };
    for (final int nparams : parameterCounts)
    {
        // Assert the unrolled versions
        final boolean doAssert = nparams <= 6;
        gradientProcedureLinearIsFasterThanGradientProcedureMatrix(nparams, factory1, factory2, doAssert);
    }
}
/**
 * Times procedures created by two factories on the same 100 data sets, logs the relative speed and optionally
 * asserts that the second factory's procedures are faster.
 * <p>
 * Before timing, both procedures are run twice on every data set and their alpha and beta results are
 * asserted to be identical. The timing assertion message carries both timings so that a failure can be
 * diagnosed from the report alone.
 *
 * @param nparams
 *            the number of parameters of the fake gradient function
 * @param factory1
 *            the factory for the baseline procedures
 * @param factory2
 *            the factory for the procedures expected to be faster
 * @param doAssert
 *            set to true to assert that factory2 is faster than factory1
 */
private void gradientProcedureLinearIsFasterThanGradientProcedureMatrix(final int nparams,
        final BaseLSQLVMGradientProcedureFactory factory1, final BaseLSQLVMGradientProcedureFactory factory2,
        boolean doAssert)
{
    org.junit.Assume.assumeTrue(speedTests || TestSettings.RUN_SPEED_TESTS);

    final int iter = 100;
    // Fixed seed so that every run uses the same data
    rdg = new RandomDataGenerator(new Well19937c(30051977));

    final ArrayList<double[]> paramsList = new ArrayList<double[]>(iter);
    final ArrayList<double[]> yList = new ArrayList<double[]>(iter);

    // NOTE(review): this creates single-peak Gaussian data, so each parameter array has 7 elements whatever the
    // value of nparams. This presumably relies on FakeGradientFunction not reading the parameters - confirm.
    createData(1, iter, paramsList, yList);

    // Remove the timing of the function call by creating a dummy function
    final Gradient1Function func = new FakeGradientFunction(blockWidth, nparams);

    for (int i = 0; i < paramsList.size(); i++)
    {
        BaseLSQLVMGradientProcedure p = factory1.createProcedure(yList.get(i), func);
        p.gradient(paramsList.get(i));
        p.gradient(paramsList.get(i));

        BaseLSQLVMGradientProcedure p2 = factory2.createProcedure(yList.get(i), func);
        p2.gradient(paramsList.get(i));
        p2.gradient(paramsList.get(i));

        // Check they are the same
        Assert.assertArrayEquals("A " + i, p.getAlphaLinear(), p2.getAlphaLinear(), 0);
        Assert.assertArrayEquals("B " + i, p.beta, p2.beta, 0);
    }

    // Realistic loops for an optimisation
    final int loops = 15;

    // Run till stable timing
    Timer t1 = new Timer()
    {
        @Override
        void run()
        {
            // k cycles through the parameter sets so each inner call uses a different set
            for (int i = 0, k = 0; i < paramsList.size(); i++)
            {
                BaseLSQLVMGradientProcedure p = factory1.createProcedure(yList.get(i), func);
                for (int j = loops; j-- > 0;)
                    p.gradient(paramsList.get(k++ % iter));
            }
        }
    };
    long time1 = t1.getTime();

    // Require at least as many timed runs as the first timer used
    Timer t2 = new Timer(t1.loops)
    {
        @Override
        void run()
        {
            for (int i = 0, k = 0; i < paramsList.size(); i++)
            {
                BaseLSQLVMGradientProcedure p2 = factory2.createProcedure(yList.get(i), func);
                for (int j = loops; j-- > 0;)
                    p2.gradient(paramsList.get(k++ % iter));
            }
        }
    };
    long time2 = t2.getTime();

    log("Standard %s = %d : Unrolled %s %d = %d : %fx\n", factory1.getClass().getSimpleName(), time1,
            factory2.getClass().getSimpleName(), nparams, time2, (1.0 * time1) / time2);
    if (doAssert)
        Assert.assertTrue(factory2.getClass().getSimpleName() + " " + nparams + " took " + time2 +
                "ns which is not faster than " + factory1.getClass().getSimpleName() + " " + time1 + "ns",
                time2 < time1);
}
/**
 * Checks the gradients computed by the procedure for a single-peak free circular ERF Gaussian function.
 */
@Test
public void gradientProcedureComputesGradient()
{
    final ErfGaussian2DFunction gaussian = new SingleFreeCircularErfGaussian2DFunction(blockWidth, blockWidth);
    gradientProcedureComputesGradient(gaussian);
}
/**
 * Compares the beta vector computed by the procedure with a central-difference estimate of the gradient of
 * the procedure value, for each gradient parameter of 100 simulated single-peak data sets.
 *
 * @param func
 *            the function used to create the procedure
 */
private void gradientProcedureComputesGradient(ErfGaussian2DFunction func)
{
    final int nparams = func.getNumberOfGradients();
    final int[] indices = func.gradientIndices();

    final int iter = 100;
    rdg = new RandomDataGenerator(new Well19937c(30051977));

    final ArrayList<double[]> paramsList = new ArrayList<double[]>(iter);
    final ArrayList<double[]> yList = new ArrayList<double[]>(iter);

    createData(1, iter, paramsList, yList, true);

    // Relative step used for the finite difference
    final double delta = 1e-3;
    final DoubleEquality equality = new DoubleEquality(3, 1e-3);

    for (int i = 0; i < paramsList.size(); i++)
    {
        final double[] a = paramsList.get(i);
        // Working copy in which one parameter at a time is shifted
        final double[] shifted = a.clone();

        final BaseLSQLVMGradientProcedure p = LSQLVMGradientProcedureFactory.create(yList.get(i), null, func);
        p.gradient(a);
        // Take a copy before the procedure is reused to evaluate the shifted parameters
        final double[] beta = p.beta.clone();

        for (int j = 0; j < nparams; j++)
        {
            // j indexes the gradients; k is the matching position in the parameter array
            final int k = indices[j];

            final double d;
            if (a[k] == 0)
                d = 1e-3;
            else
                d = a[k] * delta;

            shifted[k] = a[k] + d;
            p.value(shifted);
            final double upper = p.value;

            shifted[k] = a[k] - d;
            p.value(shifted);
            final double lower = p.value;

            // Restore the parameter before moving on to the next one
            shifted[k] = a[k];

            // Apply a factor of -2 to compute the actual gradients:
            // See Numerical Recipes in C++, 2nd Ed. Equation 15.5.6 for Nonlinear Models
            final double analytical = beta[j] * -2;
            final double numerical = (upper - lower) / (2 * d);

            Assert.assertTrue("Not same gradient @ " + j, equality.almostEqualComplement(analytical, numerical));
        }
    }
}
/**
 * Checks that adding a constant bias to both the data and the background parameter does not change the
 * alpha matrix, the beta vector or the solution of the linear system (all compared to within 1e-10).
 * <p>
 * The {@code Background} field is lowered to 1e-2 while the data is created and restored afterwards.
 * Both calls to the solver are checked for success, so that a failed solve is reported as such and not as an
 * array mismatch. Parameter names in the printed output are looked up through the gradient indices of the
 * function.
 */
@Test
public void gradientProcedureComputesSameOutputWithBias()
{
    final ErfGaussian2DFunction func = new SingleFreeCircularErfGaussian2DFunction(blockWidth, blockWidth);
    final int nparams = func.getNumberOfGradients();
    // Maps a gradient index to its position in the parameter array, which is what getName is given elsewhere
    final int[] indices = func.gradientIndices();

    final int iter = 100;
    rdg = new RandomDataGenerator(new Well19937c(30051977));

    final ArrayList<double[]> paramsList = new ArrayList<double[]>(iter);
    final ArrayList<double[]> yList = new ArrayList<double[]>(iter);

    final ArrayList<double[]> alphaList = new ArrayList<double[]>(iter);
    final ArrayList<double[]> betaList = new ArrayList<double[]>(iter);
    final ArrayList<double[]> xList = new ArrayList<double[]>(iter);

    // Manipulate the background
    final double defaultBackground = Background;
    try
    {
        Background = 1e-2;
        createData(1, iter, paramsList, yList, true);

        final EJMLLinearSolver solver = new EJMLLinearSolver(5, 1e-6);

        // Reference results without any bias
        for (int i = 0; i < paramsList.size(); i++)
        {
            final double[] y = yList.get(i);
            final double[] a = paramsList.get(i);
            final BaseLSQLVMGradientProcedure p = LSQLVMGradientProcedureFactory.create(y, null, func);
            p.gradient(a);
            final double[] beta = p.beta;
            alphaList.add(p.getAlphaLinear());
            // Copy beta now: the same array is handed to the solver below and then stored as the solution
            betaList.add(beta.clone());
            for (int j = 0; j < nparams; j++)
            {
                if (Math.abs(beta[j]) < 1e-6)
                    System.out.printf("[%d] Tiny beta %s %g\n", i, func.getName(indices[j]), beta[j]);
            }
            // Solve
            if (!solver.solve(p.getAlphaMatrix(), beta))
                throw new AssertionError("Failed to solve the system without bias @ " + i);
            xList.add(beta);
        }

        for (double b : new double[] { -500, -100, -10, -1, -0.1, 0, 0.1, 1, 10, 100, 500 })
        {
            final Statistics[] rel = new Statistics[nparams];
            final Statistics[] abs = new Statistics[nparams];
            for (int i = 0; i < nparams; i++)
            {
                rel[i] = new Statistics();
                abs[i] = new Statistics();
            }

            for (int i = 0; i < paramsList.size(); i++)
            {
                // Shift the data and the background parameter by the same amount
                final double[] y = add(yList.get(i), b);
                final double[] a = paramsList.get(i).clone();
                a[0] += b;
                final BaseLSQLVMGradientProcedure p = LSQLVMGradientProcedureFactory.create(y, null, func);
                p.gradient(a);
                final double[] beta = p.beta;
                final double[] alpha2 = alphaList.get(i);
                final double[] beta2 = betaList.get(i);
                final double[] x2 = xList.get(i);

                Assert.assertArrayEquals("Beta", beta2, beta, 1e-10);
                Assert.assertArrayEquals("Alpha", alpha2, p.getAlphaLinear(), 1e-10);

                // Solve
                if (!solver.solve(p.getAlphaMatrix(), beta))
                    throw new AssertionError("Failed to solve the system @ " + i + " with bias " + b);
                Assert.assertArrayEquals("X", x2, beta, 1e-10);

                for (int j = 0; j < nparams; j++)
                {
                    rel[j].add(DoubleEquality.relativeError(x2[j], beta[j]));
                    abs[j].add(Math.abs(x2[j] - beta[j]));
                }
            }

            for (int i = 0; i < nparams; i++)
                System.out.printf("Bias = %.2f : %s : Rel %g +/- %g: Abs %g +/- %g\n", b,
                        func.getName(indices[i]), rel[i].getMean(), rel[i].getStandardDeviation(),
                        abs[i].getMean(), abs[i].getStandardDeviation());
        }
    }
    finally
    {
        // Restore the field so other tests in this instance see the default
        Background = defaultBackground;
    }
}
/**
 * Creates a new array holding each input value plus a constant. The input array is not modified.
 *
 * @param d
 *            the values
 * @param b
 *            the constant to add
 * @return the new array
 */
private double[] add(double[] d, double b)
{
    final double[] shifted = new double[d.length];
    for (int i = shifted.length; i-- > 0;)
        shifted[i] = d[i] + b;
    return shifted;
}
/**
 * Creates random Gaussian data and returns the data together with an estimate of the parameters.
 * <p>
 * The parameters are drawn around the class fields: {@code params[0]} from Background and, for each peak
 * starting at index {@code j = 1 + 6 * peak}, {@code params[j]} from Signal, {@code params[j + 2]} from Xpos,
 * {@code params[j + 3]} from Ypos, {@code params[j + 4]} from Xwidth and {@code params[j + 5]} from Ywidth.
 * {@code params[j + 1]} is not written. The function value at each point is used as the mean of a Poisson
 * sample to create the data.
 *
 * @param npeaks
 *            the number of peaks
 * @param params
 *            set on output; indices up to {@code 6 * npeaks} are written
 * @param randomiseParams
 *            set to true to randomise the params again after the data has been created, so that the output
 *            params differ from those used to generate the data
 * @return the data (blockWidth * blockWidth values)
 */
private double[] doubleCreateGaussianData(int npeaks, double[] params, boolean randomiseParams)
{
    // Offsets from the first parameter of a peak that are populated, and the matching centre values.
    // The order defines the order of the random draws.
    final int[] offsets = { 0, 2, 3, 4, 5 };
    final double[] centres = { Signal, Xpos, Ypos, Xwidth, Ywidth };

    // Generate a 2D Gaussian
    final ErfGaussian2DFunction func = (ErfGaussian2DFunction) GaussianFunctionFactory.create2D(npeaks, blockWidth,
            blockWidth, GaussianFunctionFactory.FIT_ERF_FREE_CIRCLE, null);

    params[0] = random(Background);
    for (int peak = 0; peak < npeaks; peak++)
    {
        final int start = 1 + 6 * peak;
        for (int m = 0; m < offsets.length; m++)
            params[start + offsets[m]] = random(centres[m]);
    }

    final double[] y = new double[blockWidth * blockWidth];
    func.initialise(params);
    for (int i = 0; i < y.length; i++)
    {
        // Add random Poisson noise
        y[i] = rdg.nextPoisson(func.eval(i));
    }

    if (!randomiseParams)
        return y;

    // Perturb the parameters that were used to create the data
    params[0] = random(params[0]);
    for (int peak = 0; peak < npeaks; peak++)
    {
        final int start = 1 + 6 * peak;
        for (int m = 0; m < offsets.length; m++)
        {
            final int index = start + offsets[m];
            params[index] = random(params[index]);
        }
    }
    return y;
}
/**
 * Perturbs a value by a uniform random amount of up to 10% of the value in either direction.
 *
 * @param d
 *            the value
 * @return the perturbed value
 */
private double random(double d)
{
    final double tenPercent = d * 0.1;
    final double shift = rdg.nextUniform(-tenPercent, tenPercent);
    return d + shift;
}
/**
 * Creates Gaussian data sets with randomised output parameters.
 *
 * @param npeaks
 *            the number of peaks
 * @param iter
 *            the number of data sets to create
 * @param paramsList
 *            receives the parameters of each data set
 * @param yList
 *            receives the data of each data set
 * @return the indices 0 to blockWidth * blockWidth - 1
 */
protected int[] createData(int npeaks, int iter, ArrayList<double[]> paramsList, ArrayList<double[]> yList)
{
    final boolean randomiseParams = true;
    return createData(npeaks, iter, paramsList, yList, randomiseParams);
}
/**
 * Creates Gaussian data sets, appending the parameters and data of each one to the given lists.
 *
 * @param npeaks
 *            the number of peaks; each parameter array has 1 + 6 * npeaks elements
 * @param iter
 *            the number of data sets to create
 * @param paramsList
 *            receives the parameters of each data set
 * @param yList
 *            receives the data of each data set
 * @param randomiseParams
 *            passed through to the data generator for each data set
 * @return the indices 0 to blockWidth * blockWidth - 1
 */
protected int[] createData(int npeaks, int iter, ArrayList<double[]> paramsList, ArrayList<double[]> yList,
        boolean randomiseParams)
{
    final int[] x = new int[blockWidth * blockWidth];
    for (int i = x.length; i-- > 0;)
        x[i] = i;

    int remaining = iter;
    while (remaining-- > 0)
    {
        final double[] params = new double[1 + 6 * npeaks];
        final double[] y = doubleCreateGaussianData(npeaks, params, randomiseParams);
        paramsList.add(params);
        yList.add(y);
    }
    return x;
}
/**
 * Creates data sets of uniform random numbers, appending the parameters and data of each one to the lists.
 *
 * @param nparams
 *            the number of parameters in each parameter array
 * @param iter
 *            the number of data sets to create
 * @param paramsList
 *            receives the parameters of each data set
 * @param yList
 *            receives the data of each data set
 * @return the indices 0 to blockWidth * blockWidth - 1
 */
protected int[] createFakeData(int nparams, int iter, ArrayList<double[]> paramsList, ArrayList<double[]> yList)
{
    final int[] x = new int[blockWidth * blockWidth];
    for (int i = x.length; i-- > 0;)
        x[i] = i;

    int remaining = iter;
    while (remaining-- > 0)
    {
        final double[] params = new double[nparams];
        final double[] y = createFakeData(params);
        paramsList.add(params);
        yList.add(y);
    }
    return x;
}
/**
 * Fills the parameters with uniform random numbers and creates a data array of uniform random numbers.
 *
 * @param params
 *            set on output
 * @return the data (blockWidth * blockWidth values)
 */
private double[] createFakeData(double[] params)
{
    final RandomGenerator source = rdg.getRandomGenerator();

    // The parameters are drawn before the data
    int index = 0;
    while (index < params.length)
        params[index++] = source.nextDouble();

    final double[] y = new double[blockWidth * blockWidth];
    for (index = 0; index < y.length; index++)
        y[index] = source.nextDouble();
    return y;
}
/**
 * Creates a new list containing a copy of each array in the input list, in the same order.
 *
 * @param paramsList
 *            the list to copy
 * @return the new list
 */
protected ArrayList<double[]> copyList(ArrayList<double[]> paramsList)
{
    final ArrayList<double[]> copies = new ArrayList<double[]>(paramsList.size());
    for (final double[] params : paramsList)
        copies.add(copydouble(params));
    return copies;
}
/**
 * Creates a copy of the array.
 *
 * @param d
 *            the array
 * @return a new array with the same length and content
 */
private double[] copydouble(double[] d)
{
    return d.clone();
}
/**
 * Writes a formatted message to the standard output.
 *
 * @param format
 *            the format string
 * @param args
 *            the arguments referenced by the format string
 */
void log(String format, Object... args)
{
    System.out.format(format, args);
}
}