/*- * Copyright 2017 Diamond Light Source Ltd. * * All rights reserved. This program and the accompanying materials * are made available under the terms of the Eclipse Public License v1.0 * which accompanies this distribution, and is available at * http://www.eclipse.org/legal/epl-v10.html */ package uk.ac.diamond.scisoft.analysis.dataset.function; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import org.eclipse.dawnsci.analysis.api.roi.IRectangularROI; import org.eclipse.dawnsci.analysis.dataset.DatasetCache; import org.eclipse.dawnsci.analysis.dataset.impl.FFT; import org.eclipse.dawnsci.analysis.dataset.impl.Signal; import org.eclipse.dawnsci.analysis.dataset.impl.function.DatasetToDatasetFunction; import org.eclipse.january.IMonitor; import org.eclipse.january.dataset.Dataset; import org.eclipse.january.dataset.DatasetFactory; import org.eclipse.january.dataset.DatasetUtils; import org.eclipse.january.dataset.DoubleDataset; import org.eclipse.january.dataset.IDataset; import org.eclipse.january.dataset.LinearAlgebra; import org.eclipse.january.dataset.Maths; import org.eclipse.january.dataset.Slice; import org.eclipse.january.dataset.SliceND; /** * Register 1D data using a cross correlation method that has sub-pixel accuracy * <p> * This is suitable for noisy 1D data */ public class RegisterData1D implements DatasetToDatasetFunction { private IRectangularROI roi = null; private double tukeyWidth = 0.0; private int[] shape; private int[] pShape; // padded shape private Dataset window = null; // window function private SliceND slice = null; private boolean dirty = true; private Dataset tFilter = null; private Dataset filter; private DatasetCache filtered; private IMonitor monitor = null; private boolean shiftImage = true; public RegisterData1D() { filtered = new DatasetCache(new DatasetToDatasetFunction() { @Override public List<? extends IDataset> value(IDataset... datasets) { List<Dataset>list = new ArrayList<>(); for (IDataset d : datasets) { list.add(pTF(d)); } return list; } }); } /** * @param monitor The monitor to set. */ public void setMonitor(IMonitor monitor) { this.monitor = monitor; } private int[] padShape(int[] shape) { int[] s = shape.clone(); // for (int i = 0; i < s.length; i++) { // s[i] += shape[i] - 1; // pad // } return s; } public void setShiftImage(boolean shift) { this.shiftImage = shift; } public void update(int[] shape) { this.shape = shape; update(); } private void update() { int[] wShape; if (roi == null) { slice = new SliceND(shape); pShape = padShape(shape); wShape = shape; } else { double[] beg = roi.getPoint(); double[] end = roi.getEndPoint(); // use row-major ordering slice = new SliceND(shape, new Slice((int) Math.floor(beg[0]), (int) Math.ceil(end[0]))); wShape = slice.getShape(); pShape = padShape(wShape); } window = Signal.tukeyWindow(wShape[0], tukeyWidth); if (filter != null) { Dataset nfilter = DatasetFactory.zeros(pShape); nfilter.setSlice(filter, null, filter.getShapeRef(), null); tFilter = FFT.fftn(nfilter, pShape, null); } dirty = false; } /** * Set filter to use for convolving images * @param filter */ public void setFilter(IDataset filter) { this.filter = DatasetUtils.convertToDataset(filter); dirty = true; } /** * Preprocess, transform and filter * @param image * @return result */ public Dataset pTF(IDataset image) { // TODO use gradient images (dx, dy) as complex pair DoubleDataset preprocessed = DatasetUtils.cast(DoubleDataset.class, image.getSlice(slice)); preprocessed.isubtract(preprocessed.mean()); if (window != null) preprocessed.imultiply(window); Dataset transform = FFT.fftn(preprocessed, pShape, null); if (tFilter != null) { transform.imultiply(tFilter); } return transform; } public Dataset getPTF(IDataset image) { return filtered.get(image); } /** * Set width for window function. Default is 0 * @param width for Tukey window */ public void setWindowFunction(double width) { tukeyWidth = width; dirty = true; } /** * * @param rectangle (can be null) */ public void setRectangle(IRectangularROI rectangle) { roi = rectangle; dirty = true; } private static final double R_LIMIT = 1; /** * @param datasets array of datasets * @return pairs of datasets of shift and shifted images */ @Override public List<Dataset> value(IDataset... datasets) { if (monitor != null) { monitor.subTask("Registering images"); } if (dirty) { if (datasets == null || datasets.length == 0) { throw new IllegalArgumentException("Need at least one image"); } IDataset image = datasets[0]; if (image.getRank() != 1) { throw new IllegalArgumentException("Dataset must be 1D"); } shape = image.getShape(); update(); } List<Dataset> result = new ArrayList<Dataset>(); int n = datasets.length; /* * Work out shifts for each of n(n-1)/2 pairs * * r_i = displacement between I_i and I_(i+1) * b_(ij) = computed distance between I_i and I_j (where i < j) * * then * b_(01) = r_0 * b_(02) = r_0 + r_1 * ... * b_(0(N-1)) = r_0 + r_1 + r_(N-2) * b_(12) = r_1 * ... * ... * ... * b_((N-2)(N-1)) = r_(N-2) * or * b_(ij) = A_((ij)k) r_k, k = [0, N-2] * which can be solved as a linear least squares problem: * r = (A^T A)^-1 A^T b */ int m = n*(n-1)/2; // total number of knowns // create all values for A and b as list of rows double shift; List<double[]> list = new ArrayList<>(); int rlen = n; for (int i = 0; i < n - 1; i++) { double[] rowk0 = new double[rlen]; rowk0[i] = 1; Dataset cf = Maths.conjugate(filtered.get(datasets[i])); shift = ccFindShift(cf, datasets[i+1]); if (monitor != null) { if(monitor.isCancelled()) { return result; } monitor.worked(1); } rowk0[rlen - 1] = shift; list.add(rowk0); for (int j = i + 2; j < n; j++) { rowk0 = new double[rlen]; for (int l = i; l < j; l++) { rowk0[l] = 1; } shift = ccFindShift(cf, datasets[j]); if (monitor != null) { if(monitor.isCancelled()) { return result; } monitor.worked(1); } rowk0[rlen - 1] = shift; list.add(rowk0); } } assert list.size() == m; boolean[] use = new boolean[m]; Arrays.fill(use, true); DoubleDataset[] fit = null; int used = 0; do { // need to check residual and prune outliers // loop if any false until too many rows are dropped // (used > (n - 1) for least square) fit = fitLeastSquares(list, use); DoubleDataset residuals = fit[1]; int rows = residuals.getSize(); used = 0; for (int i = 0, k = 0; i < m; i++) { if (use[i]) { if (residuals.getDouble(k) > R_LIMIT) { use[i] = false; } else { used++; } k++; } } if (used == rows) { // no change break; } if (monitor != null) { if(monitor.isCancelled()) { return result; } monitor.worked(1); } // System.err.println("Useable " + Arrays.toString(use)); } while (used >= rlen - 1); if (used < rlen - 1) { // dropped too many rows return null; } @SuppressWarnings("null") DoubleDataset vecr = fit[0]; shift = 0; result.add(DatasetFactory.createFromObject(shift)); Dataset shiftedImage = shiftImage ? DatasetUtils.convertToDataset(datasets[0]).clone() : null; result.add(shiftedImage); for (int i = 1; i < n; i++) { int x = i - 1; shift += vecr.get(x); // System.err.println("Cumulative shifts for " + i + ": " + Arrays.toString(shift)); result.add(DatasetFactory.createFromObject(shift)); shiftedImage = shiftImage ? shiftData(DatasetUtils.convertToDataset(datasets[i]), shift) : null; result.add(shiftedImage); if (shiftImage && monitor != null) { if(monitor.isCancelled()) { return result; } monitor.worked(1); } } return result; } private DoubleDataset[] fitLeastSquares(List<double[]> rows, boolean[] use) { int used = 0; for (boolean u : use) { if (u) { used++; } } int end = rows.get(0).length - 1; DoubleDataset vecb = DatasetFactory.zeros(DoubleDataset.class, used); DoubleDataset matA = DatasetFactory.zeros(DoubleDataset.class, used, end); int k = 0; for (int i = 0; i < use.length; i++) { if (use[i]) { double[] row = rows.get(i); for (int j = 0; j < end; j++) { matA.set(row[j], k, j); } vecb.set(row[end], k); k++; } } // System.err.println(matA.toString(true)); Dataset matAT = matA.getTransposedView(); Dataset matInv = LinearAlgebra.calcPseudoInverse(LinearAlgebra.dotProduct(matAT, matA)); DoubleDataset vecr = DatasetUtils.cast(DoubleDataset.class, LinearAlgebra.dotProduct(matInv, LinearAlgebra.dotProduct(matAT, vecb))); // System.err.println("Shifts : " + vecr.toString(true)); DoubleDataset residuals = DatasetUtils.cast(DoubleDataset.class, Maths.abs(Maths.subtract(LinearAlgebra.dotProduct(matA, vecr), vecb))); // System.err.println("Residuals are " + residuals.toString(true)); return new DoubleDataset[] {vecr, residuals}; } private double peakThresholdFraction = 0.90; // fraction of peak height to use as threshold /** * Set threshold for determination of the centroid of cross-correlation peak * @param threshold as a fraction of peak maximum (default is 0.85) */ public void setPeakCentroidThresholdFraction(double threshold) { peakThresholdFraction = threshold; } public double ccFindShift(IDataset fim, IDataset imb) { Dataset cc = crossCorrelate(fim, imb); int[] maxpos = cc.maxPos(); // peak pos // crop to threshold intercept with at least one side of peak and find centroid double threshold = cc.max().doubleValue() * peakThresholdFraction; int left = maxpos[0]; int right = left; double sum = cc.max().doubleValue(), sum0 = sum*left; double vl, vr; do { vl = cc.getDouble(--left); sum += vl; sum0 += vl*left; vr = cc.getDouble(++right); sum += vr; sum0 += vl*right; } while (vl > threshold && vr > threshold); int hs = (cc.getShapeRef()[0] + 1)/2; return sum0/sum - hs; } /** * @param fim * @param imb * @return return central region of cross-correlation */ public Dataset crossCorrelate(IDataset fim, IDataset imb) { if (dirty) { update(); } Dataset spectrum = Maths.multiply(fim, filtered.get(imb)); Dataset cc = FFT.ifftn(spectrum, pShape, null).getRealView(); return FFT.shift(cc, true); } /** * @param im * @param shift * @return shifted data */ public static Dataset shiftData(Dataset im, double shift) { Dataset newImage = DatasetFactory.zeros(im); int[] shape = im.getShapeRef(); double cx0; for (int x0 = 0; x0 < shape[0]; x0++) { cx0 = x0 + shift; newImage.set(Maths.interpolate(im, cx0), x0); } return newImage; } }