/*- * Copyright (c) 2012 Diamond Light Source Ltd. * * All rights reserved. This program and the accompanying materials * are made available under the terms of the Eclipse Public License v1.0 * which accompanies this distribution, and is available at * http://www.eclipse.org/legal/epl-v10.html */ package uk.ac.diamond.scisoft.analysis.fitting.functions; import org.apache.commons.math3.linear.Array2DRowRealMatrix; import org.apache.commons.math3.linear.LUDecomposition; import org.apache.commons.math3.linear.MatrixUtils; import org.apache.commons.math3.linear.NonSquareMatrixException; import org.eclipse.dawnsci.analysis.api.fitting.functions.IParameter; import org.eclipse.january.dataset.DoubleDataset; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * A ND Gaussian function * * The parameters are mean peak position coordinates (N), volume (1), diagonal elements of covariance matrix (N) * and normalized upper triangle elements of covariance matrix (N*(N-1)/2). The last set of parameters * are normalized by the diagonal elements. */ public class GaussianND extends AFunction { private static String NAME = "GaussianND"; private static String DESC = "An N-dimensional Gaussian function." + "\nThe parameters are mean peak position coordinates (N), volume (1)," + "\ndiagonal elements of covariance matrix (N) and normalized upper triangle" + "\nelements of covariance matrix N*(N-1)/2. The last set of parameters are" + "\nnormalized by the diagonal elements."; private transient int rank; private transient double[] pos = null; /** * Setup the logging facilities */ private static transient final Logger logger = LoggerFactory.getLogger(GaussianND.class); public GaussianND() { super(new double[]{0,0,0}); } public GaussianND(IParameter... params) { super(params); } /** * Constructor which takes the N + 1 + N + N(N-1)/2 properties required, which are volume, elements of the mean vector, * diagonal elements and upper triangle elements of the covariance matrix (as this is symmetric) * @param params */ public GaussianND(double... params) { super(params); int nparams = params.length; // check if correct number of parameters given int guess = -1; for (rank = 0; guess < nparams; ) { guess = 1 + (rank * (rank + 3)) / 2; } if (guess != nparams) { logger.error("Given number of parameters {} is not equal to {}", nparams, guess); throw new IllegalArgumentException("Given number of parameters " + nparams + " is not equal to" + guess); } // check triangle values int n = 2*rank + 1; for (int i = 0; i < nparams; i++) { double tri = getParameter(n).getValue(); if (Math.abs(tri) > 1) { logger.error("Parameter {} ({}) is outside valid range [-1,1]", i, tri); throw new IllegalArgumentException("Parameter " + i + " (" + tri + ") is outside valid range [-1,1]"); } } } /** * Create a multi-dimensional Gaussian function * @param maxVol maximum "volume" * @param minPeakPosition array containing minimum peak position * @param maxPeakPosition array containing maximum peak position * @param maxSigma maximum magnitude for any element in covariance matrix */ public GaussianND(double maxVol, double[] minPeakPosition, double[] maxPeakPosition, double maxSigma) { super(1 + (minPeakPosition.length*(minPeakPosition.length+3))/2); rank = minPeakPosition.length; if (maxPeakPosition.length != rank) { logger.error("Two vectors are not of same length"); throw new IllegalArgumentException("Two vectors are not of same length"); } int n = 0; IParameter p; for (int i = 0; i < rank; i++) { p = getParameter(n++); p.setLowerLimit(minPeakPosition[i]); p.setUpperLimit(maxPeakPosition[i]); p.setValue((minPeakPosition[i] + maxPeakPosition[i]) / 2.0); } p = getParameter(n++); p.setLowerLimit(0); p.setUpperLimit(maxVol); p.setValue(maxVol / 2.0); double sigmasq = maxSigma * maxSigma; for (int i = 0; i < rank; i++) { p = getParameter(n++); p.setLowerLimit(0); p.setUpperLimit(sigmasq); p.setValue(sigmasq/100.); } for (int i = 0; i < rank; i++) { for (int j = i + 1; j < rank; j++) { p = getParameter(n++); p.setLowerLimit(-1); p.setUpperLimit(1); p.setValue(0); } } } @Override protected void setNames() { setNames(NAME, DESC); } /** * Setting peak position * @param pos */ public void setPeakPosition(double[] pos) { if (pos.length != rank) { logger.error("Peak position vector has wrong length"); throw new IllegalArgumentException("Peak position vector has wrong length"); } for (int i = 0; i < rank; i++) { getParameter(i).setValue(pos[i]); } } /** * Setting volume of Gaussian (the integrated value) * @param volume */ public void setVolume(double volume) { getParameter(rank).setValue(volume); } /** * @return maximum value */ public double getPeakValue() { if (isDirty()) calcCachedParameters(); return norm; } private transient Array2DRowRealMatrix invcov; // inverse of covariance matrix private transient double norm; private void calcCachedParameters() { if (pos == null || pos.length != rank) { pos = new double[rank]; } int n = 0; for (int i = 0; i < rank; i++) { pos[i] = getParameterValue(n); n++; } // logger.info("New pos at {}", pos); norm = getParameterValue(n); n++; if (rank == 0) return; Array2DRowRealMatrix covar = (Array2DRowRealMatrix) MatrixUtils.createRealMatrix(rank, rank); for (int i = 0; i < rank; i++) { covar.setEntry(i, i, getParameterValue(n)); n++; } for (int i = 0; i < rank; i++) { double diagi = Math.sqrt(covar.getEntry(i, i)); for (int j = i + 1; j < rank; j++) { double diag = Math.sqrt(covar.getEntry(j, j))*diagi; double el = diag*getParameterValue(n); covar.setEntry(i, j, el); covar.setEntry(j, i, el); n++; } } // logger.info("New cov {}", covar); LUDecomposition decomp = null; try { decomp = new LUDecomposition(covar); } catch (NonSquareMatrixException e) { logger.error("Non-square covariance matrix"); throw new IllegalArgumentException("Non-square covariance matrix"); } invcov = (Array2DRowRealMatrix) decomp.getSolver().getInverse(); // logger.info("Inverse covariance matrix is {}", invcov); norm /= Math.sqrt(Math.pow(2.*Math.PI, rank) * decomp.getDeterminant()); // logger.info("Normalization factor is {}", norm); setDirty(false); } @Override public double val(double... values) { if (isDirty()) calcCachedParameters(); double[] v = values.clone(); for (int i = 0; i < v.length; i++) v[i] -= pos[i]; double[] u = invcov.operate(v); double arg = 0; for (int i = 0; i < v.length; i++) arg += u[i] * v[i]; return norm * Math.exp(-0.5 * arg); } @Override public void fillWithValues(DoubleDataset data, CoordinatesIterator it) { if (isDirty()) calcCachedParameters(); it.reset(); double[] coords = it.getCoordinates(); int j = 0; double[] buffer = data.getData(); while (it.hasNext()) { double[] v = coords.clone(); for (int i = 0; i < v.length; i++) { v[i] -= pos.length > 0 ? pos[i] : 0; } if (invcov == null) return; double[] u = invcov.operate(v); double arg = 0; for (int i = 0; i < v.length; i++) arg += u[i] * v[i]; buffer[j++] = norm * Math.exp(-0.5 * arg); } } }