/*- * Copyright (c) 2012 Diamond Light Source Ltd. * * All rights reserved. This program and the accompanying materials * are made available under the terms of the Eclipse Public License v1.0 * which accompanies this distribution, and is available at * http://www.eclipse.org/legal/epl-v10.html */ package uk.ac.diamond.scisoft.analysis.fitting.functions; import java.text.DecimalFormat; import java.util.Arrays; import java.util.Collections; import java.util.Comparator; import java.util.List; import org.apache.commons.lang.ArrayUtils; import org.apache.commons.math3.complex.Complex; import org.ddogleg.solver.PolynomialOps; import org.ddogleg.solver.PolynomialRoots; import org.ddogleg.solver.RootFinderType; import org.eclipse.dawnsci.analysis.api.fitting.functions.IParameter; import org.eclipse.january.dataset.Dataset; import org.eclipse.january.dataset.DatasetFactory; import org.eclipse.january.dataset.DatasetUtils; import org.eclipse.january.dataset.DoubleDataset; import org.eclipse.january.dataset.Maths; import org.ejml.data.Complex64F; /** * Class that wrappers the equation <br> * y(x) = a_0 x^n + a_1 x^(n-1) + a_2 x^(n-2) + ... + a_(n-1) x + a_n */ public class Polynomial extends AFunction { private static final String NAME = "Polynomial"; private static final String DESC = "A polynomial of degree n." + "\n y(x) = a_0 x^n + a_1 x^(n-1) + a_2 x^(n-2) + ... + a_(n-1) x + a_n"; private transient double[] a; private transient int nparams; // actually degree + 1 /** * Basic constructor, not advisable to use */ public Polynomial() { this(0); } /** * Make a polynomial of given degree (0 - constant, 1 - linear, 2 - quadratic, etc) * * @param degree */ public Polynomial(final int degree) { super(degree + 1); } /** * Make a polynomial with given parameters * * @param params */ public Polynomial(double[] params) { super(params); } /** * Make a polynomial with given parameters * * @param params */ public Polynomial(IParameter... params) { super(params); } /** * Constructor that allows for the positioning of all the parameter bounds * * @param min * minimum boundaries * @param max * maximum boundaries */ public Polynomial(double[] min, double[] max) { super(0); if (min.length != max.length) { throw new IllegalArgumentException("Bound arrays must be of equal length"); } nparams = min.length; parameters = new Parameter[nparams]; a = new double[nparams]; for (int i = 0; i < nparams; i++) { a[i] = 0.5 * (min[i] + max[i]); parameters[i] = new Parameter(a[i], min[i], max[i]); } setNames(); } @Override protected void setNames() { if (isDirty() && nparams < getNoOfParameters()) { nparams = getNoOfParameters(); } String[] paramNames = new String[nparams]; for (int i = 0; i < nparams; i++) { paramNames[i] = "a_" + i; } setNames(NAME, DESC, paramNames); } private void calcCachedParameters() { if (a == null || a.length != nparams) { a = new double[nparams]; } for (int i = 0; i < nparams; i++) { a[i] = getParameterValue(i); } setDirty(false); } @Override public double val(double... values) { if (isDirty()) { calcCachedParameters(); } final double position = values[0]; double v = a[0]; for (int i = 1; i < nparams; i++) { v = v * position + a[i]; } return v; } @Override public void fillWithValues(DoubleDataset data, CoordinatesIterator it) { if (isDirty()) calcCachedParameters(); it.reset(); double[] coords = it.getCoordinates(); int i = 0; double[] buffer = data.getData(); while (it.hasNext()) { double v = a[0]; double p = coords[0]; for (int j = 1; j < nparams; j++) { v = v * p + a[j]; } buffer[i++] = v; } } @Override public double partialDeriv(IParameter parameter, double... position) { if (isDuplicated(parameter)) return super.partialDeriv(parameter, position); int i = indexOfParameter(parameter); if (i < 0) return 0; final double pos = position[0]; final int n = nparams - 1 - i; switch (n) { case 0: return 1.0; case 1: return pos; case 2: return pos * pos; default: return Math.pow(pos, n); } } @Override public void fillWithPartialDerivativeValues(IParameter parameter, DoubleDataset data, CoordinatesIterator it) { Dataset pos = DatasetUtils.convertToDataset(it.getValues()[0]); final int n = nparams - 1 - indexOfParameter(parameter); switch (n) { case 0: data.fill(1); break; case 1: data.setSlice(pos); break; case 2: Maths.square(pos, data); break; default: Maths.power(pos, n, data); break; } } /** * Create a 2D dataset which contains in each row a coordinate raised to n-th powers. * <p> * This is for solving the linear least squares problem * * @param coords * @return matrix */ public DoubleDataset makeMatrix(Dataset coords) { final int rows = coords.getSize(); DoubleDataset matrix = DatasetFactory.zeros(DoubleDataset.class, rows, nparams); for (int i = 0; i < rows; i++) { final double x = coords.getDouble(i); double v = 1.0; for (int j = nparams - 1; j >= 0; j--) { matrix.setItem(v, i, j); v *= x; } } return matrix; } /** * Set the degree after a class instantiation * * @param degree */ public void setDegree(int degree) { nparams = degree + 1; parameters = createParameters(nparams); dirty = true; setNames(); if (parent != null) { parent.updateParameters(); } } public String getStringEquation() { StringBuilder out = new StringBuilder(); DecimalFormat df = new DecimalFormat("0.#####E0"); for (int i = nparams - 1; i >= 2; i--) { out.append(df.format(parameters[nparams - 1 - i].getValue())); out.append(String.format("x^%d + ", i)); } if (nparams >= 2) out.append(df.format(parameters[nparams - 2].getValue()) + "x + "); if (nparams >= 1) out.append(df.format(parameters[nparams - 1].getValue())); return out.toString(); } /** * Find all roots * * @return all roots or null if there is any problem finding the roots */ public Complex[] findRoots() { if (isDirty()) { calcCachedParameters(); } return findRoots(a); } /** * Find all roots * * @param coeffs * @return all roots or null if there is any problem finding the roots */ public static Complex[] findRoots(double... coeffs) { double[] reverse = coeffs.clone(); ArrayUtils.reverse(reverse); double max = Double.NEGATIVE_INFINITY; for (double r : reverse) { max = Math.max(max, Math.abs(r)); } for (int i = 0; i < reverse.length; i++) { reverse[i] /= max; } org.ddogleg.solver.Polynomial p = org.ddogleg.solver.Polynomial.wrap(reverse); PolynomialRoots rf = PolynomialOps.createRootFinder(p.computeDegree(), RootFinderType.EVD); if (rf.process(p)) { // reorder to NumPy's roots output List<Complex64F> rts = rf.getRoots(); Complex[] out = new Complex[rts.size()]; int i = 0; for (Complex64F r : rts) { out[i++] = new Complex(r.getReal(), r.getImaginary()); } return sort(out); } return null; } private static Complex[] sort(Complex[] values) { // reorder to NumPy's roots output List<Complex> rts = Arrays.asList(values); Collections.sort(rts, new Comparator<Complex>() { @Override public int compare(Complex o1, Complex o2) { double a = o1.getReal(); double b = o2.getReal(); double u = 10 * Math.ulp(Math.max(Math.abs(a), Math.abs(b))); if (Math.abs(a - b) > u) return a < b ? -1 : 1; a = o1.getImaginary(); b = o2.getImaginary(); if (a == b) return 0; return a < b ? 1 : -1; } }); return rts.toArray(new Complex[0]); } }