package gdsc.smlm.fitting.linear;
import gdsc.smlm.TestSettings;
import gdsc.smlm.fitting.nonlinear.gradient.GradientCalculator;
import gdsc.smlm.function.gaussian.SingleFreeCircularGaussian2DFunction;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Random;
import org.apache.commons.math3.util.FastMath;
import org.junit.Assert;
import org.junit.Test;
/**
 * Compares the {@code GaussJordan} solver with the {@code EJMLLinearSolver} for correctness and speed.
 * <p>
 * The test systems are 6x6 matrices and 6-element vectors built by {@code createData} from noisy 2D Gaussian
 * data using a fixed random seed, so every run uses the same systems. Speed tests only run when
 * {@code TestSettings.RUN_SPEED_TESTS} is set and only assert on the timings when
 * {@code TestSettings.ASSERT_SPEED_TESTS} is set.
 */
public class SolverSpeedTest
{
    /** Number of systems the constructor attempts to generate. */
    int MAX_ITER = 20000;
    /** Seeded random source so the generated systems are reproducible. */
    Random rand;
    /** The matrix (A in Ax=b) of each generated system. */
    ArrayList<float[][]> Adata;
    /** The vector (b in Ax=b) of each generated system, index-aligned with {@link #Adata}. */
    ArrayList<float[]> Bdata;

    /**
     * Generates up to {@link #MAX_ITER} 6x6 systems. A system is kept only when {@code createData} returns true.
     */
    public SolverSpeedTest()
    {
        rand = new Random(30051977);
        Adata = new ArrayList<float[][]>();
        Bdata = new ArrayList<float[]>();
        for (int i = 0; i < MAX_ITER; i++)
        {
            float[][] a = new float[6][6];
            float[] b = new float[6];
            if (createData(a, b, false))
            {
                Adata.add(a);
                Bdata.add(b);
            }
        }
    }

    /**
     * Checks that GaussJordan.solve and EJMLLinearSolver.solveLinear + invert agree on the success flag, the
     * solution vector and the matrix left in place.
     * <p>
     * NOTE(review): the matrix comparison assumes GaussJordan.solve leaves the inverse of A in {@code a}
     * (it is compared against the EJML-inverted {@code a2}) — confirm against GaussJordan.
     */
    @Test
    public void solveLinearAndGaussJordanReturnSameSolutionAndInversionResult()
    {
        int ITER = 100;
        ArrayList<double[][]> A = copyAdouble(this.Adata, ITER);
        ArrayList<double[]> B = copyBdouble(this.Bdata, ITER);
        ArrayList<double[][]> A2 = copyAdouble(this.Adata, ITER);
        ArrayList<double[]> B2 = copyBdouble(this.Bdata, ITER);
        GaussJordan solver = new GaussJordan();
        EJMLLinearSolver solver2 = new EJMLLinearSolver();
        for (int i = 0; i < A.size(); i++)
        {
            double[][] a = A.get(i);
            double[] b = B.get(i);
            double[][] a2 = A2.get(i);
            double[] b2 = B2.get(i);
            boolean r1 = solver.solve(a, b);
            boolean r2 = solver2.solveLinear(a2, b2);
            solver2.invert(a2);
            // Compare the flags by value; assertSame on autoboxed booleans only works via the Boolean cache
            Assert.assertEquals("Different solve result @ " + i, r1, r2);
            for (int j = 0; j < b.length; j++)
            {
                // Relative agreement: the ratio of the two solutions should be close to 1
                Assert.assertEquals("Different b[" + j + "] result @ " + i, 1.0, b[j] / b2[j], 1e-2);
                String msg2 = "Different a[" + j + "] result @ " + i;
                for (int k = 0; k < a[j].length; k++)
                    Assert.assertEquals(msg2, 1.0, a[j][k] / a2[j][k], 0.2);
            }
        }
    }

    /**
     * Checks that GaussJordan.solve and EJMLLinearSolver.solve agree on the success flag and the solution
     * vector.
     */
    @Test
    public void solveLinearAndGaussJordanReturnSameSolutionResult()
    {
        int ITER = 100;
        ArrayList<double[][]> A = copyAdouble(this.Adata, ITER);
        ArrayList<double[]> B = copyBdouble(this.Bdata, ITER);
        ArrayList<double[][]> A2 = copyAdouble(this.Adata, ITER);
        ArrayList<double[]> B2 = copyBdouble(this.Bdata, ITER);
        GaussJordan solver = new GaussJordan();
        EJMLLinearSolver solver2 = new EJMLLinearSolver();
        // Bound by the copied size: the copies are capped at the number of available systems
        for (int i = 0; i < A.size(); i++)
        {
            double[][] a = A.get(i);
            double[] b = B.get(i);
            double[][] a2 = A2.get(i);
            double[] b2 = B2.get(i);
            boolean r1 = solver.solve(a, b);
            boolean r2 = solver2.solve(a2, b2);
            Assert.assertEquals("Different solve result @ " + i, r1, r2);
            for (int j = 0; j < b.length; j++)
                Assert.assertEquals("Different b[" + j + "] result @ " + i, 1.0, b[j] / b2[j], 1e-2);
        }
    }

    /**
     * Checks that the float and double overloads of GaussJordan.solve agree on the success flag, the solution
     * vector and the matrix left in place.
     */
    @Test
    public void gaussJordanFloatAndDoubleReturnSameSolutionAndInversionResult()
    {
        int ITER = 100;
        ArrayList<float[][]> A = copyAfloat(this.Adata, ITER);
        ArrayList<float[]> B = copyBfloat(this.Bdata, ITER);
        ArrayList<double[][]> A2 = copyAdouble(this.Adata, ITER);
        ArrayList<double[]> B2 = copyBdouble(this.Bdata, ITER);
        GaussJordan solver = new GaussJordan();
        for (int i = 0; i < A.size(); i++)
        {
            float[][] a = A.get(i);
            float[] b = B.get(i);
            double[][] a2 = A2.get(i);
            double[] b2 = B2.get(i);
            boolean r1 = solver.solve(a, b);
            boolean r2 = solver.solve(a2, b2);
            Assert.assertEquals("Different solve result @ " + i, r1, r2);
            for (int j = 0; j < b.length; j++)
            {
                Assert.assertEquals("Different b[" + j + "] result @ " + i, 1.0, b[j] / b2[j], 1e-2);
                String msg2 = "Different a[" + j + "] result @ " + i;
                for (int k = 0; k < a[j].length; k++)
                    Assert.assertEquals(msg2, 1.0, a[j][k] / a2[j][k], 0.2);
            }
        }
    }

    /**
     * Times GaussJordan (float) against EJMLLinearSolver.solveLinear + invert. Both are warmed up once on
     * separate copies of the data before timing.
     */
    @Test
    public void solveLinearWithInversionIsFasterThanGaussJordanFloat()
    {
        org.junit.Assume.assumeTrue(TestSettings.RUN_SPEED_TESTS);
        int ITER = 10000;
        ArrayList<float[][]> A = copyAfloat(this.Adata, ITER);
        ArrayList<float[]> B = copyBfloat(this.Bdata, ITER);
        ArrayList<double[][]> A2 = copyAdouble(this.Adata, ITER);
        ArrayList<double[]> B2 = copyBdouble(this.Bdata, ITER);
        GaussJordan solver = new GaussJordan();
        EJMLLinearSolver solver2 = new EJMLLinearSolver();
        // Warm-up pass on throw-away copies (the solvers modify their inputs in place)
        runFloat(copyAfloat(this.Adata, ITER), copyBfloat(this.Bdata, ITER), ITER, solver);
        solveLinearWithInversion(copyAdouble(this.Adata, ITER), copyBdouble(this.Bdata, ITER), ITER, solver2);
        long start1 = System.nanoTime();
        runFloat(A, B, ITER, solver);
        start1 = System.nanoTime() - start1;
        long start2 = System.nanoTime();
        solveLinearWithInversion(A2, B2, ITER, solver2);
        start2 = System.nanoTime() - start2;
        log("GaussJordanFloat = %d : LinearSolver.solveLinearWithInversion = %d : %fx\n", start1, start2,
                (1.0 * start1) / start2);
        if (TestSettings.ASSERT_SPEED_TESTS)
            Assert.assertTrue(start2 < start1);
    }

    /**
     * Times GaussJordan (float) against EJMLLinearSolver.solveLinear, after one warm-up pass of each.
     */
    @Test
    public void solveLinearIsFasterThanGaussJordanFloat()
    {
        org.junit.Assume.assumeTrue(TestSettings.RUN_SPEED_TESTS);
        int ITER = 10000;
        ArrayList<float[][]> A = copyAfloat(this.Adata, ITER);
        ArrayList<float[]> B = copyBfloat(this.Bdata, ITER);
        ArrayList<double[][]> A2 = copyAdouble(this.Adata, ITER);
        ArrayList<double[]> B2 = copyBdouble(this.Bdata, ITER);
        GaussJordan solver = new GaussJordan();
        EJMLLinearSolver solver2 = new EJMLLinearSolver();
        runFloat(copyAfloat(this.Adata, ITER), copyBfloat(this.Bdata, ITER), ITER, solver);
        solveLinear(copyAdouble(this.Adata, ITER), copyBdouble(this.Bdata, ITER), ITER, solver2);
        long start1 = System.nanoTime();
        runFloat(A, B, ITER, solver);
        start1 = System.nanoTime() - start1;
        long start2 = System.nanoTime();
        solveLinear(A2, B2, ITER, solver2);
        start2 = System.nanoTime() - start2;
        log("GaussJordanFloat = %d : LinearSolver.solveLinear = %d : %fx\n", start1, start2, (1.0 * start1) / start2);
        if (TestSettings.ASSERT_SPEED_TESTS)
            Assert.assertTrue(start2 < start1);
    }

    /**
     * Solves up to {@code ITER} float systems with GaussJordan.solve, discarding the results.
     *
     * @param A
     *            the matrices (modified in place by the solver)
     * @param B
     *            the vectors (modified in place by the solver)
     * @param ITER
     *            the maximum number of systems to solve; clamped to {@code A.size()}
     * @param solver
     *            the solver
     */
    protected void runFloat(ArrayList<float[][]> A, ArrayList<float[]> B, int ITER, GaussJordan solver)
    {
        // Clamp like the other solve helpers so a short data list cannot cause an IndexOutOfBoundsException
        ITER = FastMath.min(ITER, A.size());
        for (int i = 0; i < ITER; i++)
        {
            solver.solve(A.get(i), B.get(i));
        }
    }

    /**
     * Times GaussJordan (double) against EJMLLinearSolver.solveLinear + invert, after one warm-up pass of each.
     */
    @Test
    public void solveLinearWithInversionIsFasterThanGaussJordanDouble()
    {
        org.junit.Assume.assumeTrue(TestSettings.RUN_SPEED_TESTS);
        int ITER = 10000;
        ArrayList<double[][]> A = copyAdouble(this.Adata, ITER);
        ArrayList<double[]> B = copyBdouble(this.Bdata, ITER);
        ArrayList<double[][]> A2 = copyAdouble(this.Adata, ITER);
        ArrayList<double[]> B2 = copyBdouble(this.Bdata, ITER);
        GaussJordan solver = new GaussJordan();
        EJMLLinearSolver solver2 = new EJMLLinearSolver();
        solveGaussJordan(copyAdouble(this.Adata, ITER), copyBdouble(this.Bdata, ITER), ITER, solver);
        solveLinearWithInversion(copyAdouble(this.Adata, ITER), copyBdouble(this.Bdata, ITER), ITER, solver2);
        long start1 = System.nanoTime();
        solveGaussJordan(A, B, ITER, solver);
        start1 = System.nanoTime() - start1;
        long start2 = System.nanoTime();
        solveLinearWithInversion(A2, B2, ITER, solver2);
        start2 = System.nanoTime() - start2;
        log("GaussJordanDouble = %d : LinearSolver.solveLinearWithInversion = %d : %fx\n", start1, start2,
                (1.0 * start1) / start2);
        if (TestSettings.ASSERT_SPEED_TESTS)
            Assert.assertTrue(start2 < start1);
    }

    /**
     * Times GaussJordan (double) against EJMLLinearSolver.solveLinear, after one warm-up pass of each.
     */
    @Test
    public void solveLinearIsFasterThanGaussJordanDouble()
    {
        org.junit.Assume.assumeTrue(TestSettings.RUN_SPEED_TESTS);
        int ITER = 10000;
        ArrayList<double[][]> A = copyAdouble(this.Adata, ITER);
        ArrayList<double[]> B = copyBdouble(this.Bdata, ITER);
        ArrayList<double[][]> A2 = copyAdouble(this.Adata, ITER);
        ArrayList<double[]> B2 = copyBdouble(this.Bdata, ITER);
        GaussJordan solver = new GaussJordan();
        EJMLLinearSolver solver2 = new EJMLLinearSolver();
        solveGaussJordan(copyAdouble(this.Adata, ITER), copyBdouble(this.Bdata, ITER), ITER, solver);
        solveLinear(copyAdouble(this.Adata, ITER), copyBdouble(this.Bdata, ITER), ITER, solver2);
        long start1 = System.nanoTime();
        solveGaussJordan(A, B, ITER, solver);
        start1 = System.nanoTime() - start1;
        long start2 = System.nanoTime();
        solveLinear(A2, B2, ITER, solver2);
        start2 = System.nanoTime() - start2;
        log("GaussJordanDouble = %d : LinearSolver.solveLinear = %d : %fx\n", start1, start2, (1.0 * start1) / start2);
        if (TestSettings.ASSERT_SPEED_TESTS)
            Assert.assertTrue(start2 < start1);
    }

    /**
     * Times GaussJordan (double) against EJMLLinearSolver.solveCholesky, after one warm-up pass of each.
     */
    @Test
    public void solveCholeskyIsFasterThanGaussJordanDouble()
    {
        org.junit.Assume.assumeTrue(TestSettings.RUN_SPEED_TESTS);
        int ITER = 10000;
        ArrayList<double[][]> A = copyAdouble(this.Adata, ITER);
        ArrayList<double[]> B = copyBdouble(this.Bdata, ITER);
        ArrayList<double[][]> A2 = copyAdouble(this.Adata, ITER);
        ArrayList<double[]> B2 = copyBdouble(this.Bdata, ITER);
        GaussJordan solver = new GaussJordan();
        EJMLLinearSolver solver2 = new EJMLLinearSolver();
        solveGaussJordan(copyAdouble(this.Adata, ITER), copyBdouble(this.Bdata, ITER), ITER, solver);
        solveCholesky(copyAdouble(this.Adata, ITER), copyBdouble(this.Bdata, ITER), ITER, solver2);
        long start1 = System.nanoTime();
        solveGaussJordan(A, B, ITER, solver);
        start1 = System.nanoTime() - start1;
        long start2 = System.nanoTime();
        solveCholesky(A2, B2, ITER, solver2);
        start2 = System.nanoTime() - start2;
        log("GaussJordanDouble = %d : LinearSolver.solveCholesky = %d : %fx\n", start1, start2,
                (1.0 * start1) / start2);
        if (TestSettings.ASSERT_SPEED_TESTS)
            Assert.assertTrue(start2 < start1);
    }

    /**
     * Times GaussJordan (double) against EJMLLinearSolver.solveCholeskyLDLT, after one warm-up pass of each.
     */
    @Test
    public void solveCholeskyLDLTIsFasterThanGaussJordanDouble()
    {
        org.junit.Assume.assumeTrue(TestSettings.RUN_SPEED_TESTS);
        int ITER = 10000;
        ArrayList<double[][]> A = copyAdouble(this.Adata, ITER);
        ArrayList<double[]> B = copyBdouble(this.Bdata, ITER);
        ArrayList<double[][]> A2 = copyAdouble(this.Adata, ITER);
        ArrayList<double[]> B2 = copyBdouble(this.Bdata, ITER);
        GaussJordan solver = new GaussJordan();
        EJMLLinearSolver solver2 = new EJMLLinearSolver();
        solveGaussJordan(copyAdouble(this.Adata, ITER), copyBdouble(this.Bdata, ITER), ITER, solver);
        solveCholeskyLDLT(copyAdouble(this.Adata, ITER), copyBdouble(this.Bdata, ITER), ITER, solver2);
        long start1 = System.nanoTime();
        solveGaussJordan(A, B, ITER, solver);
        start1 = System.nanoTime() - start1;
        long start2 = System.nanoTime();
        solveCholeskyLDLT(A2, B2, ITER, solver2);
        start2 = System.nanoTime() - start2;
        log("GaussJordanDouble = %d : LinearSolver.solveCholeskyLDLT = %d : %fx\n", start1, start2,
                (1.0 * start1) / start2);
        if (TestSettings.ASSERT_SPEED_TESTS)
            Assert.assertTrue(start2 < start1);
    }

    /**
     * Times GaussJordan (double) against EJMLLinearSolver.solve, after one warm-up pass of each.
     */
    @Test
    public void solveIsFasterThanGaussJordanDouble()
    {
        org.junit.Assume.assumeTrue(TestSettings.RUN_SPEED_TESTS);
        int ITER = 10000;
        ArrayList<double[][]> A = copyAdouble(this.Adata, ITER);
        ArrayList<double[]> B = copyBdouble(this.Bdata, ITER);
        ArrayList<double[][]> A2 = copyAdouble(this.Adata, ITER);
        ArrayList<double[]> B2 = copyBdouble(this.Bdata, ITER);
        GaussJordan solver = new GaussJordan();
        EJMLLinearSolver solver2 = new EJMLLinearSolver();
        solveGaussJordan(copyAdouble(this.Adata, ITER), copyBdouble(this.Bdata, ITER), ITER, solver);
        solve(copyAdouble(this.Adata, ITER), copyBdouble(this.Bdata, ITER), ITER, solver2);
        long start1 = System.nanoTime();
        solveGaussJordan(A, B, ITER, solver);
        start1 = System.nanoTime() - start1;
        long start2 = System.nanoTime();
        solve(A2, B2, ITER, solver2);
        start2 = System.nanoTime() - start2;
        log("GaussJordanDouble = %d : LinearSolver.solve = %d : %fx\n", start1, start2, (1.0 * start1) / start2);
        if (TestSettings.ASSERT_SPEED_TESTS)
            Assert.assertTrue(start2 < start1);
    }

    /**
     * Builds one 6x6 system from noisy 2D Gaussian data and writes it, as floats, into the supplied arrays.
     * <p>
     * A Gaussian with random parameters is evaluated at 100 points and uniform noise in (-5, 5) is added. The
     * parameters are then perturbed by up to +/-1. The matrix and vector come from
     * GradientCalculator.findLinearised, and the matrix diagonal is scaled by 1.001 (a lambda shift).
     *
     * @param alpha
     *            output matrix (assumed at least 6x6 — its used size is driven by {@code beta.length})
     * @param beta
     *            output vector (assumed length 6)
     * @param positiveDefinite
     *            if true, only accept the system when EJMLLinearSolver.solveCholeskyLDLT succeeds on a copy
     * @return true if the system should be used
     */
    private boolean createData(float[][] alpha, float[] beta, boolean positiveDefinite)
    {
        // Generate a 2D Gaussian
        SingleFreeCircularGaussian2DFunction func = new SingleFreeCircularGaussian2DFunction(10, 10);
        double[] a = new double[] {
                // Background, Amplitude, Angle, Xpos, Ypos, Xwidth, yWidth
                20 + rand.nextDouble() * 5, 10 + rand.nextDouble() * 5, 0, 5 + rand.nextDouble() * 2,
                5 + rand.nextDouble() * 2, 5 + rand.nextDouble() * 2, 5 + rand.nextDouble() * 2 };
        // One observed value per data point index
        double[] y = new double[100];
        func.initialise(a);
        for (int i = 0; i < y.length; i++)
        {
            // Add random noise in (-5, 5)
            y[i] = func.eval(i) + ((rand.nextDouble() < 0.5) ? -rand.nextDouble() * 5 : rand.nextDouble() * 5);
        }

        // Randomise parameters by up to +/-1 so the fit is not at the true solution
        for (int i = 0; i < a.length; i++)
            a[i] += (rand.nextDouble() < 0.5) ? -rand.nextDouble() : rand.nextDouble();

        // Compute the Hessian and parameter gradient vector
        GradientCalculator calc = new GradientCalculator(6);
        double[][] alpha2 = new double[6][6];
        double[] beta2 = new double[6];
        calc.findLinearised(y.length, y, a, alpha2, beta2, func);

        // Update the Hessian using a lambda shift
        double lambda = 1.001;
        for (int i = 0; i < alpha2.length; i++)
            alpha2[i][i] *= lambda;

        // Copy back
        for (int i = 0; i < beta.length; i++)
        {
            beta[i] = (float) beta2[i];
            for (int j = 0; j < beta.length; j++)
            {
                alpha[i][j] = (float) alpha2[i][j];
            }
        }

        // Check for a positive definite matrix
        if (positiveDefinite)
        {
            EJMLLinearSolver solver = new EJMLLinearSolver();
            return solver.solveCholeskyLDLT(copydouble(alpha), copydouble(beta));
        }

        return true;
    }

    /**
     * Deep-copies the first {@code min(a.size(), iter)} float matrices.
     */
    private ArrayList<float[][]> copyAfloat(ArrayList<float[][]> a, int iter)
    {
        iter = FastMath.min(a.size(), iter);
        ArrayList<float[][]> a2 = new ArrayList<float[][]>(iter);
        for (int i = 0; i < iter; i++)
            a2.add(copyfloat(a.get(i)));
        return a2;
    }

    /**
     * Deep-copies a float matrix row by row. Rectangular and jagged matrices are supported.
     */
    private float[][] copyfloat(float[][] d)
    {
        float[][] d2 = new float[d.length][];
        for (int i = 0; i < d.length; i++)
            d2[i] = Arrays.copyOf(d[i], d[i].length);
        return d2;
    }

    /**
     * Copies the first {@code min(b.size(), iter)} float vectors.
     */
    private ArrayList<float[]> copyBfloat(ArrayList<float[]> b, int iter)
    {
        iter = FastMath.min(b.size(), iter);
        ArrayList<float[]> b2 = new ArrayList<float[]>(iter);
        for (int i = 0; i < iter; i++)
            b2.add(Arrays.copyOf(b.get(i), b.get(i).length));
        return b2;
    }

    /**
     * Copies the first {@code min(a.size(), iter)} float matrices, widening them to double.
     */
    private ArrayList<double[][]> copyAdouble(ArrayList<float[][]> a, int iter)
    {
        iter = FastMath.min(a.size(), iter);
        ArrayList<double[][]> a2 = new ArrayList<double[][]>(iter);
        for (int i = 0; i < iter; i++)
            a2.add(copydouble(a.get(i)));
        return a2;
    }

    /**
     * Widens a float matrix to double row by row. Rectangular and jagged matrices are supported.
     */
    private double[][] copydouble(float[][] d)
    {
        double[][] d2 = new double[d.length][];
        for (int i = 0; i < d.length; i++)
            d2[i] = copydouble(d[i]);
        return d2;
    }

    /**
     * Copies the first {@code min(b.size(), iter)} float vectors, widening them to double.
     */
    private ArrayList<double[]> copyBdouble(ArrayList<float[]> b, int iter)
    {
        iter = FastMath.min(b.size(), iter);
        ArrayList<double[]> b2 = new ArrayList<double[]>(iter);
        for (int i = 0; i < iter; i++)
            b2.add(copydouble(b.get(i)));
        return b2;
    }

    /**
     * Widens a float vector to double.
     */
    private double[] copydouble(float[] d)
    {
        double[] d2 = new double[d.length];
        for (int i = 0; i < d.length; i++)
            d2[i] = d[i];
        return d2;
    }

    /**
     * Solves up to {@code ITER} systems (clamped to {@code A.size()}) with GaussJordan.solve, discarding the
     * results.
     */
    protected void solveGaussJordan(ArrayList<double[][]> A, ArrayList<double[]> B, int ITER, GaussJordan solver)
    {
        ITER = FastMath.min(ITER, A.size());
        for (int i = 0; i < ITER; i++)
        {
            solver.solve(A.get(i), B.get(i));
        }
    }

    /**
     * For up to {@code ITER} systems (clamped to {@code A.size()}), calls EJMLLinearSolver.solveLinear and then
     * invert on the same matrix, discarding the results.
     */
    protected void solveLinearWithInversion(ArrayList<double[][]> A, ArrayList<double[]> B, int ITER,
            EJMLLinearSolver solver)
    {
        ITER = FastMath.min(ITER, A.size());
        for (int i = 0; i < ITER; i++)
        {
            double[][] a = A.get(i);
            solver.solveLinear(a, B.get(i));
            solver.invert(a);
        }
    }

    /**
     * Solves up to {@code ITER} systems (clamped to {@code A.size()}) with EJMLLinearSolver.solveLinear,
     * discarding the results.
     */
    protected void solveLinear(ArrayList<double[][]> A, ArrayList<double[]> B, int ITER, EJMLLinearSolver solver)
    {
        ITER = FastMath.min(ITER, A.size());
        for (int i = 0; i < ITER; i++)
        {
            solver.solveLinear(A.get(i), B.get(i));
        }
    }

    /**
     * Solves up to {@code ITER} systems (clamped to {@code A.size()}) with EJMLLinearSolver.solveCholesky,
     * discarding the results.
     */
    protected void solveCholesky(ArrayList<double[][]> A, ArrayList<double[]> B, int ITER, EJMLLinearSolver solver)
    {
        ITER = FastMath.min(ITER, A.size());
        for (int i = 0; i < ITER; i++)
        {
            solver.solveCholesky(A.get(i), B.get(i));
        }
    }

    /**
     * Solves up to {@code ITER} systems (clamped to {@code A.size()}) with EJMLLinearSolver.solveCholeskyLDLT,
     * discarding the results.
     */
    protected void solveCholeskyLDLT(ArrayList<double[][]> A, ArrayList<double[]> B, int ITER, EJMLLinearSolver solver)
    {
        ITER = FastMath.min(ITER, A.size());
        for (int i = 0; i < ITER; i++)
        {
            solver.solveCholeskyLDLT(A.get(i), B.get(i));
        }
    }

    /**
     * Solves up to {@code ITER} systems (clamped to {@code A.size()}) with EJMLLinearSolver.solve, discarding
     * the results.
     */
    protected void solve(ArrayList<double[][]> A, ArrayList<double[]> B, int ITER, EJMLLinearSolver solver)
    {
        ITER = FastMath.min(ITER, A.size());
        for (int i = 0; i < ITER; i++)
        {
            solver.solve(A.get(i), B.get(i));
        }
    }

    /**
     * Prints a formatted message to standard output.
     */
    void log(String format, Object... args)
    {
        System.out.printf(format, args);
    }
}