package com.nativelibs4java.opencl.util.fft;

import com.nativelibs4java.opencl.*;
import com.nativelibs4java.opencl.util.Transformer.AbstractTransformer;
import com.nativelibs4java.util.NIOUtils;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.bridj.*;
import static org.bridj.Pointer.*;

// TODO implement something like http://locklessinc.com/articles/non_power_of_2_fft/
/**
 * Base class for FFT transformers that only accept power-of-two lengths.
 * <p>
 * Subclasses provide the three OpenCL steps (twiddle-factor generation, reordering copy and
 * one Cooley-Tukey pass); this class caches the per-length helper buffers and chains the
 * passes together through their events.
 *
 * @param <T> element type of the OpenCL buffers being transformed
 * @param <A> second type parameter forwarded unchanged to {@link AbstractTransformer}
 */
abstract class AbstractFFTPow2<T, A> extends AbstractTransformer<T, A> {

    /** Offsets buffers already uploaded to the device, keyed by transform length. */
    private final Map<Integer, CLBuffer<Integer>> cachedOffsetsBufs = new HashMap<Integer, CLBuffer<Integer>>();

    /** Twiddle-factor buffers already computed on the device, keyed by pass size N. */
    Map<Integer, CLBuffer<T>> cachedTwiddleFactors = new HashMap<Integer, CLBuffer<T>>();

    AbstractFFTPow2(CLContext context, Class<T> primitiveClass) {
        super(context, primitiveClass);
    }

    /**
     * Returns the buffer of source offsets used by the reordering copy for the given length,
     * building and caching it on first use.
     * <p>
     * The table is filled by {@link #fft_compute_offsetsX}; for a power-of-two length this is
     * the bit-reversal permutation of {@code 0..length-1}.
     *
     * @param length number of entries in the offsets table
     * @return the cached (or newly created) offsets buffer
     */
    protected synchronized CLBuffer<Integer> getOffsetsBuf(int length) {
        CLBuffer<Integer> offsetsBuf = cachedOffsetsBufs.get(length);
        if (offsetsBuf == null) {
            int[] offsets = new int[length];
            fft_compute_offsetsX(offsets, length, 1, 0, 0);
            offsetsBuf = context.createBuffer(CLMem.Usage.InputOutput, pointerToInts(offsets), true);
            cachedOffsetsBufs.put(length, offsetsBuf);
        }
        return offsetsBuf;
    }

    /**
     * Enqueues the computation of the twiddle factors for a pass of size {@code N} into {@code buf}.
     *
     * @return the event signalling completion of the computation
     */
    protected abstract CLEvent cooleyTukeyFFTTwiddleFactors(CLQueue queue, int N, CLBuffer<T> buf, CLEvent... evts) throws CLException;

    /**
     * Enqueues the copy of {@code inBuf} into {@code outBuf}, driven by the offsets table.
     *
     * @return the event signalling completion of the copy
     */
    protected abstract CLEvent cooleyTukeyFFTCopy(CLQueue queue, CLBuffer<T> inBuf, CLBuffer<T> outBuf, int length, CLBuffer<Integer> offsetsBuf, boolean inverse, CLEvent... evts) throws CLException;

    /**
     * Enqueues one Cooley-Tukey pass of size {@code N} over {@code Y}.
     *
     * @param inverse 1 for the inverse transform, 0 for the forward one
     * @param dims    two-element array <code>{N / 2, number of blocks}</code> as built by this class
     * @return the event signalling completion of the pass
     */
    protected abstract CLEvent cooleyTukeyFFT(CLQueue queue, CLBuffer<T> Y, int N, CLBuffer<T> twiddleFactors, int inverse, int[] dims, CLEvent... evts) throws CLException;

    /**
     * Returns the twiddle-factor buffer for a pass of size {@code N}, computing and caching it
     * on first use.
     * <p>
     * On a cache miss this method blocks (while holding this object's monitor) until the
     * device has finished filling the buffer, so the returned buffer is always ready to use.
     *
     * @param queue queue on which the twiddle factors are computed if they are not cached yet
     * @param N     pass size; the buffer is allocated with {@code N} elements
     * @return the cached (or newly computed) twiddle-factor buffer
     * @throws CLException if the OpenCL computation fails
     */
    protected synchronized CLBuffer<T> getTwiddleFactorsBuf(CLQueue queue, int N) throws CLException {
        CLBuffer<T> buf = cachedTwiddleFactors.get(N);
        if (buf == null) {
            buf = context.createBuffer(CLMem.Usage.InputOutput, primitiveClass, N);
            // Only publish the buffer in the cache once its contents are complete.
            CLEvent.waitFor(cooleyTukeyFFTTwiddleFactors(queue, N, buf));
            cachedTwiddleFactors.put(N, buf);
        }
        return buf;
    }

    /**
     * Recursively fills {@code offsetsX} with, for every destination index, the source index
     * it must be read from: even-strided elements go to the first half of the destination
     * range and odd-strided elements to the second half.
     *
     * @param offsetsX table being filled
     * @param N        size of the sub-problem
     * @param s        stride between consecutive source elements of this sub-problem
     * @param offsetX  source index of the first element of this sub-problem
     * @param offsetY  destination index of the first element of this sub-problem
     */
    private void fft_compute_offsetsX(int[] offsetsX, int N, int s, int offsetX, int offsetY) {
        if (N == 1) {
            offsetsX[offsetY] = offsetX;
        } else {
            int halfN = N / 2;
            int twiceS = s * 2;
            fft_compute_offsetsX(offsetsX, halfN, twiceS, offsetX, offsetY);
            fft_compute_offsetsX(offsetsX, halfN, twiceS, offsetX + s, offsetY + halfN);
        }
    }

    /**
     * Enqueues the transform of {@code inBuf} into {@code outBuf}: first the reordering copy,
     * then the chain of Cooley-Tukey passes applied in place on {@code outBuf}.
     * <p>
     * The transform length is half the element count of {@code inBuf}
     * (presumably because values are stored as interleaved pairs - confirm against the kernels).
     *
     * @return the event of the last enqueued operation; for a length of 1 no pass is needed,
     *         so this is the event of the copy itself
     * @throws UnsupportedOperationException if the length is not a power of two
     * @throws CLException if an OpenCL call fails
     */
    @Override
    public CLEvent transform(CLQueue queue, CLBuffer<T> inBuf, CLBuffer<T> outBuf, boolean inverse, CLEvent... eventsToWaitFor) throws CLException {
        int length = (int)inBuf.getElementCount() / 2;
        if (Integer.bitCount(length) != 1) {
            throw new UnsupportedOperationException("Only supports FFTs of power-of-two-sized arrays (was given array of length " + length + ")");
        }

        CLBuffer<Integer> offsetsBuf = getOffsetsBuf(length);
        CLEvent copyEvt = cooleyTukeyFFTCopy(queue, inBuf, outBuf, length, offsetsBuf, inverse, eventsToWaitFor);
        CLEvent dftEvt = fft(queue, inBuf, length, 1, inverse ? 1 : 0, 1, outBuf, copyEvt);

        // fft() enqueues nothing and returns null for a single-element transform; hand back
        // the copy event instead so callers can still wait for outBuf to be written.
        return dftEvt != null ? dftEvt : copyEvt;
    }

    /**
     * Enqueues the passes for a sub-problem of size {@code N}, smallest pass first, each pass
     * waiting on the previous one.
     * <p>
     * {@code X} and {@code s} are only threaded through the recursion; the passes themselves
     * operate on {@code Y}.
     *
     * @param blocks number of independent sub-problems of size {@code N}; doubles at each
     *               recursion level
     * @return the event of the pass of size {@code N}, or {@code null} when {@code N == 1}
     *         (nothing to enqueue)
     */
    private CLEvent fft(CLQueue queue, CLBuffer<T> X, int N, int s, int inverse, int blocks, CLBuffer<T> Y, CLEvent... eventsToWaitFor) throws CLException {
        if (N == 1) {
            return null;
        }

        int halfN = N / 2;
        CLEvent[] evts;
        if (halfN > 1) {
            // Smaller passes must complete before this one: wait on the half-size pass.
            evts = new CLEvent[] { fft(queue, X, halfN, s * 2, inverse, blocks * 2, Y, eventsToWaitFor) };
        } else {
            // Innermost pass: wait directly on the caller-supplied events (the copy).
            evts = eventsToWaitFor;
        }
        return cooleyTukeyFFT(queue, Y, N, getTwiddleFactorsBuf(queue, N), inverse, new int[] { halfN, blocks }, evts);
    }
}