package org.nd4j.linalg.util;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import static org.junit.Assert.*;
/**
* @author Adam Gibson
*/
@RunWith(Parameterized.class)
public class ShapeTest extends BaseNd4jTest {
public ShapeTest(Nd4jBackend backend) {
super(backend);
}
@Test
public void testToOffsetZero() {
INDArray matrix = Nd4j.rand(3, 5);
INDArray rowOne = matrix.getRow(1);
INDArray row1Copy = Shape.toOffsetZero(rowOne);
assertEquals(rowOne, row1Copy);
INDArray rows = matrix.getRows(1, 2);
INDArray rowsOffsetZero = Shape.toOffsetZero(rows);
assertEquals(rows, rowsOffsetZero);
INDArray tensor = Nd4j.rand(new int[] {3, 3, 3});
INDArray getTensor = tensor.slice(1).slice(1);
INDArray getTensorZero = Shape.toOffsetZero(getTensor);
assertEquals(getTensor, getTensorZero);
}
@Test
public void testDupLeadingTrailingZeros() {
testDupHelper(1, 10);
testDupHelper(10, 1);
testDupHelper(1, 10, 1);
testDupHelper(1, 10, 1, 1);
testDupHelper(1, 10, 2);
testDupHelper(2, 10, 1, 1);
testDupHelper(1, 1, 1, 10);
testDupHelper(10, 1, 1, 1);
testDupHelper(1, 1);
}
private void testDupHelper(int... shape) {
INDArray arr = Nd4j.ones(shape);
INDArray arr2 = arr.dup();
assertArrayEquals(arr.shape(), arr2.shape());
assertTrue(arr.equals(arr2));
}
@Test
public void testLeadingOnes() {
INDArray arr = Nd4j.create(1, 5, 5);
assertEquals(1, arr.getLeadingOnes());
INDArray arr2 = Nd4j.create(2, 2);
assertEquals(0, arr2.getLeadingOnes());
INDArray arr4 = Nd4j.create(1, 1, 5, 5);
assertEquals(2, arr4.getLeadingOnes());
}
@Test
public void testTrailingOnes() {
INDArray arr2 = Nd4j.create(5, 5, 1);
assertEquals(1, arr2.getTrailingOnes());
INDArray arr4 = Nd4j.create(5, 5, 1, 1);
assertEquals(2, arr4.getTrailingOnes());
}
@Test
public void testElementWiseCompareOnesInMiddle() {
INDArray arr = Nd4j.linspace(1, 6, 6).reshape(2, 3);
INDArray onesInMiddle = Nd4j.linspace(1, 6, 6).reshape(2, 1, 3);
for (int i = 0; i < arr.length(); i++) {
double val = arr.getDouble(i);
double middleVal = onesInMiddle.getDouble(i);
assertEquals(val, middleVal, 1e-1);
}
}
@Test
public void testSumLeadingTrailingZeros() {
testSumHelper(1, 5, 5);
testSumHelper(5, 5, 1);
testSumHelper(1, 5, 1);
testSumHelper(1, 5, 5, 5);
testSumHelper(5, 5, 5, 1);
testSumHelper(1, 5, 5, 1);
testSumHelper(1, 5, 5, 5, 5);
testSumHelper(5, 5, 5, 5, 1);
testSumHelper(1, 5, 5, 5, 1);
testSumHelper(1, 5, 5, 5, 5, 5);
testSumHelper(5, 5, 5, 5, 5, 1);
testSumHelper(1, 5, 5, 5, 5, 1);
}
private void testSumHelper(int... shape) {
INDArray array = Nd4j.ones(shape);
for (int i = 0; i < shape.length; i++) {
for (int j = 0; j < array.vectorsAlongDimension(i); j++) {
INDArray vec = array.vectorAlongDimension(j, i);
}
array.sum(i);
}
}
@Override
public char ordering() {
return 'f';
}
}