/*- * Copyright 2017 Diamond Light Source Ltd. * * All rights reserved. This program and the accompanying materials * are made available under the terms of the Eclipse Public License v1.0 * which accompanies this distribution, and is available at * http://www.eclipse.org/legal/epl-v10.html */ package uk.ac.diamond.scisoft.analysis.dataset.function; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import org.eclipse.dawnsci.analysis.api.roi.IRectangularROI; import org.eclipse.dawnsci.analysis.dataset.impl.FFT; import org.eclipse.dawnsci.analysis.dataset.impl.Signal; import org.eclipse.dawnsci.analysis.dataset.impl.function.DatasetToDatasetFunction; import org.eclipse.january.dataset.Dataset; import org.eclipse.january.dataset.DatasetFactory; import org.eclipse.january.dataset.DatasetUtils; import org.eclipse.january.dataset.DoubleDataset; import org.eclipse.january.dataset.IDataset; import org.eclipse.january.dataset.IndexIterator; import org.eclipse.january.dataset.LinearAlgebra; import org.eclipse.january.dataset.Maths; import org.eclipse.january.dataset.Slice; import org.eclipse.january.dataset.SliceND; import uk.ac.diamond.scisoft.analysis.fitting.Fitter; import uk.ac.diamond.scisoft.analysis.fitting.functions.Gaussian; /** * Register images using a phase correlation method that has sub-pixel accuracy * <p> * This fails for noisy, relatively featureless images */ public class RegisterImage implements DatasetToDatasetFunction { private IRectangularROI roi = null; private double tukeyWidth = 0.25; private IDataset anchor; private int[] shape; private int[] pShape; // padded shape private Dataset window; // window function private Dataset cfAnchor; // conjugate transform of windowed anchor private Dataset fAnchor; // transform of windowed anchor private SliceND slice = null; private boolean dirty = true; private Dataset tFilter = null; private Dataset filter; public RegisterImage() { } /** * Set reference image that will act as an anchor * @param reference */ public void setReference(IDataset reference) { if (reference.getRank() != 2) { throw new IllegalArgumentException("Reference dataset must be 2D"); } anchor = reference; shape = anchor.getShape(); dirty = true; } private int[] padShape(int[] shape) { int[] s = shape.clone(); // for (int i = 0; i < s.length; i++) { // s[i] += shape[i] - 1; // pad // } return s; } private void update() { int[] wShape; if (roi == null) { slice = new SliceND(shape); pShape = padShape(shape); wShape = shape; } else { double[] beg = roi.getPoint(); double[] end = roi.getEndPoint(); // use row-major ordering slice = new SliceND(shape, new Slice((int) Math.floor(beg[1]), (int) Math.ceil(end[1])), new Slice((int) Math.floor(beg[0]), (int) Math.ceil(end[0]))); wShape = slice.getShape(); pShape = padShape(wShape); } window = LinearAlgebra.outerProduct(Signal.tukeyWindow(wShape[0], tukeyWidth), Signal.tukeyWindow(wShape[1], tukeyWidth)); if (filter != null) { Dataset nfilter = DatasetFactory.zeros(pShape); nfilter.setSlice(filter, null, filter.getShapeRef(), null); tFilter = FFT.fftn(nfilter, pShape, null); } fAnchor = pTF(anchor); cfAnchor = Maths.conjugate(fAnchor); dirty = false; } /** * Set filter to use for convolving images * @param filter */ public void setFilter(IDataset filter) { this.filter = DatasetUtils.convertToDataset(filter); dirty = true; } /** * Preprocess, transform and filter * @param image * @return result */ public Dataset pTF(IDataset image) { // TODO use gradient images (dx, dy) as complex pair DoubleDataset preprocessed = DatasetUtils.cast(DoubleDataset.class, image.getSlice(slice)); preprocessed.isubtract(preprocessed.mean()).imultiply(window); Dataset transform = FFT.fftn(preprocessed, pShape, null); if (tFilter != null) { transform.imultiply(tFilter); } return transform; } /** * Set width for window function * @param width */ public void setWindowFunction(double width) { tukeyWidth = width; dirty = true; } /** * * @param rectangle (can be null) */ public void setRectangle(IRectangularROI rectangle) { roi = rectangle; dirty = true; } /** * @param datasets array of datasets * @return pairs of datasets of shift and shifted images */ @Override public List<Dataset> value(IDataset... datasets) { if (dirty) { update(); } List<Dataset> result = new ArrayList<Dataset>(); double[] shifts; for (IDataset d : datasets) { if (!Arrays.equals(d.getShape(), shape)) { throw new IllegalArgumentException("Shape of dataset must match reference image"); } Dataset pCorrelation = phaseCorrelate(d); shifts = calcForooshShift(pCorrelation); System.err.println("Foroosh : " + Arrays.toString(shifts)); shifts = findCentroid(pCorrelation, Math.min(pCorrelation.getShapeRef()[0], 7)); System.err.println("Centroid: " + Arrays.toString(shifts)); shifts = fitGaussians(pCorrelation, Math.min(pCorrelation.getShapeRef()[0], 11)); System.err.println("Fit : " + Arrays.toString(shifts)); result.add(DatasetFactory.createFromObject(shifts)); Dataset shiftedImage = shiftImage(DatasetUtils.convertToDataset(d) , shifts); result.add(shiftedImage); } return result; } /** * @param im * @return return central region of phase correlation */ public Dataset phaseCorrelate(IDataset im) { if (dirty) { update(); } Dataset fImage = pTF(im); // phase correlate // Dataset spectrum = Maths.phaseAsComplexNumber(fImage.imultiply(cfAnchor), true); Dataset spectrum = Maths.phaseAsComplexNumber(Maths.dividez(fImage, fAnchor), true); // more stable??? Dataset pc = FFT.ifftn(spectrum, pShape, null).getRealView(); return FFT.fftshift(pc, null); } /** * @param im * @return return central region of phase correlation */ public Dataset phaseCorrelate2(IDataset im) { if (dirty) { update(); } Dataset fImage = pTF(im); // phase correlate Dataset spectrum = Maths.phaseAsComplexNumber(fImage.imultiply(cfAnchor), true); // Dataset spectrum = Maths.phaseAsComplexNumber(Maths.dividez(fImage, fAnchor), true); // more stable??? Dataset pc = FFT.ifftn(spectrum, pShape, null).getRealView(); return FFT.fftshift(pc, null); } /** * @param im * @return return central region of cross-correlation */ public Dataset crossCorrelate(IDataset im) { if (dirty) { update(); } Dataset fImage = pTF(im); Dataset spectrum = fImage.imultiply(cfAnchor); Dataset cc = FFT.ifftn(spectrum, pShape, null).getRealView(); return FFT.fftshift(cc, null); } /** * @param im * @param factor * @return return central region of cross-correlation */ public Dataset crossCorrelate(IDataset im, int factor) { if (dirty) { update(); } Dataset fImage = pTF(im); Dataset spectrum = fImage.imultiply(cfAnchor); int[] nshape; if (factor > 1) { nshape = pShape.clone(); for (int i = 0; i < nshape.length; i++) { nshape[i] *= factor; } spectrum = FFT.zeroPad(spectrum, nshape, true); } else { nshape = pShape; } Dataset cc = FFT.ifftn(spectrum, nshape, null).getRealView(); return FFT.fftshift(cc, null); } // Foroosh et al, "Extension of Phase Correlation to Subpixel Registration", // IEEE Trans. Image Processing, v11n3, 188-200 (2002) protected double[] calcForooshShift(Dataset pc) { int[] maxpos = pc.maxPos(); // peak pos System.out.println("Max: " + Arrays.toString(maxpos)); double c0 = pc.getDouble(maxpos); double[] shifts = new double[2]; for (int i = 0; i < 2; i++) { int l = pc.getShapeRef()[i]; maxpos[i]++; if (maxpos[i] < l) { final double c1 = pc.getDouble(maxpos); double shift = c1/(c1-c0); if (Math.abs(shift) > 1) { shift = c1/(c1+c0); } shifts[i] = maxpos[i] - 1 + shift - l/2; } maxpos[i]--; } return shifts; } protected double[] fitGaussians(Dataset pc, int side) { int[] maxpos = pc.maxPos(); // peak pos int hs = (side+1)/2; int[] beg = new int[] {maxpos[0] - hs + 1, maxpos[1] - hs + 1}; double[] shifts = new double[2]; Gaussian gaussian = new Gaussian(); Dataset data = pc.getSliceView(new int[] {beg[0], maxpos[1]}, new int[] {beg[0] + side, maxpos[1]+1}, null).squeeze(); shifts[0] = fitGaussian(gaussian, data) + beg[0] - (pc.getShapeRef()[0])/2; data = pc.getSliceView(new int[] {maxpos[0], beg[1]}, new int[] {maxpos[0]+1, beg[1] + side}, null).squeeze(); shifts[1] = fitGaussian(gaussian, data) + beg[1] - (pc.getShapeRef()[1])/2; return shifts; } private double fitGaussian(Gaussian gaussian, Dataset data) { int length = data.getSize(); gaussian.setParameterValues(length/2, 2., ((Number) data.sum()).doubleValue()); Dataset axis = DatasetFactory.createRange(DoubleDataset.class, data.getSize()); try { Fitter.ApacheConjugateGradientFit(new Dataset[] {axis}, data, gaussian); return gaussian.getPosition(); } catch (Exception e) { } return Double.NaN; } protected double[] findCentroid(Dataset pc, int side) { int[] maxpos = pc.maxPos(); // peak pos int hs = (side+1)/2; int[] beg = new int[] {maxpos[0] - hs + 1, maxpos[1] - hs + 1}; IndexIterator it = pc.getSliceIterator(new int[] {beg[0], beg[1]}, new int[] {beg[0] + side, beg[1] + side}, null); int[] pos = it.getPos(); double sum = 0, sum0 = 0, sum1 = 0; while (it.hasNext()) { double v = pc.getElementDoubleAbs(it.index); sum += v; sum0 += v*pos[0]; sum1 += v*pos[1]; } return new double[] {sum0/sum - (pc.getShapeRef()[0])/2, sum1/sum - (pc.getShapeRef()[1])/2}; } private Dataset shiftImage(Dataset im, double[] shifts) { Dataset newImage = DatasetFactory.zeros(im); double cx0, cx1; for (int x0 = 0; x0 < shape[0]; x0++) { cx0 = x0 - shifts[0]; for (int x1 = 0; x1 < shape[1]; x1++) { cx1 = x1 - shifts[1]; newImage.set(Maths.interpolate(im, cx0, cx1), x0, x1); } } return newImage; } }