/*-
*
* * Copyright 2015 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*
*/
package org.nd4j.linalg;
import org.junit.After;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.complex.IComplexDouble;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.complex.IComplexNumber;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.VectorFFT;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.linalg.util.ComplexUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.*;
/**
* Tests for a complex ndarray
*
* @author Adam Gibson
*/
@Ignore
@RunWith(Parameterized.class)
public class ComplexNDArrayTestsC extends BaseComplexNDArrayTests {
private static Logger log = LoggerFactory.getLogger(ComplexNDArrayTestsC.class);
public ComplexNDArrayTestsC() {}
public ComplexNDArrayTestsC(Nd4jBackend backend) {
super(backend);
}
@Before
public void before() throws Exception {
super.before();
}
@After
public void after() throws Exception {
super.after();
}
@Test
public void testConstruction() {
IComplexNDArray arr2 = Nd4j.createComplex(new int[] {3, 2});
assertEquals(3, arr2.rows());
assertEquals(arr2.rows(), arr2.rows());
assertEquals(2, arr2.columns());
assertEquals(arr2.columns(), arr2.columns());
assertTrue(arr2.isMatrix());
IComplexNDArray arr = Nd4j.createComplex(new double[] {0, 1}, new int[] {1, 1});
//only each complex double: one element
assertEquals(1, arr.length());
//both real and imaginary components
assertEquals(2, arr.data().length());
IComplexNumber n1 = (IComplexNumber) arr.getScalar(0).element();
assertEquals(0, n1.realComponent().doubleValue(), 1e-1);
IComplexDouble[] two = new IComplexDouble[2];
two[0] = Nd4j.createDouble(1, 0);
two[1] = Nd4j.createDouble(2, 0);
double[] testArr = {1, 0, 2, 0};
IComplexNDArray assertComplexDouble = Nd4j.createComplex(testArr, new int[] {1, 2});
IComplexNDArray testComplexDouble = Nd4j.createComplex(two, new int[] {1, 2});
assertEquals(assertComplexDouble, testComplexDouble);
}
@Test
public void testSort() {
IComplexNDArray matrix = Nd4j.complexLinSpace(1, 4, 4).reshape(2, 2);
IComplexNDArray sorted = Nd4j.sort(matrix.dup(), 1, true);
assertEquals(matrix, sorted);
IComplexNDArray reversed = Nd4j.createComplex(new float[] {2, 0, 1, 0, 4, 0, 3, 0}, new int[] {2, 2});
IComplexNDArray sortedReversed = Nd4j.sort(matrix, 1, false);
assertEquals(reversed, sortedReversed);
}
@Test
public void testSortWithIndicesDescending() {
IComplexNDArray toSort = Nd4j.complexLinSpace(1, 4, 4).reshape(2, 2);
//indices,data
INDArray[] sorted = Nd4j.sortWithIndices(toSort.dup(), 1, false);
INDArray sorted2 = Nd4j.sort(toSort.dup(), 1, false);
assertEquals(sorted[1], sorted2);
INDArray shouldIndex = Nd4j.create(new float[] {1, 0, 1, 0}, new int[] {2, 2});
assertEquals(shouldIndex, sorted[0]);
}
@Test
public void testSortWithIndices() {
IComplexNDArray toSort = Nd4j.complexLinSpace(1, 4, 4).reshape(2, 2);
//indices,data
INDArray[] sorted = Nd4j.sortWithIndices(toSort.dup(), 1, true);
INDArray sorted2 = Nd4j.sort(toSort.dup(), 1, true);
assertEquals(sorted[1], sorted2);
INDArray shouldIndex = Nd4j.create(new float[] {0, 1, 0, 1}, new int[] {2, 2});
assertEquals(shouldIndex, sorted[0]);
}
@Test
public void testAssignOffset() {
IComplexNDArray arr = Nd4j.complexOnes(5, 5);
IComplexNDArray row = arr.slice(1);
row.assign(1);
assertEquals(Nd4j.complexOnes(5), row);
}
@Test
public void testDimShuffle() {
IComplexNDArray n = Nd4j.complexLinSpace(1, 4, 4).reshape(2, 2);
IComplexNDArray twoOneTwo =
n.dimShuffle(new Object[] {0, 'x', 1}, new int[] {0, 1}, new boolean[] {false, false});
assertTrue(Arrays.equals(new int[] {2, 1, 2}, twoOneTwo.shape()));
IComplexNDArray reverse =
n.dimShuffle(new Object[] {1, 'x', 0}, new int[] {1, 0}, new boolean[] {false, false});
assertTrue(Arrays.equals(new int[] {2, 1, 2}, reverse.shape()));
}
@Test
public void testPutComplex() {
INDArray fourTwoTwo = Nd4j.linspace(1, 16, 16).reshape(4, 2, 2);
IComplexNDArray test = Nd4j.createComplex(4, 2, 2);
for (int i = 0; i < test.vectorsAlongDimension(0); i++) {
INDArray vector = fourTwoTwo.vectorAlongDimension(i, 0);
IComplexNDArray complexVector = test.vectorAlongDimension(i, 0);
for (int j = 0; j < complexVector.length(); j++) {
complexVector.putReal(j, vector.getFloat(j));
}
}
for (int i = 0; i < test.vectorsAlongDimension(0); i++) {
INDArray vector = fourTwoTwo.vectorAlongDimension(i, 0);
IComplexNDArray complexVector = test.vectorAlongDimension(i, 0);
assertEquals(vector, complexVector.real());
}
}
@Test
public void testColumnWithReshape() {
IComplexNDArray ones = Nd4j.complexOnes(4).reshape(2, 2);
IComplexNDArray column = Nd4j.createComplex(new float[] {2, 0, 6, 0});
ones.putColumn(1, column);
assertEquals(column, ones.getColumn(1));
}
@Test
public void testPutSlice() {
}
@Test
public void testSum() {
IComplexNDArray n = Nd4j.createComplex(Nd4j.create(Nd4j.linspace(1, 8, 8).data(), new int[] {2, 2, 2}));
assertEquals(Nd4j.createDouble(36, 0), n.sumComplex());
}
@Test
public void testCreateComplexFromReal() {
INDArray n = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6, 7, 8}, new int[] {2, 4});
IComplexNDArray nComplex = Nd4j.createComplex(n);
for (int i = 0; i < n.vectorsAlongDimension(0); i++) {
INDArray vec = n.vectorAlongDimension(i, 0);
IComplexNDArray vecComplex = nComplex.vectorAlongDimension(i, 0);
assertEquals(vec.length(), vecComplex.length());
for (int j = 0; j < vec.length(); j++) {
IComplexNumber currComplex = vecComplex.getComplex(j);
double curr = vec.getFloat(j);
assertEquals(curr, currComplex.realComponent().doubleValue(), 1e-1);
}
assertEquals(vec, vecComplex.getReal());
}
}
@Test
public void testVectorOffsetRavel() {
IComplexNDArray arr = Nd4j.complexLinSpace(1, 20, 20).reshape(4, 5);
for (int i = 0; i < arr.slices(); i++) {
assertEquals(arr.slice(i), arr.slice(i).ravel());
}
}
@Test
public void testSliceVsVectorAlongDimension() {
IComplexNDArray arr = Nd4j.complexLinSpace(1, 20, 20).reshape(4, 5);
assertEquals(arr.slices(), arr.vectorsAlongDimension(1));
for (int i = 0; i < arr.slices(); i++) {
assertEquals(arr.vectorAlongDimension(i, 1), arr.slice(i));
assertEquals(arr.vectorAlongDimension(i, 1).ravel(), arr.slice(i).ravel());
}
}
@Test
public void testVectorAlongDimension() {
INDArray n = Nd4j.linspace(1, 8, 8).reshape(2, 4);
IComplexNDArray nComplex = Nd4j.createComplex(Nd4j.linspace(1, 8, 8)).reshape(2, 4);
assertEquals(n.vectorsAlongDimension(0), nComplex.vectorsAlongDimension(0));
for (int i = 0; i < n.vectorsAlongDimension(0); i++) {
INDArray vec = n.vectorAlongDimension(i, 0);
IComplexNDArray vecComplex = nComplex.vectorAlongDimension(i, 0);
assertEquals(vec.length(), vecComplex.length());
for (int j = 0; j < vec.length(); j++) {
IComplexNumber currComplex = vecComplex.getComplex(j);
double curr = vec.getFloat(j);
assertEquals(curr, currComplex.realComponent().doubleValue(), 1e-1);
}
assertEquals(vec, vecComplex.getReal());
}
}
@Test
public void testVectorGet() {
IComplexNDArray arr = Nd4j.createComplex(Nd4j.create(Nd4j.linspace(1, 8, 8).data(), new int[] {1, 8}));
for (int i = 0; i < arr.length(); i++) {
IComplexNumber curr = arr.getComplex(i);
assertEquals(Nd4j.createDouble(i + 1, 0), curr);
}
IComplexNDArray matrix = Nd4j.createComplex(Nd4j.create(Nd4j.linspace(1, 8, 8).data(), new int[] {2, 4}));
IComplexNDArray row = matrix.getRow(1);
IComplexNDArray column = matrix.getColumn(1);
IComplexNDArray validate = Nd4j.createComplex(Nd4j.create(new double[] {5, 6, 7, 8}, new int[] {1, 4}));
IComplexNumber d = row.getComplex(3);
assertEquals(Nd4j.createDouble(8, 0), d);
assertEquals(row, validate);
IComplexNumber d2 = column.getComplex(1);
assertEquals(Nd4j.createDouble(6, 0), d2);
}
@Test
public void testTensorStrides() {
INDArray arr = Nd4j.createComplex(106, 1, 3, 3);
//(144, 144, 48, 16)
int[] assertion = ArrayUtil.of(18, 18, 6, 2);
int[] arrShape = arr.stride();
assertTrue(Arrays.equals(assertion, arrShape));
Nd4j.factory().setOrder('f');
arr = Nd4j.createComplex(106, 1, 3, 3);
//(16, 1696, 1696, 5088)
assertion = ArrayUtil.of(2, 212, 212, 636);
arrShape = arr.stride();
assertTrue(Arrays.equals(assertion, arrShape));
}
@Test
public void testLinearView() {
IComplexNDArray n = Nd4j.complexLinSpace(1, 4, 4).reshape(2, 2);
IComplexNDArray row = n.getRow(1);
IComplexNDArray linear = row.linearView();
assertEquals(row, linear);
IComplexNDArray large = Nd4j.complexLinSpace(1, 1000, 1000).reshape(2, 500);
IComplexNDArray largeLinear = large.linearView();
for (int i = 0; i < largeLinear.length(); i++)
assertEquals(i + 1, largeLinear.getReal(i), 1e-1);
IComplexNDArray largeTensor = large.reshape(1000, 1, 1, 1);
for (int i = 0; i < largeLinear.length(); i++)
assertEquals(i + 1, largeTensor.getReal(i), 1e-1);
}
@Test
public void testSwapAxes() {
IComplexNDArray n = Nd4j.createComplex(Nd4j.create(new double[] {1, 2, 3}, new int[] {3, 1}));
IComplexNDArray swapped = n.swapAxes(1, 0);
assertEquals(n.transpose(), swapped);
//vector despite being transposed should have same linear index
assertEquals(swapped.getScalar(0), n.getScalar(0));
assertEquals(swapped.getScalar(1), n.getScalar(1));
assertEquals(swapped.getScalar(2), n.getScalar(2));
IComplexNDArray n2 = Nd4j.createComplex(Nd4j.create(Nd4j.linspace(0, 7, 8).data(), new int[] {2, 2, 2}));
IComplexNDArray assertion = n2.permute(new int[] {2, 1, 0});
IComplexNDArray validate =
Nd4j.createComplex(Nd4j.create(new double[] {0, 4, 2, 6, 1, 5, 3, 7}, new int[] {2, 2, 2}));
assertEquals(validate, assertion);
IComplexNDArray v1 = Nd4j.createComplex(Nd4j.create(Nd4j.linspace(1, 8, 8).data(), new int[] {8, 1}));
IComplexNDArray swap = v1.swapAxes(1, 0);
IComplexNDArray transposed = v1.transpose();
assertEquals(swap, transposed);
transposed.put(1, Nd4j.scalar(9));
swap.put(1, Nd4j.scalar(9));
assertEquals(transposed, swap);
assertEquals(transposed.getScalar(1).element(), swap.getScalar(1).element());
IComplexNDArray row = n2.slice(0).getRow(1);
row.put(1, Nd4j.scalar(9));
IComplexNumber n3 = (IComplexNumber) row.getScalar(1).element();
assertEquals(9, n3.realComponent().doubleValue(), 1e-1);
}
@Test
public void testSliceOffset() {
IComplexNDArray test = Nd4j.complexLinSpace(1, 10, 10).reshape(2, 5);
IComplexNDArray testSlice0 = Nd4j.complexLinSpace(1, 5, 5);
IComplexNDArray testSlice1 = Nd4j.complexLinSpace(6, 10, 5);
assertEquals(testSlice0, test.slice(0));
assertEquals(testSlice1, test.slice(1));
IComplexNDArray sliceOfSlice0 = test.slice(0).slice(0);
assertEquals(sliceOfSlice0.getComplex(0), Nd4j.createComplexNumber(1, 0));
assertEquals(test.slice(1).slice(0).getComplex(0), Nd4j.createComplexNumber(6, 0));
assertEquals(test.slice(1).getComplex(1), Nd4j.createComplexNumber(7, 0));
}
@Test
public void testSlice() {
IComplexNDArray slices = Nd4j.createComplex(2, 3);
slices.put(0, 0, 1);
slices.put(0, 1, 2);
slices.put(0, 2, 3);
slices.put(1, 1, 4);
IComplexNDArray assertion = Nd4j.createComplex(new IComplexNumber[] {Nd4j.createComplexNumber(1, 0),
Nd4j.createComplexNumber(2, 0), Nd4j.createComplexNumber(3, 0),
});
assertEquals(assertion, slices.slice(0));
INDArray arr = Nd4j.create(Nd4j.linspace(1, 24, 24).data(), new int[] {4, 3, 2});
IComplexNDArray arr2 = Nd4j.createComplex(arr);
assertEquals(arr, arr2.getReal());
INDArray firstSlice = arr.slice(0);
INDArray firstSliceTest = arr2.slice(0).getReal();
assertEquals(firstSlice, firstSliceTest);
INDArray secondSlice = arr.slice(1);
INDArray secondSliceTest = arr2.slice(1).getReal();
assertEquals(secondSlice, secondSliceTest);
INDArray slice0 = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6}, new int[] {3, 2});
INDArray slice2 = Nd4j.create(new double[] {7, 8, 9, 10, 11, 12}, new int[] {3, 2});
IComplexNDArray testSliceComplex = arr2.slice(0);
IComplexNDArray testSliceComplex2 = arr2.slice(1);
INDArray testSlice0 = testSliceComplex.getReal();
INDArray testSlice1 = testSliceComplex2.getReal();
assertEquals(slice0, testSlice0);
assertEquals(slice2, testSlice1);
//weird slice striding issues here. try to avoid hacks related to if() the problem is not complex related
INDArray n2 = Nd4j.create(Nd4j.linspace(1, 30, 30).data(), new int[] {3, 5, 2});
INDArray swapped = n2.swapAxes(n2.shape().length - 1, 1);
INDArray firstSlice2 = swapped.slice(0).slice(0);
//problem ends here. Something with slicing?
IComplexNDArray testSlice = Nd4j.createComplex(firstSlice2);
IComplexNDArray testNoOffset =
Nd4j.createComplex(new double[] {1, 0, 3, 0, 5, 0, 7, 0, 9, 0}, new int[] {1, 5});
assertEquals(testSlice, testNoOffset);
}
@Test
public void testSliceConstructor() {
List<IComplexNDArray> testList = new ArrayList<>();
for (int i = 0; i < 5; i++)
testList.add(Nd4j.complexScalar(i + 1));
IComplexNDArray test = Nd4j.createComplex(testList, new int[] {1, testList.size()});
IComplexNDArray expected = Nd4j.createComplex(Nd4j.create(new double[] {1, 2, 3, 4, 5}, new int[] {1, 5}));
assertEquals(expected, test);
}
@Test
public void testVectorInit() {
DataBuffer data = Nd4j.linspace(1, 4, 4).data();
IComplexNDArray arr = Nd4j.createComplex(data, new int[] {1, 4});
assertEquals(true, arr.isRowVector());
IComplexNDArray arr2 = Nd4j.createComplex(data, new int[] {1, 4});
assertEquals(true, arr2.isRowVector());
IComplexNDArray columnVector = Nd4j.createComplex(data, new int[] {4, 1});
assertEquals(true, columnVector.isColumnVector());
}
@Test
public void testMmulOffset() {
IComplexNDArray three = Nd4j.createComplex(Nd4j.create(new double[] {3, 4}, new int[] {1, 2}));
IComplexNDArray test = Nd4j.createComplex(Nd4j.create(Nd4j.linspace(1, 30, 30).data(), new int[] {3, 5, 2}));
IComplexNDArray sliceRow = test.slice(0).getRow(1);
assertEquals(getFailureMessage(), three, sliceRow);
IComplexNDArray twoSix = Nd4j.createComplex(Nd4j.create(new double[] {2, 6}, new int[] {2, 1}));
IComplexNDArray threeTwoSix = three.mmul(twoSix);
IComplexNDArray sliceRowTwoSix = sliceRow.mmul(twoSix);
verifyElements(three, sliceRow);
assertEquals(getFailureMessage(), threeTwoSix, sliceRowTwoSix);
}
@Test
public void testIterateOverAllRows() {
Nd4j.EPS_THRESHOLD = 1e-1;
IComplexNDArray ones = Nd4j.complexOnes(5, 5);
VectorFFT fft = new VectorFFT(ones);
IComplexNDArray assertion = Nd4j.createComplex(5, 5);
for (int i = 0; i < assertion.rows(); i++)
assertion.getRow(i).putScalar(0, Nd4j.createComplexNumber(5, 0));
Nd4j.getExecutioner().iterateOverAllRows(fft);
assertEquals(getFailureMessage(), assertion, ones);
}
@Test
public void testRowVectorGemm() {
IComplexNDArray linspace = Nd4j.complexLinSpace(1, 4, 4);
IComplexNDArray other = Nd4j.complexLinSpace(1, 16, 16).reshape(4, 4);
IComplexNDArray result = linspace.mmul(other);
IComplexNDArray assertion = Nd4j.createComplex(ComplexUtil.complexNumbersFor(new double[] {90, 100, 110, 120}));
assertEquals(assertion, result);
}
@Test
public void testRealConversion() {
IComplexNDArray arr = Nd4j.createComplex(1, 5);
INDArray arr1 = Nd4j.create(1, 5);
assertEquals(arr, Nd4j.createComplex(arr1));
IComplexNDArray arr3 = Nd4j.complexLinSpace(1, 6, 6).reshape(2, 3);
INDArray linspace = Nd4j.linspace(1, 6, 6).reshape(2, 3);
assertEquals(arr3, Nd4j.createComplex(linspace));
}
@Test
public void testTranspose() {
IComplexNDArray ndArray =
Nd4j.createComplex(
new double[] {1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0, 0.0, 6.0, 0.0,
6.999999999999999, 0.0, 8.0, 0.0, 9.0, 0.0, 10.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0},
new int[] {16, 1});
IComplexNDArray transposed2 = ndArray.transpose();
assertEquals(16, transposed2.columns());
}
@Test
public void testConjugate() {
IComplexNDArray negative = Nd4j.createComplex(new double[] {1, -1, 2, -1}, new int[] {1, 2});
IComplexNDArray positive = Nd4j.createComplex(new double[] {1, 1, 2, 1}, new int[] {1, 2});
assertEquals(negative, positive.conj());
}
@Test
public void testGetRow() {
IComplexNDArray arr = Nd4j.createComplex(new int[] {3, 2});
IComplexNDArray row = Nd4j.createComplex(new double[] {1, 0, 2, 0}, new int[] {1, 2});
arr.putRow(0, row);
IComplexNDArray firstRow = arr.getRow(0);
assertEquals(true, Shape.shapeEquals(new int[] {1, 2}, firstRow.shape()));
IComplexNDArray testRow = arr.getRow(0);
assertEquals(row, testRow);
IComplexNDArray row1 = Nd4j.createComplex(new double[] {3, 0, 4, 0}, new int[] {1, 2});
arr.putRow(1, row1);
assertEquals(true, Shape.shapeEquals(new int[] {2}, arr.getRow(0).shape()));
IComplexNDArray testRow1 = arr.getRow(1);
assertEquals(row1, testRow1);
INDArray fourTwoTwo = Nd4j.linspace(1, 16, 16).reshape(4, 2, 2);
IComplexNDArray multiRow = Nd4j.createComplex(fourTwoTwo);
IComplexNDArray test = Nd4j.createComplex(Nd4j.create(new double[] {7, 8}, new int[] {1, 2}));
IComplexNDArray multiRowSlice = multiRow.slice(1);
IComplexNDArray testMultiRow = multiRowSlice.getRow(1);
assertEquals(test, testMultiRow);
}
@Test
public void testMultiDimensionalCreation() {
INDArray fourTwoTwo = Nd4j.linspace(1, 16, 16).reshape(4, 2, 2);
IComplexNDArray multiRow = Nd4j.createComplex(fourTwoTwo);
multiRow.toString();
assertEquals(fourTwoTwo, multiRow.getReal());
}
@Test
public void testGetComplex() {
IComplexNDArray arr = Nd4j.createComplex(Nd4j.create(Nd4j.createBuffer(new double[] {1, 2, 3, 4, 5})));
IComplexNumber num = arr.getComplex(4);
assertEquals(Nd4j.createDouble(5, 0), num);
IComplexNDArray matrix = Nd4j.complexLinSpace(1, 10, 10).reshape(2, 5);
IComplexNDArray slice = matrix.slice(0);
IComplexNDArray assertion = Nd4j.complexLinSpace(1, 5, 5);
assertEquals(assertion, slice);
IComplexNDArray assert2 = Nd4j.complexLinSpace(6, 10, 5);
assertEquals(assert2, matrix.slice(1));
}
@Test
public void testGetColumn() {
IComplexNDArray arr = Nd4j.createComplex(Nd4j.create(Nd4j.linspace(1, 8, 8).data(), new int[] {2, 4}));
IComplexNDArray column2 = arr.getColumn(1);
IComplexNDArray result = Nd4j.createComplex(Nd4j.create(new double[] {2, 6}, new int[] {1, 2}));
assertEquals(result, column2);
assertEquals(true, Shape.shapeEquals(new int[] {2, 1}, column2.shape()));
IComplexNDArray column = Nd4j.createComplex(new double[] {11, 0, 12, 0}, new int[] {1, 2});
arr.putColumn(1, column);
IComplexNDArray firstColumn = arr.getColumn(1);
assertEquals(column, firstColumn);
IComplexNDArray column1 = Nd4j.createComplex(new double[] {5, 0, 6, 0}, new int[] {1, 2});
arr.putColumn(1, column1);
assertEquals(true, Shape.shapeEquals(new int[] {2, 1}, arr.getColumn(1).shape()));
IComplexNDArray testC = arr.getColumn(1);
assertEquals(column1, testC);
IComplexNDArray multiSlice =
Nd4j.createComplex(Nd4j.create(Nd4j.linspace(1, 32, 32).data(), new int[] {4, 4, 2}));
IComplexNDArray testColumn = Nd4j.createComplex(Nd4j.create(new double[] {10, 12, 14, 16}, new int[] {1, 4}));
IComplexNDArray sliceColumn = multiSlice.slice(1).getColumn(1);
assertEquals(sliceColumn, testColumn);
IComplexNDArray testColumn2 = Nd4j.createComplex(Nd4j.create(new double[] {17, 19, 21, 23}, new int[] {1, 4}));
IComplexNDArray testSlice2 = multiSlice.slice(2);
IComplexNDArray testSlice2ColumnZero = testSlice2.getColumn(0);
assertEquals(testColumn2, testSlice2ColumnZero);
IComplexNDArray testColumn3 = Nd4j.createComplex(Nd4j.create(new double[] {18, 20, 22, 24}, new int[] {1, 4}));
IComplexNDArray testSlice3 = multiSlice.slice(2).getColumn(1);
assertEquals(testColumn3, testSlice3);
}
@Test
public void testGetIndexing() {
Nd4j.MAX_SLICES_TO_PRINT = Integer.MAX_VALUE;
Nd4j.MAX_ELEMENTS_PER_SLICE = Integer.MAX_VALUE;
IComplexNDArray tenByTen = Nd4j.complexLinSpace(1, 100, 100).reshape(10, 10);
IComplexNDArray thirtyToSixty =
(IComplexNDArray) Transforms.round(Nd4j.complexLinSpace(31, 60, 30)).reshape(3, 10);
IComplexNDArray test = tenByTen.get(NDArrayIndex.interval(3, 6), NDArrayIndex.interval(0, tenByTen.columns()));
assertEquals(thirtyToSixty, test);
}
@Test
public void testPutAndGet() {
IComplexNDArray multiRow = Nd4j.createComplex(2, 2);
multiRow.putScalar(0, 0, Nd4j.createComplexNumber(1, 0));
multiRow.putScalar(0, 1, Nd4j.createComplexNumber(2, 0));
multiRow.putScalar(1, 0, Nd4j.createComplexNumber(3, 0));
multiRow.putScalar(1, 1, Nd4j.createComplexNumber(4, 0));
assertEquals(Nd4j.createComplexNumber(1, 0), multiRow.getComplex(0, 0));
assertEquals(Nd4j.createComplexNumber(2, 0), multiRow.getComplex(0, 1));
assertEquals(Nd4j.createComplexNumber(3, 0), multiRow.getComplex(1, 0));
assertEquals(Nd4j.createComplexNumber(4, 0), multiRow.getComplex(1, 1));
IComplexNDArray arr = Nd4j.createComplex(Nd4j.create(new double[] {1, 2, 3, 4}, new int[] {2, 2}));
assertEquals(4, arr.length());
assertEquals(8, arr.data().length());
arr.put(1, 1, Nd4j.scalar(5.0));
IComplexNumber n1 = arr.getComplex(1, 1);
IComplexNumber n2 = arr.getComplex(1, 1);
assertEquals(5.0, n1.realComponent().doubleValue(), 1e-1);
assertEquals(0.0, n2.imaginaryComponent().doubleValue(), 1e-1);
}
@Test
public void testGetReal() {
DataBuffer data = Nd4j.linspace(1, 8, 8).data();
int[] shape = new int[] {1, 8};
IComplexNDArray arr = Nd4j.createComplex(shape);
for (int i = 0; i < arr.length(); i++)
arr.put(i, Nd4j.scalar(data.getFloat(i)));
INDArray arr2 = Nd4j.create(data, shape);
assertEquals(arr2, arr.getReal());
INDArray ones = Nd4j.ones(10);
IComplexNDArray n2 = Nd4j.complexOnes(10);
assertEquals(ones, n2.getReal());
}
@Test
public void testBroadcast() {
IComplexNDArray arr = Nd4j.complexLinSpace(1, 5, 5);
IComplexNDArray arrs = arr.broadcast(new int[] {5, 5});
IComplexNDArray arrs3 = Nd4j.createComplex(5, 5);
assertTrue(Arrays.equals(arrs.shape(), arrs3.shape()));
for (int i = 0; i < arrs.slices(); i++)
arrs3.putSlice(i, arr);
assertEquals(arrs3, arrs);
}
@Test
public void testBasicOperations() {
IComplexNDArray arr = Nd4j.createComplex(new double[] {0, 1, 2, 1, 1, 2, 3, 4}, new int[] {2, 2});
IComplexNumber scalar = arr.sumComplex();
double sum = scalar.realComponent().doubleValue();
assertEquals(6, sum, 1e-1);
arr.addi(1);
scalar = arr.sumComplex();
sum = scalar.realComponent().doubleValue();
assertEquals(10, sum, 1e-1);
arr.subi(Nd4j.createDouble(1, 0));
scalar = arr.sumComplex().asDouble();
sum = scalar.realComponent().doubleValue();
assertEquals(6, sum, 1e-1);
}
@Test
public void testComplexCalculation() {
IComplexNDArray arr = Nd4j.createComplex(
new IComplexNumber[][] {{Nd4j.createComplexNumber(1, 1), Nd4j.createComplexNumber(2, 1)},
{Nd4j.createComplexNumber(3, 2), Nd4j.createComplexNumber(4, 2)}});
IComplexNumber scalar = arr.sumComplex();
double sum = scalar.realComponent().doubleValue();
assertEquals(10, sum, 1e-1);
double sumImag = scalar.imaginaryComponent().doubleValue();
assertEquals(6, sumImag, 1e-1);
IComplexNDArray res = arr.add(Nd4j.createComplexNumber(1, 1));
scalar = res.sumComplex();
sum = scalar.realComponent().doubleValue();
assertEquals(14, sum, 1e-1);
sumImag = scalar.imaginaryComponent().doubleValue();
assertEquals(10, sumImag, 1e-1);
//original array should keep as it is
sum = arr.sumComplex().realComponent().doubleValue();
assertEquals(10, sum, 1e-1);
}
@Test
public void testElementWiseOps() {
IComplexNDArray n1 = Nd4j.complexScalar(1);
IComplexNDArray n2 = Nd4j.complexScalar(2);
assertEquals(Nd4j.complexScalar(3), n1.add(n2));
assertFalse(n1.add(n2).equals(n1));
IComplexNDArray n3 = Nd4j.complexScalar(3);
IComplexNDArray n4 = Nd4j.complexScalar(4);
IComplexNDArray subbed = n4.sub(n3);
IComplexNDArray mulled = n4.mul(n3);
IComplexNDArray div = n4.div(n3);
assertFalse(subbed.equals(n4));
assertFalse(mulled.equals(n4));
assertEquals(Nd4j.complexScalar(1), subbed);
assertEquals(Nd4j.complexScalar(12), mulled);
assertEquals(Nd4j.complexScalar(1.3333333333333333), div);
IComplexNDArray multiDimensionElementWise =
Nd4j.createComplex(Nd4j.create(Nd4j.linspace(1, 24, 24).data(), new int[] {4, 3, 2}));
IComplexNumber sum2 = multiDimensionElementWise.sumComplex();
assertEquals(sum2, Nd4j.createDouble(300, 0));
IComplexNDArray added = multiDimensionElementWise.add(Nd4j.complexScalar(1));
IComplexNumber sum3 = added.sumComplex();
assertEquals(sum3, Nd4j.createDouble(324, 0));
}
@Test
public void testFlatten() {
IComplexNDArray arr = Nd4j.createComplex(Nd4j.create(Nd4j.linspace(1, 4, 4).data(), new int[] {2, 2}));
IComplexNDArray flattened = arr.ravel();
assertEquals(arr.length(), flattened.length());
assertTrue(Shape.shapeEquals(new int[] {1, 4}, flattened.shape()));
for (int i = 0; i < arr.length(); i++) {
IComplexNumber get = (IComplexNumber) flattened.getScalar(i).element();
assertEquals(i + 1, get.realComponent().doubleValue(), 1e-1);
}
}
@Test
public void testMatrixGet() {
IComplexNDArray arr = Nd4j.createComplex((Nd4j.linspace(1, 4, 4))).reshape(2, 2);
IComplexNumber n1 = arr.getComplex(0, 0);
IComplexNumber n2 = arr.getComplex(0, 1);
IComplexNumber n3 = arr.getComplex(1, 0);
IComplexNumber n4 = arr.getComplex(1, 1);
assertEquals(1, n1.realComponent().doubleValue(), 1e-1);
assertEquals(2, n2.realComponent().doubleValue(), 1e-1);
assertEquals(3, n3.realComponent().doubleValue(), 1e-1);
assertEquals(4, n4.realComponent().doubleValue(), 1e-1);
}
@Override
public char ordering() {
return 'c';
}
}