package org.nd4j.linalg.slicing; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.SpecifiedIndex; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; /** * @author Adam Gibson */ @RunWith(Parameterized.class) public class SlicingTestsC extends BaseNd4jTest { public SlicingTestsC(Nd4jBackend backend) { super(backend); } @Test public void testSliceRowVector() { INDArray arr = Nd4j.zeros(5); System.out.println(arr.slice(1)); } @Test public void testSliceAssertion() { INDArray arr = Nd4j.linspace(1, 30, 30).reshape(3, 5, 2); INDArray firstRow = arr.slice(0).slice(0); for (int i = 0; i < firstRow.length(); i++) { System.out.println(firstRow.getDouble(i)); } System.out.println(firstRow); } @Test public void testSliceShape() { INDArray arr = Nd4j.linspace(1, 30, 30).reshape(3, 5, 2); INDArray sliceZero = arr.slice(0); for (int i = 0; i < sliceZero.rows(); i++) { INDArray row = sliceZero.slice(i); for (int j = 0; j < row.length(); j++) { System.out.println(row.getDouble(j)); } System.out.println(row); } INDArray assertion = Nd4j.create(new double[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, new int[] {5, 2}); for (int i = 0; i < assertion.rows(); i++) { INDArray row = assertion.slice(i); for (int j = 0; j < row.length(); j++) { System.out.println(row.getDouble(j)); } System.out.println(row); } assertArrayEquals(new int[] {5, 2}, sliceZero.shape()); assertEquals(assertion, sliceZero); INDArray assertionTwo = Nd4j.create(new double[] {11, 12, 13, 14, 15, 16, 17, 18, 19, 20}, new int[] {5, 2}); INDArray sliceTest = arr.slice(1); assertEquals(assertionTwo, sliceTest); } @Test public void testSwapReshape() { INDArray n2 = Nd4j.create(Nd4j.linspace(1, 30, 30).data(), new int[] {3, 5, 2}); INDArray swapped = n2.swapAxes(n2.shape().length - 1, 1); INDArray firstSlice2 = swapped.slice(0).slice(0); INDArray oneThreeFiveSevenNine = Nd4j.create(new float[] {1, 3, 5, 7, 9}); assertEquals(firstSlice2, oneThreeFiveSevenNine); INDArray raveled = oneThreeFiveSevenNine.reshape(5, 1); INDArray raveledOneThreeFiveSevenNine = oneThreeFiveSevenNine.reshape(5, 1); assertEquals(raveled, raveledOneThreeFiveSevenNine); INDArray firstSlice3 = swapped.slice(0).slice(1); INDArray twoFourSixEightTen = Nd4j.create(new float[] {2, 4, 6, 8, 10}); assertEquals(firstSlice2, oneThreeFiveSevenNine); INDArray raveled2 = twoFourSixEightTen.reshape(5, 1); INDArray raveled3 = firstSlice3.reshape(5, 1); assertEquals(raveled2, raveled3); } @Test public void testGetRow() { INDArray arr = Nd4j.linspace(1, 6, 6).reshape(2, 3); INDArray get = arr.getRow(1); INDArray get2 = arr.get(NDArrayIndex.point(1), NDArrayIndex.all()); INDArray assertion = Nd4j.create(new double[] {4, 5, 6}); assertEquals(assertion, get); assertEquals(get, get2); get2.assign(Nd4j.linspace(1, 3, 3)); assertEquals(Nd4j.linspace(1, 3, 3), get2); INDArray threeByThree = Nd4j.linspace(1, 9, 9).reshape(3, 3); INDArray offsetTest = threeByThree.get(new SpecifiedIndex(1, 2), NDArrayIndex.all()); INDArray threeByThreeAssertion = Nd4j.create(new double[][] {{4, 5, 6}, {7, 8, 9}}); assertEquals(threeByThreeAssertion, offsetTest); } @Test public void testVectorIndexing() { INDArray zeros = Nd4j.create(1, 400000); INDArray get = zeros.get(NDArrayIndex.interval(0, 300000)); assertArrayEquals(new int[] {1, 300000}, get.shape()); } @Override public char ordering() { return 'c'; } }