package gdsc.smlm.fitting.linear;
import java.util.Arrays;
import org.apache.commons.math3.random.RandomDataGenerator;
import org.apache.commons.math3.random.Well19937c;
import org.junit.Assert;
import org.junit.Test;
import gdsc.core.test.BaseTimingTask;
import gdsc.core.test.TimingService;
import gdsc.core.utils.DoubleEquality;
import gdsc.core.utils.TurboList;
import gdsc.smlm.fitting.nonlinear.gradient.GradientCalculator;
import gdsc.smlm.fitting.nonlinear.gradient.GradientCalculatorFactory;
import gdsc.smlm.function.ValueProcedure;
import gdsc.smlm.function.gaussian.Gaussian2DFunction;
import gdsc.smlm.function.gaussian.GaussianFunctionFactory;
public class EJMLLinearSolverTest
{
//@formatter:off
@Test
public void canSolveLinearEquation()
{
    final EJMLLinearSolver solver = new EJMLLinearSolver(3, 1e-6);
    final double tolerance = 1e-4f;

    // Linear system a x = b, using the 3x3 example matrix from
    // https://en.wikipedia.org/wiki/Positive-definite_matrix
    final double[][] a = {
        { 2, -1, 0 },
        { -1, 2, -1 },
        { 0, -1, 2 } };
    final double[] b = { 3, 3, 4 };

    // Reference values: after the calls below the test expects b to hold x
    // and a to hold the inverse of the original matrix
    final double[] expectedX = { 4.75, 6.5, 5.25 };
    final double[][] expectedInverse = {
        { 0.75, 0.5, 0.25 },
        { 0.5, 1, 0.5 },
        { 0.25, 0.5, 0.75 } };

    final boolean solved = solver.solve(a, b);
    solver.invert(a);

    Assert.assertTrue("Failed to solve", solved);
    Assert.assertArrayEquals("Bad solution", expectedX, b, tolerance);
    log("x = %s\n", Arrays.toString(b));

    // Compare the in-place inverse one row at a time
    for (int row = 0; row < b.length; row++)
    {
        log("a[%d] = %s\n", row, Arrays.toString(a[row]));
        Assert.assertArrayEquals("Bad inversion", expectedInverse[row], a[row], tolerance);
    }
}
@Test
public void canSolveLinearEquationWithZeroInB()
{
    final EJMLLinearSolver solver = new EJMLLinearSolver(3, 1e-6);
    final double tolerance = 1e-4f;

    // Same 3x3 matrix as https://en.wikipedia.org/wiki/Positive-definite_matrix
    // but the right-hand side b contains a zero entry
    final double[][] a = {
        { 2, -1, 0 },
        { -1, 2, -1 },
        { 0, -1, 2 } };
    final double[] b = { 3, 0, 4 };

    // Reference values for the solution (returned in b) and the inverse (returned in a)
    final double[] expectedX = { 3.25, 3.5, 3.75 };
    final double[][] expectedInverse = {
        { 0.75, 0.5, 0.25 },
        { 0.5, 1, 0.5 },
        { 0.25, 0.5, 0.75 } };

    final boolean solved = solver.solve(a, b);
    solver.invert(a);

    Assert.assertTrue("Failed to solve", solved);
    Assert.assertArrayEquals("Bad solution", expectedX, b, tolerance);
    log("x = %s\n", Arrays.toString(b));

    // Compare the in-place inverse one row at a time
    for (int row = 0; row < b.length; row++)
    {
        log("a[%d] = %s\n", row, Arrays.toString(a[row]));
        Assert.assertArrayEquals("Bad inversion", expectedInverse[row], a[row], tolerance);
    }
}
@Test
public void canSolveLinearEquationWithZeroInA()
{
    final EJMLLinearSolver solver = new EJMLLinearSolver(3, 1e-6);
    final double tolerance = 1e-4f;

    // Linear system a x = b where row 1 and column 1 of a (and b[1]) are entirely zero
    final double[][] a = {
        { 2, 0, -1, 0 },
        { 0, 0, 0, 0 },
        { -1, 0, 2, -1 },
        { 0, 0, -1, 2 } };
    final double[] b = { 3, 0, 3, 4 };

    // Reference values: the zero row/column is expected to stay zero in both
    // the solution (returned in b) and the inverse (returned in a)
    final double[] expectedX = { 4.75, 0, 6.5, 5.25 };
    final double[][] expectedInverse = {
        { 0.75, 0, 0.5, 0.25 },
        { 0, 0, 0, 0 },
        { 0.5, 0, 1, 0.5 },
        { 0.25, 0, 0.5, 0.75 } };

    final boolean solved = solver.solve(a, b);
    solver.invert(a);

    Assert.assertTrue("Failed to solve", solved);
    Assert.assertArrayEquals("Bad solution", expectedX, b, tolerance);
    log("x = %s\n", Arrays.toString(b));

    // Compare the in-place inverse one row at a time
    for (int row = 0; row < b.length; row++)
    {
        log("a[%d] = %s\n", row, Arrays.toString(a[row]));
        Assert.assertArrayEquals("Bad inversion", expectedInverse[row], a[row], tolerance);
    }
}
@Test
public void canSolveLinearEquationWithZerosInA()
{
    // Default-constructed solver with the equality checker supplied explicitly
    final EJMLLinearSolver solver = new EJMLLinearSolver();
    solver.setEqual(new DoubleEquality(3, 1e-16));
    final double tolerance = 1e-4f;

    // Linear system a x = b where rows/columns 1, 3 and 4 of a (and the
    // matching entries of b) are entirely zero
    final double[][] a = {
        { 2, 0, -1, 0, 0, 0 },
        { 0, 0, 0, 0, 0, 0 },
        { -1, 0, 2, 0, 0, -1 },
        { 0, 0, 0, 0, 0, 0 },
        { 0, 0, 0, 0, 0, 0 },
        { 0, 0, -1, 0, 0, 2 } };
    final double[] b = { 3, 0, 3, 0, 0, 4 };

    // Reference values: the zero rows/columns are expected to stay zero in
    // both the solution (returned in b) and the inverse (returned in a)
    final double[] expectedX = { 4.75, 0, 6.5, 0, 0, 5.25 };
    final double[][] expectedInverse = {
        { 0.75, 0, 0.5, 0, 0, 0.25 },
        { 0, 0, 0, 0, 0, 0 },
        { 0.5, 0, 1, 0, 0, 0.5 },
        { 0, 0, 0, 0, 0, 0 },
        { 0, 0, 0, 0, 0, 0 },
        { 0.25, 0, 0.5, 0, 0, 0.75 } };

    final boolean solved = solver.solve(a, b);
    solver.invert(a);

    Assert.assertTrue("Failed to solve", solved);
    Assert.assertArrayEquals("Bad solution", expectedX, b, tolerance);
    log("x = %s\n", Arrays.toString(b));

    // Compare the in-place inverse one row at a time
    for (int row = 0; row < b.length; row++)
    {
        log("a[%d] = %s\n", row, Arrays.toString(a[row]));
        Assert.assertArrayEquals("Bad inversion", expectedInverse[row], a[row], tolerance);
    }
}
@Test
public void canInvert()
{
    final EJMLLinearSolver solver = new EJMLLinearSolver(3, 1e-6);
    final double tolerance = 1e-4f;

    // Matrix to invert in place, taken from
    // https://en.wikipedia.org/wiki/Positive-definite_matrix
    final double[][] a = {
        { 2, -1, 0 },
        { -1, 2, -1 },
        { 0, -1, 2 } };

    // Reference inverse
    final double[][] expectedInverse = {
        { 0.75, 0.5, 0.25 },
        { 0.5, 1, 0.5 },
        { 0.25, 0.5, 0.75 } };

    final boolean inverted = solver.invertSymmPosDef(a);
    Assert.assertTrue("Failed to invert", inverted);

    // The matrix is square so the column count of the first row bounds the row loop
    for (int row = 0; row < a[0].length; row++)
    {
        log("a[%d] = %s\n", row, Arrays.toString(a[row]));
        Assert.assertArrayEquals("Bad inversion", expectedInverse[row], a[row], tolerance);
    }
}
@Test
public void canInvertWithZeros()
{
    final EJMLLinearSolver solver = new EJMLLinearSolver(3, 1e-6);
    final double tolerance = 1e-4f;

    // Matrix to invert in place; rows/columns 1, 3 and 4 are entirely zero
    final double[][] a = {
        { 2, 0, -1, 0, 0, 0 },
        { 0, 0, 0, 0, 0, 0 },
        { -1, 0, 2, 0, 0, -1 },
        { 0, 0, 0, 0, 0, 0 },
        { 0, 0, 0, 0, 0, 0 },
        { 0, 0, -1, 0, 0, 2 } };

    // Reference inverse: the zero rows/columns are expected to remain zero
    final double[][] expectedInverse = {
        { 0.75, 0, 0.5, 0, 0, 0.25 },
        { 0, 0, 0, 0, 0, 0 },
        { 0.5, 0, 1, 0, 0, 0.5 },
        { 0, 0, 0, 0, 0, 0 },
        { 0, 0, 0, 0, 0, 0 },
        { 0.25, 0, 0.5, 0, 0, 0.75 } };

    final boolean inverted = solver.invertSymmPosDef(a);
    Assert.assertTrue("Failed to invert", inverted);

    // The matrix is square so the column count of the first row bounds the row loop
    for (int row = 0; row < a[0].length; row++)
    {
        log("a[%d] = %s\n", row, Arrays.toString(a[row]));
        Assert.assertArrayEquals("Bad inversion", expectedInverse[row], a[row], tolerance);
    }
}
@Test
public void canInvertDiagonal()
{
    final EJMLLinearSolver solver = new EJMLLinearSolver(3, 1e-6);

    // Input matrix taken from https://en.wikipedia.org/wiki/Positive-definite_matrix
    final double[][] a = {
        { 2, -1, 0 },
        { -1, 2, -1 },
        { 0, -1, 2 } };

    // Reference values: the diagonal of the inverse of a
    final double[] expectedDiagonal = { 0.75, 1, 0.75 };

    final double[] diagonal = solver.invertSymmPosDefDiagonal(a);

    // A null return is treated as a failed inversion
    Assert.assertNotNull("Failed to invert", diagonal);
    log("a diagonal = %s\n", Arrays.toString(diagonal));
    Assert.assertArrayEquals("Bad inversion", expectedDiagonal, diagonal, 1e-4);
}
@Test
public void canInvertDiagonalWithZeros()
{
    final EJMLLinearSolver solver = new EJMLLinearSolver(3, 1e-6);

    // Input matrix; rows/columns 1, 3 and 4 are entirely zero
    final double[][] a = {
        { 2, 0, -1, 0, 0, 0 },
        { 0, 0, 0, 0, 0, 0 },
        { -1, 0, 2, 0, 0, -1 },
        { 0, 0, 0, 0, 0, 0 },
        { 0, 0, 0, 0, 0, 0 },
        { 0, 0, -1, 0, 0, 2 } };

    // Reference values: inverse diagonal, with zeros expected at the zeroed indices
    final double[] expectedDiagonal = { 0.75, 0, 1, 0, 0, 0.75 };

    final double[] diagonal = solver.invertSymmPosDefDiagonal(a);

    // A null return is treated as a failed inversion
    Assert.assertNotNull("Failed to invert", diagonal);
    log("a diagonal = %s\n", Arrays.toString(diagonal));
    Assert.assertArrayEquals("Bad inversion", expectedDiagonal, diagonal, 1e-4);
}
//@formatter:on
/**
 * Base timing task that applies one solver method to a batch of systems.
 * Subclasses choose the solver method by implementing {@link #solve(double[][], double[])}.
 */
private abstract class SolverTimingTask extends BaseTimingTask
{
    // Source data. These are stored by reference; getData(int) produces the deep copies.
    double[][][] a;
    double[][] b;
    // The equality checker is only set while the constructor validates the method;
    // it is reset to null afterwards so the timed runs are a pure speed test.
    EJMLLinearSolver solver = new EJMLLinearSolver();

    /**
     * Stores the input data and checks that the solver method succeeds on every system.
     *
     * @param name the method name (the length of b[0] is appended to form the task name)
     * @param a the matrices
     * @param b the vectors
     * @throws RuntimeException if any system fails to solve during the check
     */
    public SolverTimingTask(String name, double[][][] a, double[][] b)
    {
        super(name + " " + b[0].length);
        this.a = a;
        this.b = b;

        // Check the solver gets a good answer, working on a copy of the data
        solver.setEqual(new DoubleEquality(3, 1e-6));
        final Object[] copy = (Object[]) getData(0);
        final double[][][] aCopy = (double[][][]) copy[0];
        final double[][] bCopy = (double[][]) copy[1];
        for (int k = 0; k < aCopy.length; k++)
        {
            final boolean ok = solve(aCopy[k], bCopy[k]);
            if (!ok)
            {
                throw new RuntimeException(getName() + " failed to solve");
            }
        }
        solver.setEqual(null);
    }

    /**
     * @return always 1
     */
    public int getSize()
    {
        return 1;
    }

    /**
     * Creates a deep copy of the stored matrices and vectors.
     *
     * @param i ignored
     * @return an Object[] holding the copied double[][][] matrices and the copied double[][] vectors
     */
    public Object getData(int i)
    {
        final int count = this.b.length;
        final int size = this.b[0].length;
        final double[][][] aCopy = new double[count][][];
        final double[][] bCopy = new double[count][];
        for (int k = count - 1; k >= 0; k--)
        {
            final double[][] rows = new double[size][];
            for (int j = size - 1; j >= 0; j--)
            {
                rows[j] = this.a[k][j].clone();
            }
            aCopy[k] = rows;
            bCopy[k] = this.b[k].clone();
        }
        return new Object[] { aCopy, bCopy };
    }

    /**
     * Solves every system in the data. The result of each solve is ignored.
     *
     * @param data an Object[] in the layout produced by {@link #getData(int)}
     * @return null
     */
    public Object run(Object data)
    {
        final Object[] pair = (Object[]) data;
        final double[][][] matrices = (double[][][]) pair[0];
        final double[][] vectors = (double[][]) pair[1];
        for (int k = 0; k < matrices.length; k++)
        {
            solve(matrices[k], vectors[k]);
        }
        return null;
    }

    /**
     * Solves a single system using the method under test.
     *
     * @param a the matrix
     * @param b the vector
     * @return the value returned by the solver method
     */
    abstract boolean solve(double[][] a, double[] b);
}
/**
 * Timing task named "Linear" that solves using {@code solveLinear}.
 */
private class LinearSolverTimingTask extends SolverTimingTask
{
    public LinearSolverTimingTask(double[][][] a, double[][] b)
    {
        super("Linear", a, b);
    }

    /**
     * Delegates to the solver's {@code solveLinear} method and returns its result.
     */
    boolean solve(double[][] a, double[] b)
    {
        final boolean solved = this.solver.solveLinear(a, b);
        return solved;
    }
}
/**
 * Timing task named "Cholesky" that solves using {@code solveCholesky}.
 */
private class CholeskySolverTimingTask extends SolverTimingTask
{
    public CholeskySolverTimingTask(double[][][] a, double[][] b)
    {
        super("Cholesky", a, b);
    }

    /**
     * Delegates to the solver's {@code solveCholesky} method and returns its result.
     */
    boolean solve(double[][] a, double[] b)
    {
        final boolean solved = this.solver.solveCholesky(a, b);
        return solved;
    }
}
/**
 * Timing task named "CholeskyLDLT" that solves using {@code solveCholeskyLDLT}.
 */
private class CholeskyLDLTSolverTimingTask extends SolverTimingTask
{
    public CholeskyLDLTSolverTimingTask(double[][][] a, double[][] b)
    {
        super("CholeskyLDLT", a, b);
    }

    /**
     * Delegates to the solver's {@code solveCholeskyLDLT} method and returns its result.
     */
    boolean solve(double[][] a, double[] b)
    {
        final boolean solved = this.solver.solveCholeskyLDLT(a, b);
        return solved;
    }
}
/**
 * Timing task named "PseudoInverse" that solves using {@code solvePseudoInverse}.
 */
private class PseudoInverseSolverTimingTask extends SolverTimingTask
{
    public PseudoInverseSolverTimingTask(double[][][] a, double[][] b)
    {
        super("PseudoInverse", a, b);
    }

    /**
     * Delegates to the solver's {@code solvePseudoInverse} method and returns its result.
     */
    boolean solve(double[][] a, double[] b)
    {
        final boolean solved = this.solver.solvePseudoInverse(a, b);
        return solved;
    }
}
/**
 * Timing task named "DirectInversion" that solves using {@code solveDirectInversion}.
 */
private class DirectInversionSolverTimingTask extends SolverTimingTask
{
    public DirectInversionSolverTimingTask(double[][][] a, double[][] b)
    {
        super("DirectInversion", a, b);
    }

    /**
     * Delegates to the solver's {@code solveDirectInversion} method and returns its result.
     */
    boolean solve(double[][] a, double[] b)
    {
        final boolean solved = this.solver.solveDirectInversion(a, b);
        return solved;
    }
}
// Create a speed test of the different methods
@Test
public void runSpeedTest5()
{
    // Time the solver methods on data built with the FIT_ERF_CIRCLE function flags
    final int flags = GaussianFunctionFactory.FIT_ERF_CIRCLE;
    runSpeedTest(flags);
}
@Test
public void runSpeedTest4()
{
    // Time the solver methods on data built with the FIT_ERF_FIXED function flags
    final int flags = GaussianFunctionFactory.FIT_ERF_FIXED;
    runSpeedTest(flags);
}
@Test
public void runSpeedTest3()
{
    // Time the solver methods on data built with the FIT_SIMPLE_NB_FIXED function flags
    final int flags = GaussianFunctionFactory.FIT_SIMPLE_NB_FIXED;
    runSpeedTest(flags);
}
@Test
public void runSpeedTest2()
{
    // Time the solver methods on data built with the FIT_SIMPLE_NS_NB_FIXED function flags
    final int flags = GaussianFunctionFactory.FIT_SIMPLE_NS_NB_FIXED;
    runSpeedTest(flags);
}
/**
 * Builds a batch of linearised systems (alpha, beta) from simulated Poisson data for a
 * Gaussian function created with the given flags, then times each solver method on the batch.
 *
 * @param flags the function flags passed to GaussianFunctionFactory.create2D
 */
private void runSpeedTest(int flags)
{
    final Gaussian2DFunction f0 = GaussianFunctionFactory.create2D(1, 10, 10, flags, null);
    final int n = f0.size();
    final int np = f0.getNumberOfGradients();
    final GradientCalculator calc = GradientCalculatorFactory.newCalculator(np);
    // Fixed seed
    final RandomDataGenerator rdg = new RandomDataGenerator(new Well19937c(30051977));

    // Buffer for the simulated data; refilled for every parameter combination
    final double[] y = new double[n];
    final TurboList<double[][]> aList = new TurboList<double[][]>();
    final TurboList<double[]> bList = new TurboList<double[]>();

    // Parameter values to combine
    final double[] testbackground = { 0.2, 0.7 };
    final double[] testsignal1 = { 30, 100, 300 };
    final double[] testcx1 = { 4.9, 5.3 };
    final double[] testcy1 = { 4.8, 5.2 };
    final double[] testw1 = { 1.1, 1.2, 1.5 };

    for (double background : testbackground)
    {
        // Peak 1
        for (double signal1 : testsignal1)
        {
            for (double cx1 : testcx1)
            {
                for (double cy1 : testcy1)
                {
                    for (double w1 : testw1)
                    {
                        final double[] p = { background, signal1, 0, cx1, cy1, w1, w1 };
                        f0.initialise(p);

                        // Replace each function value with a Poisson sample
                        f0.forEach(new ValueProcedure()
                        {
                            int cursor = 0;

                            public void execute(double value)
                            {
                                y[cursor++] = rdg.nextPoisson(value);
                            }
                        });

                        final double[][] alpha = new double[np][np];
                        final double[] beta = new double[np];
                        // The returned value is not used.
                        // (A commented-out variant scaled the diagonal of alpha by a
                        // lambda factor, as per the LVM algorithm.)
                        calc.findLinearised(n, y, p, alpha, beta, f0);

                        aList.add(alpha);
                        bList.add(beta);
                    }
                }
            }
        }
    }

    final double[][][] a = aList.toArray(new double[aList.size()][][]);
    final double[][] b = bList.toArray(new double[bList.size()][]);

    // The run count is 10000 divided by the number of systems in the batch
    final TimingService ts = new TimingService(10000 / a.length);
    ts.execute(new LinearSolverTimingTask(a, b));
    ts.execute(new CholeskySolverTimingTask(a, b));
    ts.execute(new CholeskyLDLTSolverTimingTask(a, b));
    ts.execute(new PseudoInverseSolverTimingTask(a, b));
    ts.execute(new DirectInversionSolverTimingTask(a, b));
    ts.repeat();
    ts.report();
}
/**
 * Writes a formatted message to standard output.
 *
 * @param format the format string
 * @param args the format arguments
 */
void log(String format, Object... args)
{
    System.out.format(format, args);
}
}