package org.nd4j.linalg.shape;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.util.NDArrayMath;
import static org.junit.Assert.assertEquals;
/**
 * Tests for the shape/offset helpers exposed by {@link NDArrayMath}:
 * vectors, matrices and tensors per slice, slice lengths and slice offsets.
 *
 * <p>The class is run once per {@link Nd4jBackend} supplied by the
 * {@link Parameterized} runner, and overrides {@link #ordering()} to return {@code 'f'}.
 *
 * @author Adam Gibson
 */
@RunWith(Parameterized.class)
public class NDArrayMathTests extends BaseNd4jTest {

    /**
     * Creates the test instance for one backend.
     *
     * @param backend the backend handed in by the parameterized runner and
     *                forwarded to the base test class
     */
    public NDArrayMathTests(Nd4jBackend backend) {
        super(backend);
    }

    /**
     * Checks {@code vectorsPerSlice} for a 2x2x2x2 array, a 2x2 matrix and
     * the first slice of the 2x2x2x2 array.
     */
    @Test
    public void testVectorPerSlice() {
        INDArray arr = Nd4j.create(2, 2, 2, 2);
        assertEquals(4, NDArrayMath.vectorsPerSlice(arr));

        INDArray matrix = Nd4j.create(2, 2);
        assertEquals(2, NDArrayMath.vectorsPerSlice(matrix));

        INDArray arrSliceZero = arr.slice(0);
        assertEquals(4, NDArrayMath.vectorsPerSlice(arrSliceZero));
    }

    /** Checks that a 2x2x2x2 array reports two matrices per slice. */
    @Test
    public void testMatricesPerSlice() {
        INDArray arr = Nd4j.create(2, 2, 2, 2);
        assertEquals(2, NDArrayMath.matricesPerSlice(arr));
    }

    /** Checks that a 2x2x2x2 array reports a length of 8 per slice. */
    @Test
    public void testLengthPerSlice() {
        INDArray arr = Nd4j.create(2, 2, 2, 2);
        int lengthPerSlice = NDArrayMath.lengthPerSlice(arr);
        assertEquals(8, lengthPerSlice);
    }

    /**
     * Checks that slice 1 of a 3x2x2 array is reported at offset 4.
     *
     * <p>NOTE(review): the method name looks like a typo of
     * {@code testOffsetForSlice}, which is already used by another test in this
     * class; the name is kept so existing test identifiers do not change.
     */
    @Test
    public void toffsetForSlice() {
        INDArray arr = Nd4j.create(3, 2, 2);
        int slice = 1;
        assertEquals(4, NDArrayMath.offsetForSlice(arr, slice));
    }

    /** Checks that index 2 maps onto vector index 4 for a 3x2x2 array. */
    @Test
    public void testMapOntoVector() {
        INDArray arr = Nd4j.create(3, 2, 2);
        // Expected value goes first so that a failure message reports
        // "expected" and "actual" the right way round.
        assertEquals(4, NDArrayMath.mapIndexOntoVector(2, arr));
    }

    /** Checks {@code vectorsPerSlice} for a 3x2x2 array and a 2x2 matrix. */
    @Test
    public void testNumVectors() {
        INDArray arr = Nd4j.create(3, 2, 2);
        assertEquals(4, NDArrayMath.vectorsPerSlice(arr));

        INDArray matrix = Nd4j.create(2, 2);
        assertEquals(2, NDArrayMath.vectorsPerSlice(matrix));
    }

    /**
     * Checks {@code sliceOffsetForTensor} and {@code tensorsPerSlice} with a
     * 2x2 tensor shape against a permuted 2x2x2x2 array and a 3x2x2 array,
     * then prints (without asserting) the offsets for a 2x2x2 array.
     */
    @Test
    public void testOffsetForSlice() {
        INDArray arr = Nd4j.linspace(1, 16, 16).reshape(2, 2, 2, 2);
        int[] dimensions = {0, 1};
        INDArray permuted = arr.permute(2, 3, 0, 1);

        // Expected slice offset for each tensor index of the permuted array.
        int[] test = {0, 0, 1, 1};
        for (int i = 0; i < permuted.tensorssAlongDimension(dimensions); i++) {
            assertEquals(test[i], NDArrayMath.sliceOffsetForTensor(i, permuted, new int[] {2, 2}));
        }

        int arrTensorsPerSlice = NDArrayMath.tensorsPerSlice(arr, new int[] {2, 2});
        assertEquals(2, arrTensorsPerSlice);

        INDArray arr2 = Nd4j.linspace(1, 12, 12).reshape(3, 2, 2);
        // For the 3x2x2 array each tensor index is expected to map to the same slice index.
        int[] assertions = {0, 1, 2};
        for (int i = 0; i < assertions.length; i++) {
            assertEquals(assertions[i], NDArrayMath.sliceOffsetForTensor(i, arr2, new int[] {2, 2}));
        }

        int tensorsPerSlice = NDArrayMath.tensorsPerSlice(arr2, new int[] {2, 2});
        assertEquals(1, tensorsPerSlice);

        // NOTE(review): the remainder only prints values and asserts nothing;
        // it acts as a smoke check that these calls do not throw.
        INDArray otherTest = Nd4j.linspace(1, 144, 144).reshape(6, 3, 2, 2, 2);
        System.out.println(otherTest);

        INDArray baseArr = Nd4j.linspace(1, 8, 8).reshape(2, 2, 2);
        for (int i = 0; i < baseArr.tensorssAlongDimension(0, 1); i++) {
            System.out.println(NDArrayMath.sliceOffsetForTensor(i, baseArr, new int[] {2, 2}));
        }
    }

    /** Checks that a 3x2x2 array reports one matrix per slice. */
    @Test
    public void testOddDimensions() {
        INDArray arr = Nd4j.create(3, 2, 2);
        int numMatrices = NDArrayMath.matricesPerSlice(arr);
        assertEquals(1, numMatrices);
    }

    /** Checks that a 2x2x2x2 array reports eight vectors in total. */
    @Test
    public void testTotalVectors() {
        INDArray arr2 = Nd4j.create(2, 2, 2, 2);
        assertEquals(8, NDArrayMath.numVectors(arr2));
    }

    /**
     * Returns the ordering character used for this test class.
     *
     * @return {@code 'f'} (presumably column-major/Fortran ordering — confirm
     *         against {@code BaseNd4jTest})
     */
    @Override
    public char ordering() {
        return 'f';
    }
}