package org.nd4j.linalg.shape;
import org.apache.commons.math3.util.Pair;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.checkutil.NDArrayCreationUtil;
import org.nd4j.linalg.factory.Nd4jBackend;
import java.nio.IntBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
/**
 * Tests for the static shape-information helpers in {@link Shape}: linear-index to subscript
 * conversion, and the {@link IntBuffer} / {@link DataBuffer} overloads of the rank, shape, size,
 * stride and offset accessors.
 *
 * @author Adam Gibson
 */
@RunWith(Parameterized.class)
public class StaticShapeTests extends BaseNd4jTest {

    public StaticShapeTests(Nd4jBackend backend) {
        super(backend);
    }

    /**
     * Checks {@link Shape#ind2subC} (C / row-major order) and {@link Shape#ind2sub}
     * (F / column-major order) against subscripts computed directly with division and modulo,
     * for every linear index of several 2d shapes (including degenerate 1xN and Nx1 shapes).
     */
    @Test
    public void testShapeInd2Sub() {
        int[][] shapes = {{2, 2}, {2, 3}, {3, 2}, {1, 4}, {4, 1}};
        for (int[] shape : shapes) {
            int rows = shape[0];
            int cols = shape[1];
            for (int index = 0; index < rows * cols; index++) {
                // C order: the last dimension varies fastest.
                int[] expectedC = {index / cols, index % cols};
                // F order: the first dimension varies fastest.
                int[] expectedF = {index % rows, index / rows};
                String ctx = "shape=" + Arrays.toString(shape) + ", index=" + index;
                assertArrayEquals("C " + ctx, expectedC, Shape.ind2subC(shape, index));
                assertArrayEquals("F " + ctx, expectedF, Shape.ind2sub(shape, index));
            }
        }
    }

    /**
     * Verifies that the IntBuffer ({@code arr.shapeInfo()}) and DataBuffer
     * ({@code arr.shapeInfoDataBuffer()}) overloads of {@code Shape.rank}, {@code Shape.shape},
     * {@code Shape.size}, {@code Shape.stride}, {@code Shape.offset} and {@code Shape.getOffset}
     * agree with each other and with {@link INDArray#shape()} / {@link INDArray#stride()}.
     * This is checked for every array returned by {@link NDArrayCreationUtil} for each listed
     * shape, rank 2 through 6.
     */
    @Test
    public void testBufferToIntShapeStrideMethods() {
        List<List<Pair<INDArray, String>>> lists = new ArrayList<>();
        lists.add(NDArrayCreationUtil.getAllTestMatricesWithShape(3, 4, 12345));
        lists.add(NDArrayCreationUtil.getAllTestMatricesWithShape(1, 4, 12345));
        lists.add(NDArrayCreationUtil.getAllTestMatricesWithShape(3, 1, 12345));
        lists.add(NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, 3, 4, 5));
        lists.add(NDArrayCreationUtil.getAll4dTestArraysWithShape(12345, 3, 4, 5, 6));
        lists.add(NDArrayCreationUtil.getAll4dTestArraysWithShape(12345, 3, 1, 5, 1));
        lists.add(NDArrayCreationUtil.getAll5dTestArraysWithShape(12345, 3, 4, 5, 6, 7));
        lists.add(NDArrayCreationUtil.getAll6dTestArraysWithShape(12345, 3, 4, 5, 6, 7, 8));

        // Parallel to 'lists': lists.get(i) holds the arrays whose expected shape is shapes[i].
        int[][] shapes = new int[][] {{3, 4}, {1, 4}, {3, 1}, {3, 4, 5}, {3, 4, 5, 6}, {3, 1, 5, 1},
                {3, 4, 5, 6, 7}, {3, 4, 5, 6, 7, 8}};
        // Guard against the two structures drifting apart when cases are added or removed.
        assertEquals(shapes.length, lists.size());

        for (int i = 0; i < shapes.length; i++) {
            List<Pair<INDArray, String>> list = lists.get(i);
            int[] shape = shapes[i];
            for (Pair<INDArray, String> p : list) {
                INDArray arr = p.getFirst();
                // Text paired with the array by NDArrayCreationUtil; used only to label failures.
                String desc = p.getSecond();
                assertArrayEquals(desc, shape, arr.shape());
                int[] thisStride = arr.stride();
                assertEquals(desc, shape.length, thisStride.length);
                IntBuffer ib = arr.shapeInfo();
                DataBuffer db = arr.shapeInfoDataBuffer();

                // Check shape calculation
                assertEquals(desc, shape.length, Shape.rank(ib));
                assertEquals(desc, shape.length, Shape.rank(db));
                assertArrayEquals(desc, shape, Shape.shape(ib));
                assertArrayEquals(desc, shape, Shape.shape(db));
                for (int j = 0; j < shape.length; j++) {
                    assertEquals(desc, shape[j], Shape.size(ib, j));
                    assertEquals(desc, shape[j], Shape.size(db, j));
                    assertEquals(desc, thisStride[j], Shape.stride(ib, j));
                    assertEquals(desc, thisStride[j], Shape.stride(db, j));
                }

                // Check base offset
                assertEquals(desc, Shape.offset(ib), Shape.offset(db));

                // Check offset calculation for every index of the array: the int[] overload
                // is the reference; the fixed-arity overloads (ranks 2-4) must agree with it.
                NdIndexIterator iter = new NdIndexIterator(shape);
                while (iter.hasNext()) {
                    int[] next = iter.next();
                    long offset1 = Shape.getOffset(ib, next);
                    assertEquals(desc, offset1, Shape.getOffset(db, next));
                    switch (shape.length) {
                        case 2:
                            assertEquals(desc, offset1, Shape.getOffset(ib, next[0], next[1]));
                            assertEquals(desc, offset1, Shape.getOffset(db, next[0], next[1]));
                            break;
                        case 3:
                            assertEquals(desc, offset1, Shape.getOffset(ib, next[0], next[1], next[2]));
                            assertEquals(desc, offset1, Shape.getOffset(db, next[0], next[1], next[2]));
                            break;
                        case 4:
                            assertEquals(desc, offset1,
                                    Shape.getOffset(ib, next[0], next[1], next[2], next[3]));
                            assertEquals(desc, offset1,
                                    Shape.getOffset(db, next[0], next[1], next[2], next[3]));
                            break;
                        case 5:
                        case 6:
                            // No 5d and 6d getOffset overloads
                            break;
                        default:
                            throw new IllegalStateException("Unexpected rank: " + shape.length);
                    }
                }
            }
        }
    }

    @Override
    public char ordering() {
        return 'f';
    }
}