/*- ******************************************************************************* * Copyright (c) 2011, 2014 Diamond Light Source Ltd. * All rights reserved. This program and the accompanying materials * are made available under the terms of the Eclipse Public License v1.0 * which accompanies this distribution, and is available at * http://www.eclipse.org/legal/epl-v10.html * * Contributors: * Peter Chang - initial API and implementation and/or initial documentation *******************************************************************************/ package org.eclipse.dawnsci.analysis.dataset.impl; import java.util.Arrays; import java.util.SortedSet; import java.util.TreeSet; import org.eclipse.january.dataset.ComplexDoubleDataset; import org.eclipse.january.dataset.ComplexFloatDataset; import org.eclipse.january.dataset.Dataset; import org.eclipse.january.dataset.DatasetFactory; import org.eclipse.january.dataset.DatasetUtils; import org.eclipse.january.dataset.DoubleDataset; import org.eclipse.january.dataset.PositionIterator; import org.eclipse.january.dataset.Slice; import org.eclipse.january.dataset.SliceIterator; import org.eclipse.january.dataset.SliceND; import org.jtransforms.fft.DoubleFFT_1D; import org.jtransforms.fft.DoubleFFT_2D; import org.jtransforms.fft.DoubleFFT_3D; import org.jtransforms.fft.FloatFFT_1D; import org.jtransforms.fft.FloatFFT_2D; import org.jtransforms.fft.FloatFFT_3D; import org.slf4j.Logger; import org.slf4j.LoggerFactory; // TODO fast path for defaults? // axes and shape correspondence /** * Class to hold methods to compute discrete Fourier transforms * * Emulates numpy.fft */ public class FFT { /** * Setup the logging facilities */ protected static final Logger logger = LoggerFactory.getLogger(FFT.class); /** * forward 1D fast Fourier transform * @param a dataset * @return new dataset holding transform */ public static Dataset fft(final Dataset a) { return fft(a, a.getShapeRef()[a.getRank() - 1], -1); } /** * forward 1D fast Fourier transform * @param a dataset * @param n number of points * @param axis (negative numbers refer to axes from end, eg. -1 is last axis) * @return new dataset holding transform */ public static Dataset fft(final Dataset a, final int n, int axis) { if (n <= 0) { logger.error("number of points should be greater than zero"); throw new IllegalArgumentException("number of points should be greater than zero"); } axis = a.checkAxis(axis); return fft1d(a, n, axis); } /** * forward 2D fast Fourier transform * @param a dataset * @param s shape of FFT dataset (if null, use whole dataset) * @param axes for FFT (if null, default as [-2,-1]) * @return new dataset holding transform */ public static Dataset fft2(final Dataset a, int[] s, int[] axes) { int rank = a.getRank(); if (rank < 2) { logger.error("dataset should be at least 2 dimensional"); throw new IllegalArgumentException("dataset should be at least 2 dimensional"); } if (axes == null) { axes = new int[] {rank-2, rank-1}; } else if (axes.length != 2) { logger.error("axes should have two entries"); throw new IllegalArgumentException("axes should have two entries"); } if (s == null) { s = new int[2]; int[] shape = a.getShapeRef(); for (int i = 0; i < 2; i++) { axes[i] = a.checkAxis(axes[i]); s[i] = shape[axes[i]]; } } else if (s.length < 2) { logger.error("shape should not have more than 2 dimensions"); throw new IllegalArgumentException("shape should not have more than 2 dimensions"); } return fftn(a, s, axes); } /** * forward nD fast Fourier transform * @param a dataset * @param s shape of FFT dataset (if null, use whole dataset) * @param axes for FFT (if null, default as [..., -1]) * @return new dataset holding transform */ public static Dataset fftn(final Dataset a, int[] s, int[] axes) { int[] shape = a.getShape(); int rank = shape.length; Dataset result = null; if (s == null) { if (axes == null) { s = shape; axes = new int[rank]; for (int i = 0; i < rank; i++) axes[i] = i; } else { s = new int[axes.length]; Arrays.sort(axes); for (int i = 0; i < axes.length; i++) { axes[i] = a.checkAxis(axes[i]); s[i] = shape[axes[i]]; } } } else { if (s.length > rank) { logger.error("shape of FFT should not have more dimensions than dataset"); throw new IllegalArgumentException("shape of FFT should not have more dimensions than dataset"); } if (axes == null) { axes = new int[s.length]; for (int i = 0; i < s.length; i++) axes[i] = rank - s.length + i; } else { if (s.length != axes.length) { logger.error("shape of FFT should have same rank as axes"); throw new IllegalArgumentException("shape of FFT should have same rank as axes"); } } } if (s.length > 3) { logger.error("Fourier transform across more than 3 dimensions are not supported"); throw new IllegalArgumentException("Fourier transform across more than 3 dimensions are not supported"); } for (int i = 0; i < axes.length; i++) { if (s[i] <= 0) { logger.error("dimensions should be greater than zero"); throw new IllegalArgumentException("dimensions should be greater than zero"); } axes[i] = a.checkAxis(axes[i]); } switch (s.length) { case 1: result = fft1d(a, s[0], axes[0]); break; case 2: result = fft2d(a, s, axes); break; case 3: result = fft3d(a, s, axes); break; } return result; } private static int[] newShape(final int[] shape, final int[] s, final int[] axes) { int[] nshape = shape.clone(); for (int i = 0; i < s.length; i++) { nshape[axes[i]] = s[i]; } return nshape; } private static Dataset fft1d(final Dataset a, final int n, final int axis) { Dataset result = null; Dataset dest = null; int[] shape; PositionIterator pi; int[] pos; boolean[] hit; switch (a.getDType()) { case Dataset.FLOAT32: case Dataset.COMPLEX64: FloatFFT_1D ffft = new FloatFFT_1D(n); shape = a.getShape(); shape[axis] = n; result = DatasetFactory.zeros(ComplexFloatDataset.class, shape); dest = DatasetFactory.zeros(ComplexFloatDataset.class, new int[] {2*n}); float[] fdata = (float[]) dest.getBuffer(); pi = a.getPositionIterator(axis); pos = pi.getPos(); hit = pi.getOmit(); while (pi.hasNext()) { Arrays.fill(fdata, 0.f); a.copyItemsFromAxes(pos, hit, dest); ffft.complexForward(fdata); result.setItemsOnAxes(pos, hit, fdata); } break; case Dataset.FLOAT64: case Dataset.COMPLEX128: DoubleFFT_1D dfft = new DoubleFFT_1D(n); shape = a.getShape(); shape[axis] = n; result = DatasetFactory.zeros(ComplexDoubleDataset.class, shape); dest = DatasetFactory.zeros(ComplexDoubleDataset.class, new int[] {2*n}); double[] ddata = (double[]) dest.getBuffer(); pi = a.getPositionIterator(axis); pos = pi.getPos(); hit = pi.getOmit(); while (pi.hasNext()) { Arrays.fill(ddata, 0.); a.copyItemsFromAxes(pos, hit, dest); dfft.complexForward(ddata); result.setItemsOnAxes(pos, hit, ddata); } break; default: logger.warn("Non-float dataset not yet supported"); break; } return result; } private static Dataset fft2d(final Dataset a, final int[] s, final int[] axes) { Dataset result = null; Dataset dest = null; PositionIterator pi; int[] pos; boolean[] hit; switch (a.getDType()) { case Dataset.FLOAT32: case Dataset.COMPLEX64: FloatFFT_2D ffft = new FloatFFT_2D(s[0], s[1]); float[] fdata = null; result = DatasetFactory.zeros(ComplexFloatDataset.class, newShape(a.getShapeRef(), s, axes)); dest = DatasetFactory.zeros(ComplexFloatDataset.class, s); fdata = (float[]) dest.getBuffer(); pi = a.getPositionIterator(axes); pos = pi.getPos(); hit = pi.getOmit(); while (pi.hasNext()) { Arrays.fill(fdata, 0.f); a.copyItemsFromAxes(pos, hit, dest); ffft.complexForward(fdata); result.setItemsOnAxes(pos, hit, fdata); } break; case Dataset.FLOAT64: case Dataset.COMPLEX128: DoubleFFT_2D dfft = new DoubleFFT_2D(s[0], s[1]); double[] ddata = null; result = DatasetFactory.zeros(ComplexDoubleDataset.class, newShape(a.getShapeRef(), s, axes)); dest = DatasetFactory.zeros(ComplexDoubleDataset.class, s); ddata = (double[]) dest.getBuffer(); pi = a.getPositionIterator(axes); pos = pi.getPos(); hit = pi.getOmit(); while (pi.hasNext()) { Arrays.fill(ddata, 0.); a.copyItemsFromAxes(pos, hit, dest); dfft.complexForward(ddata); result.setItemsOnAxes(pos, hit, ddata); } break; default: logger.warn("Non-float dataset not yet supported"); break; } return result; } private static Dataset fft3d(final Dataset a, final int[] s, final int[] axes) { Dataset result = null; Dataset dest = null; PositionIterator pi; int[] pos; boolean[] hit; switch (a.getDType()) { case Dataset.FLOAT32: case Dataset.COMPLEX64: FloatFFT_3D ffft = new FloatFFT_3D(s[0], s[1], s[2]); float[] fdata = null; result = DatasetFactory.zeros(ComplexFloatDataset.class, newShape(a.getShapeRef(), s, axes)); dest = DatasetFactory.zeros(ComplexFloatDataset.class, s); fdata = (float[]) dest.getBuffer(); pi = a.getPositionIterator(axes); pos = pi.getPos(); hit = pi.getOmit(); while (pi.hasNext()) { Arrays.fill(fdata, 0.f); a.copyItemsFromAxes(pos, hit, dest); ffft.complexForward(fdata); result.setItemsOnAxes(pos, hit, fdata); } break; case Dataset.FLOAT64: case Dataset.COMPLEX128: DoubleFFT_3D dfft = new DoubleFFT_3D(s[0], s[1], s[2]); double[] ddata = null; result = DatasetFactory.zeros(ComplexDoubleDataset.class, newShape(a.getShapeRef(), s, axes)); dest = DatasetFactory.zeros(ComplexDoubleDataset.class, s); ddata = (double[]) dest.getBuffer(); pi = a.getPositionIterator(axes); pos = pi.getPos(); hit = pi.getOmit(); while (pi.hasNext()) { Arrays.fill(ddata, 0.); a.copyItemsFromAxes(pos, hit, dest); dfft.complexForward(ddata); result.setItemsOnAxes(pos, hit, ddata); } break; default: logger.warn("Non-float dataset not yet supported"); break; } return result; } /** * inverse 1D fast Fourier transform * @param a dataset * @return new dataset holding transform */ public static Dataset ifft(final Dataset a) { return ifft(a, a.getShapeRef()[a.getRank() - 1], -1); } /** * inverse 1D fast Fourier transform * @param a dataset * @param n number of points * @param axis (negative numbers refer to axes from end, eg. -1 is last axis) * @return new dataset holding transform */ public static Dataset ifft(final Dataset a, final int n, int axis) { if (n <= 0) { logger.error("number of points should be greater than zero"); throw new IllegalArgumentException("number of points should be greater than zero"); } axis = a.checkAxis(axis); return ifft1d(a, n, axis); } /** * inverse 2D fast Fourier transform * @param a dataset * @param s shape of FFT dataset (if null, use whole dataset) * @param axes for FFT (default as [-2,-1]) * @return new dataset holding transform */ public static Dataset ifft2(final Dataset a, int[] s, int[] axes) { int rank = a.getRank(); if (rank < 2) { logger.error("dataset should be at least 2 dimensional"); throw new IllegalArgumentException("dataset should be at least 2 dimensional"); } if (axes == null) { axes = new int[] {rank-2, rank-1}; } else if (axes.length != 2) { logger.error("axes should have two entries"); throw new IllegalArgumentException("axes should have two entries"); } if (s == null) { s = new int[2]; int[] shape = a.getShapeRef(); for (int i = 0; i < 2; i++) { axes[i] = a.checkAxis(axes[i]); s[i] = shape[axes[i]]; } } else if (s.length < 2) { logger.error("shape should not have more than 2 dimensions"); throw new IllegalArgumentException("shape should not have more than 2 dimensions"); } return ifftn(a, s, axes); } /** * inverse nD fast Fourier transform * @param a dataset * @param s shape of FFT dataset (if null, use whole dataset) * @param axes for FFT (if null, default as [..., -1]) * @return new dataset holding transform */ public static Dataset ifftn(final Dataset a, int[] s, int[] axes) { int[] shape = a.getShape(); int rank = shape.length; Dataset result = null; if (s == null) { if (axes == null) { s = shape; axes = new int[rank]; for (int i = 0; i < rank; i++) axes[i] = i; } else { s = new int[axes.length]; Arrays.sort(axes); for (int i = 0; i < axes.length; i++) { axes[i] = a.checkAxis(axes[i]); s[i] = shape[axes[i]]; } } } else { if (s.length > rank) { logger.error("shape of FFT should not have more dimensions than dataset"); throw new IllegalArgumentException("shape of FFT should not have more dimensions than dataset"); } if (axes == null) { axes = new int[s.length]; for (int i = 0; i < s.length; i++) axes[i] = rank - s.length + i; } else { if (s.length != axes.length) { logger.error("shape of FFT should have same rank as axes"); throw new IllegalArgumentException("shape of FFT should have same rank as axes"); } } } if (s.length > 3) { logger.error("Fourier transform across more than 3 dimensions are not supported"); throw new IllegalArgumentException("Fourier transform across more than 3 dimensions are not supported"); } for (int i = 0; i < axes.length; i++) { if (s[i] <= 0) { logger.error("dimensions should be greater than zero"); throw new IllegalArgumentException("dimensions should be greater than zero"); } axes[i] = a.checkAxis(axes[i]); } switch (s.length) { case 1: result = ifft1d(a, s[0], axes[0]); break; case 2: result = ifft2d(a, s, axes); break; case 3: result = ifft3d(a, s, axes); break; } return result; } private static Dataset ifft1d(final Dataset a, final int n, final int axis) { Dataset result = null; Dataset dest = null; int[] shape; PositionIterator pi; int[] pos; boolean[] hit; switch (a.getDType()) { case Dataset.FLOAT32: case Dataset.COMPLEX64: FloatFFT_1D ffft = new FloatFFT_1D(n); float[] fdata = null; shape = a.getShape(); shape[axis] = n; result = DatasetFactory.zeros(ComplexFloatDataset.class, shape); dest = DatasetFactory.zeros(ComplexFloatDataset.class, new int[] {n}); fdata = (float[]) dest.getBuffer(); pi = a.getPositionIterator(axis); pos = pi.getPos(); hit = pi.getOmit(); while (pi.hasNext()) { Arrays.fill(fdata, 0.f); a.copyItemsFromAxes(pos, hit, dest); ffft.complexInverse(fdata, true); result.setItemsOnAxes(pos, hit, fdata); } break; case Dataset.FLOAT64: case Dataset.COMPLEX128: DoubleFFT_1D dfft = new DoubleFFT_1D(n); double[] ddata = null; shape = a.getShape(); shape[axis] = n; result = DatasetFactory.zeros(ComplexDoubleDataset.class, shape); dest = DatasetFactory.zeros(ComplexDoubleDataset.class, new int[] {n}); ddata = (double[]) dest.getBuffer(); pi = a.getPositionIterator(axis); pos = pi.getPos(); hit = pi.getOmit(); while (pi.hasNext()) { Arrays.fill(ddata, 0.); a.copyItemsFromAxes(pos, hit, dest); dfft.complexInverse(ddata, true); result.setItemsOnAxes(pos, hit, ddata); } break; default: logger.warn("Non-float dataset not yet supported"); break; } return result; } private static Dataset ifft2d(final Dataset a, final int[] s, final int[] axes) { Dataset result = null; Dataset dest = null; PositionIterator pi; int[] pos; boolean[] hit; switch (a.getDType()) { case Dataset.FLOAT32: case Dataset.COMPLEX64: FloatFFT_2D ffft = new FloatFFT_2D(s[0], s[1]); float[] fdata = null; result = DatasetFactory.zeros(ComplexFloatDataset.class, newShape(a.getShapeRef(), s, axes)); dest = DatasetFactory.zeros(ComplexFloatDataset.class, s); fdata = (float[]) dest.getBuffer(); pi = a.getPositionIterator(axes); pos = pi.getPos(); hit = pi.getOmit(); while (pi.hasNext()) { Arrays.fill(fdata, 0.f); a.copyItemsFromAxes(pos, hit, dest); ffft.complexInverse(fdata, true); result.setItemsOnAxes(pos, hit, fdata); } break; case Dataset.FLOAT64: case Dataset.COMPLEX128: DoubleFFT_2D dfft = new DoubleFFT_2D(s[0], s[1]); double[] ddata = null; result = DatasetFactory.zeros(ComplexDoubleDataset.class, newShape(a.getShapeRef(), s, axes)); dest = DatasetFactory.zeros(ComplexDoubleDataset.class, s); ddata = (double[]) dest.getBuffer(); pi = a.getPositionIterator(axes); pos = pi.getPos(); hit = pi.getOmit(); while (pi.hasNext()) { Arrays.fill(ddata, 0.); a.copyItemsFromAxes(pos, hit, dest); dfft.complexInverse(ddata, true); result.setItemsOnAxes(pos, hit, ddata); } break; default: logger.warn("Non-float dataset not yet supported"); break; } return result; } private static Dataset ifft3d(final Dataset a, final int[] s, final int[] axes) { Dataset result = null; Dataset dest = null; PositionIterator pi; int[] pos; boolean[] hit; switch (a.getDType()) { case Dataset.FLOAT32: case Dataset.COMPLEX64: FloatFFT_3D ffft = new FloatFFT_3D(s[0], s[1], s[2]); float[] fdata = null; result = DatasetFactory.zeros(ComplexFloatDataset.class, newShape(a.getShapeRef(), s, axes)); dest = DatasetFactory.zeros(ComplexFloatDataset.class, s); fdata = (float[]) dest.getBuffer(); pi = a.getPositionIterator(axes); pos = pi.getPos(); hit = pi.getOmit(); while (pi.hasNext()) { Arrays.fill(fdata, 0.f); a.copyItemsFromAxes(pos, hit, dest); ffft.complexInverse(fdata, true); result.setItemsOnAxes(pos, hit, fdata); } break; case Dataset.FLOAT64: case Dataset.COMPLEX128: DoubleFFT_3D dfft = new DoubleFFT_3D(s[0], s[1], s[2]); double[] ddata = null; result = DatasetFactory.zeros(ComplexDoubleDataset.class, newShape(a.getShapeRef(), s, axes)); dest = DatasetFactory.zeros(ComplexDoubleDataset.class, s); ddata = (double[]) dest.getBuffer(); pi = a.getPositionIterator(axes); pos = pi.getPos(); hit = pi.getOmit(); while (pi.hasNext()) { Arrays.fill(ddata, 0.); a.copyItemsFromAxes(pos, hit, dest); dfft.complexInverse(ddata, true); result.setItemsOnAxes(pos, hit, ddata); } break; default: logger.warn("Non-float dataset not yet supported"); break; } return result; } /** * Shift zero-frequency component to centre of dataset * @param a * @param axes (if null, then shift all axes) * @return shifted dataset */ public static Dataset fftshift(final Dataset a, int[] axes) { SortedSet<Integer> axis = new TreeSet<Integer>(); int alen; if (axes == null) { alen = a.getRank(); for (int i = 0; i < alen; i++) { axis.add(i); } // return shift(a, true); } else { alen = axes.length; for (int i = 0; i < alen; i++) { axis.add(a.checkAxis(axes[i])); } // boolean all = axis.size() == a.getRank(); // if (all) { // int n = 0; // for (int i : axis) { // if (i != n++) { // all = false; // break; // } // } // } } // if (all) { // return shift(a, true); // } Dataset result = a; int[] indices; for (int i : axis) { int n = a.getShapeRef()[i]; int p = (n+1)/2; // logger.info("Shift {} by {}", i, p); indices = new int[n]; for (int j = 0; j < n; j++) { if (j < n - p) indices[j] = p + j; else indices[j] = j - n + p; } result = DatasetUtils.take(result, indices, i); } return result; } /** * Reverse shift of zero-frequency component to centre of dataset * @param a * @param axes (if null, then shift all axes) * @return shifted dataset */ public static Dataset ifftshift(final Dataset a, int[] axes) { SortedSet<Integer> axis = new TreeSet<Integer>(); int alen; if (axes == null) { alen = a.getRank(); for (int i = 0; i < alen; i++) { axis.add(i); } } else { alen = axes.length; for (int i = 0; i < alen; i++) { axis.add(a.checkAxis(axes[i])); } // // boolean all = axis.size() == a.getRank(); // if (all) { // int n = 0; // for (int i : axis) { // if (i != n++) { // all = false; // break; // } // } // } } Dataset result = a; int[] indices; for (int i : axis) { int n = a.getShapeRef()[i]; int p = n - (n+1)/2; // logger.info("Shift {} by {}", axis, p); indices = new int[n]; for (int j = 0; j < n; j++) { if (j < n - p) indices[j] = p + j; else indices[j] = j - n + p; } result = DatasetUtils.take(result, indices, i); } return result; } /** * Discrete FFT sample frequencies * @param n number of samples * @param d sample spacing * @return frequencies */ public static Dataset sampleFrequencies(int n, double d) { int hn = n/2; return DatasetUtils.roll(DatasetFactory.createRange(DoubleDataset.class, n).isubtract(hn).imultiply(1/(d*n)), n - hn, null); } /** * Pad out dataset to new shape with zeros. There are two ways to zero-pad: * <ol> * <li>set data in volume starting at origin</li> * <li>split data in 2^N volumes and scatter to corners of larger volume</li> * </ol> * * @param input * @param newShape * @param inFreqSpace if true, then pad higher frequencies as zero * @return zero-padded dataset */ public static Dataset zeroPad(Dataset input, int[] newShape, boolean inFreqSpace) { Dataset output = DatasetFactory.zeros(input.getElementsPerItem(), newShape, input.getDType()); if (inFreqSpace) { int rank = input.getRank(); int[] iShape = input.getShapeRef(); int[] hShape = iShape.clone(); int[] rShape = iShape.clone(); for (int i = 0; i < rank; i++) { hShape[i] = (hShape[i] + 1) / 2; rShape[i] -= hShape[i]; } int[] del = new int[rank]; Arrays.fill(del, 2); SliceIterator it = new SliceIterator(input.getShapeRef(), input.getSize(), null, hShape, del.clone()); int[] pos = it.getPos(); SliceND iSlice = new SliceND(iShape, new Slice(1)); SliceND oSlice = new SliceND(newShape, new Slice(1)); while (it.hasNext()) { for (int i = 0; i < rank; i++) { int b = pos[i]; int l = hShape[i]; if (b == 0) { iSlice.setSlice(i, b, b + l, 1); oSlice.setSlice(i, 0, l, 1); } else { iSlice.setSlice(i, b, iShape[i], 1); l = newShape[i]; oSlice.setSlice(i, l - rShape[i], l, 1); } } output.setSlice(input.getSliceView(iSlice), oSlice); } } else { output.setSlice(input, null, input.getShapeRef(), null); } return output; } /** * Shift frequency components * <p> * A forward shift moves the zero-frequency from the origin position * to the centre of the dataset. A backward shift reverts a forward shift * * @param input * @param forward if true, shift forward, else shift back * @return shifted dataset */ public static Dataset shift(Dataset input, boolean forward) { Dataset output = DatasetFactory.zeros(input); int rank = input.getRank(); int[] iShape = input.getShapeRef(); int[] hShape = new int[rank]; int[] rShape = new int[rank]; int adjust = forward ? 1 : -1; for (int i = 0; i < rank; i++) { int l = (iShape[i] + adjust) / 2; hShape[i] = l; rShape[i] = iShape[i] - l; } int[] del = new int[rank]; Arrays.fill(del, 2); SliceIterator it = new SliceIterator(input.getShapeRef(), input.getSize(), null, hShape, del.clone()); int[] pos = it.getPos(); SliceND iSlice = new SliceND(iShape, new Slice(1)); SliceND oSlice = new SliceND(iShape, new Slice(1)); while (it.hasNext()) { for (int i = 0; i < rank; i++) { int b = pos[i]; int l = hShape[i]; if (b == 0) { iSlice.setSlice(i, b, b + l, 1); oSlice.setSlice(i, rShape[i], iShape[i], 1); } else { iSlice.setSlice(i, b, iShape[i], 1); oSlice.setSlice(i, 0, rShape[i], 1); } } output.setSlice(input.getSliceView(iSlice), oSlice); } return output; } }