package gdsc.smlm.utils;
import java.util.Arrays;
import org.apache.commons.math3.util.FastMath;
import org.jtransforms.fft.DoubleFFT_1D;
import org.jtransforms.utils.CommonUtils;
/*-----------------------------------------------------------------------------
* GDSC SMLM Software
*
* Copyright (C) 2013 Alex Herbert
* Genome Damage and Stability Centre
* University of Sussex, UK
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 3 of the License, or
* (at your option) any later version.
*---------------------------------------------------------------------------*/
/**
* Simple class to perform convolution
*/
public class Convolution
{
/**
* Calculates the <a href="http://en.wikipedia.org/wiki/Convolution">
* convolution</a> between two sequences.
* <p>
* The solution is obtained via straightforward computation of the convolution sum (and not via FFT). Whenever the
* computation needs an element that would be located at an index outside the input arrays, the value is assumed to
* be zero.
* <p>
* This has been taken from Apache Commons Math v3.3: org.apache.commons.math3.util.MathArrays
*
* @param x
* First sequence.
* Typically, this sequence will represent an input signal to a system.
* @param h
* Second sequence.
* Typically, this sequence will represent the impulse response of the system.
* @return the convolution of {@code x} and {@code h}.
* This array's length will be {@code x.length + h.length - 1}.
* @throws IllegalArgumentException
* if either {@code x} or {@code h} is {@code null} or either {@code x} or {@code h} is empty.
*/
public static double[] convolve(double[] x, double[] h) throws IllegalArgumentException
{
checkInput(x, h);
final int xLen = x.length;
final int hLen = h.length;
// initialize the output array
final int totalLength = xLen + hLen - 1;
final double[] y = new double[totalLength];
// straightforward implementation of the convolution sum
for (int n = 0; n < totalLength; n++)
{
double yn = 0;
int k = FastMath.max(0, n + 1 - xLen);
int j = n - k;
while (k < hLen && j >= 0)
{
yn += x[j--] * h[k++];
}
y[n] = yn;
}
return y;
}
/**
* Calculates the <a href="http://en.wikipedia.org/wiki/Convolution">
* convolution</a> between two sequences.
* <p>
* The solution is obtained via multiplication in the frequency domain.
*
* @param x
* First sequence.
* Typically, this sequence will represent an input signal to a system.
* @param h
* Second sequence.
* Typically, this sequence will represent the impulse response of the system.
* @return the convolution of {@code x} and {@code h}.
* This array's length will be {@code x.length + h.length - 1}.
* @throws IllegalArgumentException
* if either {@code x} or {@code h} is {@code null} or either {@code x} or {@code h} is empty.
*/
public static double[] convolveFFT(double[] x, double[] h) throws IllegalArgumentException
{
checkInput(x, h);
if (x.length < h.length)
{
// Swap so that the longest array is the signal
final double[] tmp = x;
x = h;
h = tmp;
}
final int xLen = x.length;
final int hLen = h.length;
final int totalLength = xLen + hLen - 1;
// Get length to a power of 2
int newL = CommonUtils.nextPow2(totalLength);
// Double the new length for complex values in DoubleFFT_1D
x = Arrays.copyOf(x, 2 * newL);
h = Arrays.copyOf(h, x.length);
//double[] tmp = new double[x.length];
DoubleFFT_1D fft = new DoubleFFT_1D(newL);
// FFT
fft.realForwardFull(x);
fft.realForwardFull(h);
// Complex multiply. Reuse data array
for (int i = 0; i < x.length; i += 2)
{
int j = i + 1;
double xi = x[i];
double xj = x[j];
x[i] = xi * h[i] - xj * h[j];
x[j] = xi * h[j] + xj * h[i];
}
// Inverse FFT
fft.complexInverse(x, true);
// Fill result with real part
final double[] y = new double[totalLength];
for (int i = 0; i < totalLength; i++)
{
y[i] = x[2 * i];
}
return y;
}
/**
* Calculates the <a href="http://en.wikipedia.org/wiki/Convolution">
* convolution</a> between two sequences.
* <p>
* The solution is obtained using either the spatial or frequency domain depending on the size. The switch is made
* when the min array length is above 127 and the product of the lengths is above 40000. Speed tests have
* been performed for single threaded FFT computation. The FFT library begins multi-threaded computation when the
* size of the array is above length 8192.
*
* @param x
* First sequence.
* Typically, this sequence will represent an input signal to a system.
* @param h
* Second sequence.
* Typically, this sequence will represent the impulse response of the system.
* @return the convolution of {@code x} and {@code h}.
* This array's length will be {@code x.length + h.length - 1}.
* @throws IllegalArgumentException
* if either {@code x} or {@code h} is {@code null} or either {@code x} or {@code h} is empty.
*/
public static double[] convolveFast(double[] x, double[] h) throws IllegalArgumentException
{
checkInput(x, h);
// See Junit class ConvolveTest to determine when to switch to the FFT method.
// This is not perfect for all length combinations but the switch will happen
// when the two methods are roughly the same speed.
int min, max;
if (x.length < h.length)
{
min = x.length;
max = h.length;
}
else
{
min = h.length;
max = x.length;
}
if (min >= 128 && (long) min * (long) max > 40000L)
return convolveFFT(x, h);
return convolve(x, h);
}
private static void checkInput(double[] x, double[] h)
{
if (x == null)
throw new IllegalArgumentException("Input x is null");
if (h == null)
throw new IllegalArgumentException("Input g is null");
if (x.length == 0 || h.length == 0)
{
throw new IllegalArgumentException("Input x or h have no length");
}
}
}