package org.nd4j.linalg.shape;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.util.ArrayUtil;
import java.nio.IntBuffer;
import static org.junit.Assert.*;
/**
 * Tests for the packed shape-information buffer produced by {@link Shape#createShapeInformation}
 * and read back through the static {@link Shape} accessors, plus a check of the shape and strides
 * of an array built via {@link Nd4j}.
 *
 * <p>Created by agibsoncccc on 1/30/16.
 */
@RunWith(Parameterized.class)
public class ShapeBufferTests extends BaseNd4jTest {

    /** Creates the test instance for the backend supplied by the parameterized runner. */
    public ShapeBufferTests(Nd4jBackend backend) {
        super(backend);
    }

    /** Always {@code 'c'} for this suite. */
    @Override
    public char ordering() {
        return 'c';
    }

    /**
     * Packs {@code shape} and {@code stride} into a shape-information buffer with order
     * {@code 'c'} and returns an {@link IntBuffer} view of it.
     *
     * <p>The trailing arguments {@code 0} and {@code 1} are exactly what each test passed before
     * this helper existed; {@link #testShape()} asserts that the {@code 1} is reported back as the
     * element-wise stride.
     *
     * @param shape  per-dimension sizes to record
     * @param stride per-dimension strides to record
     * @return an int view of the packed shape information
     */
    private static IntBuffer cOrderShapeInfo(int[] shape, int[] stride) {
        return Shape.createShapeInformation(shape, stride, 0, 1, 'c').asNioInt();
    }

    /** A buffer describing a two-dimensional shape must report a rank equal to the shape length. */
    @Test
    public void testRank() {
        int[] shape = {2, 4};
        int[] stride = {1, 2};
        IntBuffer buff = cOrderShapeInfo(shape, stride);
        assertEquals(shape.length, Shape.rank(buff));
    }

    /**
     * A 2x2 array obtained by reshaping {@code linspace(1, 4, 4)} must report size 2 on both
     * dimensions and the strides {@link ArrayUtil#calcStrides} computes for shape {2, 2}.
     */
    @Test
    public void testArrCreationShape() {
        int[] expectedShape = {2, 2};
        INDArray arr = Nd4j.linspace(1, 4, 4).reshape(2, 2);
        for (int i = 0; i < expectedShape.length; i++) {
            assertEquals(expectedShape[i], arr.size(i));
        }
        int[] expectedStride = ArrayUtil.calcStrides(expectedShape);
        for (int i = 0; i < expectedStride.length; i++) {
            assertEquals(expectedStride[i], arr.stride(i));
        }
    }

    /**
     * Every value written into the buffer (shape, stride, order, element-wise stride) must read
     * back unchanged, and a 2x4 shape must not be classified as a vector.
     */
    @Test
    public void testShape() {
        int[] shape = {2, 4};
        // NOTE(review): {1, 2} is the column-major stride for shape {2, 4}, yet order 'c' is
        // recorded. This test only checks that the values round-trip, not that they agree.
        int[] stride = {1, 2};
        IntBuffer buff = cOrderShapeInfo(shape, stride);

        IntBuffer shapeView = Shape.shapeOf(buff);
        assertTrue(Shape.contentEquals(shape, shapeView));
        IntBuffer strideView = Shape.stride(buff);
        assertTrue(Shape.contentEquals(stride, strideView));
        assertEquals('c', Shape.order(buff));
        assertEquals(1, Shape.elementWiseStride(buff));
        assertFalse(Shape.isVector(buff));

        // Re-read both views after the other accessors have run; presumably this guards against
        // an accessor disturbing the underlying buffer (e.g. its position) — confirm intent.
        assertTrue(Shape.contentEquals(shape, Shape.shapeOf(buff)));
        assertTrue(Shape.contentEquals(stride, Shape.stride(buff)));
    }

    /** A buffer describing a 1x2 shape must be classified as a vector. */
    @Test
    public void testBuff() {
        IntBuffer buff = cOrderShapeInfo(new int[] {1, 2}, new int[] {1, 2});
        assertTrue(Shape.isVector(buff));
    }
}