/*
* Copyright (c) 2012 Diamond Light Source Ltd.
*
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the Eclipse Public License v1.0
* which accompanies this distribution, and is available at
* http://www.eclipse.org/legal/epl-v10.html
*/
package uk.ac.diamond.scisoft.analysis.optimize;
import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.analysis.MultivariateVectorFunction;
import org.apache.commons.math3.exception.TooManyEvaluationsException;
import org.apache.commons.math3.fitting.leastsquares.EvaluationRmsChecker;
import org.apache.commons.math3.fitting.leastsquares.GaussNewtonOptimizer;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresBuilder;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer.Optimum;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem;
import org.apache.commons.math3.fitting.leastsquares.LevenbergMarquardtOptimizer;
import org.apache.commons.math3.fitting.leastsquares.MultivariateJacobianFunction;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.linear.SingularMatrixException;
import org.apache.commons.math3.optim.InitialGuess;
import org.apache.commons.math3.optim.MaxEval;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.SimpleBounds;
import org.apache.commons.math3.optim.SimplePointChecker;
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType;
import org.apache.commons.math3.optim.nonlinear.scalar.MultivariateFunctionPenaltyAdapter;
import org.apache.commons.math3.optim.nonlinear.scalar.MultivariateOptimizer;
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction;
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunctionGradient;
import org.apache.commons.math3.optim.nonlinear.scalar.gradient.NonLinearConjugateGradientOptimizer;
import org.apache.commons.math3.optim.nonlinear.scalar.gradient.NonLinearConjugateGradientOptimizer.Formula;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.BOBYQAOptimizer;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CMAESOptimizer;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.MultiDirectionalSimplex;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.NelderMeadSimplex;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.PowellOptimizer;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer;
import org.apache.commons.math3.random.Well19937c;
import org.apache.commons.math3.util.Pair;
import org.eclipse.dawnsci.analysis.api.fitting.functions.IParameter;
import org.eclipse.january.IMonitor;
import org.eclipse.january.dataset.Dataset;
import org.eclipse.january.dataset.DatasetFactory;
import org.eclipse.january.dataset.DatasetUtils;
import org.eclipse.january.dataset.DoubleDataset;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import uk.ac.diamond.scisoft.analysis.fitting.functions.AFunction;
import uk.ac.diamond.scisoft.analysis.fitting.functions.CoordinatesIterator;
/**
 * Optimizer that delegates the actual minimisation to Apache Commons Math 3.
 * <p>
 * The algorithm is chosen at construction time through {@link Optimizer}. The two least-squares
 * algorithms ({@link Optimizer#GAUSS_NEWTON} and {@link Optimizer#LEVENBERG_MARQUARDT}) are driven
 * through a Jacobian function; every other algorithm minimises the scalar value returned by
 * {@code calculateResidual(double[])}.
 * <p>
 * The fields {@code n}, {@code params}, {@code coords}, {@code data}, {@code weight} and
 * {@code function} used throughout are not declared here; they come from the superclass hierarchy.
 * In this class {@code n} is used as the number of parameters.
 */
public class ApacheOptimizer extends AbstractOptimizer implements ILeastSquaresOptimizer {
	private static final Logger logger = LoggerFactory.getLogger(ApacheOptimizer.class);

	/** Iteration limit given to CMA-ES and to the least-squares problem */
	private static final int MAX_ITER = 10000;
	/** Evaluation limit given to every optimizer */
	private static final int MAX_EVAL = 1000000;
	/** Relative threshold used by the convergence checkers */
	private static final double REL_TOL = 1e-7;
	/** Absolute threshold used by the convergence checkers */
	private static final double ABS_TOL = 1e-15;

	/**
	 * Algorithms supported by this optimizer
	 */
	public enum Optimizer { SIMPLEX_MD, SIMPLEX_NM, POWELL, CMAES, BOBYQA, CONJUGATE_GRADIENT, GAUSS_NEWTON, LEVENBERG_MARQUARDT }

	/**
	 * Seed for the random generator used by CMA-ES; when null, the generator's default seeding is used
	 */
	public Long seed = null;

	private Optimizer optimizer;

	/**
	 * Parameter errors from the most recent least-squares fit, or null if none could be estimated
	 */
	private double[] errors = null;

	/**
	 * @param opt algorithm to use for all subsequent optimizations
	 */
	public ApacheOptimizer(Optimizer opt) {
		optimizer = opt;
	}

	/**
	 * Creates the scalar objective function used by the scalar optimizers.
	 *
	 * @return function whose value at a parameter vector is {@code calculateResidual(parameters)}
	 */
	public MultivariateFunction createFunction() {
		return new MultivariateFunction() {
			@Override
			public double value(double[] parameters) {
				return calculateResidual(parameters);
			}
		};
	}

	/**
	 * Creates the gradient of the scalar objective function.
	 *
	 * @return function whose i-th component is {@code calculateResidualDerivative} evaluated for the
	 *         i-th parameter at the given parameter vector
	 */
	public MultivariateVectorFunction createGradientFunction() {
		return new MultivariateVectorFunction() {
			@Override
			public double[] value(double[] parameters) throws IllegalArgumentException {
				double[] gradient = new double[n];
				for (int i = 0; i < n; i++) {
					gradient[i] = calculateResidualDerivative(params.get(i), parameters);
				}
				return gradient;
			}
		};
	}

	/**
	 * Creates the model function (values and Jacobian) used by the least-squares optimizers.
	 * <p>
	 * When the fitting function is an {@link AFunction}, values and partial derivatives are written
	 * into datasets that are allocated once here and reused by every evaluation. The Jacobian array
	 * is likewise allocated once and wrapped without copying.
	 * <p>
	 * NOTE(review): because those buffers are shared, a pair returned by an earlier evaluation is
	 * overwritten by the next one - confirm that callers never hold on to an old evaluation.
	 *
	 * @return function returning the model values and a matrix with one row per coordinate and one
	 *         column per parameter
	 * @throws IllegalMonitorStateException (from the returned function) if the function's monitor
	 *         reports cancellation
	 */
	public MultivariateJacobianFunction createJacobianFunction() {
		final int size = coords[0].getSize();
		final AFunction afn;
		final CoordinatesIterator it;
		final DoubleDataset vd, pvd;
		final double[][] dm = new double[size][n];
		if (function instanceof AFunction) {
			afn = (AFunction) function;
			it = CoordinatesIterator.createIterator(data == null ? null : data.getShapeRef(), coords);
			vd = (DoubleDataset) DatasetFactory.zeros(coords[0].getShapeRef(), Dataset.FLOAT64);
			pvd = vd.clone();
		} else {
			afn = null;
			it = null;
			vd = null;
			pvd = null;
		}

		return new MultivariateJacobianFunction() {
			@SuppressWarnings("null")
			@Override
			public Pair<RealVector, RealMatrix> value(RealVector point) {
				IMonitor monitor = function.getMonitor();
				if (monitor != null && monitor.isCancelled()) {
					throw new IllegalMonitorStateException("Monitor cancelled");
				}

				// avoid a copy when the backing array is directly accessible
				if (point instanceof ArrayRealVector) {
					setParameterValues(((ArrayRealVector) point).getDataRef());
				} else {
					setParameterValues(point.toArray());
				}

				final double[] dv;
				if (afn != null) {
					// vd, pvd and it are non-null whenever afn is non-null (see set-up above)
					dv = vd.getData();
					afn.fillWithValues(vd, it);
					double[] pd = pvd.getData();
					for (int i = 0; i < n; i++) { // assuming number of parameters is less than number of coordinates
						afn.fillWithPartialDerivativeValues(params.get(i), pvd, it);
						storeColumn(dm, i, pd, size);
					}
				} else {
					dv = calculateValues().getData();
					for (int i = 0; i < n; i++) {
						DoubleDataset dp = (DoubleDataset) DatasetUtils.cast(
								function.calculatePartialDerivativeValues(params.get(i), coords), Dataset.FLOAT64);
						storeColumn(dm, i, dp.getData(), size);
					}
				}
				// wrap both arrays without copying
				return new Pair<RealVector, RealMatrix>(new ArrayRealVector(dv, false), new Array2DRowRealMatrix(dm, false));
			}
		};
	}

	/**
	 * Copies the first {@code rows} entries of {@code values} into the given column of {@code matrix}
	 */
	private static void storeColumn(double[][] matrix, int column, double[] values, int rows) {
		for (int j = 0; j < rows; j++) {
			matrix[j][column] = values[j];
		}
	}

	/**
	 * Creates the Commons Math scalar optimizer that corresponds to the selected algorithm.
	 *
	 * @return new optimizer instance
	 * @throws IllegalStateException if the selected algorithm is a least-squares one
	 */
	private MultivariateOptimizer createOptimizer() {
		SimplePointChecker<PointValuePair> checker = new SimplePointChecker<PointValuePair>(REL_TOL, ABS_TOL);
		switch (optimizer) {
		case CONJUGATE_GRADIENT:
			return new NonLinearConjugateGradientOptimizer(Formula.POLAK_RIBIERE, checker);
		case BOBYQA:
			return new BOBYQAOptimizer(n + 2);
		case CMAES:
			return new CMAESOptimizer(MAX_ITER, 0., true, 0, 10, seed == null ? new Well19937c() : new Well19937c(seed),
					false, new SimplePointChecker<PointValuePair>(REL_TOL, ABS_TOL));
		case POWELL:
			return new PowellOptimizer(REL_TOL, ABS_TOL, checker);
		case SIMPLEX_MD:
		case SIMPLEX_NM:
			// both simplex flavours share the optimizer; the simplex itself is supplied at optimize time
			return new SimplexOptimizer(checker);
		default:
			throw new IllegalStateException("Should not be called");
		}
	}

	/**
	 * Creates the Commons Math least-squares optimizer that corresponds to the selected algorithm.
	 *
	 * @return new optimizer instance
	 * @throws IllegalStateException if the selected algorithm is not a least-squares one
	 */
	private LeastSquaresOptimizer createLeastSquaresOptimizer() {
		switch (optimizer) {
		case GAUSS_NEWTON:
			return new GaussNewtonOptimizer();
		case LEVENBERG_MARQUARDT:
			return new LevenbergMarquardtOptimizer();
		default:
			throw new IllegalStateException("Should not be called");
		}
	}

	/**
	 * Collects the lower and upper limits of all parameters.
	 *
	 * @return bounds built from each parameter's lower and upper limit
	 */
	private SimpleBounds createBounds() {
		double[] lower = new double[n];
		double[] upper = new double[n];
		for (int i = 0; i < n; i++) {
			IParameter p = params.get(i);
			lower[i] = p.getLowerLimit();
			upper[i] = p.getUpperLimit();
		}
		return new SimpleBounds(lower, upper);
	}

	/**
	 * Dispatches to the least-squares or the scalar code path according to the selected algorithm
	 */
	@Override
	void internalOptimize() throws Exception {
		switch (optimizer) {
		case LEVENBERG_MARQUARDT:
		case GAUSS_NEWTON:
			internalLeastSquaresOptimize();
			break;
		default:
			internalScalarOptimize();
			break;
		}
	}

	/**
	 * Runs one of the scalar optimizers and, on success, stores the best point as the new parameter
	 * values.
	 * <p>
	 * BOBYQA and CMA-ES receive the parameter limits as bounds. The simplex methods and Powell do not
	 * take bounds, so for those the objective function is wrapped in a penalty adapter built from the
	 * same limits. Conjugate gradient is run without any bounds handling.
	 *
	 * @throws IllegalArgumentException if the evaluation limit is exceeded
	 * @throws IllegalStateException if the selected algorithm is a least-squares one
	 */
	private void internalScalarOptimize() {
		MultivariateOptimizer opt = createOptimizer();
		SimpleBounds bd = createBounds();
		double offset = 1e12;
		double[] scale = new double[n];
		for (int i = 0; i < n; i++) {
			scale[i] = offset * 0.25;
		}

		MultivariateFunction fn = createFunction();
		if (optimizer == Optimizer.SIMPLEX_MD || optimizer == Optimizer.SIMPLEX_NM || optimizer == Optimizer.POWELL) {
			fn = new MultivariateFunctionPenaltyAdapter(fn, bd.getLower(), bd.getUpper(), offset, scale);
		}
		ObjectiveFunction of = new ObjectiveFunction(fn);
		InitialGuess ig = new InitialGuess(getParameterValues());
		MaxEval me = new MaxEval(MAX_EVAL);

		try {
			PointValuePair result;
			switch (optimizer) {
			case CONJUGATE_GRADIENT:
				result = opt.optimize(ig, GoalType.MINIMIZE, of, me,
						new ObjectiveFunctionGradient(createGradientFunction()));
				break;
			case BOBYQA:
				result = opt.optimize(ig, GoalType.MINIMIZE, of, me, bd);
				break;
			case CMAES: {
				double[] sigma = new double[n];
				for (int i = 0; i < n; i++) {
					IParameter par = params.get(i);
					double v = par.getValue();
					double r = Math.max(par.getUpperLimit() - v, v - par.getLowerLimit());
					sigma[i] = r * 0.05; // 5% of the larger distance from the value to a limit
				}
				int population = (int) Math.ceil(4 + Math.log(n)) + 1;
				logger.trace("Population size: {}", population);
				result = opt.optimize(ig, GoalType.MINIMIZE, of, me, bd,
						new CMAESOptimizer.Sigma(sigma), new CMAESOptimizer.PopulationSize(population));
				break;
			}
			case POWELL:
				// previously missing: an optimizer was created for POWELL but never run
				result = opt.optimize(ig, GoalType.MINIMIZE, of, me);
				break;
			case SIMPLEX_MD:
				result = opt.optimize(ig, GoalType.MINIMIZE, of, me, new MultiDirectionalSimplex(n));
				break;
			case SIMPLEX_NM:
				result = opt.optimize(ig, GoalType.MINIMIZE, of, me, new NelderMeadSimplex(n));
				break;
			default:
				throw new IllegalStateException("Should not be called");
			}

			// evaluated before the result is applied, so it reports the residual at the start point
			double ires = calculateResidual(opt.getStartPoint());
			logger.trace("Residual: {} from {}", result.getValue(), ires);
			double res = result.getValue();
			// only accept a result whose value is neither NaN nor positive infinity
			if (res < Double.POSITIVE_INFINITY) {
				setParameterValues(result.getPoint());
			}
			logger.trace("Used {} evals and {} iters", opt.getEvaluations(), opt.getIterations());
		} catch (IllegalArgumentException e) {
			logger.error("Start point has wrong dimension", e);
			// should not happen!
		} catch (TooManyEvaluationsException e) {
			throw new IllegalArgumentException("Could not fit as optimizer did not converge", e);
		}
	}

	/**
	 * Runs one of the least-squares optimizers, stores the optimum as the new parameter values and
	 * tries to estimate the parameter errors.
	 * <p>
	 * Any previous error estimate is discarded first, so that {@link #guessParametersErrors()} never
	 * reports values belonging to an earlier fit. If the covariance matrix is singular the errors
	 * stay null and a warning is logged.
	 *
	 * @throws IllegalArgumentException if anything goes wrong while building or solving the problem;
	 *         the original exception is attached as the cause
	 */
	private void internalLeastSquaresOptimize() {
		LeastSquaresOptimizer opt = createLeastSquaresOptimizer();
		errors = null;

		try {
			LeastSquaresBuilder builder = new LeastSquaresBuilder().model(createJacobianFunction())
					.target(data.getData()).start(getParameterValues()).lazyEvaluation(false)
					.maxEvaluations(MAX_EVAL).maxIterations(MAX_ITER);

			builder.checker(new EvaluationRmsChecker(REL_TOL, ABS_TOL));

			if (weight != null) {
				builder.weight(MatrixUtils.createRealDiagonalMatrix(weight.getData()));
			}

			// TODO add validator
			LeastSquaresProblem problem = builder.build();

			Optimum result = opt.optimize(problem);

			RealVector res = result.getPoint();
			setParameterValues(res instanceof ArrayRealVector ? ((ArrayRealVector) res).getDataRef() : res.toArray());
			try {
				RealVector err = result.getSigma(1e-14);

				// sqrt(S / (n - m) * C[i][i]);
				// NOTE(review): yields non-finite values when the number of points does not exceed
				// the number of parameters
				double cost = result.getCost();
				int nPoints = data.getSize();
				int nParams = getParameterValues().length;

				double[] s = err instanceof ArrayRealVector ? ((ArrayRealVector) err).getDataRef() : err.toArray();
				double[] estimated = new double[s.length];
				for (int i = 0; i < estimated.length; i++) {
					estimated[i] = Math.sqrt(((cost * cost) / ((nPoints - nParams)) * (s[i] * s[i])));
				}
				errors = estimated;
			} catch (SingularMatrixException e) {
				logger.warn("Could not find errors as covariance matrix was singular");
			}

			logger.trace("Residual: {} from {}", result.getRMS(), Math.sqrt(calculateResidual()));
		} catch (Exception e) {
			logger.error("Problem with least squares optimizer", e);
			throw new IllegalArgumentException("Problem with least squares optimizer", e);
		}
	}

	/**
	 * @return parameter errors estimated by the most recent least-squares fit, or null if no
	 *         estimate is available
	 */
	@Override
	public double[] guessParametersErrors() {
		return errors;
	}
}