/*-
*
* * Copyright 2015 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*
*/
package org.nd4j.linalg.fft;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.complex.IComplexNumber;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.VectorFFT;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.Assert.assertEquals;
/**
 * Tests for the FFT operations: the {@link VectorFFT} op, {@code FFT.fftn}, and the
 * {@code Nd4j.getFFt()} instance ({@code fft}, {@code fftn}, {@code rawfft}).
 *
 * <p>Several tests loosen the global comparison tolerance {@code Nd4j.EPS_THRESHOLD}.
 * Computed FFT output can carry small floating-point residue (see the {@code 4.371139E-7}
 * imaginary part expected in {@link #testSimple()}). {@code testColumnVector} also changes
 * {@code Nd4j.MAX_ELEMENTS_PER_SLICE} and {@code Nd4j.MAX_SLICES_TO_PRINT}. These are static
 * globals, so every test that changes them restores the previous values in a {@code finally}
 * block. Otherwise the loose tolerance would leak into other tests that run in the same JVM.
 *
 * @author Adam Gibson
 */
@Ignore
@RunWith(Parameterized.class)
public class FFTTests extends BaseNd4jTest {

    public FFTTests(Nd4jBackend backend) {
        super(backend);
    }

    /**
     * Checks the DFT of a length-5 vector of ones: 5 at index 0 and 0 elsewhere.
     */
    @Test
    public void testVectorFftOnes() {
        INDArray arr = Nd4j.ones(5);
        VectorFFT fft = new VectorFFT(arr);
        fft.exec();
        INDArray assertion = Nd4j.create(5);
        assertion.putScalar(0, 5);
        assertEquals(getFailureMessage(), assertion, fft.z());
    }

    /**
     * Checks the length-8 DFT of {@code complexLinSpace(1, 8, 8)}, i.e. the values 1..8.
     * The expected DC term is their sum, 36.
     */
    @Test
    public void testColumnVector() {
        double previousEps = Nd4j.EPS_THRESHOLD;
        int previousMaxElementsPerSlice = Nd4j.MAX_ELEMENTS_PER_SLICE;
        int previousMaxSlicesToPrint = Nd4j.MAX_SLICES_TO_PRINT;
        Nd4j.EPS_THRESHOLD = 1e-1;
        Nd4j.MAX_ELEMENTS_PER_SLICE = Integer.MAX_VALUE;
        Nd4j.MAX_SLICES_TO_PRINT = Integer.MAX_VALUE;
        try {
            IComplexNDArray complexLinSpace = Nd4j.complexLinSpace(1, 8, 8);
            IComplexNDArray n = (IComplexNDArray) Nd4j.getExecutioner()
                    .execAndReturn(new VectorFFT(complexLinSpace, 8));
            // Interleaved (real, imaginary) pairs for the 8 output bins.
            IComplexNDArray assertion = Nd4j.createComplex(new double[] {36., 0., -4., 9.65685425,
                    -4., 4, -4., 1.65685425, -4., 0., -4., -1.65685425, -4., -4., -4., -9.65685425},
                    new int[] {1, 8});
            // assertEquals takes (message, expected, actual).
            assertEquals(getFailureMessage(), assertion, n);
        } finally {
            Nd4j.EPS_THRESHOLD = previousEps;
            Nd4j.MAX_ELEMENTS_PER_SLICE = previousMaxElementsPerSlice;
            Nd4j.MAX_SLICES_TO_PRINT = previousMaxSlicesToPrint;
        }
    }

    /**
     * Checks that transforming a view obtained via {@code swapAxes}/{@code slice} gives the same
     * result as transforming a freshly created complex vector with the values 1, 4, 7, 10, 13.
     */
    @Test
    public void testWithOffset() {
        INDArray n = Nd4j.create(Nd4j.linspace(1, 30, 30).data(), new int[] {3, 5, 2});
        INDArray swapped = n.swapAxes(n.shape().length - 1, 1);
        INDArray firstSlice = swapped.slice(0).slice(0);
        IComplexNDArray test = Nd4j.createComplex(firstSlice);
        IComplexNDArray testNoOffset =
                Nd4j.createComplex(new double[] {1, 0, 4, 0, 7, 0, 10, 0, 13, 0}, new int[] {1, 5});
        assertEquals(getFailureMessage(), Nd4j.getExecutioner().execAndReturn(new VectorFFT(testNoOffset, 5)),
                Nd4j.getExecutioner().execAndReturn(new VectorFFT(test, 5)));
    }

    /**
     * Checks two-point DFTs: fft([5, 1]) = [6, 4] and fft([1, 5]) = [6, -4].
     */
    @Test
    public void testSimple() {
        double previousEps = Nd4j.EPS_THRESHOLD;
        Nd4j.EPS_THRESHOLD = 1e-1;
        try {
            IComplexNDArray arr = Nd4j.createComplex(
                    new IComplexNumber[] {Nd4j.createComplexNumber(5, 0), Nd4j.createComplexNumber(1, 0)});
            IComplexNDArray arr2 = Nd4j.createComplex(
                    new IComplexNumber[] {Nd4j.createComplexNumber(1, 0), Nd4j.createComplexNumber(5, 0)});
            IComplexNDArray assertion = Nd4j.createComplex(
                    new IComplexNumber[] {Nd4j.createComplexNumber(6, 0), Nd4j.createComplexNumber(4, 0)});
            IComplexNDArray assertion2 = Nd4j.createComplex(new IComplexNumber[] {Nd4j.createComplexNumber(6, 0),
                    Nd4j.createComplexNumber(-4, 4.371139E-7)});
            assertEquals(getFailureMessage(), assertion, Nd4j.getFFt().fft(arr));
            assertEquals(getFailureMessage(), assertion2, Nd4j.getFFt().fft(arr2));
        } finally {
            Nd4j.EPS_THRESHOLD = previousEps;
        }
    }

    /**
     * Checks {@code FFT.fftn} on {@code linspace(1, 8, 8)} reshaped to 2x2x2 against a
     * precomputed answer built slice by slice.
     */
    @Test
    public void testMultiDimFFT() {
        double previousEps = Nd4j.EPS_THRESHOLD;
        Nd4j.EPS_THRESHOLD = 1e-1;
        try {
            INDArray a = Nd4j.linspace(1, 8, 8).reshape(2, 2, 2);
            IComplexNDArray fftedAnswer = Nd4j.createComplex(2, 2, 2);
            IComplexNDArray matrix1 = Nd4j.createComplex(
                    new IComplexNumber[][] {{Nd4j.createComplexNumber(36, 0), Nd4j.createComplexNumber(-16, 0)},
                            {Nd4j.createComplexNumber(-8, 0), Nd4j.createComplexNumber(0, 0)}});
            IComplexNDArray matrix2 = Nd4j.createComplex(
                    new IComplexNumber[][] {{Nd4j.createComplexNumber(-4, 0), Nd4j.createComplexNumber(0, 0)},
                            {Nd4j.createComplexNumber(0, 0), Nd4j.createComplexNumber(0, 0)}});
            fftedAnswer.putSlice(0, matrix1);
            fftedAnswer.putSlice(1, matrix2);
            IComplexNDArray ffted = FFT.fftn(a);
            assertEquals(getFailureMessage(), fftedAnswer, ffted);
        } finally {
            // Previously this test "restored" the threshold by re-assigning 1e-1, which was a no-op.
            Nd4j.EPS_THRESHOLD = previousEps;
        }
    }

    /**
     * Checks that {@code FFT.fftn} of a 5x5 matrix of ones is 25 at [0, 0] and 0 elsewhere.
     */
    @Test
    public void testOnes() {
        double previousEps = Nd4j.EPS_THRESHOLD;
        Nd4j.EPS_THRESHOLD = 1e-1;
        try {
            IComplexNDArray ones = Nd4j.complexOnes(5, 5);
            IComplexNDArray ffted = FFT.fftn(ones);
            IComplexNDArray zeros = Nd4j.createComplex(5, 5);
            zeros.putScalar(0, 0, Nd4j.createComplexNumber(25, 0));
            assertEquals(getFailureMessage(), zeros, ffted);
        } finally {
            Nd4j.EPS_THRESHOLD = previousEps;
        }
    }

    /**
     * Smoke test: runs {@code fftn} on a 5x5x5x5 array of ones.
     *
     * <p>NOTE(review): there is no assertion on the result. This only fails if {@code fftn}
     * throws.
     */
    @Test
    public void testConv4d() {
        IComplexNDArray test = Nd4j.complexOnes(new int[] {5, 5, 5, 5});
        Nd4j.getFFt().fftn(test);
    }

    /**
     * Checks {@code rawfft(test, 3, 1)} on a 5x5 matrix of ones against a 5x3 expectation that
     * has 3 in the first entry of every slice and 0 elsewhere.
     *
     * <p>NOTE(review): the meaning of the arguments {@code 3} and {@code 1} is presumably
     * (transform length, dimension). Confirm against the FFT instance API.
     */
    @Test
    public void testRawfft() {
        double previousEps = Nd4j.EPS_THRESHOLD;
        Nd4j.EPS_THRESHOLD = 1e-1;
        try {
            IComplexNDArray test = Nd4j.complexOnes(5, 5);
            IComplexNDArray result = Nd4j.getFFt().rawfft(test, 3, 1);
            IComplexNDArray assertion = Nd4j.createComplex(5, 3);
            for (int i = 0; i < assertion.rows(); i++) {
                assertion.slice(i).putScalar(0, Nd4j.createComplexNumber(3, 0));
            }
            for (int i = 0; i < result.slices(); i++) {
                IComplexNDArray assertionSlice = assertion.slice(i);
                IComplexNDArray resultSlice = result.slice(i);
                assertEquals(getFailureMessage() + " Failed on iteration " + i, assertionSlice, resultSlice);
            }
        } finally {
            Nd4j.EPS_THRESHOLD = previousEps;
        }
    }

    @Override
    public char ordering() {
        return 'f';
    }
}