/* * Copyright (c) 2012 Diamond Light Source Ltd. * * All rights reserved. This program and the accompanying materials * are made available under the terms of the Eclipse Public License v1.0 * which accompanies this distribution, and is available at * http://www.eclipse.org/legal/epl-v10.html */ package uk.ac.diamond.scisoft.analysis.optimize; import org.eclipse.january.dataset.Dataset; import org.eclipse.january.dataset.DatasetUtils; import org.eclipse.january.dataset.Maths; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import Jama.Matrix; import Jama.SingularValueDecomposition; /** * Basic least squares solver using SVD method */ public class LinearLeastSquares { /** * Setup the logging facilities */ private static final Logger logger = LoggerFactory.getLogger(LinearLeastSquares.class); private final double threshold; // threshold ratio /** * Base constructor * * @param tolerance * ratio of lowest to highest singular value to allow */ public LinearLeastSquares(double tolerance) { threshold = tolerance; } /** * Solve linear least square problem defined by * <pre> * A x = b * </pre> * @param matrix <pre>A</pre> * @param data <pre>b</pre> * @param sigmasq estimate of squared error on each data point * @return array of values */ public double[] solve(Dataset matrix, Dataset data, Dataset sigmasq) { if (matrix.getRank() != 2) { logger.error("Matrix was not 2D"); throw new IllegalArgumentException("Matrix was not 2D"); } final int[] shape = matrix.getShape(); final int dlen = data.getShape()[0]; if (data.getRank() != 1 && shape[1] != dlen) { logger.error("Data was not 1D or else not correct length"); throw new IllegalArgumentException("Data was not 2D or else not correct length"); } final Matrix X = new Matrix((double [][]) DatasetUtils.createJavaArray(matrix.cast(Dataset.FLOAT64))); final Matrix W = new Matrix((double [][]) DatasetUtils.createJavaArray(DatasetUtils.diag(Maths.reciprocal(sigmasq.cast(Dataset.FLOAT64)), 0))); final Matrix XtW = X.transpose().times(W); final Matrix A = XtW.times(X); final SingularValueDecomposition svd = A.svd(); final double[] values = svd.getSingularValues(); final Matrix b = new Matrix(dlen, dlen); for (int i = 0; i < dlen; i++) { b.set(i, 0, data.getDouble(i)); } final Matrix c = svd.getV().transpose().times(XtW.times(b)); final double limit = threshold*values[0]; final int vlen = values.length; for (int i = 0; i < vlen; i++) { final double v = values[i]; c.set(i, 0, v >= limit ? c.get(i, 0) / v : 0); } final Matrix d = svd.getU().times(c); final double[] result = new double[vlen]; for (int i = 0; i < vlen; i++) { result[i] = d.get(i, 0); } return result; } }