package gdsc.smlm.fitting.nonlinear;
import java.util.Arrays;
import org.apache.commons.math3.exception.ConvergenceException;
import org.apache.commons.math3.exception.TooManyEvaluationsException;
import org.apache.commons.math3.exception.TooManyIterationsException;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem;
import org.apache.commons.math3.fitting.leastsquares.LevenbergMarquardtOptimizer;
import org.apache.commons.math3.fitting.leastsquares.ValueAndJacobianFunction;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.DiagonalMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.linear.SingularMatrixException;
import org.apache.commons.math3.util.Pair;
import org.apache.commons.math3.util.Precision;
import gdsc.smlm.fitting.FisherInformationMatrix;
import gdsc.smlm.fitting.FitStatus;
import gdsc.smlm.fitting.nonlinear.gradient.GradientCalculator;
import gdsc.smlm.fitting.nonlinear.gradient.GradientCalculatorFactory;
import gdsc.smlm.function.ExtendedNonLinearFunction;
import gdsc.smlm.function.MultivariateMatrixFunctionWrapper;
import gdsc.smlm.function.MultivariateVectorFunctionWrapper;
import gdsc.smlm.function.NonLinearFunction;
import gdsc.smlm.function.ValueProcedure;
import gdsc.smlm.function.gaussian.Gaussian2DFunction;
/*-----------------------------------------------------------------------------
* GDSC SMLM Software
*
* Copyright (C) 2013 Alex Herbert
* Genome Damage and Stability Centre
* University of Sussex, UK
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*---------------------------------------------------------------------------*/
/**
 * Uses the Apache Commons Math Levenberg-Marquardt method to fit a nonlinear model with coefficients (a) for a
 * set of data points (x, y).
 * <p>
 * The fit is an unweighted least-squares fit: every data point is given a weight of 1.
 */
public class ApacheLVMFitter extends LSEBaseFunctionSolver
{
    /**
     * Creates a fitter for the given function.
     *
     * @param gf
     *            the Gaussian 2D function to fit (passed to the super-class constructor)
     */
    public ApacheLVMFitter(Gaussian2DFunction gf)
    {
        super(gf);
    }

    /**
     * Fits the function to the data using the Commons Math {@link LevenbergMarquardtOptimizer}.
     * <p>
     * On success the fitted parameters are written back into {@code a}, the inherited {@code iterations},
     * {@code evaluations} and {@code value} (sum of squared residuals) fields are updated and, when requested,
     * the parameter deviations and the fitted function values are output.
     *
     * @param y
     *            the observed data (copied before being passed to the optimiser)
     * @param y_fit
     *            if not null, filled with the function values at the fitted parameters (one per function point, in
     *            evaluation order)
     * @param a
     *            the function parameters; used for the initial solution and updated with the fitted solution
     * @param a_dev
     *            if not null, filled with the parameter deviations. These are taken from the optimum covariance
     *            matrix, or from the reciprocal of the Fisher information diagonal if the covariance matrix cannot
     *            be computed due to a singular matrix.
     * @return {@link FitStatus#OK} on success; {@link FitStatus#TOO_MANY_EVALUATIONS},
     *         {@link FitStatus#TOO_MANY_ITERATIONS} or {@link FitStatus#SINGULAR_NON_LINEAR_MODEL} for the
     *         corresponding optimiser failures; {@link FitStatus#UNKNOWN} for any other exception
     */
    public FitStatus computeFit(double[] y, final double[] y_fit, double[] a, double[] a_dev)
    {
        final int n = y.length;

        try
        {
            // Different convergence thresholds seem to have no effect on the resulting fit, only the number of
            // iterations for convergence
            final double initialStepBoundFactor = 100;
            final double costRelativeTolerance = 1e-10;
            final double parRelativeTolerance = 1e-10;
            final double orthoTolerance = 1e-10;
            final double threshold = Precision.SAFE_MIN;

            // Extract the parameters to be fitted
            final double[] initialSolution = getInitialSolution(a);

            // TODO - Pass in more advanced stopping criteria.

            // Create the target and weight arrays. The target is a copy of the input data
            // and the weights are all 1 (unweighted fit).
            final double[] yd = y.clone();
            final double[] w = new double[n];
            Arrays.fill(w, 1);

            LevenbergMarquardtOptimizer optimizer = new LevenbergMarquardtOptimizer(initialStepBoundFactor,
                    costRelativeTolerance, parRelativeTolerance, orthoTolerance, threshold);

            // NOTE(review): the evaluation count is left unbounded and the solver's configured
            // max evaluations is applied as the iteration limit - presumably intentional; confirm.
            //@formatter:off
            LeastSquaresBuilder builder = new LeastSquaresBuilder()
                    .maxEvaluations(Integer.MAX_VALUE)
                    .maxIterations(getMaxEvaluations())
                    .start(initialSolution)
                    .target(yd)
                    .weight(new DiagonalMatrix(w));
            //@formatter:on

            if (f instanceof ExtendedNonLinearFunction && ((ExtendedNonLinearFunction) f).canComputeValuesAndJacobian())
            {
                // Compute together, or each individually
                builder.model(new ValueAndJacobianFunction()
                {
                    final ExtendedNonLinearFunction fun = (ExtendedNonLinearFunction) f;

                    public Pair<RealVector, RealMatrix> value(RealVector point)
                    {
                        final double[] p = point.toArray();
                        final Pair<double[], double[][]> result = fun.computeValuesAndJacobian(p);
                        // Wrap the arrays without copying
                        return new Pair<RealVector, RealMatrix>(new ArrayRealVector(result.getFirst(), false),
                                new Array2DRowRealMatrix(result.getSecond(), false));
                    }

                    public RealVector computeValue(double[] params)
                    {
                        return new ArrayRealVector(fun.computeValues(params), false);
                    }

                    public RealMatrix computeJacobian(double[] params)
                    {
                        return new Array2DRowRealMatrix(fun.computeJacobian(params), false);
                    }
                });
            }
            else
            {
                // Compute separately
                builder.model(new MultivariateVectorFunctionWrapper((NonLinearFunction) f, a, n),
                        new MultivariateMatrixFunctionWrapper((NonLinearFunction) f, a, n));
            }

            LeastSquaresProblem problem = builder.build();

            Optimum optimum = optimizer.optimize(problem);

            // Copy the fitted parameters back into the full parameter array
            final double[] parameters = optimum.getPoint().toArray();
            setSolution(a, parameters);

            iterations = optimum.getIterations();
            evaluations = optimum.getEvaluations();

            if (a_dev != null)
            {
                try
                {
                    double[][] covar = optimum.getCovariances(threshold).getData();
                    setDeviations(a_dev, covar);
                }
                catch (SingularMatrixException e)
                {
                    // Matrix inversion failed. In order to return a solution
                    // return the reciprocal of the diagonal of the Fisher information
                    // for a loose bound on the limit
                    final int[] gradientIndices = f.gradientIndices();
                    final int nparams = gradientIndices.length;
                    GradientCalculator calculator = GradientCalculatorFactory.newCalculator(nparams);
                    final double[] I = calculator.fisherInformationDiagonal(n, a, (NonLinearFunction) f);
                    Arrays.fill(a_dev, 0);
                    for (int i = nparams; i-- > 0;)
                        a_dev[gradientIndices[i]] = FisherInformationMatrix.reciprocalSqrt(I[i]);
                }
            }

            // Compute the fitted function values
            if (y_fit != null)
            {
                final Gaussian2DFunction gf = (Gaussian2DFunction) this.f;
                gf.initialise0(a);
                gf.forEach(new ValueProcedure()
                {
                    int i = 0;

                    public void execute(double value)
                    {
                        // Advance the index so each value is stored in its own position
                        y_fit[i++] = value;
                    }
                });
            }

            // As this is unweighted then we can do this to get the sum of squared residuals
            value = optimum.getResiduals().dotProduct(optimum.getResiduals());
        }
        catch (TooManyEvaluationsException e)
        {
            return FitStatus.TOO_MANY_EVALUATIONS;
        }
        catch (TooManyIterationsException e)
        {
            return FitStatus.TOO_MANY_ITERATIONS;
        }
        catch (ConvergenceException e)
        {
            // Occurs when QR decomposition fails - mark as a singular non-linear model (no solution)
            return FitStatus.SINGULAR_NON_LINEAR_MODEL;
        }
        catch (Exception e)
        {
            // TODO - Find out the other exceptions from the fitter and add return values to match.
            return FitStatus.UNKNOWN;
        }

        return FitStatus.OK;
    }

    /**
     * Evaluates the function for the given parameters without fitting. The result of the gradient calculator's
     * {@code findLinearised} call is stored in the inherited {@code value} field.
     *
     * @param y
     *            the observed data
     * @param y_fit
     *            passed to the gradient calculator together with the data (may receive the function values -
     *            TODO confirm against {@code GradientCalculator.findLinearised})
     * @param a
     *            the function parameters
     * @return always true
     */
    @Override
    public boolean computeValue(double[] y, double[] y_fit, double[] a)
    {
        final int nparams = f.gradientIndices().length;
        GradientCalculator calculator = GradientCalculatorFactory.newCalculator(nparams, false);

        // Since we know the function is a Gaussian2DFunction
        value = calculator.findLinearised(y.length, y, y_fit, a, (NonLinearFunction) f);

        return true;
    }
}