/*- * Copyright 2014 Diamond Light Source Ltd. * * All rights reserved. This program and the accompanying materials * are made available under the terms of the Eclipse Public License v1.0 * which accompanies this distribution, and is available at * http://www.eclipse.org/legal/epl-v10.html */ package uk.ac.diamond.scisoft.analysis.diffraction; import org.apache.commons.math3.analysis.MultivariateFunction; import org.apache.commons.math3.exception.TooManyEvaluationsException; import org.apache.commons.math3.optim.InitialGuess; import org.apache.commons.math3.optim.MaxEval; import org.apache.commons.math3.optim.PointValuePair; import org.apache.commons.math3.optim.SimpleBounds; import org.apache.commons.math3.optim.SimplePointChecker; import org.apache.commons.math3.optim.nonlinear.scalar.GoalType; import org.apache.commons.math3.optim.nonlinear.scalar.MultivariateFunctionPenaltyAdapter; import org.apache.commons.math3.optim.nonlinear.scalar.MultivariateOptimizer; import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction; import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.BOBYQAOptimizer; import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CMAESOptimizer; import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.MultiDirectionalSimplex; import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer; import org.apache.commons.math3.random.Well19937c; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * Utilities for fitting */ public class FittingUtils { private static Logger logger = LoggerFactory.getLogger(FittingUtils.class); public static Long seed = null; private static final int MAX_ITER = 10000; private static final int MAX_EVAL = 1000000; private static final double REL_TOL = 1e-7; private static final double ABS_TOL = 1e-15; public enum Optimizer { Simplex, CMAES, BOBYQA } static Optimizer optimizer = Optimizer.CMAES; /** * @param n * @return optimizer */ public static MultivariateOptimizer createOptimizer(int n) { return createOptimizer(optimizer, n); } /** * @param opt * @param n * @return optimizer */ public static MultivariateOptimizer createOptimizer(Optimizer opt, int n) { switch (opt) { case BOBYQA: return new BOBYQAOptimizer(n + 2); case CMAES: return new CMAESOptimizer(MAX_ITER, 0., true, 0, 10, seed == null ? new Well19937c() : new Well19937c(seed), false, new SimplePointChecker<PointValuePair>(REL_TOL, ABS_TOL)); case Simplex: default: return new SimplexOptimizer(new SimplePointChecker<PointValuePair>(REL_TOL*1e3, ABS_TOL*1e3)); } } /** * Optimize given function * @param f * @param opt * @param min * @return residual */ public static double optimize(FitFunction f, MultivariateOptimizer opt, double min) { double res = Double.NaN; try { PointValuePair result; if (opt instanceof BOBYQAOptimizer) { result = opt.optimize(new InitialGuess(f.getInitial()), GoalType.MINIMIZE, new ObjectiveFunction(f), new MaxEval(MAX_EVAL), f.getBounds()); } else if (opt instanceof CMAESOptimizer) { int p = (int) Math.ceil(4 + Math.log(f.getN())) + 1; logger.trace("Population size: {}", p); result = opt.optimize(new InitialGuess(f.getInitial()), GoalType.MINIMIZE, new ObjectiveFunction(f), new CMAESOptimizer.Sigma(f.getSigma()), new CMAESOptimizer.PopulationSize(p), new MaxEval(MAX_EVAL), f.getBounds()); } else { int n = f.getN(); double offset = 1e12; double[] scale = new double[n]; for (int i = 0; i < n; i++) { scale[i] = offset*0.25; } SimpleBounds bnds = f.getBounds(); MultivariateFunctionPenaltyAdapter of = new MultivariateFunctionPenaltyAdapter(f, bnds.getLower(), bnds.getUpper(), offset, scale); result = opt.optimize(new InitialGuess(f.getInitial()), GoalType.MINIMIZE, new ObjectiveFunction(of), new MaxEval(MAX_EVAL), new MultiDirectionalSimplex(n)); // new NelderMeadSimplex(n)); } // logger.info("Q-space fit: rms = {}, x^2 = {}", opt.getRMS(), opt.getChiSquare()); double ires = f.value(opt.getStartPoint()); logger.trace("Residual: {} from {}", result.getValue(), ires); res = result.getValue(); if (res < min) f.setParameters(result.getPoint()); logger.trace("Used {} evals and {} iters", opt.getEvaluations(), opt.getIterations()); // System.err.printf("Used %d evals and %d iters\n", opt.getEvaluations(), opt.getIterations()); // logger.info("Q-space fit: rms = {}, x^2 = {}", opt.getRMS(), opt.getChiSquare()); } catch (IllegalArgumentException e) { logger.error("Start point has wrong dimension", e); // should not happen! } catch (TooManyEvaluationsException e) { throw new IllegalArgumentException("Could not fit as optimizer did not converge"); // logger.error("Convergence problem: max iterations ({}) exceeded", opt.getMaxIterations()); } return res; } /** * Basic fit function interface */ public interface FitFunction extends MultivariateFunction { /** * Set stored parameters * @param arg */ public void setParameters(double[] arg); /** * @return stored parameters */ public double[] getParameters(); /** * @return standard deviations which may be used for sampling parameter space */ public double[] getSigma(); /** * @return bounds on parameters */ public SimpleBounds getBounds(); /** * @return initial parameters */ public double[] getInitial(); /** * Set initial parameters * @param init */ public void setInitial(double... init); /** * @return number of parameters */ public int getN(); } }