package org.nd4j.linalg.util;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.Assert.assertEquals;
/**
* @author Adam Gibson
*/
@RunWith(Parameterized.class)
public class ShapeTestC extends BaseNd4jTest {
public ShapeTestC(Nd4jBackend backend) {
super(backend);
}
@Test
public void testToOffsetZero() {
INDArray matrix = Nd4j.rand(3, 5);
INDArray rowOne = matrix.getRow(1);
INDArray row1Copy = Shape.toOffsetZero(rowOne);
assertEquals(rowOne, row1Copy);
INDArray rows = matrix.getRows(1, 2);
INDArray rowsOffsetZero = Shape.toOffsetZero(rows);
assertEquals(rows, rowsOffsetZero);
INDArray tensor = Nd4j.rand(new int[] {3, 3, 3});
INDArray getTensor = tensor.slice(1).slice(1);
INDArray getTensorZero = Shape.toOffsetZero(getTensor);
assertEquals(getTensor, getTensorZero);
}
@Test
public void testElementWiseCompareOnesInMiddle() {
INDArray arr = Nd4j.linspace(1, 6, 6).reshape(2, 3);
INDArray onesInMiddle = Nd4j.linspace(1, 6, 6).reshape(2, 1, 3);
for (int i = 0; i < arr.length(); i++)
assertEquals(arr.getDouble(i), onesInMiddle.getDouble(i), 1e-3);
}
@Override
public char ordering() {
return 'c';
}
}