/*- * Copyright (c) 2012 Diamond Light Source Ltd. * * All rights reserved. This program and the accompanying materials * are made available under the terms of the Eclipse Public License v1.0 * which accompanies this distribution, and is available at * http://www.eclipse.org/legal/epl-v10.html */ package uk.ac.diamond.scisoft.analysis.fitting.functions; import java.io.Serializable; import java.lang.reflect.Constructor; import java.util.ArrayList; import java.util.Arrays; import org.eclipse.dawnsci.analysis.api.fitting.functions.IFunction; import org.eclipse.dawnsci.analysis.api.fitting.functions.IOperator; import org.eclipse.dawnsci.analysis.api.fitting.functions.IParameter; import org.eclipse.january.IMonitor; import org.eclipse.january.dataset.Comparisons; import org.eclipse.january.dataset.Dataset; import org.eclipse.january.dataset.DatasetFactory; import org.eclipse.january.dataset.DatasetUtils; import org.eclipse.january.dataset.DoubleDataset; import org.eclipse.january.dataset.IDataset; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * Base abstract class for IFunction implementation. At a minimum, the fillWithValues() method needs * to be added. The fillWithPartialDerivativeValues() and/or calculatePartialDerivativeValues() * methods can be overridden if exact derivatives are needed. * * Note, if the implemented function can alter the number of parameters then it should call its * parent operator's update parameters method. */ public abstract class AFunction implements IFunction, Serializable { /** * Setup the logging facilities */ private static transient final Logger logger = LoggerFactory.getLogger(AFunction.class); /** * The array of parameters which specify all the variables in the minimisation problem */ protected IParameter[] parameters; /** * The name of the function, a description more than anything else. */ protected String name = "default"; /** * The description of the function */ protected String description = "default"; protected boolean dirty = true; protected IMonitor monitor = null; protected IOperator parent; protected void setNames(String name, String description, String... parameterNames) { this.name = name; this.description = description; int n = Math.min(parameterNames.length, parameters.length); for (int i = 0; i < n; i++) { IParameter p = getParameter(i); p.setName(parameterNames[i]); } } /** * Implement to set names and descriptions for function and its parameters */ protected abstract void setNames(); /** * Constructor for zero parameter functions */ public AFunction() { parameters = new IParameter[0]; setNames(); } /** * Constructor which simply generates the parameters but uninitialised * * @param numberOfParameters */ public AFunction(int numberOfParameters) { parameters = createParameters(numberOfParameters); setNames(); } protected static IParameter[] createParameters(int numberOfParameters) { IParameter[] params = new IParameter[numberOfParameters]; for (int i = 0; i < numberOfParameters; i++) { params[i] = new Parameter(); } return params; } /** * Constructor which takes a list of parameter values as its starting configuration * * @param params * An array of starting parameter values as doubles. */ public AFunction(double... params) { parameters = createParameters(params.length); setParameterValues(params); setNames(); } /** * Constructor which is given a set of parameters to begin with. * * @param params * An array of parameters */ public AFunction(IParameter... params) { parameters = createParameters(params.length); setParameters(params); setNames(); } /** * @param function * @param parameter * @return index of parameter or -1 if parameter is not in function */ public static int indexOfParameter(IFunction function, IParameter parameter) { if (function == null || parameter == null) return -1; if (function instanceof AFunction) return ((AFunction) function).indexOfParameter(parameter); for (int j = 0, jmax = function.getNoOfParameters(); j < jmax; j++) { if (parameter == function.getParameter(j)) { return j; } } return -1; } /** * @param parameter * @return index of parameter or -1 if parameter is not in function */ protected int indexOfParameter(IParameter parameter) { for (int i = 0; i < parameters.length; i++) { if (parameter == parameters[i]) { return i; } } return -1; } @Override public String getName() { return name; } @Override public void setName(String newName) { name = newName; } @Override public String getDescription() { return description; } @Override public void setDescription(String newDescription) { description = newDescription; } @Override public IParameter getParameter(int index) { return parameters[index]; } @Override public IParameter[] getParameters() { IParameter[] params = new IParameter[parameters.length]; for (int i = 0; i < parameters.length; i++) { params[i] = parameters[i]; } return params; } @Override public int getNoOfParameters() { return parameters.length; } @Override public double getParameterValue(int index) { return parameters[index].getValue(); } @Override final public double[] getParameterValues() { int n = getNoOfParameters(); double[] result = new double[n]; for (int j = 0; j < n; j++) { result[j] = getParameterValue(j); } return result; } @Override public void setParameter(int index, IParameter parameter) { if (indexOfParameter(parameter) == index) return; parameters[index] = parameter; dirty = true; } @Override public void setParameterValues(double... params) { int nparams = Math.min(params.length, parameters.length); for (int j = 0; j < nparams; j++) { parameters[j].setValue(params[j]); } dirty = true; } protected void setParameters(IParameter... params) { int nparams = Math.min(params.length, parameters.length); for (int j = 0; j < nparams; j++) { IParameter op = params[j]; IParameter np = parameters[j]; np.setValue(op.getValue()); np.setLimits(op.getLowerLimit(), op.getUpperLimit()); np.setFixed(op.isFixed()); } dirty = true; } @Override public String toString() { StringBuffer out = new StringBuffer(); int n = getNoOfParameters(); out.append(String.format("'%s' has %d parameters:\n", name, n)); for (int i = 0; i < n; i++) { IParameter p = getParameter(i); out.append(String.format("%d) %s = %g in range [%g, %g]\n", i, p.getName(), p.getValue(), p.getLowerLimit(), p.getUpperLimit())); } return out.toString(); } /** * This implementation is a numerical approximation. Overriding methods should check * for duplicated parameters before doing any calculation and either cope with this * or use this numerical approximation */ @Override public double partialDeriv(IParameter parameter, double... values) { if (indexOfParameter(parameter) < 0) return 0; return calcNumericalDerivative(A_TOLERANCE, R_TOLERANCE, parameter, values); } /** * @param param * @return true if there is more than one occurrence of given parameter in function */ protected boolean isDuplicated(IParameter param) { boolean found = false; int n = getNoOfParameters(); for (int i = 0; i < n; i++) { if (getParameter(i) == param) { if (found) { // found twice return true; } found = true; } } return false; } private final static double DELTA = 1/256.; // initial value private final static double DELTA_FACTOR = 0.25; protected final static double A_TOLERANCE = 1e-9; // absolute tolerance protected final static double R_TOLERANCE = 1e-9; // relative tolerance /** * @param param * @param delta * @return true if delta is large enough to change the parameter value */ private static final boolean isDeltaLargeEnough(IParameter param, double delta) { double v = Math.abs(param.getValue()); return (v == 0 ? delta : delta * v) > Math.ulp(v); } /** * @param abs * @param rel * @param param * @param values * @return partial derivative up to tolerances */ protected double calcNumericalDerivative(double abs, double rel, IParameter param, double... values) { double delta = DELTA; double previous = numericalDerivative(delta, param, values); double aprevious = Math.abs(previous); double current = 0; double acurrent = 0; double absDifference; double previousAbsDifference = Double.POSITIVE_INFINITY; final double absDifferenceRatio = 1.00; delta *= DELTA_FACTOR; while (isDeltaLargeEnough(param, delta)) { current = numericalDerivative(delta, param, values); acurrent = Math.abs(current); absDifference = Math.abs(current - previous); if (absDifference <= abs + rel*Math.max(acurrent, aprevious)) break; // If the difference is increasing, then we are no longer // approaching the convergence criterion. Assume we have just // passed the best we are going to get, and break, passing back // the previous value. if (absDifference > absDifferenceRatio * previousAbsDifference) { current = previous; break; } previousAbsDifference = absDifference; previous = current; aprevious = acurrent; delta *= DELTA_FACTOR; } return current; } /** * Calculate partial derivative. This is a numerical approximation. * @param param * @param values * @return partial derivative */ private double numericalDerivative(double delta, IParameter param, double... values) { double v = param.getValue(); double dv = delta * (v != 0 ? v : 1); param.setValue(v - dv); dirty = true; double minval = val(values); param.setValue(v + dv); dirty = true; double maxval = val(values); param.setValue(v); dirty = true; return (maxval - minval) / (2. * dv); } @Override public DoubleDataset calculateValues(IDataset... coords) { return calculateValues(null, coords); } private DoubleDataset calculateValues(int[] outShape, IDataset... coords) { CoordinatesIterator it = CoordinatesIterator.createIterator(outShape, coords); DoubleDataset result = DatasetFactory.zeros(DoubleDataset.class, it.getShape()); fillWithValues(result, it); result.setName(name); return result; } @Override public DoubleDataset calculatePartialDerivativeValues(IParameter parameter, IDataset... coords) { return calculatePartialDerivativeValues(null, parameter, coords); } private DoubleDataset calculatePartialDerivativeValues(int[] outShape, IParameter parameter, IDataset... coords) { CoordinatesIterator it = CoordinatesIterator.createIterator(outShape, coords); DoubleDataset result = DatasetFactory.zeros(DoubleDataset.class, it.getShape()); if (indexOfParameter(parameter) >= 0) internalFillWithPartialDerivativeValues(parameter, result, it); result.setName(name); return result; } private void internalFillWithPartialDerivativeValues(IParameter parameter, DoubleDataset data, CoordinatesIterator it) { if (isDuplicated(parameter)) { calcNumericalDerivativeDataset(A_TOLERANCE, R_TOLERANCE, parameter, data, it); } else { fillWithPartialDerivativeValues(parameter, data, it); } } /** * Fill dataset with values. Implementations should reset the iterator before use * @param data * @param it */ abstract public void fillWithValues(DoubleDataset data, CoordinatesIterator it); /** * Fill dataset with partial derivatives. Implementations should reset the iterator before use * <p> * This implementation is a numerical approximation. * <p> * Note that is called only if there are no duplicated parameters otherwise, * a numerical approximation is used. To change this behaviour, also override * {@link #calculatePartialDerivativeValues(IParameter, IDataset...)} * @param parameter * @param data * @param it */ public void fillWithPartialDerivativeValues(IParameter parameter, DoubleDataset data, CoordinatesIterator it) { calcNumericalDerivativeDataset(A_TOLERANCE, R_TOLERANCE, parameter, data, it); } /** * Calculate partial derivatives up to tolerances * @param abs * @param rel * @param param * @param data * @param it */ protected void calcNumericalDerivativeDataset(double abs, double rel, IParameter param, DoubleDataset data, CoordinatesIterator it) { DoubleDataset previous = DatasetFactory.zeros(DoubleDataset.class, it.getShape()); double delta = DELTA; fillWithNumericalDerivativeDataset(delta, param, previous, it); DoubleDataset current = DatasetFactory.zeros(DoubleDataset.class, it.getShape()); delta *= DELTA_FACTOR; while (isDeltaLargeEnough(param, delta)) { fillWithNumericalDerivativeDataset(delta, param, current, it); if (Comparisons.allCloseTo(previous, current, rel, abs)) break; DoubleDataset temp = previous; previous = current; current = temp; delta *= DELTA_FACTOR; } // if (!isDeltaLargeEnough(param, delta)) { // logger.warn("Numerical derivative did not converge!"); // } data.setSlice(current); } /** * Calculate partial derivative. This is a numerical approximation. * @param delta * @param param * @param data * @param it */ private void fillWithNumericalDerivativeDataset(double delta, IParameter param, DoubleDataset data, CoordinatesIterator it) { double v = param.getValue(); double dv = delta * (v != 0 ? v : 1); param.setValue(v + dv); dirty = true; fillWithValues(data, it); it.reset(); param.setValue(v - dv); dirty = true; DoubleDataset temp = DatasetFactory.zeros(DoubleDataset.class, it.getShape()); fillWithValues(temp, it); data.isubtract(temp); data.imultiply(0.5/dv); param.setValue(v); dirty = true; } /** * @return true if any parameters have changed */ public boolean isDirty() { return dirty; } @Override public void setDirty(boolean isDirty) { dirty = isDirty; } @Override public double residual(boolean allValues, IDataset data, IDataset weight, IDataset... coords) { double residual = 0; if (allValues) { DoubleDataset ddata = (DoubleDataset) DatasetUtils.convertToDataset(data).cast(Dataset.FLOAT64); residual = ddata.residual(calculateValues(ddata.getShapeRef(), coords), DatasetUtils.convertToDataset(weight), false); } else { // stochastic sampling of coords; // int NUMBER_OF_SAMPLES = 100; //TODO logger.error("Stochastic sampling has not been implemented yet"); throw new UnsupportedOperationException("Stochastic sampling has not been implemented yet"); } if (monitor != null) { monitor.worked(1); if (monitor.isCancelled()) { throw new IllegalMonitorStateException("Monitor cancelled"); } } return residual; } @Override public int hashCode() { final int prime = 31; int result = 1; result = prime * result + ((name == null) ? 0 : name.hashCode()); result = prime * result + Arrays.hashCode(parameters); return result; } @Override public boolean equals(Object obj) { if (this == obj) return true; if (obj == null) return false; if (getClass() != obj.getClass()) return false; AFunction other = (AFunction) obj; if (name == null) { if (other.name != null) return false; } else if (!name.equals(other.name)) return false; if (!Arrays.equals(parameters, other.parameters)) return false; return true; } @Override public AFunction copy() throws Exception { Constructor<? extends AFunction> c = getClass().getConstructor(IParameter[].class); //Makes a copy of each parameter, rather passing reference. int nParameters = parameters.length; IParameter[] paramCopy = new IParameter[nParameters]; for (int i = 0; i < parameters.length; i++) { paramCopy[i] = new Parameter(parameters[i]); } AFunction function = c.newInstance((Object) paramCopy); return function; } @Override public IMonitor getMonitor() { return monitor; } @Override public void setMonitor(IMonitor monitor) { this.monitor = monitor; } @Override public boolean isValid() { return true; } @Override public IOperator getParentOperator() { return parent; } @Override public void setParentOperator(IOperator parent) { this.parent = parent; } /** * Get the parameter values as an array, excluding parameters which are fixed * @return a double[] of non fixed parameter values */ public double[] getParameterValuesNoFixed() { ArrayList<Double> values = new ArrayList<Double>(); for (int i = 0; i < getNoOfParameters(); i++) { if (getParameter(i).isFixed() == false) { values.add(getParameter(i).getValue()); } } double[] start = new double[values.size()]; for (int i= 0; i < start.length; i++) { start[i] = values.get(i); } return start; } /** * Get the parameter upper bounds as an array, excluding parameters which are fixed * @return a double[] of non fixed parameter upper bounds */ public double[] getUpperBoundsNoFixed() { ArrayList<Double> values = new ArrayList<Double>(); for (int i = 0; i < getNoOfParameters(); i++) { if (getParameter(i).isFixed() == false) { values.add(getParameter(i).getUpperLimit()); } } double[] start = new double[values.size()]; for (int i= 0; i < start.length; i++) { start[i] = values.get(i); } return start; } /** * Get the parameter lower bounds as an array, excluding parameters which are fixed * @return a double[] of non fixed parameter lower bounds */ public double[] getLowerBoundsNoFixed() { ArrayList<Double> values = new ArrayList<Double>(); for (int i = 0; i < getNoOfParameters(); i++) { if (getParameter(i).isFixed() == false) { values.add(getParameter(i).getLowerLimit()); } } double[] start = new double[values.size()]; for (int i= 0; i < start.length; i++) { start[i] = values.get(i); } return start; } /** * Set the values of all non fixed parameters * @param values */ public void setParameterValuesNoFixed(double[] values) { int argpos = 0; for (int i = 0; i < getNoOfParameters(); i++) { if (getParameter(i).isFixed() == false) { getParameter(i).setValue(values[argpos]); argpos++; } } setDirty(true); } }