package org.nd4j.linalg.api; import org.apache.commons.math3.util.Pair; import org.junit.Test; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.checkutil.NDArrayCreationUtil; import org.nd4j.linalg.factory.Nd4jBackend; import static org.junit.Assert.assertArrayEquals; /** * Created by Alex on 30/04/2016. */ public class TestNDArrayCreationUtil extends BaseNd4jTest { public TestNDArrayCreationUtil(Nd4jBackend backend) { super(backend); } @Test public void testShapes() { int[] shape2d = {2, 3}; for (Pair<INDArray, String> p : NDArrayCreationUtil.getAllTestMatricesWithShape(2, 3, 12345)) { assertArrayEquals(p.getSecond(), shape2d, p.getFirst().shape()); } int[] shape3d = {2, 3, 4}; for (Pair<INDArray, String> p : NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, shape3d)) { assertArrayEquals(p.getSecond(), shape3d, p.getFirst().shape()); } int[] shape4d = {2, 3, 4, 5}; for (Pair<INDArray, String> p : NDArrayCreationUtil.getAll4dTestArraysWithShape(12345, shape4d)) { assertArrayEquals(p.getSecond(), shape4d, p.getFirst().shape()); } int[] shape5d = {2, 3, 4, 5, 6}; for (Pair<INDArray, String> p : NDArrayCreationUtil.getAll5dTestArraysWithShape(12345, shape5d)) { assertArrayEquals(p.getSecond(), shape5d, p.getFirst().shape()); } int[] shape6d = {2, 3, 4, 5, 6, 7}; for (Pair<INDArray, String> p : NDArrayCreationUtil.getAll6dTestArraysWithShape(12345, shape6d)) { assertArrayEquals(p.getSecond(), shape6d, p.getFirst().shape()); } } @Override public char ordering() { return 'c'; } }