package gdsc.smlm.fitting.nonlinear;
import java.util.Arrays;
import org.apache.commons.math3.random.RandomDataGenerator;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.random.Well19937c;
import org.apache.commons.math3.stat.inference.TTest;
import org.junit.Assert;
import org.junit.Test;
import gdsc.core.utils.DoubleEquality;
import gdsc.core.utils.Statistics;
import gdsc.core.utils.StoredDataStatistics;
import gdsc.smlm.fitting.FitStatus;
import gdsc.smlm.fitting.FunctionSolver;
import gdsc.smlm.fitting.nonlinear.stop.ErrorStoppingCriteria;
import gdsc.smlm.function.gaussian.Gaussian2DFunction;
import gdsc.smlm.function.gaussian.GaussianFunctionFactory;
/**
* Test that a bounded fitter can return the same results with and without bounds.
*/
public class BoundedFunctionSolverTest
{
long seed = 30051977; //System.currentTimeMillis() + System.identityHashCode(this);
//long seed = System.currentTimeMillis() + System.identityHashCode(this);
RandomGenerator randomGenerator = new Well19937c(seed);
RandomDataGenerator dataGenerator = new RandomDataGenerator(randomGenerator);
// Basic Gaussian
static double bias = 100;
static double[] params = new double[7];
static double[] base = { 0.8, 1, 1.2 };
static double[] signal = { 1000, 2000, 5000, 10000 }; // 100, 200, 400, 800 };
static double[] noise = { 0.1, 0.5, 1 };
static double[] shift = { -1, 0, 1 };
static double[] factor = { 0.7, 1, 1.3 };
static int size = 11;
static
{
params[Gaussian2DFunction.BACKGROUND] = 5;
params[Gaussian2DFunction.X_POSITION] = size / 2;
params[Gaussian2DFunction.Y_POSITION] = size / 2;
params[Gaussian2DFunction.X_SD] = 1.4;
}
private static double[] defaultClampValues;
static
{
defaultClampValues = new double[7];
// Taken from the 3D-DAO-STORM paper:
// (Babcock et al. 2012) A high-density 3D localization algorithm for stochastic optical
// reconstruction microscopy. Optical Nanoscopy. 2012 1:6
// DOI: 10.1186/2192-2853-1-6
// Page 3
// Note: It is not clear if the background/signal are in ADUs or photons. I assume photons.
// This seems big for background in photons
defaultClampValues[Gaussian2DFunction.BACKGROUND] = 100;
//defaultClampValues[Gaussian2DFunction.BACKGROUND] = 20;
defaultClampValues[Gaussian2DFunction.SIGNAL] = 1000;
defaultClampValues[Gaussian2DFunction.SHAPE] = Math.PI;
defaultClampValues[Gaussian2DFunction.X_POSITION] = 1;
defaultClampValues[Gaussian2DFunction.Y_POSITION] = 1;
defaultClampValues[Gaussian2DFunction.X_SD] = 3;
defaultClampValues[Gaussian2DFunction.Y_SD] = 3;
}
// TODO - Test if local search param if useful when using clamping
// Standard LVM
@Test
public void canFitSingleGaussianLVM()
{
fitSingleGaussianLVM(0, 0, false);
}
// Bounded/Clamped LVM
@Test
public void canFitSingleGaussianBLVMNoBounds()
{
fitSingleGaussianLVM(1, 0, false);
}
@Test
public void canFitSingleGaussianBLVM()
{
fitSingleGaussianLVM(2, 0, false);
}
@Test
public void canFitSingleGaussianCLVM()
{
fitSingleGaussianLVM(0, 1, false);
}
@Test
public void canFitSingleGaussianDCLVM()
{
fitSingleGaussianLVM(0, 2, false);
}
@Test
public void canFitSingleGaussianBCLVM()
{
fitSingleGaussianLVM(2, 1, false);
}
@Test
public void canFitSingleGaussianBDCLVM()
{
fitSingleGaussianLVM(2, 2, false);
}
// MLE LVM
@Test
public void canFitSingleGaussianLVMMLE()
{
fitSingleGaussianLVM(0, 0, true);
}
@Test
public void canFitSingleGaussianBLVMMLENoBounds()
{
fitSingleGaussianLVM(1, 0, true);
}
@Test
public void canFitSingleGaussianBLVMMLE()
{
fitSingleGaussianLVM(2, 0, true);
}
private void fitSingleGaussianLVM(int bounded, int clamping, boolean mle)
{
canFitSingleGaussian(getLVM(bounded, clamping, mle), bounded == 2, !mle);
}
// Is Bounded/Clamped LVM better?
@Test
public void fitSingleGaussianBLVMBetterThanLVM()
{
fitSingleGaussianBetterLVM(true, 0, false, false, 0, false);
}
@Test
public void fitSingleGaussianCLVMBetterThanLVM()
{
fitSingleGaussianBetterLVM(false, 1, false, false, 0, false);
}
@Test
public void fitSingleGaussianBCLVMBetterThanLVM()
{
fitSingleGaussianBetterLVM(true, 1, false, false, 0, false);
}
@Test
public void fitSingleGaussianDCLVMBetterThanLVM()
{
fitSingleGaussianBetterLVM(false, 2, false, false, 0, false);
}
@Test
public void fitSingleGaussianBDCLVMBetterThanLVM()
{
fitSingleGaussianBetterLVM(true, 2, false, false, 0, false);
}
@Test
public void fitSingleGaussianLVMMLEBetterThanLVM()
{
fitSingleGaussianBetterLVM(false, 0, true, false, 0, false);
}
@Test
public void fitSingleGaussianBLVMMLEBetterThanLVM()
{
fitSingleGaussianBetterLVM(true, 0, true, false, 0, false);
}
@Test
public void fitSingleGaussianCLVMMLEBetterThanLVM()
{
fitSingleGaussianBetterLVM(false, 1, true, false, 0, false);
}
@Test
public void fitSingleGaussianBCLVMMLEBetterThanLVM()
{
fitSingleGaussianBetterLVM(true, 1, true, false, 0, false);
}
@Test
public void fitSingleGaussianDCLVMMLEBetterThanLVM()
{
fitSingleGaussianBetterLVM(false, 2, true, false, 0, false);
}
@Test
public void fitSingleGaussianBDCLVMMLEBetterThanLVM()
{
fitSingleGaussianBetterLVM(true, 2, true, false, 0, false);
}
@Test
public void fitSingleGaussianBLVMMLEBetterThanLVMMLE()
{
fitSingleGaussianBetterLVM(true, 0, true, false, 0, true);
}
@Test
public void fitSingleGaussianCLVMMLEBetterThanLVMMLE()
{
fitSingleGaussianBetterLVM(false, 1, true, false, 0, true);
}
@Test
public void fitSingleGaussianDCLVMMLEBetterThanLVMMLE()
{
fitSingleGaussianBetterLVM(false, 2, true, false, 0, true);
}
@Test
public void fitSingleGaussianBDCLVMMLEBetterThanLVMMLE()
{
fitSingleGaussianBetterLVM(true, 2, true, false, 0, true);
}
@Test
public void fitSingleGaussianBLVMMLEBetterThanBLVM()
{
fitSingleGaussianBetterLVM(true, 0, true, true, 0, false);
}
@Test
public void fitSingleGaussianBCLVMMLEBetterThanBCLVM()
{
fitSingleGaussianBetterLVM(true, 1, true, true, 1, false);
}
@Test
public void fitSingleGaussianBDCLVMMLEBetterThanBDCLVM()
{
fitSingleGaussianBetterLVM(true, 2, true, true, 2, false);
}
private void fitSingleGaussianBetterLVM(boolean bounded2, int clamping2, boolean mle2, boolean bounded,
int clamping, boolean mle)
{
NonLinearFit solver = getLVM((bounded) ? 2 : 1, clamping, mle);
NonLinearFit solver2 = getLVM((bounded2) ? 2 : 1, clamping2, mle2);
canFitSingleGaussianBetter(solver, bounded, !mle, solver2, bounded2, !mle2, getLVMName(bounded, clamping, mle),
getLVMName(bounded2, clamping2, mle2));
}
private NonLinearFit getLVM(int bounded, int clamping, boolean mle)
{
Gaussian2DFunction f = GaussianFunctionFactory.create2D(1, size, size, GaussianFunctionFactory.FIT_CIRCLE,
null);
StoppingCriteria sc = new ErrorStoppingCriteria(5);
sc.setMaximumIterations(100);
NonLinearFit solver = (bounded != 0 || clamping != 0) ? new BoundedNonLinearFit(f, sc)
: new NonLinearFit(f, sc);
if (clamping != 0)
{
BoundedNonLinearFit bsolver = (BoundedNonLinearFit) solver;
bsolver.setClampValues(defaultClampValues);
bsolver.setDynamicClamp(clamping == 2);
// Local search anecdotally only works with clamped LVM fitters that are not bounded.
// It must act like a soft-bounding search. For now this is not a used feature and
// will not be formally tested
//bsolver.setLocalSearch(3);
}
solver.setMLE(mle);
solver.setInitialLambda(1);
return solver;
}
private String getLVMName(boolean bounded, int clamping, boolean mle)
{
return ((bounded) ? "B" : "") + ((clamping == 0) ? "" : ((clamping == 1) ? "C" : "DC")) + "LVM" +
((mle) ? " MLE" : "");
}
private void canFitSingleGaussian(FunctionSolver solver, boolean applyBounds, boolean withBias)
{
randomGenerator.setSeed(seed);
for (double s : signal)
{
double[] expected = createParams(1, s, 0, 0, 1, withBias);
double[] lower = createParams(0, s * 0.5, -0.2, -0.2, 0.8, withBias);
double[] upper = createParams(3, s * 2, 0.2, 0.2, 1.2, withBias);
if (applyBounds)
solver.setBounds(lower, upper);
for (double n : noise)
{
double[] data = drawGaussian(expected, n, withBias);
for (double db : base)
for (double dx : shift)
for (double dy : shift)
for (double dsx : factor)
{
double[] p = createParams(db, s, dx, dy, dsx, withBias);
double[] fp = fitGaussian(solver, data, p, expected);
for (int i = 0; i < expected.length; i++)
{
if (fp[i] < lower[i])
Assert.assertTrue(
String.format("Fit Failed: [%d] %.2f < %.2f: %s != %s", i, fp[i],
lower[i], Arrays.toString(fp), Arrays.toString(expected)),
false);
if (fp[i] > upper[i])
Assert.assertTrue(
String.format("Fit Failed: [%d] %.2f > %.2f: %s != %s", i, fp[i],
upper[i], Arrays.toString(fp), Arrays.toString(expected)),
false);
}
}
}
}
}
private void canFitSingleGaussianBetter(FunctionSolver solver, boolean applyBounds, boolean withBias,
FunctionSolver solver2, boolean applyBounds2, boolean withBias2, String name, String name2)
{
int LOOPS = 5;
randomGenerator.setSeed(seed);
double bias2 = (withBias != withBias2) ? (withBias) ? -bias : bias : 0;
StoredDataStatistics[] stats = new StoredDataStatistics[6];
String[] statName = { "Signal", "X", "Y" };
int[] betterPrecision = new int[3];
int[] totalPrecision = new int[3];
int[] betterAccuracy = new int[3];
int[] totalAccuracy = new int[3];
int i1 = 0, i2 = 0;
for (double s : signal)
{
double[] expected = createParams(1, s, 0, 0, 1, withBias);
if (applyBounds)
{
double[] lower = createParams(0, s * 0.5, -0.2, -0.2, 0.8, withBias);
double[] upper = createParams(3, s * 2, 0.2, 0.2, 1.2, withBias);
solver.setBounds(lower, upper);
}
double[] expected2 = createParams(1, s, 0, 0, 1, withBias2);
if (applyBounds2)
{
double[] lower2 = createParams(0, s * 0.5, -0.2, -0.2, 0.8, withBias2);
double[] upper2 = createParams(3, s * 2, 0.2, 0.2, 1.2, withBias2);
solver2.setBounds(lower2, upper2);
}
for (double n : noise)
{
for (int loop = LOOPS; loop-- > 0;)
{
double[] data = drawGaussian(expected, n, withBias);
double[] data2 = data.clone();
for (int i = 0; i < data.length; i++)
data2[i] += bias2;
for (int i = 0; i < stats.length; i++)
stats[i] = new StoredDataStatistics();
for (double db : base)
for (double dx : shift)
for (double dy : shift)
for (double dsx : factor)
{
double[] p = createParams(db, s, dx, dy, dsx, withBias);
double[] fp = fitGaussian(solver, data, p, expected);
i1 += solver.getEvaluations();
double[] p2 = createParams(db, s, dx, dy, dsx, withBias2);
double[] fp2 = fitGaussian(solver2, data2, p2, expected2);
i2 += solver2.getEvaluations();
// Get the mean and sd (the fit precision)
compare(fp, expected, fp2, expected2, Gaussian2DFunction.SIGNAL, stats[0],
stats[1]);
compare(fp, expected, fp2, expected2, Gaussian2DFunction.X_POSITION, stats[2],
stats[3]);
compare(fp, expected, fp2, expected2, Gaussian2DFunction.Y_POSITION, stats[4],
stats[5]);
// Use the distance
//stats[2].add(distance(fp, expected));
//stats[3].add(distance(fp2, expected2));
}
double alpha = 0.05; // two sided
for (int i = 0; i < stats.length; i += 2)
{
double u1 = stats[i].getMean();
double u2 = stats[i + 1].getMean();
double sd1 = stats[i].getStandardDeviation();
double sd2 = stats[i + 1].getStandardDeviation();
TTest tt = new TTest();
boolean diff = tt.tTest(stats[i].getValues(), stats[i + 1].getValues(), alpha);
int index = i / 2;
String msg = String.format("%s vs %s : %.1f (%.1f) %s %f +/- %f vs %f +/- %f (N=%d) %b", name2,
name, s, n, statName[index], u2, sd2, u1, sd1, stats[i].getN(), diff);
if (diff)
{
// Different means. Check they are roughly the same
if (DoubleEquality.almostEqualRelativeOrAbsolute(u1, u2, 0.1, 0))
{
// Basically the same. Check which is more precise
if (!DoubleEquality.almostEqualRelativeOrAbsolute(sd1, sd2, 0.05, 0))
{
if (sd2 < sd1)
{
betterPrecision[index]++;
println(msg + " P*");
}
else
println(msg + " P");
totalPrecision[index]++;
}
}
else
{
// Check which is more accurate (closer to zero)
u1 = Math.abs(u1);
u2 = Math.abs(u2);
if (u2 < u1)
{
betterAccuracy[index]++;
println(msg + " A*");
}
else
println(msg + " A");
totalAccuracy[index]++;
}
}
else
{
// The same means. Check that it is more precise
if (!DoubleEquality.almostEqualRelativeOrAbsolute(sd1, sd2, 0.05, 0))
{
if (sd2 < sd1)
{
betterPrecision[index]++;
println(msg + " P*");
}
else
println(msg + " P");
totalPrecision[index]++;
}
}
}
}
}
}
int better = 0, total = 0;
for (int index = 0; index < statName.length; index++)
{
better += betterPrecision[index] + betterAccuracy[index];
total += totalPrecision[index] + totalAccuracy[index];
test(name2, name, statName[index] + " P", betterPrecision[index], totalPrecision[index],
printBetterDetails);
test(name2, name, statName[index] + " A", betterAccuracy[index], totalAccuracy[index], printBetterDetails);
}
test(name2, name, String.format("All (eval [%d] [%d]) : ", i2, i1), better, total, true);
}
private void test(String name2, String name, String statName, int better, int total, boolean print)
{
double p = 100.0 * better / total;
String msg = String.format("%s vs %s : %s %d / %d (%.1f)", name2, name, statName, better, total, p);
if (print)
System.out.println(msg);
// Do not test if we don't have many examples
if (total <= 10)
{
return;
}
// Disable this for now so builds do not fail during the test phase
// It seems that most of the time clamping and bounds improve things.
// There are a few cases where Bounds or Clamping alone do not improve things.
// Use of Dynamic Clamping is always better.
// Use of Bounded Dynamic Clamping is always better.
// The test may be unrealistic as the initial params are close to the actual answer.
//Assert.assertTrue(msg, p >= 50.0);
}
boolean printBetterDetails = false;
private void println(String msg)
{
// TODO Auto-generated method stub
if (printBetterDetails)
System.out.println(msg);
}
static double distance(double[] o, double[] e)
{
double dx = o[Gaussian2DFunction.X_POSITION] - e[Gaussian2DFunction.X_POSITION];
double dy = o[Gaussian2DFunction.Y_POSITION] - e[Gaussian2DFunction.Y_POSITION];
// Use the signs of the coords to assign a direction vector
return Math.sqrt(dx * dx + dy * dy) * Math.signum(Math.signum(dy) * Math.signum(dx));
}
private void compare(double[] o1, double[] e1, double[] o2, double[] e2, int i, Statistics stats1,
Statistics stats2)
{
compare(o1[i], e1[i], o2[i], e2[i], stats1, stats2);
}
private void compare(double o1, double e1, double o2, double e2, Statistics stats1, Statistics stats2)
{
stats1.add(o1 - e1);
stats2.add(o2 - e2);
}
private double[] createParams(double db, double signal, double dx, double dy, double dsx, boolean withBias)
{
double[] p = params.clone();
p[Gaussian2DFunction.BACKGROUND] *= db;
if (withBias)
p[Gaussian2DFunction.BACKGROUND] += bias;
p[Gaussian2DFunction.SIGNAL] = signal;
p[Gaussian2DFunction.X_POSITION] += dx;
p[Gaussian2DFunction.Y_POSITION] += dy;
p[Gaussian2DFunction.X_SD] *= dsx;
return p;
}
private double[] fitGaussian(FunctionSolver solver, double[] data, double[] params, double[] expected)
{
params = params.clone();
FitStatus status = solver.fit(data, null, params, null);
if (status != FitStatus.OK)
Assert.assertTrue(String.format("Fit Failed: %s i=%d: %s != %s", status.toString(), solver.getIterations(),
Arrays.toString(params), Arrays.toString(expected)), false);
return params;
}
/**
* Draw a Gaussian with Poisson shot noise and Gaussian read noise
*
* @param params
* The Gaussian parameters
* @param noise
* The read noise
* @param withBias
* @return The data
*/
private double[] drawGaussian(double[] params, double noise, boolean withBias)
{
double[] data = new double[size * size];
int n = params.length / 6;
Gaussian2DFunction f = GaussianFunctionFactory.create2D(n, size, size, GaussianFunctionFactory.FIT_CIRCLE,
null);
f.initialise(params);
final double bias = (withBias) ? BoundedFunctionSolverTest.bias : 0;
for (int i = 0; i < data.length; i++)
{
data[i] = bias + dataGenerator.nextPoisson(f.eval(i) - bias);
}
if (noise != 0)
for (int i = 0; i < data.length; i++)
data[i] += dataGenerator.nextGaussian(0, noise);
//gdsc.core.ij.Utils.display("Spot", data, size, size);
return data;
}
}